naga-29.0.3/.cargo/config.toml

[alias]
xtask = "run -p naga-xtask --"

naga-29.0.3/.cargo_vcs_info.json

{
  "git": {
    "sha1": "4cbe6232b2d7c289b6e1a38416a6ae1461a22e81"
  },
  "path_in_vcs": "naga"
}

naga-29.0.3/.gitattributes

tests/naga/out/**/* text eol=lf

naga-29.0.3/.gitignore

/target
**/*.rs.bk
Cargo.lock
.DS_Store
.fuse_hidden*
.idea
.vscode
*.swp
/*.dot
/*.metal
/*.metallib
/*.ron
/*.spv
/*.vert
/*.frag
/*.comp
/*.wgsl
/*.hlsl
/*.txt

naga-29.0.3/CHANGELOG.md

# Change Log

For changelogs after v0.14, see [the wgpu changelog](../CHANGELOG.md).

## v0.14 (2023-10-25)

#### GENERAL

- Add support for const-expressions. ([#2309](https://github.com/gfx-rs/naga/pull/2309)) **@teoxoy**, **@jimblandy**
- Add support for the `rgb10a2uint` storage format. ([#2525](https://github.com/gfx-rs/naga/pull/2525)) **@teoxoy**
- Implement module compaction for snapshot testing and the CLI. ([#2472](https://github.com/gfx-rs/naga/pull/2472)) **@jimblandy**
- Fix validation and GLSL parsing of `ldexp`. ([#2449](https://github.com/gfx-rs/naga/pull/2449)) **@fornwall**
- Add support for dual source blending. ([#2427](https://github.com/gfx-rs/naga/pull/2427)) **@freqmod**
- Bump `indexmap` to v2. ([#2426](https://github.com/gfx-rs/naga/pull/2426)) **@daxpedda**
- Bump MSRV to 1.65. ([#2420](https://github.com/gfx-rs/naga/pull/2420)) **@jimblandy**

#### API

- Split `UnaryOperator::Not` into `UnaryOperator::LogicalNot` & `UnaryOperator::BitwiseNot`. ([#2554](https://github.com/gfx-rs/naga/pull/2554)) **@teoxoy**
- Remove `IsFinite` & `IsNormal` relational functions. ([#2532](https://github.com/gfx-rs/naga/pull/2532)) **@teoxoy**
- Derive `PartialEq` on `Expression`.
  ([#2417](https://github.com/gfx-rs/naga/pull/2417)) **@robtfm**
- Use `FastIndexMap` for `SpecialTypes::predeclared_types`. ([#2495](https://github.com/gfx-rs/naga/pull/2495)) **@jimblandy**

#### CLI

- Change `--generate-debug-symbols` from an `option` to a `switch`. ([#2472](https://github.com/gfx-rs/naga/pull/2472)) **@jimblandy**
- Add support for `.{vert,frag,comp}.glsl` files. ([#2462](https://github.com/gfx-rs/naga/pull/2462)) **@eliemichel**

#### VALIDATOR

- Require `Capabilities::FLOAT64` for 64-bit floating-point literals. ([#2567](https://github.com/gfx-rs/naga/pull/2567)) **@jimblandy**
- Add `Capabilities::CUBE_ARRAY_TEXTURES`. ([#2530](https://github.com/gfx-rs/naga/pull/2530)) **@teoxoy**
- Disallow passing pointers to variables in the workgroup address space to functions. ([#2507](https://github.com/gfx-rs/naga/pull/2507)) **@teoxoy**
- Avoid OOM with large sparse resource bindings. ([#2561](https://github.com/gfx-rs/naga/pull/2561)) **@teoxoy**
- Require that `Function` and `Private` variables be `CONSTRUCTIBLE`. ([#2545](https://github.com/gfx-rs/naga/pull/2545)) **@jimblandy**
- Disallow floating-point NaNs and infinities. ([#2508](https://github.com/gfx-rs/naga/pull/2508)) **@teoxoy**
- Temporarily disable uniformity analysis for the fragment stage. ([#2515](https://github.com/gfx-rs/naga/pull/2515)) **@teoxoy**
- Validate that `textureSampleBias` is only used in the fragment stage. ([#2515](https://github.com/gfx-rs/naga/pull/2515)) **@teoxoy**
- Validate variable initializer for address spaces. ([#2513](https://github.com/gfx-rs/naga/pull/2513)) **@teoxoy**
- Prevent using multiple push constant variables in one entry point. ([#2484](https://github.com/gfx-rs/naga/pull/2484)) **@andriyDev**
- Validate `binding_array` variable address space. ([#2422](https://github.com/gfx-rs/naga/pull/2422)) **@teoxoy**
- Validate storage buffer access.
  ([#2415](https://github.com/gfx-rs/naga/pull/2415)) **@teoxoy**

#### WGSL-IN

- Fix expected min arg count of `textureLoad`. ([#2584](https://github.com/gfx-rs/naga/pull/2584)) **@teoxoy**
- Turn `Error::Other` into `Error::Internal`, to help devs. ([#2574](https://github.com/gfx-rs/naga/pull/2574)) **@jimblandy**
- Fix OOB typifier indexing. ([#2570](https://github.com/gfx-rs/naga/pull/2570)) **@teoxoy**
- Add support for the `bgra8unorm` storage format. ([#2542](https://github.com/gfx-rs/naga/pull/2542) & [#2550](https://github.com/gfx-rs/naga/pull/2550)) **@nical**
- Remove the `outerProduct` built-in function. ([#2535](https://github.com/gfx-rs/naga/pull/2535)) **@teoxoy**
- Add support for `i32` overload of the `sign` built-in function. ([#2463](https://github.com/gfx-rs/naga/pull/2463)) **@fornwall**
- Properly implement `modf` and `frexp`. ([#2454](https://github.com/gfx-rs/naga/pull/2454)) **@fornwall**
- Add support for scalar overloads of `all` & `any` built-in functions. ([#2445](https://github.com/gfx-rs/naga/pull/2445)) **@fornwall**
- Don't splat the left-hand operand of a binary operation if it's not a scalar. ([#2444](https://github.com/gfx-rs/naga/pull/2444)) **@fornwall**
- Avoid splatting all binary operator expressions. ([#2440](https://github.com/gfx-rs/naga/pull/2440)) **@fornwall**
- Error on repeated or missing `@workgroup_size()`. ([#2435](https://github.com/gfx-rs/naga/pull/2435)) **@fornwall**
- Error on repeated attributes. ([#2428](https://github.com/gfx-rs/naga/pull/2428)) **@fornwall**
- Fix error message for invalid `texture{Load,Store}()` on arrayed textures. ([#2432](https://github.com/gfx-rs/naga/pull/2432)) **@fornwall**

#### SPV-IN

- Disable `Modf` & `Frexp` and translate `ModfStruct` & `FrexpStruct` to their IR equivalents. ([#2527](https://github.com/gfx-rs/naga/pull/2527)) **@teoxoy**
- Don't advertise support for `Capability::ImageMSArray` & `Capability::InterpolationFunction`.
  ([#2529](https://github.com/gfx-rs/naga/pull/2529)) **@teoxoy**
- Fix `OpImageQueries` to allow Uints. ([#2404](https://github.com/gfx-rs/naga/pull/2404)) **@evahop**

#### GLSL-IN

- Disable `modf` & `frexp`. ([#2527](https://github.com/gfx-rs/naga/pull/2527)) **@teoxoy**

#### SPV-OUT

- Require `ClipDistance` & `CullDistance` capabilities if necessary. ([#2528](https://github.com/gfx-rs/naga/pull/2528)) **@teoxoy**
- Change `naga::back::spv::DebugInfo::file_name` to a `&Path`. ([#2501](https://github.com/gfx-rs/naga/pull/2501)) **@jimblandy**
- Always give structs with runtime arrays a `Block` decoration. ([#2455](https://github.com/gfx-rs/naga/pull/2455)) **@TheoDulka**
- Decorate the result of the `OpLoad` with `NonUniform` (not the access chain) when loading images/samplers (resources in the Handle address space). ([#2422](https://github.com/gfx-rs/naga/pull/2422)) **@teoxoy**
- Cache `OpConstantNull`. ([#2414](https://github.com/gfx-rs/naga/pull/2414)) **@evahop**

#### MSL-OUT

- Add and fix minimum Metal version checks for optional functionality. ([#2486](https://github.com/gfx-rs/naga/pull/2486)) **@teoxoy**
- Make varyings' struct members unique. ([#2521](https://github.com/gfx-rs/naga/pull/2521)) **@evahop**
- Add experimental vertex pulling transform flag. ([#5254](https://github.com/gfx-rs/wgpu/pull/5254)) **@bradwerth**
- Fix up some generated MSL for vertex buffer unpack functions. ([#5829](https://github.com/gfx-rs/wgpu/pull/5829)) **@bradwerth**
- Turn the vertex pulling transform on by default. ([#5773](https://github.com/gfx-rs/wgpu/pull/5773)) **@bradwerth**

#### GLSL-OUT

- Cull functions that should not be available for a given stage. ([#2531](https://github.com/gfx-rs/naga/pull/2531)) **@teoxoy**
- Rename identifiers containing double underscores. ([#2510](https://github.com/gfx-rs/naga/pull/2510)) **@evahop**
- Polyfill `frexp`. ([#2504](https://github.com/gfx-rs/naga/pull/2504)) **@evahop**
- Add built-in functions to keywords.
  ([#2410](https://github.com/gfx-rs/naga/pull/2410)) **@fornwall**

#### WGSL-OUT

- Generate correct code for bit complement on integers. ([#2548](https://github.com/gfx-rs/naga/pull/2548)) **@jimblandy**
- Don't include type parameter in splat expressions. ([#2469](https://github.com/gfx-rs/naga/pull/2469)) **@jimblandy**

## v0.13 (2023-07-21)

#### GENERAL

- Move from `make` to `cargo xtask` workflows. ([#2297](https://github.com/gfx-rs/naga/pull/2297)) **@ErichDonGubler**
- Omit non-referenced expressions from output. ([#2378](https://github.com/gfx-rs/naga/pull/2378)) **@teoxoy**
- Bump `bitflags` to v2. ([#2358](https://github.com/gfx-rs/naga/pull/2358)) **@daxpedda**
- Implement `workgroupUniformLoad`. ([#2201](https://github.com/gfx-rs/naga/pull/2201)) **@DJMcNab**

#### API

- Expose early depth test field. ([#2393](https://github.com/gfx-rs/naga/pull/2393)) **@Joeoc2001**
- Split image bounds check policy. ([#2265](https://github.com/gfx-rs/naga/pull/2265)) **@teoxoy**
- Change type of constant-sized arrays to `NonZeroU32`. ([#2337](https://github.com/gfx-rs/naga/pull/2337)) **@teoxoy**
- Introduce `GlobalCtx`. ([#2335](https://github.com/gfx-rs/naga/pull/2335)) **@teoxoy**
- Introduce `Expression::Literal`. ([#2333](https://github.com/gfx-rs/naga/pull/2333)) **@teoxoy**
- Introduce `Expression::ZeroValue`. ([#2332](https://github.com/gfx-rs/naga/pull/2332)) **@teoxoy**
- Add support for const-expressions (only at the API level; functionality is still WIP). ([#2266](https://github.com/gfx-rs/naga/pull/2266)) **@teoxoy**, **@jimblandy**

#### DOCS

- Document which expressions are in scope for a `break_if` expression. ([#2326](https://github.com/gfx-rs/naga/pull/2326)) **@jimblandy**

#### VALIDATOR

- Don't `use std::ops::Index`, used only when `"validate"` is on. ([#2383](https://github.com/gfx-rs/naga/pull/2383)) **@jimblandy**
- Remove unneeded `ConstantError::Unresolved{Component,Size}`.
  ([#2330](https://github.com/gfx-rs/naga/pull/2330)) **@ErichDonGubler**
- Remove `TypeError::UnresolvedBase`. ([#2308](https://github.com/gfx-rs/naga/pull/2308)) **@ErichDonGubler**

#### WGSL-IN

- Error on param redefinition. ([#2342](https://github.com/gfx-rs/naga/pull/2342)) **@SparkyPotato**

#### SPV-IN

- Improve documentation for SPIR-V control flow parsing. ([#2324](https://github.com/gfx-rs/naga/pull/2324)) **@jimblandy**
- Obey the `is_depth` field of `OpTypeImage`. ([#2341](https://github.com/gfx-rs/naga/pull/2341)) **@expenses**
- Convert conditional backedges to `break if`. ([#2290](https://github.com/gfx-rs/naga/pull/2290)) **@eddyb**

#### GLSL-IN

- Support commas in structure definitions. ([#2400](https://github.com/gfx-rs/naga/pull/2400)) **@fornwall**

#### SPV-OUT

- Add debug info. ([#2379](https://github.com/gfx-rs/naga/pull/2379)) **@wicast**
- Use `IndexSet` instead of `HashSet` for iterated sets (capabilities/extensions). ([#2389](https://github.com/gfx-rs/naga/pull/2389)) **@eddyb**
- Support array bindings of buffers. ([#2282](https://github.com/gfx-rs/naga/pull/2282)) **@kvark**

#### MSL-OUT

- Rename `allow_point_size` to `allow_and_force_point_size`. ([#2280](https://github.com/gfx-rs/naga/pull/2280)) **@teoxoy**
- Initialize arrays inline. ([#2331](https://github.com/gfx-rs/naga/pull/2331)) **@teoxoy**

#### HLSL-OUT

- Implement Pack/Unpack for HLSL. ([#2353](https://github.com/gfx-rs/naga/pull/2353)) **@Elabajaba**
- Complete HLSL reserved symbols. ([#2367](https://github.com/gfx-rs/naga/pull/2367)) **@teoxoy**
- Handle case-insensitive FXC keywords. ([#2347](https://github.com/gfx-rs/naga/pull/2347)) **@PJB3005**
- Fix return type for firstbitlow/high. ([#2315](https://github.com/gfx-rs/naga/pull/2315)) **@evahop**

#### GLSL-OUT

- `textureSize` level must be a signed integer. ([#2397](https://github.com/gfx-rs/naga/pull/2397)) **@nical**
- Fix functions with array return type.
  ([#2382](https://github.com/gfx-rs/naga/pull/2382)) **@Gordon-F**

#### WGSL-OUT

- Output `@interpolate(flat)` attribute for integer locations. ([#2318](https://github.com/gfx-rs/naga/pull/2318)) **@expenses**

## v0.12.3 (2023-07-09)

#### WGSL-OUT

- (Backport) Output `@interpolate(flat)` attribute for integer locations. ([#2318](https://github.com/gfx-rs/naga/pull/2318)) **@expenses**

## v0.12.2 (2023-05-30)

#### SPV-OUT

- (Backport) Support array bindings of buffers. ([#2282](https://github.com/gfx-rs/naga/pull/2282)) **@kvark**

## v0.12.1 (2023-05-18)

#### SPV-IN

- (Backport) Convert conditional backedges to `break if`. ([#2290](https://github.com/gfx-rs/naga/pull/2290)) **@eddyb**

## v0.12 (2023-04-19)

#### GENERAL

- Allow `array_index` to be unsigned. ([#2298](https://github.com/gfx-rs/naga/pull/2298)) **@daxpedda**
- Add ray query support. ([#2256](https://github.com/gfx-rs/naga/pull/2256)) **@kvark**
- Add partial derivative builtins. ([#2277](https://github.com/gfx-rs/naga/pull/2277)) **@evahop**
- Skip `gl_PerVertex` unused builtins in the SPIR-V frontend. ([#2272](https://github.com/gfx-rs/naga/pull/2272)) **@teoxoy**
- Differentiate between `i32` and `u32` in switch statement cases. ([#2269](https://github.com/gfx-rs/naga/pull/2269)) **@evahop**
- Fix zero initialization of workgroup memory. ([#2259](https://github.com/gfx-rs/naga/pull/2259)) **@teoxoy**
- Add `countTrailingZeros`. ([#2243](https://github.com/gfx-rs/naga/pull/2243)) **@gents83**
- Fix texture built-ins where u32 was expected. ([#2245](https://github.com/gfx-rs/naga/pull/2245)) **@evahop**
- Add `countLeadingZeros`. ([#2226](https://github.com/gfx-rs/naga/pull/2226)) **@evahop**
- [glsl/hlsl-out] Write sizes of arrays behind pointers in function arguments. ([#2250](https://github.com/gfx-rs/naga/pull/2250)) **@pluiedev**

#### VALIDATOR

- Validate vertex stage returns the position built-in.
  ([#2264](https://github.com/gfx-rs/naga/pull/2264)) **@teoxoy**
- Enforce that discard is only used in the fragment stage. ([#2262](https://github.com/gfx-rs/naga/pull/2262)) **@Uriopass**
- Add `Capabilities::MULTISAMPLED_SHADING`. ([#2255](https://github.com/gfx-rs/naga/pull/2255)) **@teoxoy**
- Add `Capabilities::EARLY_DEPTH_TEST`. ([#2255](https://github.com/gfx-rs/naga/pull/2255)) **@teoxoy**
- Add `Capabilities::MULTIVIEW`. ([#2255](https://github.com/gfx-rs/naga/pull/2255)) **@teoxoy**
- Improve forward declaration validation. ([#2232](https://github.com/gfx-rs/naga/pull/2232)) **@JCapucho**

#### WGSL-IN

- Use `alias` instead of `type` for type aliases. ([#2299](https://github.com/gfx-rs/naga/pull/2299)) **@FL33TW00D**
- Add predeclared vector and matrix type aliases. ([#2251](https://github.com/gfx-rs/naga/pull/2251)) **@evahop**
- Improve invalid assignment diagnostic. ([#2233](https://github.com/gfx-rs/naga/pull/2233)) **@SparkyPotato**
- Expect semicolons wherever required. ([#2233](https://github.com/gfx-rs/naga/pull/2233)) **@SparkyPotato**
- Fix panic on invalid zero array size. ([#2233](https://github.com/gfx-rs/naga/pull/2233)) **@SparkyPotato**
- Check for leading `{` while parsing a block. ([#2233](https://github.com/gfx-rs/naga/pull/2233)) **@SparkyPotato**

#### SPV-IN

- Don't apply interpolation to fragment shader outputs. ([#2239](https://github.com/gfx-rs/naga/pull/2239)) **@JCapucho**

#### GLSL-IN

- Add switch implicit type conversion. ([#2273](https://github.com/gfx-rs/naga/pull/2273)) **@evahop**
- Document some fields of `naga::front::glsl::context::Context`. ([#2244](https://github.com/gfx-rs/naga/pull/2244)) **@jimblandy**
- Perform implicit casts for output parameters. ([#2063](https://github.com/gfx-rs/naga/pull/2063)) **@JCapucho**
- Add `not` vector relational builtin. ([#2227](https://github.com/gfx-rs/naga/pull/2227)) **@JCapucho**
- Add double overloads for relational vector builtins.
  ([#2227](https://github.com/gfx-rs/naga/pull/2227)) **@JCapucho**
- Add bool overloads for relational vector builtins. ([#2227](https://github.com/gfx-rs/naga/pull/2227)) **@JCapucho**

#### SPV-OUT

- Fix invalid SPIR-V being generated from integer dot products. ([#2291](https://github.com/gfx-rs/naga/pull/2291)) **@PyryM**
- Fix adding illegal decorators on fragment outputs. ([#2286](https://github.com/gfx-rs/naga/pull/2286)) **@Wumpf**
- Fix `countLeadingZeros` impl. ([#2258](https://github.com/gfx-rs/naga/pull/2258)) **@teoxoy**
- Cache constant composites. ([#2257](https://github.com/gfx-rs/naga/pull/2257)) **@evahop**
- Support SPIR-V version 1.4. ([#2230](https://github.com/gfx-rs/naga/pull/2230)) **@kvark**

#### MSL-OUT

- Replace `per_stage_map` with `per_entry_point_map`. ([#2237](https://github.com/gfx-rs/naga/pull/2237)) **@armansito**
- Update `firstLeadingBit` for signed integers. ([#2235](https://github.com/gfx-rs/naga/pull/2235)) **@evahop**

#### HLSL-OUT

- Use `Interlocked` intrinsic for atomic integers. ([#2294](https://github.com/gfx-rs/naga/pull/2294)) **@ErichDonGubler**
- Document storage access generation. ([#2295](https://github.com/gfx-rs/naga/pull/2295)) **@jimblandy**
- Emit constructor functions for arrays. ([#2281](https://github.com/gfx-rs/naga/pull/2281)) **@ErichDonGubler**
- Clear `named_expressions` inserted by duplicated blocks. ([#2116](https://github.com/gfx-rs/naga/pull/2116)) **@teoxoy**

#### GLSL-OUT

- Skip `invariant` for `gl_FragCoord` on WebGL2. ([#2254](https://github.com/gfx-rs/naga/pull/2254)) **@grovesNL**
- Inject default `gl_PointSize = 1.0` in vertex shaders if `FORCE_POINT_SIZE` option was set. ([#2223](https://github.com/gfx-rs/naga/pull/2223)) **@REASY**

## v0.11.1 (2023-05-18)

#### SPV-IN

- (Backport) Convert conditional backedges to `break if`.
  ([#2290](https://github.com/gfx-rs/naga/pull/2290)) **@eddyb**

## v0.11 (2023-01-25)

- Move to the Rust 2021 edition ([#2085](https://github.com/gfx-rs/naga/pull/2085)) **@ErichDonGubler**
- Bump MSRV to 1.63 ([#2129](https://github.com/gfx-rs/naga/pull/2129)) **@teoxoy**

#### API

- Add handle validation pass to `Validator` ([#2090](https://github.com/gfx-rs/naga/pull/2090)) **@ErichDonGubler**
- Add `Range::new_from_bounds` ([#2148](https://github.com/gfx-rs/naga/pull/2148)) **@robtfm**

#### DOCS

- Fix docs for `Emit` statements ([#2208](https://github.com/gfx-rs/naga/pull/2208)) **@jimblandy**
- Fix invalid `<...>` URLs with code spans ([#2176](https://github.com/gfx-rs/naga/pull/2176)) **@ErichDonGubler**
- Explain how case clauses with multiple selectors are supported ([#2126](https://github.com/gfx-rs/naga/pull/2126)) **@teoxoy**
- Document `EarlyDepthTest` and `ConservativeDepth` syntax ([#2132](https://github.com/gfx-rs/naga/pull/2132)) **@coreh**

#### VALIDATOR

- Allow `u32` coordinates for `textureStore`/`textureLoad` ([#2172](https://github.com/gfx-rs/naga/pull/2172)) **@PENGUINLIONG**
- Fix array being flagged as constructible when its base isn't ([#2111](https://github.com/gfx-rs/naga/pull/2111)) **@teoxoy**
- Add `type_flags` to `ModuleInfo` ([#2111](https://github.com/gfx-rs/naga/pull/2111)) **@teoxoy**
- Remove overly restrictive array stride check ([#2215](https://github.com/gfx-rs/naga/pull/2215)) **@fintelia**
- Let the uniformity analysis trust the handle validation pass ([#2200](https://github.com/gfx-rs/naga/pull/2200)) **@jimblandy**
- Fix warnings when building tests without validation ([#2177](https://github.com/gfx-rs/naga/pull/2177)) **@jimblandy**
- Add `ValidationFlags::BINDINGS` ([#2156](https://github.com/gfx-rs/naga/pull/2156)) **@kvark**
- Fix `textureGather` on `texture_2d` ([#2138](https://github.com/gfx-rs/naga/pull/2138)) **@JMS55**

#### ALL (FRONTENDS/BACKENDS)

- Support 16-bit unorm/snorm formats
  ([#2210](https://github.com/gfx-rs/naga/pull/2210)) **@fintelia**
- Support `gl_PointCoord` ([#2180](https://github.com/gfx-rs/naga/pull/2180)) **@Neo-Zhixing**

#### ALL BACKENDS

- Add support for zero-initializing workgroup memory ([#2111](https://github.com/gfx-rs/naga/pull/2111)) **@teoxoy**

#### WGSL-IN

- Implement module-level scoping ([#2075](https://github.com/gfx-rs/naga/pull/2075)) **@SparkyPotato**
- Remove `isFinite` and `isNormal` ([#2218](https://github.com/gfx-rs/naga/pull/2218)) **@evahop**
- Update inverse hyperbolic built-ins ([#2218](https://github.com/gfx-rs/naga/pull/2218)) **@evahop**
- Add `refract` built-in ([#2218](https://github.com/gfx-rs/naga/pull/2218)) **@evahop**
- Update reserved keywords ([#2130](https://github.com/gfx-rs/naga/pull/2130)) **@teoxoy**
- Remove non-32bit integers ([#2146](https://github.com/gfx-rs/naga/pull/2146)) **@teoxoy**
- Remove `workgroup_size` builtin ([#2147](https://github.com/gfx-rs/naga/pull/2147)) **@teoxoy**
- Remove fallthrough statement ([#2126](https://github.com/gfx-rs/naga/pull/2126)) **@teoxoy**

#### SPV-IN

- Support binding arrays ([#2199](https://github.com/gfx-rs/naga/pull/2199)) **@Patryk27**

#### GLSL-IN

- Fix position propagation in lowering ([#2079](https://github.com/gfx-rs/naga/pull/2079)) **@JCapucho**
- Update initializer list type when parsing ([#2066](https://github.com/gfx-rs/naga/pull/2066)) **@JCapucho**
- Parenthesize unary negations to avoid `--` ([#2087](https://github.com/gfx-rs/naga/pull/2087)) **@ErichDonGubler**

#### SPV-OUT

- Add support for `atomicCompareExchangeWeak` ([#2165](https://github.com/gfx-rs/naga/pull/2165)) **@aweinstock314**
- Omit extra switch case blocks where possible ([#2126](https://github.com/gfx-rs/naga/pull/2126)) **@teoxoy**
- Fix switch cases after default not being output ([#2126](https://github.com/gfx-rs/naga/pull/2126)) **@teoxoy**

#### MSL-OUT

- Don't panic on missing bindings ([#2175](https://github.com/gfx-rs/naga/pull/2175)) **@kvark**
- Omit
  extra switch case blocks where possible ([#2126](https://github.com/gfx-rs/naga/pull/2126)) **@teoxoy**
- Fix `textureGather` compatibility on macOS 10.13 ([#2104](https://github.com/gfx-rs/naga/pull/2104)) **@xiaopengli89**
- Fix incorrect atomic bounds check on Metal back-end ([#2099](https://github.com/gfx-rs/naga/pull/2099)) **@raphlinus**
- Parenthesize unary negations to avoid `--` ([#2087](https://github.com/gfx-rs/naga/pull/2087)) **@ErichDonGubler**

#### HLSL-OUT

- Simplify `write_default_init` ([#2111](https://github.com/gfx-rs/naga/pull/2111)) **@teoxoy**
- Omit extra switch case blocks where possible ([#2126](https://github.com/gfx-rs/naga/pull/2126)) **@teoxoy**
- Properly implement bitcast ([#2097](https://github.com/gfx-rs/naga/pull/2097)) **@cwfitzgerald**
- Fix storage access chain through a matrix ([#2097](https://github.com/gfx-rs/naga/pull/2097)) **@cwfitzgerald**
- Work around FXC bug in matrix indexing ([#2096](https://github.com/gfx-rs/naga/pull/2096)) **@cwfitzgerald**
- Parenthesize unary negations to avoid `--` ([#2087](https://github.com/gfx-rs/naga/pull/2087)) **@ErichDonGubler**

#### GLSL-OUT

- Introduce a flag to include unused items ([#2205](https://github.com/gfx-rs/naga/pull/2205)) **@robtfm**
- Use `fma` polyfill for versions below GLES 320 ([#2197](https://github.com/gfx-rs/naga/pull/2197)) **@teoxoy**
- Emit reflection info for non-struct uniforms ([#2189](https://github.com/gfx-rs/naga/pull/2189)) **@Rainb0wCodes**
- Introduce a new block for switch cases ([#2126](https://github.com/gfx-rs/naga/pull/2126)) **@teoxoy**

#### WGSL-OUT

- Write correct scalar kind when `width != 4` ([#1514](https://github.com/gfx-rs/naga/pull/1514)) **@fintelia**

## v0.10.1 (2023-06-21)

SPV-OUT

- Backport #2389 (Use `IndexSet` instead of `HashSet` for iterated sets (capabilities/extensions)) by @eddyb, @jimblandy in https://github.com/gfx-rs/naga/pull/2391

SPV-IN

- Backport #2290 (Convert conditional backedges to `break if`) by @eddyb in
  https://github.com/gfx-rs/naga/pull/2387

## v0.10 (2022-10-05)

- Make termcolor dependency optional by @AldaronLau in https://github.com/gfx-rs/naga/pull/2014
- Fix clippy lints for 1.63 by @JCapucho in https://github.com/gfx-rs/naga/pull/2026
- Saturate by @evahop in https://github.com/gfx-rs/naga/pull/2025
- Use `Option::as_deref` as appropriate by @jimblandy in https://github.com/gfx-rs/naga/pull/2040
- Explicitly enable std for indexmap by @maxammann in https://github.com/gfx-rs/naga/pull/2062
- Fix compiler warning by @Gordon-F in https://github.com/gfx-rs/naga/pull/2074

API

- Implement `Clone` for `Module` by @daxpedda in https://github.com/gfx-rs/naga/pull/2013
- Remove the glsl-validate feature by @JCapucho in https://github.com/gfx-rs/naga/pull/2045

DOCS

- Document arithmetic binary operation type rules by @jimblandy in https://github.com/gfx-rs/naga/pull/2051

VALIDATOR

- Add `emit_to_{stderr,string}` helpers to validation error by @nolanderc in https://github.com/gfx-rs/naga/pull/2012
- Check that regular functions don't have bindings by @JCapucho in https://github.com/gfx-rs/naga/pull/2050

WGSL-IN

- Update reserved WGSL keywords by @norepimorphism in https://github.com/gfx-rs/naga/pull/2009
- Implement lexical scopes by @JCapucho in https://github.com/gfx-rs/naga/pull/2024
- Rename `Scope` to `Rule`, since we now have lexical scope, by @jimblandy in https://github.com/gfx-rs/naga/pull/2042
- Splat on compound assignments by @JCapucho in https://github.com/gfx-rs/naga/pull/2049
- Fix bad span in assignment lhs error by @JCapucho in https://github.com/gfx-rs/naga/pull/2054
- Fix inclusion of trivia in spans by @SparkyPotato in https://github.com/gfx-rs/naga/pull/2055
- Improve assignment diagnostics by @SparkyPotato in https://github.com/gfx-rs/naga/pull/2056
- Break up long string and reformat rest of file by @jimblandy in https://github.com/gfx-rs/naga/pull/2057
- Fix line endings on WGSL reserved words list
  by @jimblandy in https://github.com/gfx-rs/naga/pull/2059

GLSL-IN

- Add support for `.length()` by @SpaceCat-Chan in https://github.com/gfx-rs/naga/pull/2017
- Fix missing stores for local declarations by @adeline-sparks in https://github.com/gfx-rs/naga/pull/2029
- Migrate to `SymbolTable` by @JCapucho in https://github.com/gfx-rs/naga/pull/2044
- Update initializer list type when parsing by @JCapucho in https://github.com/gfx-rs/naga/pull/2066

SPV-OUT

- Don't decorate varyings with interpolation modes at pipeline start/end by @nical in https://github.com/gfx-rs/naga/pull/2038
- Decorate integer builtins as Flat in the SPIR-V writer by @nical in https://github.com/gfx-rs/naga/pull/2035
- Properly combine the fixes for #2035 and #2038 by @jimblandy in https://github.com/gfx-rs/naga/pull/2041
- Don't emit no-op `OpBitCast` instructions by @jimblandy in https://github.com/gfx-rs/naga/pull/2043

HLSL-OUT

- Use the namer to sanitise entrypoint input/output struct names by @expenses in https://github.com/gfx-rs/naga/pull/2001
- Handle Unpack2x16float in HLSL by @expenses in https://github.com/gfx-rs/naga/pull/2002
- Add support for push constants by @JCapucho in https://github.com/gfx-rs/naga/pull/2005

DOT-OUT

- Improvements by @JCapucho in https://github.com/gfx-rs/naga/pull/1987

## v0.9 (2022-06-30)

- Fix minimal-versions of dependencies ([#1840](https://github.com/gfx-rs/naga/pull/1840)) **@teoxoy**
- Update MSRV to 1.56 ([#1838](https://github.com/gfx-rs/naga/pull/1838)) **@teoxoy**

API

- Rename `TypeFlags` `INTERFACE`/`HOST_SHARED` to `IO_SHARED`/`HOST_SHAREABLE` ([#1872](https://github.com/gfx-rs/naga/pull/1872)) **@jimblandy**
- Expose more error information ([#1827](https://github.com/gfx-rs/naga/pull/1827), [#1937](https://github.com/gfx-rs/naga/pull/1937)) **@jakobhellermann** **@nical** **@jimblandy**
- Do not unconditionally make error output colorful ([#1707](https://github.com/gfx-rs/naga/pull/1707)) **@rhysd**
- Rename `StorageClass` to `AddressSpace`
  ([#1699](https://github.com/gfx-rs/naga/pull/1699)) **@kvark**
- Add a way to emit errors to a path ([#1640](https://github.com/gfx-rs/naga/pull/1640)) **@laptou**

CLI

- Add `bincode` representation ([#1729](https://github.com/gfx-rs/naga/pull/1729)) **@kvark**
- Include file path in WGSL parse error ([#1708](https://github.com/gfx-rs/naga/pull/1708)) **@rhysd**
- Add `--version` flag ([#1706](https://github.com/gfx-rs/naga/pull/1706)) **@rhysd**
- Support reading input from stdin via `--stdin-file-path` ([#1701](https://github.com/gfx-rs/naga/pull/1701)) **@rhysd**
- Use `panic = "abort"` ([#1597](https://github.com/gfx-rs/naga/pull/1597)) **@jrmuizel**

DOCS

- Standardize some docs ([#1660](https://github.com/gfx-rs/naga/pull/1660)) **@NoelTautges**
- Document `TypeInner::BindingArray` ([#1859](https://github.com/gfx-rs/naga/pull/1859)) **@jimblandy**
- Clarify accepted types for `Expression::AccessIndex` ([#1862](https://github.com/gfx-rs/naga/pull/1862)) **@NoelTautges**
- Document `proc::layouter` ([#1693](https://github.com/gfx-rs/naga/pull/1693)) **@jimblandy**
- Document Naga's promises around validation and panics ([#1828](https://github.com/gfx-rs/naga/pull/1828)) **@jimblandy**
- `FunctionInfo` doc fixes ([#1726](https://github.com/gfx-rs/naga/pull/1726)) **@jimblandy**

VALIDATOR

- Forbid returning pointers and atomics from functions ([#911](https://github.com/gfx-rs/naga/pull/911)) **@jimblandy**
- Let validation check for more unsupported builtins ([#1962](https://github.com/gfx-rs/naga/pull/1962)) **@jimblandy**
- Fix `Capabilities::SAMPLER_NON_UNIFORM_INDEXING` bitflag ([#1915](https://github.com/gfx-rs/naga/pull/1915)) **@cwfitzgerald**
- Properly check that user-defined IO uses IO-shareable types ([#912](https://github.com/gfx-rs/naga/pull/912)) **@jimblandy**
- Validate `ValuePointer` exactly like a `Pointer` to a `Scalar` ([#1875](https://github.com/gfx-rs/naga/pull/1875)) **@jimblandy**
- Reject empty structs
  ([#1826](https://github.com/gfx-rs/naga/pull/1826)) **@jimblandy**
- Validate uniform address space layout constraints ([#1812](https://github.com/gfx-rs/naga/pull/1812)) **@teoxoy**
- Improve `AddressSpace`-related error messages ([#1710](https://github.com/gfx-rs/naga/pull/1710)) **@kvark**

WGSL-IN

Main breaking changes

- Commas to separate struct members (comma after last member is optional)
  - `struct S { a: f32; b: i32; }` -> `struct S { a: f32, b: i32 }`
- Attribute syntax
  - `[[binding(0), group(0)]]` -> `@binding(0) @group(0)`
- Entry point stage attributes
  - `@stage(vertex)` -> `@vertex`
  - `@stage(fragment)` -> `@fragment`
  - `@stage(compute)` -> `@compute`
- Function renames
  - `smoothStep` -> `smoothstep`
  - `findLsb` -> `firstTrailingBit`
  - `findMsb` -> `firstLeadingBit`

Specification Changes (relevant changes have also been applied to the WGSL backend)

- Add support for `break if` ([#1993](https://github.com/gfx-rs/naga/pull/1993)) **@JCapucho**
- Update number literal format ([#1863](https://github.com/gfx-rs/naga/pull/1863)) **@teoxoy**
- Allow non-ASCII characters in identifiers ([#1849](https://github.com/gfx-rs/naga/pull/1849)) **@teoxoy**
- Update reserved keywords ([#1847](https://github.com/gfx-rs/naga/pull/1847), [#1870](https://github.com/gfx-rs/naga/pull/1870), [#1905](https://github.com/gfx-rs/naga/pull/1905)) **@teoxoy** **@Gordon-F**
- Update entry point stage attributes ([#1833](https://github.com/gfx-rs/naga/pull/1833)) **@Gordon-F**
- Make colon in case optional ([#1801](https://github.com/gfx-rs/naga/pull/1801)) **@Gordon-F**
- Rename `smoothStep` to `smoothstep` ([#1800](https://github.com/gfx-rs/naga/pull/1800)) **@Gordon-F**
- Make semicolon after struct declaration optional ([#1791](https://github.com/gfx-rs/naga/pull/1791)) **@stshine**
- Use commas to separate struct members instead of semicolons ([#1773](https://github.com/gfx-rs/naga/pull/1773)) **@Gordon-F**
- Rename `findLsb`/`findMsb` to `firstTrailingBit`/`firstLeadingBit`
  ([#1735](https://github.com/gfx-rs/naga/pull/1735)) **@kvark**
- Make parentheses optional for `if` and `switch` statements ([#1725](https://github.com/gfx-rs/naga/pull/1725)) **@Gordon-F**
- Declare attributes with `@attrib` instead of `[[attrib]]` ([#1676](https://github.com/gfx-rs/naga/pull/1676)) **@kvark**
- Allow non-structure buffer types ([#1682](https://github.com/gfx-rs/naga/pull/1682)) **@kvark**
- Remove `stride` attribute ([#1681](https://github.com/gfx-rs/naga/pull/1681)) **@kvark**

Improvements

- Implement complete validation for size and align attributes ([#1979](https://github.com/gfx-rs/naga/pull/1979)) **@teoxoy**
- Implement `firstTrailingBit`/`firstLeadingBit` u32 overloads ([#1865](https://github.com/gfx-rs/naga/pull/1865)) **@teoxoy**
- Add error for non-floating-point matrix ([#1917](https://github.com/gfx-rs/naga/pull/1917)) **@grovesNL**
- Implement partial vector & matrix identity constructors ([#1916](https://github.com/gfx-rs/naga/pull/1916)) **@teoxoy**
- Implement phony assignment ([#1866](https://github.com/gfx-rs/naga/pull/1866), [#1869](https://github.com/gfx-rs/naga/pull/1869)) **@teoxoy**
- Fix being able to match `~=` as LogicalOperation ([#1849](https://github.com/gfx-rs/naga/pull/1849)) **@teoxoy**
- Implement Binding Arrays ([#1845](https://github.com/gfx-rs/naga/pull/1845)) **@cwfitzgerald**
- Implement unary vector operators ([#1820](https://github.com/gfx-rs/naga/pull/1820)) **@teoxoy**
- Implement zero value constructors and constructors that infer their type from their parameters ([#1790](https://github.com/gfx-rs/naga/pull/1790)) **@teoxoy**
- Implement invariant attribute ([#1789](https://github.com/gfx-rs/naga/pull/1789), [#1822](https://github.com/gfx-rs/naga/pull/1822)) **@teoxoy** **@jimblandy**
- Implement increment and decrement statements ([#1788](https://github.com/gfx-rs/naga/pull/1788), [#1912](https://github.com/gfx-rs/naga/pull/1912)) **@teoxoy**
- Implement `while` loop
([#1787](https://github.com/gfx-rs/naga/pull/1787)) **@teoxoy** - Fix array size on globals ([#1717](https://github.com/gfx-rs/naga/pull/1717)) **@jimblandy** - Implement integer vector overloads for `dot` function ([#1689](https://github.com/gfx-rs/naga/pull/1689)) **@francesco-cattoglio** - Implement block comments ([#1675](https://github.com/gfx-rs/naga/pull/1675)) **@kocsis1david** - Implement assignment binary operators ([#1662](https://github.com/gfx-rs/naga/pull/1662)) **@kvark** - Implement `radians`/`degrees` builtin functions ([#1627](https://github.com/gfx-rs/naga/pull/1627)) **@encounter** - Implement `findLsb`/`findMsb` builtin functions ([#1473](https://github.com/gfx-rs/naga/pull/1473)) **@fintelia** - Implement `textureGather`/`textureGatherCompare` builtin functions ([#1596](https://github.com/gfx-rs/naga/pull/1596)) **@kvark** SPV-IN - Implement `OpBitReverse` and `OpBitCount` ([#1954](https://github.com/gfx-rs/naga/pull/1954)) **@JCapucho** - Add `MultiView` to `SUPPORTED_CAPABILITIES` ([#1934](https://github.com/gfx-rs/naga/pull/1934)) **@expenses** - Translate `OpSMod` and `OpFMod` correctly ([#1867](https://github.com/gfx-rs/naga/pull/1867), [#1995](https://github.com/gfx-rs/naga/pull/1995)) **@teoxoy** **@JCapucho** - Error on unsupported `MatrixStride` ([#1805](https://github.com/gfx-rs/naga/pull/1805)) **@teoxoy** - Align array stride for undecorated arrays ([#1724](https://github.com/gfx-rs/naga/pull/1724)) **@JCapucho** GLSL-IN - Don't allow empty last case in switch ([#1981](https://github.com/gfx-rs/naga/pull/1981)) **@JCapucho** - Fix last case fallthrough and empty switch ([#1981](https://github.com/gfx-rs/naga/pull/1981)) **@JCapucho** - Splat inputs for smoothstep if needed ([#1976](https://github.com/gfx-rs/naga/pull/1976)) **@JCapucho** - Fix parameter not changing to depth ([#1967](https://github.com/gfx-rs/naga/pull/1967)) **@JCapucho** - Fix matrix multiplication check ([#1953](https://github.com/gfx-rs/naga/pull/1953)) 
**@JCapucho** - Fix panic (stop emitter in conditional) ([#1952](https://github.com/gfx-rs/naga/pull/1952)) **@JCapucho** - Translate `mod` fn correctly ([#1867](https://github.com/gfx-rs/naga/pull/1867)) **@teoxoy** - Make the ternary operator behave as an if ([#1877](https://github.com/gfx-rs/naga/pull/1877)) **@JCapucho** - Add support for `clamp` function ([#1502](https://github.com/gfx-rs/naga/pull/1502)) **@sjinno** - Better errors for bad constant expression ([#1501](https://github.com/gfx-rs/naga/pull/1501)) **@sjinno** - Error on a `matCx2` used with the `std140` layout ([#1806](https://github.com/gfx-rs/naga/pull/1806)) **@teoxoy** - Allow nested accesses in lhs positions ([#1794](https://github.com/gfx-rs/naga/pull/1794)) **@JCapucho** - Use forced conversions for vector/matrix constructors ([#1796](https://github.com/gfx-rs/naga/pull/1796)) **@JCapucho** - Add support for `barrier` function ([#1793](https://github.com/gfx-rs/naga/pull/1793)) **@fintelia** - Fix panic (resume expression emit after `imageStore`) ([#1795](https://github.com/gfx-rs/naga/pull/1795)) **@JCapucho** - Allow multiple array specifiers ([#1780](https://github.com/gfx-rs/naga/pull/1780)) **@JCapucho** - Fix memory qualifiers being inverted ([#1779](https://github.com/gfx-rs/naga/pull/1779)) **@JCapucho** - Support arrays as input/output types ([#1759](https://github.com/gfx-rs/naga/pull/1759)) **@JCapucho** - Fix freestanding constructor parsing ([#1758](https://github.com/gfx-rs/naga/pull/1758)) **@JCapucho** - Fix matrix - scalar operations ([#1757](https://github.com/gfx-rs/naga/pull/1757)) **@JCapucho** - Fix matrix - matrix division ([#1757](https://github.com/gfx-rs/naga/pull/1757)) **@JCapucho** - Fix matrix comparisons ([#1757](https://github.com/gfx-rs/naga/pull/1757)) **@JCapucho** - Add support for `texelFetchOffset` ([#1746](https://github.com/gfx-rs/naga/pull/1746)) **@JCapucho** - Inject `sampler2DMSArray` builtins on use 
([#1737](https://github.com/gfx-rs/naga/pull/1737)) **@JCapucho** - Inject `samplerCubeArray` builtins on use ([#1736](https://github.com/gfx-rs/naga/pull/1736)) **@JCapucho** - Add support for image builtin functions ([#1723](https://github.com/gfx-rs/naga/pull/1723)) **@JCapucho** - Add support for image declarations ([#1723](https://github.com/gfx-rs/naga/pull/1723)) **@JCapucho** - Texture builtins fixes ([#1719](https://github.com/gfx-rs/naga/pull/1719)) **@JCapucho** - Type qualifiers rework ([#1713](https://github.com/gfx-rs/naga/pull/1713)) **@JCapucho** - `texelFetch` accept multisampled textures ([#1715](https://github.com/gfx-rs/naga/pull/1715)) **@JCapucho** - Fix panic when culling nested block ([#1714](https://github.com/gfx-rs/naga/pull/1714)) **@JCapucho** - Fix composite constructors ([#1631](https://github.com/gfx-rs/naga/pull/1631)) **@JCapucho** - Fix using swizzle as out arguments ([#1632](https://github.com/gfx-rs/naga/pull/1632)) **@JCapucho** SPV-OUT - Implement `reverseBits` and `countOneBits` ([#1897](https://github.com/gfx-rs/naga/pull/1897)) **@hasali19** - Use `OpCopyObject` for matrix identity casts ([#1916](https://github.com/gfx-rs/naga/pull/1916)) **@teoxoy** - Use `OpCopyObject` for bool - bool conversion due to `OpBitcast` not being feasible for booleans ([#1916](https://github.com/gfx-rs/naga/pull/1916)) **@teoxoy** - Zero init variables in function and private address spaces ([#1871](https://github.com/gfx-rs/naga/pull/1871)) **@teoxoy** - Use `SRem` instead of `SMod` ([#1867](https://github.com/gfx-rs/naga/pull/1867)) **@teoxoy** - Add support for integer vector - scalar multiplication ([#1820](https://github.com/gfx-rs/naga/pull/1820)) **@teoxoy** - Add support for matrix addition and subtraction ([#1820](https://github.com/gfx-rs/naga/pull/1820)) **@teoxoy** - Emit required decorations on wrapper struct types ([#1815](https://github.com/gfx-rs/naga/pull/1815)) **@jimblandy** - Decorate array and struct type layouts 
unconditionally ([#1815](https://github.com/gfx-rs/naga/pull/1815)) **@jimblandy** - Fix wrong `MatrixStride` for `matCx2` and `mat2xR` ([#1781](https://github.com/gfx-rs/naga/pull/1781)) **@teoxoy** - Use `OpImageQuerySize` for MS images ([#1742](https://github.com/gfx-rs/naga/pull/1742)) **@JCapucho** MSL-OUT - Insert padding initialization for global constants ([#1988](https://github.com/gfx-rs/naga/pull/1988)) **@teoxoy** - Don't rely on cached expressions ([#1975](https://github.com/gfx-rs/naga/pull/1975)) **@JCapucho** - Fix pointers to private or workgroup address spaces possibly being read only ([#1901](https://github.com/gfx-rs/naga/pull/1901)) **@teoxoy** - Zero init variables in function address space ([#1871](https://github.com/gfx-rs/naga/pull/1871)) **@teoxoy** - Make binding arrays play nice with bounds checks ([#1855](https://github.com/gfx-rs/naga/pull/1855)) **@cwfitzgerald** - Permit `invariant` qualifier on vertex shader outputs ([#1821](https://github.com/gfx-rs/naga/pull/1821)) **@jimblandy** - Fix packed `vec3` stores ([#1816](https://github.com/gfx-rs/naga/pull/1816)) **@teoxoy** - Actually test push constants to be used ([#1767](https://github.com/gfx-rs/naga/pull/1767)) **@kvark** - Properly rename entry point arguments for struct members ([#1766](https://github.com/gfx-rs/naga/pull/1766)) **@jimblandy** - Qualify read-only storage with const ([#1763](https://github.com/gfx-rs/naga/pull/1763)) **@kvark** - Fix not unary operator for integer scalars ([#1760](https://github.com/gfx-rs/naga/pull/1760)) **@vincentisambart** - Add bounds checks for `ImageLoad` and `ImageStore` ([#1730](https://github.com/gfx-rs/naga/pull/1730)) **@jimblandy** - Fix resource bindings for non-structures ([#1718](https://github.com/gfx-rs/naga/pull/1718)) **@kvark** - Always check whether _buffer_sizes arg is needed ([#1717](https://github.com/gfx-rs/naga/pull/1717)) **@jimblandy** - WGSL storage address space should always correspond to MSL device address space 
([#1711](https://github.com/gfx-rs/naga/pull/1711)) **@wtholliday** - Mitigation for MSL atomic bounds check ([#1703](https://github.com/gfx-rs/naga/pull/1703)) **@glalonde** HLSL-OUT - More `matCx2` fixes (#1989) ([#1989](https://github.com/gfx-rs/naga/pull/1989)) **@teoxoy** - Fix fallthrough in switch statements ([#1920](https://github.com/gfx-rs/naga/pull/1920)) **@teoxoy** - Fix missing break statements ([#1919](https://github.com/gfx-rs/naga/pull/1919)) **@teoxoy** - Fix `countOneBits` and `reverseBits` for signed integers ([#1928](https://github.com/gfx-rs/naga/pull/1928)) **@hasali19** - Fix array constructor return type ([#1914](https://github.com/gfx-rs/naga/pull/1914)) **@teoxoy** - Fix hlsl output for writes to scalar/vector storage buffer ([#1903](https://github.com/gfx-rs/naga/pull/1903)) **@hasali19** - Use `fmod` instead of `%` ([#1867](https://github.com/gfx-rs/naga/pull/1867)) **@teoxoy** - Use wrapped constructors when loading from storage address space ([#1893](https://github.com/gfx-rs/naga/pull/1893)) **@teoxoy** - Zero init struct constructor ([#1890](https://github.com/gfx-rs/naga/pull/1890)) **@teoxoy** - Flesh out matrix handling documentation ([#1850](https://github.com/gfx-rs/naga/pull/1850)) **@jimblandy** - Emit `row_major` qualifier on matrix uniform globals ([#1846](https://github.com/gfx-rs/naga/pull/1846)) **@jimblandy** - Fix bool splat ([#1820](https://github.com/gfx-rs/naga/pull/1820)) **@teoxoy** - Add more padding when necessary ([#1814](https://github.com/gfx-rs/naga/pull/1814)) **@teoxoy** - Support multidimensional arrays ([#1814](https://github.com/gfx-rs/naga/pull/1814)) **@teoxoy** - Don't output interpolation modifier if it's the default ([#1809](https://github.com/gfx-rs/naga/pull/1809)) **@NoelTautges** - Fix `matCx2` translation for uniform buffers ([#1802](https://github.com/gfx-rs/naga/pull/1802)) **@teoxoy** - Fix modifiers not being written in the vertex output and fragment input structs 
([#1789](https://github.com/gfx-rs/naga/pull/1789)) **@teoxoy** - Fix matrix not being declared as transposed ([#1784](https://github.com/gfx-rs/naga/pull/1784)) **@teoxoy** - Insert padding between struct members ([#1786](https://github.com/gfx-rs/naga/pull/1786)) **@teoxoy** - Fix not unary operator for integer scalars ([#1760](https://github.com/gfx-rs/naga/pull/1760)) **@vincentisambart** GLSL-OUT - Fix vector bitcasts (#1966) ([#1966](https://github.com/gfx-rs/naga/pull/1966)) **@expenses** - Perform casts in int only math functions ([#1978](https://github.com/gfx-rs/naga/pull/1978)) **@JCapucho** - Don't rely on cached expressions ([#1975](https://github.com/gfx-rs/naga/pull/1975)) **@JCapucho** - Fix type error for `countOneBits` implementation ([#1897](https://github.com/gfx-rs/naga/pull/1897)) **@hasali19** - Fix storage format for `Rgba8Unorm` ([#1955](https://github.com/gfx-rs/naga/pull/1955)) **@JCapucho** - Implement bounds checks for `ImageLoad` ([#1889](https://github.com/gfx-rs/naga/pull/1889)) **@JCapucho** - Fix feature search in expressions ([#1887](https://github.com/gfx-rs/naga/pull/1887)) **@JCapucho** - Emit globals of any type ([#1823](https://github.com/gfx-rs/naga/pull/1823)) **@jimblandy** - Add support for boolean vector `~`, `|` and `&` ops ([#1820](https://github.com/gfx-rs/naga/pull/1820)) **@teoxoy** - Fix array function arguments ([#1814](https://github.com/gfx-rs/naga/pull/1814)) **@teoxoy** - Write constant sized array type for uniform ([#1768](https://github.com/gfx-rs/naga/pull/1768)) **@hatoo** - Texture function fixes ([#1742](https://github.com/gfx-rs/naga/pull/1742)) **@JCapucho** - Push constants use anonymous uniforms ([#1683](https://github.com/gfx-rs/naga/pull/1683)) **@JCapucho** - Add support for push constant emulation ([#1672](https://github.com/gfx-rs/naga/pull/1672)) **@JCapucho** - Skip unsized types if unused ([#1649](https://github.com/gfx-rs/naga/pull/1649)) **@kvark** - Write struct and array initializers 
([#1644](https://github.com/gfx-rs/naga/pull/1644)) **@JCapucho** ## v0.8.5 (2022-01-25) MSL-OUT - Make VS-output positions invariant on even more systems ([#1697](https://github.com/gfx-rs/naga/pull/1697)) **@cwfitzgerald** - Improve support for point primitives ([#1696](https://github.com/gfx-rs/naga/pull/1696)) **@kvark** ## v0.8.4 (2022-01-24) MSL-OUT - Make VS-output positions invariant if possible ([#1687](https://github.com/gfx-rs/naga/pull/1687)) **@kvark** GLSL-OUT - Fix `floatBitsToUint` spelling ([#1688](https://github.com/gfx-rs/naga/pull/1688)) **@cwfitzgerald** - Call proper memory barrier functions ([#1680](https://github.com/gfx-rs/naga/pull/1680)) **@francesco-cattoglio** ## v0.8.3 (2022-01-20) - Don't pin `indexmap` version ([#1666](https://github.com/gfx-rs/naga/pull/1666)) **@a1phyr** MSL-OUT - Fix support for point primitives ([#1674](https://github.com/gfx-rs/naga/pull/1674)) **@kvark** GLSL-OUT - Fix sampler association ([#1671](https://github.com/gfx-rs/naga/pull/1671)) **@JCapucho** ## v0.8.2 (2022-01-11) VALIDATOR - Check structure resource types ([#1639](https://github.com/gfx-rs/naga/pull/1639)) **@kvark** WGSL-IN - Improve type mismatch errors ([#1658](https://github.com/gfx-rs/naga/pull/1658)) **@Gordon-F** SPV-IN - Implement more sign agnostic operations ([#1651](https://github.com/gfx-rs/naga/pull/1651), [#1650](https://github.com/gfx-rs/naga/pull/1650)) **@JCapucho** SPV-OUT - Fix modulo operator (use `OpFRem` instead of `OpFMod`) ([#1653](https://github.com/gfx-rs/naga/pull/1653)) **@JCapucho** MSL-OUT - Fix `texture1d` accesses ([#1647](https://github.com/gfx-rs/naga/pull/1647)) **@jimblandy** - Fix data packing functions ([#1637](https://github.com/gfx-rs/naga/pull/1637)) **@phoekz** ## v0.8.1 (2021-12-29) API - Make `WithSpan` cloneable ([#1620](https://github.com/gfx-rs/naga/pull/1620)) **@jakobhellermann** MSL-OUT - Fix packed vec access ([#1634](https://github.com/gfx-rs/naga/pull/1634)) **@kvark** - Fix packed float support 
([#1630](https://github.com/gfx-rs/naga/pull/1630)) **@kvark** HLSL-OUT - Support arrays of matrices ([#1629](https://github.com/gfx-rs/naga/pull/1629)) **@kvark** - Use `mad` instead of `fma` function ([#1580](https://github.com/gfx-rs/naga/pull/1580)) **@parasyte** GLSL-OUT - Fix conflicting names for globals ([#1616](https://github.com/gfx-rs/naga/pull/1616)) **@Gordon-F** - Fix `fma` function ([#1580](https://github.com/gfx-rs/naga/pull/1580)) **@parasyte** ## v0.8 (2021-12-18) - development release for wgpu-0.12 - lots of fixes in all parts - validator: - now gated by `validate` feature - nicely detailed error messages with spans - API: - image gather operations - WGSL-in: - remove `[[block]]` attribute - `elseif` is removed in favor of `else if` - MSL-out: - full out-of-bounds checking ## v0.7.3 (2021-12-14) - API: - `view_index` builtin - GLSL-out: - reflect textures without samplers - SPV-out: - fix incorrect pack/unpack ## v0.7.2 (2021-12-01) - validator: - check stores for proper pointer class - HLSL-out: - fix stores into `mat3` - respect array strides - SPV-out: - fix multi-word constants - WGSL-in: - permit names starting with underscores - SPV-in: - cull unused builtins - support empty debug labels - GLSL-in: - don't panic on invalid integer operations ## v0.7.1 (2021-10-12) - implement casts from and to booleans in the backends ## v0.7 (2021-10-07) - development release for wgpu-0.11 - API: - bit extraction and packing functions - hyperbolic trigonometry functions - validation is gated by a cargo feature - `view_index` builtin - separate bounds checking policies for locals/buffers/textures - IR: - types and constants are guaranteed to be unique - WGSL-in: - new hex literal parser - updated list of reserved words - rewritten logic for resolving references and pointers - `switch` can use unsigned selectors - GLSL-in: - better support for texture sampling - better logic for auto-splatting scalars - GLSL-out: - fixed storage buffer layout - fix module 
operator - HLSL-out: - fixed texture queries - SPV-in: - control flow handling is rewritten from scratch - SPV-out: - fully covered out-of-bounds checking - option to emit point size - option to clamp output depth ## v0.6.3 (2021-09-08) - Reduced heap allocations when generating WGSL, HLSL, and GLSL - WGSL-in: - support module-scope `let` type inference - SPV-in: - fix depth sampling with projection - HLSL-out: - fix local struct construction - GLSL-out: - fix `select()` order - SPV-out: - allow working around Adreno issue with `OpName` ## v0.6.2 (2021-09-01) - SPV-out fixes: - requested capabilities for 1D and cube images, storage formats - handling `break` and `continue` in a `switch` statement - avoid generating duplicate `OpTypeImage` types - HLSL-out fixes: - fix output struct member names - MSL-out fixes: - fix packing of fields in interface structs - GLSL-out fixes: - fix non-fallthrough `switch` cases - GLSL-in fixes: - avoid infinite loop on invalid statements ## v0.6.1 (2021-08-24) - HLSL-out fixes: - array arguments - pointers to array arguments - switch statement - rewritten interface matching - SPV-in fixes: - array storage texture stores - tracking sampling across function parameters - updated petgraph dependencies - MSL-out: - gradient sampling - GLSL-out: - modulo operator on floats ## v0.6 (2021-08-18) - development release for wgpu-0.10 - API: - atomic types and functions - storage access is moved from global variables to the storage class and storage texture type - new built-ins: `primitive_index` and `num_workgroups` - support for multi-sampled depth images - WGSL: - `select()` order of true/false is swapped - HLSL backend is vastly improved and now usable - GLSL frontend is heavily reworked ## v0.5 (2021-06-18) - development release for wgpu-0.9 - API: - barriers - dynamic indexing of matrices and arrays is only allowed on variables - validator now accepts a list of IR capabilities to allow - improved documentation - Infrastructure: - much 
richer test suite, focused around consuming or emitting WGSL - lazy testing on large shader corpuses - the binary is moved to a sub-crate "naga-cli" - Frontends: - GLSL frontend: - rewritten from scratch and effectively revived, no longer depends on `pomelo` - only supports 440/450/460 versions for now - has optional support for codespan messages - SPIRV frontend has improved CFG resolution (still with issues unresolved) - WGSL got better error messages, workgroup memory support - Backends: - general: better expression naming and emitting - new HLSL backend (in progress) - MSL: - support `ArraySize` expression - better texture sampling instructions - GLSL: - multisampling on GLES - WGSL is vastly improved and now usable ## v0.4.2 (2021-05-28) - SPIR-V frontend: - fix image stores - fix matrix stride check - SPIR-V backend: - fix auto-deriving the capabilities - GLSL backend: - support sample interpolation - write out swizzled vector accesses ## v0.4.1 (2021-05-14) - numerous additions and improvements to SPIR-V frontend: - int8, in16, int64 - null constant initializers for structs and matrices - `OpArrayLength`, `OpCopyMemory`, `OpInBoundsAccessChain`, `OpLogicalXxxEqual` - outer product - fix struct size alignment - initialize built-ins with default values - fix read-only decorations on struct members - fix struct size alignment in WGSL - fix `fwidth` in WGSL - fix scalars arrays in GLSL backend ## v0.4 (2021-04-29) - development release for wgpu-0.8 - API: - expressions are explicitly emitted with `Statement::Emit` - entry points have inputs in arguments and outputs in the result type - `input`/`output` storage classes are gone, but `push_constant` is added - `Interpolation` is moved into `Binding::Location` variant - real pointer semantics with required `Expression::Load` - `TypeInner::ValuePointer` is added - image query expressions are added - new `Statement::ImageStore` - all function calls are `Statement::Call` - `GlobalUse` is moved out into processing - 
`Header` is removed - entry points are an array instead of a map - new `Swizzle` and `Splat` expressions - interpolation qualifiers are extended and required - struct member layout is based on the byte offsets - Infrastructure: - control flow uniformity analysis - texture-sampler combination gathering - `CallGraph` processor is moved out into `glsl` backend - `Interface` is removed, instead the analysis produces `ModuleInfo` with all the derived info - validation of statement tree, expressions, and constants - code linting is more strict for matches - new GraphViz `dot` backend for pretty visualization of the IR - Metal support for inlined samplers - `convert` example is transformed into the default binary target named `naga` - lots of frontend and backend fixes ## v0.3.2 (2021-02-15) - fix logical expression types - fix _FragDepth_ semantics - spv-in: - derive block status of structures - spv-out: - add lots of missing math functions - implement discard ## v0.3.1 (2021-01-31) - wgsl: - support constant array sizes - spv-out: - fix block decorations on nested structures - fix fixed-size arrays - fix matrix decorations inside structures - implement read-only decorations ## v0.3 (2021-01-30) - development release for wgpu-0.7 - API: - math functions - type casts - updated storage classes - updated image sub-types - image sampling/loading options - storage images - interpolation qualifiers - early and conservative depth - Processors: - name manager - automatic layout - termination analysis - validation of types, constants, variables, and entry points ## v0.2 (2020-08-17) - development release for wgpu-0.6 ## v0.1 (2020-02-26) - initial release naga-29.0.3/Cargo.lock0000644000000322321046102023000100340ustar # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
version = 3 [[package]] name = "arbitrary" version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" dependencies = [ "derive_arbitrary", ] [[package]] name = "arrayvec" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "bit-set" version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34ddef2995421ab6a5c779542c81ee77c115206f4ad9d5a8e05f4ff49716a3dd" dependencies = [ "bit-vec", ] [[package]] name = "bit-vec" version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51" [[package]] name = "bitflags" version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" dependencies = [ "arbitrary", "serde_core", ] [[package]] name = "cfg-if" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "cfg_aliases" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "codespan-reporting" version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681" dependencies = [ "serde", "termcolor", "unicode-width", ] [[package]] name = "crunchy" version = "0.2.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "derive_arbitrary" version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "diff" version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "env_filter" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" dependencies = [ "log", ] [[package]] name = "env_logger" version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" dependencies = [ "env_filter", "log", ] [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "fixedbitset" version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "foldhash" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "foldhash" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" [[package]] 
name = "half" version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" dependencies = [ "arbitrary", "cfg-if", "crunchy", "num-traits", "serde", "zerocopy", ] [[package]] name = "hashbrown" version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ "foldhash 0.1.5", ] [[package]] name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" dependencies = [ "foldhash 0.2.0", "serde", "serde_core", ] [[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hexf-parse" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" [[package]] name = "indexmap" version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "arbitrary", "equivalent", "hashbrown 0.16.1", "serde", "serde_core", ] [[package]] name = "itertools" version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" dependencies = [ "either", ] [[package]] name = "libm" version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "log" version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "naga" version = "29.0.3" dependencies = [ "arbitrary", "arrayvec", "bit-set", "bitflags", "cfg-if", "cfg_aliases", "codespan-reporting", "diff", "env_logger", "half", "hashbrown 0.16.1", "hexf-parse", "indexmap", "itertools", "libm", "log", "num-traits", "once_cell", "petgraph", "pp-rs", "ron", "rspirv", "rustc-hash", "serde", "spirv", "strum", "thiserror", "unicode-ident", "walkdir", ] [[package]] name = "num-traits" version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", "libm", ] [[package]] name = "once_cell" version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "petgraph" version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ "fixedbitset", "hashbrown 0.15.5", "indexmap", ] [[package]] name = "pp-rs" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb458bb7f6e250e6eb79d5026badc10a3ebb8f9a15d1fff0f13d17c71f4d6dee" dependencies = [ "unicode-xid", ] [[package]] name = "proc-macro2" version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] [[package]] name = "quote" version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] [[package]] name = "ron" version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"fd490c5b18261893f14449cbd28cb9c0b637aebf161cd77900bfdedaff21ec32" dependencies = [ "bitflags", "once_cell", "serde", "serde_derive", "typeid", "unicode-ident", ] [[package]] name = "rspirv" version = "0.13.0+sdk-1.4.341.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "091dca2e1d6fd3098417b5ec88e77e80d1ba5945750943419dc976858082c296" dependencies = [ "rustc-hash", "spirv", ] [[package]] name = "rustc-hash" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "same-file" version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" dependencies = [ "winapi-util", ] [[package]] name = "serde" version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ "serde_core", "serde_derive", ] [[package]] name = "serde_core" version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "spirv" version = "0.4.0+sdk-1.4.341.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9571ea910ebd84c86af4b3ed27f9dbdc6ad06f17c5f96146b2b671e2976744f" dependencies = [ "bitflags", ] [[package]] name = "strum" version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" dependencies = [ "strum_macros", 
] [[package]] name = "strum_macros" version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" dependencies = [ "heck", "proc-macro2", "quote", "syn", ] [[package]] name = "syn" version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] [[package]] name = "termcolor" version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" dependencies = [ "winapi-util", ] [[package]] name = "thiserror" version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "typeid" version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc7d623258602320d5c55d1bc22793b57daff0ec7efc270ea7d55ce1d5f5471c" [[package]] name = "unicode-ident" version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-width" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" [[package]] name = "unicode-xid" version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" [[package]] 
name = "walkdir" version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" dependencies = [ "same-file", "winapi-util", ] [[package]] name = "winapi-util" version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ "windows-sys", ] [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-sys" version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ "windows-link", ] [[package]] name = "zerocopy" version = "0.8.42" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" version = "0.8.42" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f" dependencies = [ "proc-macro2", "quote", "syn", ] naga-29.0.3/Cargo.toml0000644000000102241046102023000100540ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. 
[package] edition = "2021" rust-version = "1.87" name = "naga" version = "29.0.3" authors = ["gfx-rs developers"] build = "build.rs" exclude = [ "bin/**/*", "tests/**/*", "Cargo.lock", "target/**/*", ] autolib = false autobins = false autoexamples = false autotests = false autobenches = false description = "Shader translator and validator. Part of the wgpu project" readme = "README.md" keywords = [ "shader", "SPIR-V", "GLSL", "MSL", ] license = "MIT OR Apache-2.0" repository = "https://github.com/gfx-rs/wgpu" [package.metadata.docs.rs] all-features = true [features] arbitrary = [ "dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary", "half/arbitrary", "half/std", ] default = [] deserialize = [ "dep:serde", "bitflags/serde", "half/serde", "hashbrown/serde", "indexmap/serde", ] dot-out = [] fs = [] glsl-in = ["dep:pp-rs"] glsl-out = [] hlsl-out = [] hlsl-out-if-target-windows = [] msl-out = [] msl-out-if-target-apple = [] serialize = [ "dep:serde", "bitflags/serde", "half/serde", "hashbrown/serde", "indexmap/serde", ] spv-in = [ "dep:petgraph", "petgraph/graphmap", "dep:spirv", ] spv-out = ["dep:spirv"] stderr = ["codespan-reporting/std"] termcolor = ["codespan-reporting/termcolor"] wgsl-in = [ "dep:hexf-parse", "dep:unicode-ident", ] wgsl-out = [] [lib] name = "naga" path = "src/lib.rs" [dependencies.arbitrary] version = "1.4.2" features = ["derive"] optional = true [dependencies.arrayvec] version = "0.7.1" default-features = false [dependencies.bit-set] version = "0.9" default-features = false [dependencies.bitflags] version = "2.9" [dependencies.cfg-if] version = "1" [dependencies.codespan-reporting] version = "0.13" default-features = false [dependencies.half] version = "2.5" features = ["num-traits"] default-features = false [dependencies.hashbrown] version = "0.16" features = [ "default-hasher", "inline-more", ] default-features = false [dependencies.hexf-parse] version = "0.2" optional = true [dependencies.indexmap] version = "2.11.4" default-features = 
false [dependencies.libm] version = "0.2.6" default-features = false [dependencies.log] version = "0.4.29" [dependencies.num-traits] version = "0.2.16" default-features = false [dependencies.once_cell] version = "1.21" features = [ "alloc", "race", ] default-features = false [dependencies.petgraph] version = "0.8" optional = true default-features = false [dependencies.pp-rs] version = "0.2.1" optional = true [dependencies.rustc-hash] version = "1.1" default-features = false [dependencies.serde] version = "1.0.225" features = [ "alloc", "derive", ] optional = true default-features = false [dependencies.spirv] version = "0.4" optional = true [dependencies.thiserror] version = "2.0.12" default-features = false [dependencies.unicode-ident] version = "1.0.5" optional = true [dev-dependencies.diff] version = "0.1" [dev-dependencies.env_logger] version = "0.11" default-features = false [dev-dependencies.hashbrown] version = "0.16" features = [ "default-hasher", "inline-more", "serde", ] default-features = false [dev-dependencies.itertools] version = "0.14" [dev-dependencies.ron] version = "0.12" [dev-dependencies.rspirv] version = "0.13" [dev-dependencies.serde] version = "1.0.225" features = [ "default", "derive", ] default-features = false [dev-dependencies.spirv] version = "0.4" [dev-dependencies.strum] version = "0.27.1" features = ["derive"] default-features = false [dev-dependencies.walkdir] version = "2.3" [build-dependencies.cfg_aliases] version = "0.2.1" [lints.clippy] alloc_instead_of_core = "warn" std_instead_of_alloc = "warn" std_instead_of_core = "warn" naga-29.0.3/Cargo.toml.orig000064400000000000000000000103071046102023000135150ustar 00000000000000[package] name = "naga" version.workspace = true authors.workspace = true edition.workspace = true description = "Shader translator and validator. 
Part of the wgpu project" repository.workspace = true keywords = ["shader", "SPIR-V", "GLSL", "MSL"] license.workspace = true exclude = ["bin/**/*", "tests/**/*", "Cargo.lock", "target/**/*"] # Override the workspace's `rust-version` key. Firefox uses `cargo vendor` to # copy the crates it actually uses out of the workspace, so it's meaningful for # them to have less restrictive MSRVs individually than the workspace as a # whole, if their code permits. See `../README.md` for details. rust-version = "1.87" [package.metadata.docs.rs] all-features = true [features] default = [] dot-out = [] glsl-in = ["dep:pp-rs"] glsl-out = [] ## Enables outputting to the Metal Shading Language (MSL). ## ## This enables MSL output regardless of the target platform. ## If you want to enable it only when targeting iOS/tvOS/watchOS/macOS, use `naga/msl-out-if-target-apple`. msl-out = [] ## Enables outputting to the Metal Shading Language (MSL) only if the target platform is iOS/tvOS/watchOS/macOS. ## ## If you want to enable MSL output regardless of the target platform, use `naga/msl-out`. msl-out-if-target-apple = [] serialize = [ "dep:serde", "bitflags/serde", "half/serde", "hashbrown/serde", "indexmap/serde", ] deserialize = [ "dep:serde", "bitflags/serde", "half/serde", "hashbrown/serde", "indexmap/serde", ] arbitrary = [ "dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary", "half/arbitrary", "half/std", ] spv-in = ["dep:petgraph", "petgraph/graphmap", "dep:spirv"] spv-out = ["dep:spirv"] wgsl-in = ["dep:hexf-parse", "dep:unicode-ident"] wgsl-out = [] ## Enables outputting to HLSL (Microsoft's High-Level Shader Language). ## ## This enables HLSL output regardless of the target platform. ## If you want to enable it only when targeting Windows, use `naga/hlsl-out-if-target-windows`. hlsl-out = [] ## Enables outputting to HLSL (Microsoft's High-Level Shader Language) only if the target platform is Windows.
## ## If you want to enable HLSL output regardless of the target platform, use `naga/hlsl-out`. hlsl-out-if-target-windows = [] ## Enables colored output through codespan-reporting and termcolor. termcolor = ["codespan-reporting/termcolor"] ## Enables writing output to stderr. stderr = ["codespan-reporting/std"] ## Enables integration with the underlying filesystem. fs = [] [dependencies] arbitrary = { workspace = true, features = ["derive"], optional = true } arrayvec.workspace = true bitflags.workspace = true bit-set.workspace = true cfg-if.workspace = true codespan-reporting = { workspace = true } hashbrown.workspace = true half = { workspace = true, features = ["num-traits"] } rustc-hash.workspace = true indexmap.workspace = true libm = { workspace = true, default-features = false } log.workspace = true num-traits.workspace = true once_cell = { workspace = true, features = ["alloc", "race"] } spirv = { workspace = true, optional = true } thiserror.workspace = true serde = { workspace = true, features = ["alloc", "derive"], optional = true } petgraph = { workspace = true, optional = true } pp-rs = { workspace = true, optional = true } hexf-parse = { workspace = true, optional = true } unicode-ident = { workspace = true, optional = true } [build-dependencies] cfg_aliases.workspace = true [dev-dependencies] diff.workspace = true env_logger.workspace = true hashbrown = { workspace = true, features = ["serde"] } hlsl-snapshots.workspace = true itertools.workspace = true naga-test.workspace = true ron.workspace = true rspirv.workspace = true # We don't actually need this; however, if we remove it, it # breaks calling `--features spirv` at the workspace level. I think # this is because there is a `dep:spirv` in the regular feature set, # so cargo tries to match the feature against that, fails because it's an optional dep, # and then refuses to build instead of ignoring it.
spirv.workspace = true serde = { workspace = true, features = ["default", "derive"] } strum = { workspace = true } walkdir.workspace = true [lints.clippy] std_instead_of_alloc = "warn" std_instead_of_core = "warn" alloc_instead_of_core = "warn" naga-29.0.3/LICENSE.APACHE000064400000000000000000000236751046102023000125670ustar 00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
END OF TERMS AND CONDITIONS naga-29.0.3/LICENSE.MIT000064400000000000000000000020661046102023000122660ustar 00000000000000MIT License Copyright (c) 2025 The gfx-rs developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. naga-29.0.3/README.md000064400000000000000000000100071046102023000121020ustar 00000000000000# Naga [![Matrix](https://img.shields.io/badge/Matrix-%23naga%3Amatrix.org-blueviolet.svg)](https://matrix.to/#/#naga:matrix.org) [![Crates.io](https://img.shields.io/crates/v/naga.svg?label=naga)](https://crates.io/crates/naga) [![Docs.rs](https://docs.rs/naga/badge.svg)](https://docs.rs/naga) [![Build Status](https://github.com/gfx-rs/naga/workflows/pipeline/badge.svg)](https://github.com/gfx-rs/naga/actions) ![MSRV](https://img.shields.io/badge/rustc-1.90-blue.svg) [![codecov.io](https://codecov.io/gh/gfx-rs/naga/branch/master/graph/badge.svg?token=9VOKYO8BM2)](https://codecov.io/gh/gfx-rs/naga) The shader translation library for the needs of [wgpu](https://github.com/gfx-rs/wgpu). 
## Supported end-points Front-end | Status | Feature | Notes | --------------- | ------------------ | ------- | ----- | SPIR-V (binary) | :white_check_mark: | spv-in | | WGSL | :white_check_mark: | wgsl-in | Fully validated | GLSL | :ok: | glsl-in | GLSL 440+ and Vulkan semantics only | Back-end | Status | Feature | Notes | --------------- | ------------------ | -------- | ----- | SPIR-V | :white_check_mark: | spv-out | | WGSL | :ok: | wgsl-out | | Metal | :white_check_mark: | msl-out | | HLSL | :white_check_mark: | hlsl-out | Shader Model 5.0+ (DirectX 11+) | GLSL | :ok: | glsl-out | GLSL 330+ and GLSL ES 300+ | AIR | | | | DXIL/DXIR | | | | DXBC | | | | DOT (GraphViz) | :ok: | dot-out | Not a shading language | :white_check_mark: = Primary support — :ok: = Secondary support — :construction: = Unsupported, but support in progress ## Conversion tool Naga can be used as a CLI, which allows testing the conversion of different code paths. First, install `naga-cli` from crates.io or directly from GitHub. ```bash # release version cargo install naga-cli # development version cargo install naga-cli --git https://github.com/gfx-rs/wgpu.git ``` Then you can run the `naga` command: ```bash naga my_shader.wgsl # validate only naga my_shader.spv my_shader.txt # dump the IR module into a file naga my_shader.spv my_shader.metal --flow-dir flow-dir # convert the SPV to Metal, also dump the SPIR-V flow graph to `flow-dir` naga my_shader.wgsl my_shader.vert --profile es310 # convert the WGSL to a GLSL vertex stage under the ES 3.10 profile ``` As naga includes a default binary target, you can also use `cargo run` without installation. This is useful when you develop naga itself or investigate the behavior of naga at a specific commit (e.g. [wgpu](https://github.com/gfx-rs/wgpu) might pin a different version of naga than the `HEAD` of this repository).
```bash cargo run my_shader.wgsl ``` ## Development workflow The main development tool is the good old `cargo test --all-features --workspace`, which runs the unit tests and also updates all the snapshots. You'll see these changes in git before committing the code. If working on a particular front-end or back-end, it may be convenient to enable the relevant features in `Cargo.toml`, e.g. ```toml default = ["spv-out"] #TEMP! ``` This lets your IDE's basic checks report errors in that code, unless your IDE is already configured to enable those features. Finally, when changes to the snapshots are made, we should verify that the produced shaders are indeed valid for the target platforms they are compiled for: ```bash cargo xtask validate spv # for Vulkan shaders, requires SPIRV-Tools installed cargo xtask validate msl # for Metal shaders, requires Xcode command-line tools installed cargo xtask validate glsl # for OpenGL shaders, requires glslang installed cargo xtask validate dot # for dot files, requires GraphViz installed cargo xtask validate wgsl # for WGSL shaders cargo xtask validate hlsl dxc # for HLSL shaders via DXC cargo xtask validate hlsl fxc # for HLSL shaders via FXC ``` naga-29.0.3/build.rs000064400000000000000000000010721046102023000122720ustar 00000000000000fn main() { cfg_aliases::cfg_aliases! { dot_out: { feature = "dot-out" }, glsl_out: { feature = "glsl-out" }, hlsl_out: { any(feature = "hlsl-out", all(target_os = "windows", feature = "hlsl-out-if-target-windows")) }, msl_out: { any(feature = "msl-out", all(target_vendor = "apple", feature = "msl-out-if-target-apple")) }, spv_out: { feature = "spv-out" }, wgsl_out: { feature = "wgsl-out" }, std: { any(test, feature = "wgsl-in", feature = "stderr", feature = "fs") }, no_std: { not(std) }, } } naga-29.0.3/src/arena/handle.rs000064400000000000000000000062511046102023000143070ustar 00000000000000//! Well-typed indices into [`Arena`]s and [`UniqueArena`]s. //! //!
This module defines [`Handle`] and related types. //! //! [`Arena`]: super::Arena //! [`UniqueArena`]: super::UniqueArena use core::{cmp::Ordering, fmt, hash, marker::PhantomData}; /// A unique index in the arena array that a handle points to. /// The "non-max" part ensures that an `Option<Handle<T>>` has /// the same size and representation as `Handle<T>`. pub type Index = crate::non_max_u32::NonMaxU32; #[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)] #[error("Handle {index} of {kind} is either not present, or inaccessible yet")] pub struct BadHandle { pub kind: &'static str, pub index: usize, } impl BadHandle { pub fn new<T>(handle: Handle<T>) -> Self { Self { kind: core::any::type_name::<T>(), index: handle.index(), } } } /// A strongly typed reference to an arena item. /// /// A `Handle<T>` value can be used as an index into an [`Arena`] or [`UniqueArena`]. /// /// [`Arena`]: super::Arena /// [`UniqueArena`]: super::UniqueArena #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr( any(feature = "serialize", feature = "deserialize"), serde(transparent) )] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct Handle<T> { index: Index, #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] marker: PhantomData<T>, } impl<T> Clone for Handle<T> { fn clone(&self) -> Self { *self } } impl<T> Copy for Handle<T> {} impl<T> PartialEq for Handle<T> { fn eq(&self, other: &Self) -> bool { self.index == other.index } } impl<T> Eq for Handle<T> {} impl<T> PartialOrd for Handle<T> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) } } impl<T> Ord for Handle<T> { fn cmp(&self, other: &Self) -> Ordering { self.index.cmp(&other.index) } } impl<T> fmt::Debug for Handle<T> { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "[{}]", self.index) } } impl<T> hash::Hash for Handle<T> { fn hash<H: hash::Hasher>(&self, hasher: &mut H) { self.index.hash(hasher) } } impl<T> Handle<T> { pub(crate)
const fn new(index: Index) -> Self { Handle { index, marker: PhantomData, } } /// Returns the index of this handle. pub const fn index(self) -> usize { self.index.get() as usize } /// Convert a `usize` index into a `Handle<T>`. pub(super) fn from_usize(index: usize) -> Self { let handle_index = u32::try_from(index) .ok() .and_then(Index::new) .expect("Failed to insert into arena. Handle overflows"); Handle::new(handle_index) } /// Write this handle's index to `formatter`, preceded by `prefix`. pub fn write_prefixed( &self, formatter: &mut fmt::Formatter, prefix: &'static str, ) -> fmt::Result { formatter.write_str(prefix)?; <usize as fmt::Display>::fmt(&self.index(), formatter) } } naga-29.0.3/src/arena/handle_set.rs000064400000000000000000000071511046102023000151620ustar 00000000000000//! The [`HandleSet`] type and associated definitions. use crate::arena::{Arena, Handle, UniqueArena}; /// A set of `Handle<T>` values. #[derive(Debug)] pub struct HandleSet<T> { /// Bound on indexes of handles stored in this set. len: usize, /// `members[i]` is true if the handle with index `i` is a member. members: bit_set::BitSet, /// This type is indexed by values of type `T`. as_keys: core::marker::PhantomData<T>, } impl<T> HandleSet<T> { /// Return a new, empty `HandleSet`. pub fn new() -> Self { Self { len: 0, members: bit_set::BitSet::new(), as_keys: core::marker::PhantomData, } } pub fn is_empty(&self) -> bool { self.members.is_empty() } /// Return a new, empty `HandleSet`, sized to hold handles from `arena`. pub fn for_arena(arena: &impl ArenaType<T>) -> Self { let len = arena.len(); Self { len, members: bit_set::BitSet::with_capacity(len), as_keys: core::marker::PhantomData, } } /// Remove all members from `self`. pub fn clear(&mut self) { self.members.make_empty(); } /// Remove all members from `self`, and reserve space to hold handles from `arena`.
pub fn clear_for_arena(&mut self, arena: &impl ArenaType<T>) { self.members.make_empty(); self.members.reserve_len(arena.len()); } /// Return an iterator over all handles that could be made members /// of this set. pub fn all_possible(&self) -> impl Iterator<Item = Handle<T>> { super::Range::full_range_from_size(self.len) } /// Add `handle` to the set. /// /// Return `true` if `handle` was not already present in the set. pub fn insert(&mut self, handle: Handle<T>) -> bool { self.members.insert(handle.index()) } /// Remove `handle` from the set. /// /// Returns `true` if `handle` was present in the set. pub fn remove(&mut self, handle: Handle<T>) -> bool { self.members.remove(handle.index()) } /// Add handles from `iter` to the set. pub fn insert_iter(&mut self, iter: impl IntoIterator<Item = Handle<T>>) { for handle in iter { self.insert(handle); } } /// Add all of the handles that can be included in this set. pub fn add_all(&mut self) { self.members.get_mut().fill(true); } pub fn contains(&self, handle: Handle<T>) -> bool { self.members.contains(handle.index()) } /// Return an iterator over all handles in `self`. pub fn iter(&self) -> impl '_ + Iterator<Item = Handle<T>> { self.members.iter().map(Handle::from_usize) } /// Removes and returns the numerically largest handle in the set, or `None` /// if the set is empty. pub fn pop(&mut self) -> Option<Handle<T>> { let members = core::mem::take(&mut self.members); let mut vec = members.into_bit_vec(); let result = vec.iter_mut().enumerate().rev().find_map(|(i, mut bit)| { if *bit { *bit = false; Some(i) } else { None } }); self.members = bit_set::BitSet::from_bit_vec(vec); result.map(Handle::from_usize) } } impl<T> Default for HandleSet<T> { fn default() -> Self { Self::new() } } pub trait ArenaType<T> { fn len(&self) -> usize; } impl<T> ArenaType<T> for Arena<T> { fn len(&self) -> usize { self.len() } } impl<T: core::hash::Hash + Eq> ArenaType<T> for UniqueArena<T> { fn len(&self) -> usize { self.len() } } naga-29.0.3/src/arena/handlevec.rs000064400000000000000000000054251046102023000150070ustar 00000000000000//!
The [`HandleVec`] type and associated definitions. use super::handle::Handle; use alloc::{vec, vec::Vec}; use core::marker::PhantomData; use core::ops; /// A [`Vec`] indexed by [`Handle`]s. /// /// A `HandleVec<T, U>` is a [`Vec<U>`] indexed by values of type `Handle<T>`, /// rather than `usize`. /// /// Rather than a `push` method, `HandleVec` has an [`insert`] method, analogous /// to [`HashMap::insert`], that requires you to provide the handle at which the /// new value should appear. However, since `HandleVec` only supports insertion /// at the end, the given handle's index must be equal to the `HandleVec`'s /// current length; otherwise, the insertion will panic. /// /// [`insert`]: HandleVec::insert /// [`HashMap::insert`]: hashbrown::HashMap::insert #[derive(Debug)] pub(crate) struct HandleVec<T, U> { inner: Vec<U>, as_keys: PhantomData<T>, } impl<T, U> Default for HandleVec<T, U> { fn default() -> Self { Self { inner: vec![], as_keys: PhantomData, } } } #[allow(dead_code)] impl<T, U> HandleVec<T, U> { pub(crate) const fn new() -> Self { Self { inner: vec![], as_keys: PhantomData, } } pub(crate) fn with_capacity(capacity: usize) -> Self { Self { inner: Vec::with_capacity(capacity), as_keys: PhantomData, } } pub(crate) const fn len(&self) -> usize { self.inner.len() } /// Insert a mapping from `handle` to `value`. /// /// Unlike a [`HashMap`], a `HandleVec` can only have new entries inserted at /// the end, like [`Vec::push`]. So the index of `handle` must equal /// [`self.len()`].
/// /// [`HashMap`]: hashbrown::HashMap /// [`self.len()`]: HandleVec::len pub(crate) fn insert(&mut self, handle: Handle, value: U) { assert_eq!(handle.index(), self.inner.len()); self.inner.push(value); } pub(crate) fn get(&self, handle: Handle) -> Option<&U> { self.inner.get(handle.index()) } pub(crate) fn clear(&mut self) { self.inner.clear() } pub(crate) fn resize(&mut self, len: usize, fill: U) where U: Clone, { self.inner.resize(len, fill); } pub(crate) fn iter(&self) -> impl Iterator { self.inner.iter() } pub(crate) fn iter_mut(&mut self) -> impl Iterator { self.inner.iter_mut() } } impl ops::Index> for HandleVec { type Output = U; fn index(&self, handle: Handle) -> &Self::Output { &self.inner[handle.index()] } } impl ops::IndexMut> for HandleVec { fn index_mut(&mut self, handle: Handle) -> &mut Self::Output { &mut self.inner[handle.index()] } } naga-29.0.3/src/arena/mod.rs000064400000000000000000000244311046102023000136330ustar 00000000000000/*! The [`Arena`], [`UniqueArena`], and [`Handle`] types. To improve translator performance and reduce memory usage, most structures are stored in an [`Arena`]. An `Arena` stores a series of `T` values, indexed by [`Handle`] values, which are just wrappers around integer indexes. For example, a `Function`'s expressions are stored in an `Arena`, and compound expressions refer to their sub-expressions via `Handle` values. A [`UniqueArena`] is just like an `Arena`, except that it stores only a single instance of each value. The value type must implement `Eq` and `Hash`. Like an `Arena`, inserting a value into a `UniqueArena` returns a `Handle` which can be used to efficiently access the value, without a hash lookup. Inserting a value multiple times returns the same `Handle`. If the `span` feature is enabled, both `Arena` and `UniqueArena` can associate a source code span with each element. 
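The arena layout described above can be modeled in a few lines: values go in one `Vec`, spans in a parallel `Vec`, and `append` returns a typed index. This minimal sketch mirrors the names `Arena`, `append`, and `get_span`, but the `Span` and `Handle` types are simplified stand-ins (assumptions), not naga's definitions.

```rust
use std::marker::PhantomData;
use std::ops::Index;

/// Hypothetical span: a byte range in some source text.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Span {
    pub start: u32,
    pub end: u32,
}

/// Typed index into an `Arena<T>` (stand-in for naga's `Handle<T>`).
pub struct Handle<T>(usize, PhantomData<T>);

// Manual impls so that `Handle<T>` is `Copy`/`PartialEq` for every `T`.
impl<T> Clone for Handle<T> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<T> Copy for Handle<T> {}
impl<T> PartialEq for Handle<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

pub struct Arena<T> {
    data: Vec<T>,
    span_info: Vec<Span>,
}

impl<T> Arena<T> {
    pub fn new() -> Self {
        Arena { data: Vec::new(), span_info: Vec::new() }
    }

    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Store `value` and its span; equal values get distinct handles.
    pub fn append(&mut self, value: T, span: Span) -> Handle<T> {
        let index = self.data.len();
        self.data.push(value);
        self.span_info.push(span);
        Handle(index, PhantomData)
    }

    /// The span recorded for `handle`, or the default span if there is none.
    pub fn get_span(&self, handle: Handle<T>) -> Span {
        self.span_info.get(handle.0).copied().unwrap_or_default()
    }
}

impl<T> Index<Handle<T>> for Arena<T> {
    type Output = T;

    fn index(&self, handle: Handle<T>) -> &T {
        &self.data[handle.0]
    }
}
```

Because a handle is just an index, a lookup is a bounds-checked vector access rather than a hash lookup. Appending an equal value twice yields two distinct handles, as the `append_non_unique` test in this file checks; deduplication is left to `fetch_or_append` or `UniqueArena`.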
[`Handle`]: Handle */ mod handle; mod handle_set; mod handlevec; mod range; mod unique_arena; pub use handle::{BadHandle, Handle}; pub(crate) use handle_set::HandleSet; pub(crate) use handlevec::HandleVec; pub use range::{BadRangeError, Range}; pub use unique_arena::UniqueArena; use alloc::vec::Vec; use core::{fmt, ops}; use crate::Span; use handle::Index; /// An arena holding some kind of component (e.g., type, constant, /// instruction, etc.) that can be referenced. /// /// Adding new items to the arena produces a strongly-typed [`Handle`]. /// The arena can be indexed using the given handle to obtain /// a reference to the stored item. #[derive(Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "serialize", serde(transparent))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(test, derive(PartialEq))] pub struct Arena { /// Values of this arena. data: Vec, #[cfg_attr(feature = "serialize", serde(skip))] span_info: Vec, } impl Default for Arena { fn default() -> Self { Self::new() } } impl fmt::Debug for Arena { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_map().entries(self.iter()).finish() } } impl Arena { /// Create a new arena with no initial capacity allocated. pub const fn new() -> Self { Arena { data: Vec::new(), span_info: Vec::new(), } } /// Extracts the inner vector. pub fn into_inner(self) -> Vec { self.data } /// Returns the current number of items stored in this arena. pub const fn len(&self) -> usize { self.data.len() } /// Returns `true` if the arena contains no elements. pub const fn is_empty(&self) -> bool { self.data.is_empty() } /// Returns an iterator over the items stored in this arena, returning both /// the item's handle and a reference to it. 
pub fn iter(&self) -> impl DoubleEndedIterator<Item = (Handle<T>, &T)> + ExactSizeIterator { self.data .iter() .enumerate() .map(|(i, v)| (Handle::from_usize(i), v)) } /// Returns an iterator over the items stored in this arena, returning /// each item's handle, a mutable reference to it, and its span. pub fn iter_mut_span( &mut self, ) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T, &Span)> + ExactSizeIterator { self.data .iter_mut() .zip(self.span_info.iter()) .enumerate() .map(|(i, (v, span))| (Handle::from_usize(i), v, span)) } /// Drains the arena, returning an iterator over the items stored. pub fn drain(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, T, Span)> { let arena = core::mem::take(self); arena .data .into_iter() .zip(arena.span_info) .enumerate() .map(|(i, (v, span))| (Handle::from_usize(i), v, span)) } /// Returns an iterator over the items stored in this arena, /// returning both the item's handle and a mutable reference to it. pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> { self.data .iter_mut() .enumerate() .map(|(i, v)| (Handle::from_usize(i), v)) } /// Adds a new value to the arena, returning a typed handle. pub fn append(&mut self, value: T, span: Span) -> Handle<T> { let index = self.data.len(); self.data.push(value); self.span_info.push(span); Handle::from_usize(index) } /// Return a handle to the first element for which `fun` returns `true`, if any. pub fn fetch_if<F: Fn(&T) -> bool>(&self, fun: F) -> Option<Handle<T>> { self.data .iter() .position(fun) .map(|index| Handle::from_usize(index)) } /// Adds a value with a custom check for uniqueness: returns a handle /// pointing to an existing element if the check succeeds, or adds a new /// element otherwise. pub fn fetch_if_or_append<F: Fn(&T, &T) -> bool>( &mut self, value: T, span: Span, fun: F, ) -> Handle<T> { if let Some(index) = self.data.iter().position(|d| fun(d, &value)) { Handle::from_usize(index) } else { self.append(value, span) } } /// Adds a value with a check for uniqueness, where the check is plain comparison.
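`fetch_if_or_append` deduplicates with a linear `position` scan driven by a caller-supplied equality, so each call compares against every existing element; the hashed lookup in `UniqueArena` avoids that for types implementing `Eq` and `Hash`. The sketch below reproduces the logic over a plain `Vec`, returning indices instead of handles (a simplification), and uses a hypothetical case-insensitive predicate to show a custom check.

```rust
/// Return the index of the first element that `fun` considers equal to
/// `value`, or push `value` and return its new index.
pub fn fetch_if_or_append<T, F>(data: &mut Vec<T>, value: T, fun: F) -> usize
where
    F: Fn(&T, &T) -> bool,
{
    if let Some(index) = data.iter().position(|existing| fun(existing, &value)) {
        index
    } else {
        data.push(value);
        data.len() - 1
    }
}

/// Plain-equality variant, analogous to `fetch_or_append`.
pub fn fetch_or_append<T: PartialEq>(data: &mut Vec<T>, value: T) -> usize {
    fetch_if_or_append(data, value, T::eq)
}
```

Passing `T::eq` as the predicate is the same trick `fetch_or_append` uses above: the method path already has the `Fn(&T, &T) -> bool` shape.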
pub fn fetch_or_append(&mut self, value: T, span: Span) -> Handle<T> where T: PartialEq, { self.fetch_if_or_append(value, span, T::eq) } /// Get a shared reference to the element at `handle`, or a [`BadHandle`] /// error if `handle` is out of range. pub fn try_get(&self, handle: Handle<T>) -> Result<&T, BadHandle> { self.data .get(handle.index()) .ok_or_else(|| BadHandle::new(handle)) } /// Get a mutable reference to an element in the arena. /// /// Panics if `handle` is out of range. pub fn get_mut(&mut self, handle: Handle<T>) -> &mut T { self.data.get_mut(handle.index()).unwrap() } /// Return the range of handles from index `old_length` to the end of the arena. pub fn range_from(&self, old_length: usize) -> Range<T> { let range = old_length as u32..self.data.len() as u32; Range::from_index_range(range, self) } /// Clears the arena, keeping all allocations. pub fn clear(&mut self) { self.data.clear(); self.span_info.clear(); } /// Return the span associated with `handle`, or the default span if none /// was recorded. pub fn get_span(&self, handle: Handle<T>) -> Span { *self .span_info .get(handle.index()) .unwrap_or(&Span::default()) } /// Check that `handle` is valid for this arena. pub fn check_contains_handle(&self, handle: Handle<T>) -> Result<(), BadHandle> { if handle.index() < self.data.len() { Ok(()) } else { Err(BadHandle::new(handle)) } } /// Check that `range` is valid for this arena. pub fn check_contains_range(&self, range: &Range<T>) -> Result<(), BadRangeError> { // Since `range.inner` is a `Range<u32>`, we only need to check that the // start precedes the end, and that the end is in range. if range.inner.start > range.inner.end { return Err(BadRangeError::new(range.clone())); } // Empty ranges are tolerated: they can be produced by compaction. 
if range.inner.start == range.inner.end { return Ok(()); } let last_handle = Handle::new(Index::new(range.inner.end - 1).unwrap()); if self.check_contains_handle(last_handle).is_err() { return Err(BadRangeError::new(range.clone())); } Ok(()) } pub(crate) fn retain_mut<P>
(&mut self, mut predicate: P) where P: FnMut(Handle, &mut T) -> bool, { let mut index = 0; let mut retained = 0; self.data.retain_mut(|elt| { let handle = Handle::from_usize(index); let keep = predicate(handle, elt); // Since `predicate` needs mutable access to each element, // we can't feasibly call it twice, so we have to compact // spans by hand in parallel as part of this iteration. if keep { self.span_info[retained] = self.span_info[index]; retained += 1; } index += 1; keep }); self.span_info.truncate(retained); } } #[cfg(feature = "deserialize")] impl<'de, T> serde::Deserialize<'de> for Arena where T: serde::Deserialize<'de>, { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let data = Vec::deserialize(deserializer)?; let span_info = core::iter::repeat_n(Span::default(), data.len()).collect(); Ok(Self { data, span_info }) } } impl ops::Index> for Arena { type Output = T; fn index(&self, handle: Handle) -> &T { &self.data[handle.index()] } } impl ops::IndexMut> for Arena { fn index_mut(&mut self, handle: Handle) -> &mut T { &mut self.data[handle.index()] } } impl ops::Index> for Arena { type Output = [T]; fn index(&self, range: Range) -> &[T] { &self.data[range.inner.start as usize..range.inner.end as usize] } } #[cfg(test)] mod tests { use super::*; #[test] fn append_non_unique() { let mut arena: Arena = Arena::new(); let t1 = arena.append(0, Default::default()); let t2 = arena.append(0, Default::default()); assert!(t1 != t2); assert!(arena[t1] == arena[t2]); } #[test] fn append_unique() { let mut arena: Arena = Arena::new(); let t1 = arena.append(0, Default::default()); let t2 = arena.append(1, Default::default()); assert!(t1 != t2); assert!(arena[t1] != arena[t2]); } #[test] fn fetch_or_append_non_unique() { let mut arena: Arena = Arena::new(); let t1 = arena.fetch_or_append(0, Default::default()); let t2 = arena.fetch_or_append(0, Default::default()); assert!(t1 == t2); assert!(arena[t1] == arena[t2]) } #[test] fn 
fetch_or_append_unique() { let mut arena: Arena = Arena::new(); let t1 = arena.fetch_or_append(0, Default::default()); let t2 = arena.fetch_or_append(1, Default::default()); assert!(t1 != t2); assert!(arena[t1] != arena[t2]); } } naga-29.0.3/src/arena/range.rs000064400000000000000000000101601046102023000141420ustar 00000000000000//! Well-typed ranges of [`Arena`]s. //! //! This module defines the [`Range`] type, representing a contiguous range of //! entries in an [`Arena`]. //! //! [`Arena`]: super::Arena use core::{fmt, marker::PhantomData, ops}; use super::{ handle::{Handle, Index}, Arena, }; /// A strongly typed range of handles. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr( any(feature = "serialize", feature = "deserialize"), serde(transparent) )] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(test, derive(PartialEq))] pub struct Range { pub(super) inner: ops::Range, #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] marker: PhantomData, } impl Range { pub(crate) const fn erase_type(self) -> Range<()> { let Self { inner, marker: _ } = self; Range { inner, marker: PhantomData, } } } // NOTE: Keep this diagnostic in sync with that of [`BadHandle`]. #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] #[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")] pub struct BadRangeError { // This error is used for many `Handle` types, but there's no point in making this generic, so // we just flatten them all to `Handle<()>` here. 
kind: &'static str, range: Range<()>, } impl BadRangeError { pub fn new<T>(range: Range<T>) -> Self { Self { kind: core::any::type_name::<T>(), range: range.erase_type(), } } } impl<T> Clone for Range<T> { fn clone(&self) -> Self { Range { inner: self.inner.clone(), marker: self.marker, } } } impl<T> fmt::Debug for Range<T> { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "[{}..{}]", self.inner.start, self.inner.end) } } impl<T> Iterator for Range<T> { type Item = Handle<T>; fn next(&mut self) -> Option<Self::Item> { if self.inner.start < self.inner.end { let next = self.inner.start; self.inner.start += 1; Some(Handle::new(Index::new(next).unwrap())) } else { None } } } impl<T> Range<T> { /// Return a range enclosing handles `first` through `last`, inclusive. pub fn new_from_bounds(first: Handle<T>, last: Handle<T>) -> Self { Self { inner: (first.index() as u32)..(last.index() as u32 + 1), marker: Default::default(), } } /// Return a range covering all handles with indices from `0` up to, but not /// including, `size`. pub(super) fn full_range_from_size(size: usize) -> Self { Self { inner: 0..size as u32, marker: Default::default(), } } /// Return the first and last handles included in `self`. /// /// If `self` is an empty range, there are no handles included, so /// return `None`. pub const fn first_and_last(&self) -> Option<(Handle<T>, Handle<T>)> { if self.inner.start < self.inner.end { Some(( // `self.inner` is end-exclusive, but these bounds are inclusive, // matching what `Range::new_from_bounds` expects. Handle::new(Index::new(self.inner.start).unwrap()), Handle::new(Index::new(self.inner.end - 1).unwrap()), )) } else { None } } /// Return the index range covered by `self`. pub fn index_range(&self) -> ops::Range<u32> { self.inner.clone() } /// Construct a `Range` that covers the indices in `inner`. 
pub fn from_index_range(inner: ops::Range<u32>, arena: &Arena<T>) -> Self { // Since `inner` is a `Range<u32>`, we only need to check that // the start and end are well-ordered, and that the end fits // within `arena`.
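`Range<T>` stores an end-exclusive `u32` range, while `new_from_bounds` accepts inclusive bounds and `first_and_last` converts back; `BadRangeError::new` records `type_name::<T>()` before erasing the element type. The sketch below models those conversions with plain `u32` indices standing in for handles (a simplification); `describe` is a hypothetical helper mirroring `BadRangeError::new`.

```rust
use std::marker::PhantomData;
use std::ops;

/// Typed, end-exclusive range of arena indices (handles simplified to `u32`).
pub struct Range<T> {
    inner: ops::Range<u32>,
    marker: PhantomData<T>,
}

impl<T> Range<T> {
    /// Inclusive bounds in, end-exclusive storage out.
    pub fn new_from_bounds(first: u32, last: u32) -> Self {
        Range { inner: first..last + 1, marker: PhantomData }
    }

    /// `None` for an empty range, otherwise the inclusive bounds.
    pub fn first_and_last(&self) -> Option<(u32, u32)> {
        if self.inner.start < self.inner.end {
            Some((self.inner.start, self.inner.end - 1))
        } else {
            None
        }
    }

    /// Drop the element type, keeping only the indices.
    pub fn erase_type(self) -> Range<()> {
        Range { inner: self.inner, marker: PhantomData }
    }
}

impl<T> Iterator for Range<T> {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        if self.inner.start < self.inner.end {
            let next = self.inner.start;
            self.inner.start += 1;
            Some(next)
        } else {
            None
        }
    }
}

/// Hypothetical helper: remember the element type's name, then erase it.
pub fn describe<T>(range: Range<T>) -> (&'static str, ops::Range<u32>) {
    (std::any::type_name::<T>(), range.erase_type().inner)
}
```

Iterating consumes the range from the front, so a fully iterated range becomes empty and `first_and_last` then returns `None`.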
assert!(inner.start <= inner.end); assert!(inner.end as usize <= arena.len()); Self { inner, marker: Default::default(), } } } naga-29.0.3/src/arena/unique_arena.rs000064400000000000000000000177321046102023000155360ustar 00000000000000//! The [`UniqueArena`] type and supporting definitions. use alloc::vec::Vec; use core::{fmt, hash, ops}; use super::handle::{BadHandle, Handle, Index}; use crate::{FastIndexSet, Span}; /// An arena whose elements are guaranteed to be unique. /// /// A `UniqueArena` holds a set of unique values of type `T`, each with an /// associated [`Span`]. Inserting a value returns a `Handle`, which can be /// used to index the `UniqueArena` and obtain shared access to the `T` element. /// Access via a `Handle` is an array lookup - no hash lookup is necessary. /// /// The element type must implement `Eq` and `Hash`. Insertions of equivalent /// elements, according to `Eq`, all return the same `Handle`. /// /// Once inserted, elements generally may not be mutated, although a `replace` /// method exists to support rare cases. /// /// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like, /// `UniqueArena` is `HashSet`-like. /// /// [`Arena`]: super::Arena #[derive(Clone)] pub struct UniqueArena { set: FastIndexSet, /// Spans for the elements, indexed by handle. /// /// The length of this vector is always equal to `set.len()`. `FastIndexSet` /// promises that its elements "are indexed in a compact range, without /// holes in the range 0..set.len()", so we can always use the indices /// returned by insertion as indices into this vector. span_info: Vec, } impl UniqueArena { /// Create a new arena with no initial capacity allocated. pub fn new() -> Self { UniqueArena { set: FastIndexSet::default(), span_info: Vec::new(), } } /// Return the current number of items stored in this arena. pub fn len(&self) -> usize { self.set.len() } /// Return `true` if the arena contains no elements. 
pub fn is_empty(&self) -> bool { self.set.is_empty() } /// Clears the arena, keeping all allocations. pub fn clear(&mut self) { self.set.clear(); self.span_info.clear(); } /// Return the span associated with `handle`. /// /// If a value has been inserted multiple times, the span returned is the /// one provided with the first insertion. pub fn get_span(&self, handle: Handle) -> Span { *self .span_info .get(handle.index()) .unwrap_or(&Span::default()) } pub(crate) fn drain_all(&mut self) -> UniqueArenaDrain<'_, T> { UniqueArenaDrain { inner_elts: self.set.drain(..), inner_spans: self.span_info.drain(..), index: Index::new(0).unwrap(), } } } pub struct UniqueArenaDrain<'a, T> { inner_elts: indexmap::set::Drain<'a, T>, inner_spans: alloc::vec::Drain<'a, Span>, index: Index, } impl Iterator for UniqueArenaDrain<'_, T> { type Item = (Handle, T, Span); fn next(&mut self) -> Option { match self.inner_elts.next() { Some(elt) => { let handle = Handle::new(self.index); self.index = self.index.checked_add(1).unwrap(); let span = self.inner_spans.next().unwrap(); Some((handle, elt, span)) } None => None, } } } impl UniqueArena { /// Returns an iterator over the items stored in this arena, returning both /// the item's handle and a reference to it. pub fn iter(&self) -> impl DoubleEndedIterator, &T)> + ExactSizeIterator { self.set.iter().enumerate().map(|(i, v)| { let index = Index::new(i as u32).unwrap(); (Handle::new(index), v) }) } /// Insert a new value into the arena. /// /// Return a [`Handle`], which can be used to index this arena to get a /// shared reference to the element. /// /// If this arena already contains an element that is `Eq` to `value`, /// return a `Handle` to the existing element, and drop `value`. /// /// If `value` is inserted into the arena, associate `span` with /// it. An element's span can be retrieved with the [`get_span`] /// method. 
/// /// [`Handle`]: Handle /// [`get_span`]: UniqueArena::get_span pub fn insert(&mut self, value: T, span: Span) -> Handle { let (index, added) = self.set.insert_full(value); if added { debug_assert!(index == self.span_info.len()); self.span_info.push(span); } debug_assert!(self.set.len() == self.span_info.len()); Handle::from_usize(index) } /// Replace an old value with a new value. /// /// # Panics /// /// - if the old value is not in the arena /// - if the new value already exists in the arena pub fn replace(&mut self, old: Handle, new: T) { let (index, added) = self.set.insert_full(new); assert!(added && index == self.set.len() - 1); self.set.swap_remove_index(old.index()).unwrap(); } /// Return this arena's handle for `value`, if present. /// /// If this arena already contains an element equal to `value`, /// return its handle. Otherwise, return `None`. pub fn get(&self, value: &T) -> Option> { self.set .get_index_of(value) .map(|index| Handle::from_usize(index)) } /// Return this arena's value at `handle`, if that is a valid handle. pub fn get_handle(&self, handle: Handle) -> Result<&T, BadHandle> { self.set .get_index(handle.index()) .ok_or_else(|| BadHandle::new(handle)) } /// Assert that `handle` is valid for this arena. 
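`UniqueArena::insert` records a span only the first time an equal value arrives, and `replace` adds the new value at the end and then calls `swap_remove_index` on the old slot, which moves the new value into the old handle's index. The sketch below reproduces those semantics without the `indexmap` crate: a `Vec<T>` plus a `HashMap<T, usize>` stand in for `FastIndexSet` (an assumption that requires `T: Clone`), and handles are plain `usize` values.

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Hypothetical span: a start/end byte pair.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Span(pub u32, pub u32);

/// Stand-in for naga's `UniqueArena`, with `usize` handles.
pub struct UniqueArena<T> {
    items: Vec<T>,
    index_of: HashMap<T, usize>,
    span_info: Vec<Span>,
}

impl<T: Eq + Hash + Clone> UniqueArena<T> {
    pub fn new() -> Self {
        UniqueArena { items: Vec::new(), index_of: HashMap::new(), span_info: Vec::new() }
    }

    /// Return the handle of an equal value if present; otherwise add
    /// `value`. Only the first insertion's span is kept.
    pub fn insert(&mut self, value: T, span: Span) -> usize {
        if let Some(&index) = self.index_of.get(&value) {
            return index;
        }
        let index = self.items.len();
        self.index_of.insert(value.clone(), index);
        self.items.push(value);
        self.span_info.push(span);
        index
    }

    /// Put `new` at `old`'s index. Panics if `new` is already present
    /// or `old` is out of range.
    pub fn replace(&mut self, old: usize, new: T) {
        assert!(!self.index_of.contains_key(&new), "new value already exists");
        let previous = std::mem::replace(&mut self.items[old], new.clone());
        self.index_of.remove(&previous);
        self.index_of.insert(new, old);
    }

    pub fn get(&self, value: &T) -> Option<usize> {
        self.index_of.get(value).copied()
    }

    pub fn get_span(&self, handle: usize) -> Span {
        self.span_info.get(handle).copied().unwrap_or_default()
    }
}

impl<T> std::ops::Index<usize> for UniqueArena<T> {
    type Output = T;

    fn index(&self, handle: usize) -> &T {
        &self.items[handle]
    }
}
```

As in naga's version, `replace` leaves the span at that index untouched, so the handle keeps the span of the value it originally held.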
pub fn check_contains_handle(&self, handle: Handle) -> Result<(), BadHandle> { if handle.index() < self.set.len() { Ok(()) } else { Err(BadHandle::new(handle)) } } } impl Default for UniqueArena { fn default() -> Self { Self::new() } } impl fmt::Debug for UniqueArena { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_map().entries(self.iter()).finish() } } impl ops::Index> for UniqueArena { type Output = T; fn index(&self, handle: Handle) -> &T { &self.set[handle.index()] } } #[cfg(feature = "serialize")] impl serde::Serialize for UniqueArena where T: Eq + hash::Hash + serde::Serialize, { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { self.set.serialize(serializer) } } #[cfg(feature = "deserialize")] impl<'de, T> serde::Deserialize<'de> for UniqueArena where T: Eq + hash::Hash + serde::Deserialize<'de>, { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let set = FastIndexSet::deserialize(deserializer)?; let span_info = core::iter::repeat_n(Span::default(), set.len()).collect(); Ok(Self { set, span_info }) } } //Note: largely borrowed from `HashSet` implementation #[cfg(feature = "arbitrary")] impl<'a, T> arbitrary::Arbitrary<'a> for UniqueArena where T: Eq + hash::Hash + arbitrary::Arbitrary<'a>, { fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { let mut arena = Self::default(); for elem in u.arbitrary_iter()? { arena.set.insert(elem?); arena.span_info.push(Span::UNDEFINED); } Ok(arena) } fn arbitrary_take_rest(u: arbitrary::Unstructured<'a>) -> arbitrary::Result { let mut arena = Self::default(); for elem in u.arbitrary_take_rest_iter()? 
{ arena.set.insert(elem?); arena.span_info.push(Span::UNDEFINED); } Ok(arena) } #[inline] fn size_hint(depth: usize) -> (usize, Option) { let depth_hint = ::size_hint(depth); arbitrary::size_hint::and(depth_hint, (0, None)) } } naga-29.0.3/src/back/continue_forward.rs000064400000000000000000000310011046102023000162250ustar 00000000000000//! Workarounds for platform bugs and limitations in switches and loops. //! //! In these docs, we use CamelCase links for Naga IR concepts, and ordinary //! `code` formatting for HLSL or GLSL concepts. //! //! ## Avoiding `continue` within `switch` //! //! As described in , the FXC HLSL //! compiler doesn't allow `continue` statements within `switch` statements, but //! Naga IR does. We work around this by introducing synthetic boolean local //! variables and branches. //! //! Specifically: //! //! - We generate code for [`Continue`] statements within [`SwitchCase`]s that //! sets an introduced `bool` local to `true` and does a `break`, jumping to //! immediately after the generated `switch`. //! //! - When generating code for a [`Switch`] statement, we conservatively assume //! it might contain such a [`Continue`] statement, so: //! //! - If it's the outermost such [`Switch`] within a [`Loop`], we declare the //! `bool` local ahead of the switch, initialized to `false`. Immediately //! after the `switch`, we check the local and do a `continue` if it's set. //! //! - If the [`Switch`] is nested within other [`Switch`]es, then after the //! generated `switch`, we check the local (which we know was declared //! before the surrounding `switch`) and do a `break` if it's set. //! //! - As an optimization, we only generate the check of the local if a //! [`Continue`] statement is encountered within the [`Switch`]. This may //! help drivers more easily identify that the `bool` is unused. //! //! So while we "weaken" the [`Continue`] statement by rendering it as a `break` //! 
statement, we also place checks immediately at the locations to which those //! `break` statements will jump, until we can be sure we've reached the //! intended target of the original [`Continue`]. //! //! In the case of nested [`Loop`] and [`Switch`] statements, there may be //! multiple introduced `bool` locals in scope, but there's no problem knowing //! which one to operate on. At any point, there is at most one [`Loop`] //! statement that could be targeted by a [`Continue`] statement, so the correct //! `bool` local to set and test is always the one introduced for the innermost //! enclosing [`Loop`]'s outermost [`Switch`]. //! //! ## Avoiding single-body `switch` statements //! //! As described in , some language //! front ends miscompile `switch` statements where all cases branch to the same //! body. Our HLSL and GLSL backends render [`Switch`] statements with a single //! [`SwitchCase`] as `do {} while(false);` loops. //! //! However, this rewriting introduces a new loop that could "capture" //! `continue` statements in its body. To avoid doing so, we apply the //! [`Continue`]-to-`break` transformation described above. //! //! [`Continue`]: crate::Statement::Continue //! [`Loop`]: crate::Statement::Loop //! [`Switch`]: crate::Statement::Switch //! [`SwitchCase`]: crate::SwitchCase use alloc::{rc::Rc, string::String, vec::Vec}; use crate::proc::Namer; /// A summary of the code surrounding a statement. enum Nesting { /// Currently nested in at least one [`Loop`] statement. /// /// [`Continue`] should apply to the innermost loop. /// /// When this entry is on the top of the stack: /// /// * When entering an inner [`Loop`] statement, push a [`Loop`][nl] state /// onto the stack. /// /// * When entering a nested [`Switch`] statement, push a [`Switch`][ns] /// state onto the stack with a new variable name. Before the generated /// `switch`, introduce a `bool` local with that name, initialized to /// `false`.
/// /// When exiting the [`Loop`] for which this entry was pushed, pop it from /// the stack. /// /// [`Continue`]: crate::Statement::Continue /// [`Loop`]: crate::Statement::Loop /// [`Switch`]: crate::Statement::Switch /// [ns]: Nesting::Switch /// [nl]: Nesting::Loop Loop, /// Currently nested in at least one [`Switch`] that may need to forward /// [`Continue`]s. /// /// This includes [`Switch`]es rendered as `do {} while(false)` loops, but /// doesn't need to include regular [`Switch`]es in backends that can /// support `continue` within switches. /// /// [`Continue`] should be forwarded to the innermost surrounding [`Loop`]. /// /// When this entry is on the top of the stack: /// /// * When entering a nested [`Loop`], push a [`Loop`][nl] state onto the /// stack. /// /// * When entering a nested [`Switch`], push a [`Switch`][ns] state onto /// the stack with a clone of the introduced `bool` variable's name. /// /// * When encountering a [`Continue`] statement, render it as code to set /// the introduced `bool` local (whose name is held in [`variable`]) to /// `true`, and then `break`. Set [`continue_encountered`] to `true` to /// record that the [`Switch`] contains a [`Continue`]. /// /// * When exiting this [`Switch`], pop its entry from the stack. If /// [`continue_encountered`] is set, then we have rendered [`Continue`] /// statements as `break` statements that jump to this point. Generate /// code to check `variable`, and if it is `true`: /// /// * If there is another [`Switch`][ns] left on top of the stack, set /// its `continue_encountered` flag, and generate a `break`. (Both /// [`Switch`][ns]es are within the same [`Loop`] and share the same /// introduced variable, so there's no need to set another flag to /// continue to exit the `switch`es.) /// /// * Otherwise, `continue`. /// /// When we exit the [`Switch`] for which this entry was pushed, pop it. 
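The forwarding described above can be illustrated in Rust, where a `continue` inside an inner `loop` also targets that inner loop, much like a `continue` inside a `do {} while(false)` in HLSL or GLSL. This hand-written sketch shows what generated code could look like for a `Continue` inside two nested `Switch`es within a `Loop`. Each single-iteration `loop` stands in for a rendered `switch`, and `should_continue` mirrors the `bool` local; it is an illustration, not actual backend output.

```rust
/// Sum `values`, skipping any value equal to `skip`. The skip is a
/// `continue` forwarded out of two nested single-iteration loops.
#[allow(clippy::never_loop)]
pub fn sum_except(values: &[i32], skip: i32) -> i32 {
    let mut total = 0;
    for &v in values {
        // Declared ahead of the outermost `switch` inside the loop.
        let mut should_continue = false;
        // Outer `switch`.
        loop {
            // Inner `switch`: the original `Continue` becomes flag + `break`.
            loop {
                if v == skip {
                    should_continue = true;
                    break;
                }
                break;
            }
            // Leaving the inner switch with another switch still enclosing
            // it: forward with `break`.
            if should_continue {
                break;
            }
            break;
        }
        // Leaving the outermost switch: a real `continue` now reaches the loop.
        if should_continue {
            continue;
        }
        total += v;
    }
    total
}
```

Both checks read the same flag, which is why the nested `Switch` entry only clones the variable's name instead of introducing a second local.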
/// /// [`Continue`]: crate::Statement::Continue /// [`Loop`]: crate::Statement::Loop /// [`Switch`]: crate::Statement::Switch /// [`variable`]: Nesting::Switch::variable /// [`continue_encountered`]: Nesting::Switch::continue_encountered /// [ns]: Nesting::Switch /// [nl]: Nesting::Loop Switch { variable: Rc, /// Set if we've generated code for a [`Continue`] statement with this /// entry on the top of the stack. /// /// If this is still clear when we finish rendering the [`Switch`], then /// we know we don't need to generate branch forwarding code. Omitting /// that may make it easier for drivers to tell that the `bool` we /// introduced ahead of the [`Switch`] is actually unused. /// /// [`Continue`]: crate::Statement::Continue /// [`Switch`]: crate::Statement::Switch continue_encountered: bool, }, } /// A micro-IR for code a backend should generate after a [`Switch`]. /// /// [`Switch`]: crate::Statement::Switch pub(super) enum ExitControlFlow { None, /// Emit `if (continue_variable) { continue; }` Continue { variable: Rc, }, /// Emit `if (continue_variable) { break; }` /// /// Used after a [`Switch`] to exit from an enclosing [`Switch`]. /// /// After the enclosing switch, its associated check will consult this same /// variable, see that it is set, and exit early. /// /// [`Switch`]: crate::Statement::Switch Break { variable: Rc, }, } /// Utility for tracking nesting of loops and switches to orchestrate forwarding /// of continue statements inside of a switch to the enclosing loop. /// /// See [module docs](self) for why we need this. #[derive(Default)] pub(super) struct ContinueCtx { stack: Vec, } impl ContinueCtx { /// Resets internal state. /// /// Use this to reuse memory between writing sessions. #[allow(dead_code, reason = "only used by some backends")] pub fn clear(&mut self) { self.stack.clear(); } /// Updates internal state to record entering a [`Loop`] statement. 
/// /// [`Loop`]: crate::Statement::Loop pub fn enter_loop(&mut self) { self.stack.push(Nesting::Loop); } /// Updates internal state to record exiting a [`Loop`] statement. /// /// [`Loop`]: crate::Statement::Loop pub fn exit_loop(&mut self) { if !matches!(self.stack.pop(), Some(Nesting::Loop)) { unreachable!("ContinueCtx stack out of sync"); } } /// Updates internal state to record entering a [`Switch`] statement. /// /// Return `Some(variable)` if the innermost enclosing [`Loop`] or [`Switch`] /// is a [`Loop`]; the caller should then introduce a new `bool` local /// variable named `variable` above the `switch`, for forwarding [`Continue`] /// statements. Nested [`Switch`]es reuse the variable declared for the /// outermost one, so this returns `None` for them, and also when there is no /// enclosing [`Loop`] at all. /// /// `variable` is guaranteed not to conflict with any names used by the /// program itself. /// /// [`Continue`]: crate::Statement::Continue /// [`Loop`]: crate::Statement::Loop /// [`Switch`]: crate::Statement::Switch pub fn enter_switch(&mut self, namer: &mut Namer) -> Option<Rc<String>> { match self.stack.last() { // If the stack is empty, we are not in a loop, so we don't need to // forward continue statements within this `Switch`. We can leave // the stack empty. None => None, Some(&Nesting::Loop) => { let variable = Rc::new(namer.call("should_continue")); self.stack.push(Nesting::Switch { variable: Rc::clone(&variable), continue_encountered: false, }); Some(variable) } Some(&Nesting::Switch { ref variable, .. }) => { self.stack.push(Nesting::Switch { variable: Rc::clone(variable), continue_encountered: false, }); // We have already declared the variable before some enclosing // `Switch`. None } } } /// Update internal state to record leaving a [`Switch`] statement. /// /// Return an [`ExitControlFlow`] value indicating what code should be /// introduced after the generated `switch` to forward continues.
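The stack discipline of `ContinueCtx` can be modeled in a small standalone program. This is a simplified sketch, not naga's code: a hypothetical counter replaces `Namer` when generating variable names, and `continue_encountered` returns a cloned `Rc<String>` instead of `&str` to keep the borrows simple.

```rust
use std::rc::Rc;

/// Simplified copy of the nesting stack entries.
enum Nesting {
    Loop,
    Switch { variable: Rc<String>, continue_encountered: bool },
}

#[derive(Debug, PartialEq)]
pub enum ExitControlFlow {
    None,
    Continue { variable: Rc<String> },
    Break { variable: Rc<String> },
}

#[derive(Default)]
pub struct ContinueCtx {
    stack: Vec<Nesting>,
    next_id: u32, // hypothetical stand-in for naga's `Namer`
}

impl ContinueCtx {
    pub fn enter_loop(&mut self) {
        self.stack.push(Nesting::Loop);
    }

    pub fn exit_loop(&mut self) {
        assert!(matches!(self.stack.pop(), Some(Nesting::Loop)), "stack out of sync");
    }

    /// `Some(name)` only when the innermost loop/switch entry is a loop.
    pub fn enter_switch(&mut self) -> Option<Rc<String>> {
        let (variable, fresh) = match self.stack.last() {
            None => return None, // not in a loop: nothing to forward
            Some(Nesting::Loop) => {
                self.next_id += 1;
                (Rc::new(format!("should_continue_{}", self.next_id)), true)
            }
            Some(Nesting::Switch { variable, .. }) => (Rc::clone(variable), false),
        };
        self.stack.push(Nesting::Switch { variable: Rc::clone(&variable), continue_encountered: false });
        fresh.then_some(variable)
    }

    /// What to emit after the generated `switch`.
    pub fn exit_switch(&mut self) -> ExitControlFlow {
        match self.stack.pop() {
            None => ExitControlFlow::None,
            Some(Nesting::Loop) => unreachable!("loop entry on top when exiting a switch"),
            Some(Nesting::Switch { variable, continue_encountered }) => {
                if !continue_encountered {
                    ExitControlFlow::None
                } else if let Some(Nesting::Switch { continue_encountered: outer, .. }) =
                    self.stack.last_mut()
                {
                    *outer = true; // propagate to the enclosing switch
                    ExitControlFlow::Break { variable }
                } else {
                    ExitControlFlow::Continue { variable }
                }
            }
        }
    }

    /// `Some(name)` if a `continue` must become `name = true; break;`.
    pub fn continue_encountered(&mut self) -> Option<Rc<String>> {
        if let Some(Nesting::Switch { variable, continue_encountered }) = self.stack.last_mut() {
            *continue_encountered = true;
            Some(Rc::clone(variable))
        } else {
            None
        }
    }
}
```

A `continue` inside two nested switches therefore produces a `Break` check after the inner switch and a `Continue` check after the outer one, both testing the same variable.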
/// /// [`Switch`]: crate::Statement::Switch pub fn exit_switch(&mut self) -> ExitControlFlow { match self.stack.pop() { // This doesn't indicate a problem: we don't start pushing entries // for `Switch` statements unless we have an enclosing `Loop`. None => ExitControlFlow::None, Some(Nesting::Loop) => { unreachable!("Unexpected loop state when exiting switch"); } Some(Nesting::Switch { variable, continue_encountered: inner_continue, }) => { if !inner_continue { // No `Continue` statement was encountered, so we didn't // introduce any `break`s jumping to this point. ExitControlFlow::None } else if let Some(&mut Nesting::Switch { continue_encountered: ref mut outer_continue, .. }) = self.stack.last_mut() { // This is nested in another `Switch`. Propagate upwards // that there is a continue statement present. *outer_continue = true; ExitControlFlow::Break { variable } } else { ExitControlFlow::Continue { variable } } } } } /// Determine what to generate for a [`Continue`] statement. /// /// If we can generate an ordinary `continue` statement, return `None`. /// /// Otherwise, we're enclosed by a [`Switch`] that is itself enclosed by a /// [`Loop`]. Return `Some(variable)` to indicate that the [`Continue`] /// should be rendered as setting `variable` to `true`, and then doing a /// `break`. /// /// This also notes that we've encountered a [`Continue`] statement, so that /// we can generate the right code to forward the branch following the /// enclosing `switch`. /// /// [`Continue`]: crate::Statement::Continue /// [`Loop`]: crate::Statement::Loop /// [`Switch`]: crate::Statement::Switch pub fn continue_encountered(&mut self) -> Option<&str> { if let Some(&mut Nesting::Switch { ref variable, ref mut continue_encountered, }) = self.stack.last_mut() { *continue_encountered = true; Some(variable) } else { None } } } naga-29.0.3/src/back/dot/mod.rs000064400000000000000000001033161046102023000142330ustar 00000000000000/*! Backend for [DOT][dot] (Graphviz). 
This backend writes a graph in the DOT language to ease IR inspection and debugging. [dot]: https://graphviz.org/doc/info/lang.html */ use alloc::{ borrow::Cow, format, string::{String, ToString}, vec::Vec, }; use core::fmt::{Error as FmtError, Write as _}; use crate::{ arena::Handle, valid::{FunctionInfo, ModuleInfo}, }; /// Configuration options for the dot backend #[derive(Clone, Default)] pub struct Options { /// Only emit function bodies pub cfg_only: bool, } /// Identifier used to address a graph node type NodeId = usize; /// Stores the target nodes for control flow statements #[derive(Default, Clone, Copy)] struct Targets { /// The node, if any, where continue operations will land continue_target: Option<NodeId>, /// The node, if any, where break operations will land break_target: Option<NodeId>, } /// Stores information about the graph of statements #[derive(Default)] struct StatementGraph { /// List of node names nodes: Vec<&'static str>, /// List of edges of the control flow, the items are defined as /// (from, to, label) flow: Vec<(NodeId, NodeId, &'static str)>, /// List of implicit edges of the control flow, used for jump /// operations such as continue or break, the items are defined as /// (from, to, label, color_id) jumps: Vec<(NodeId, NodeId, &'static str, usize)>, /// List of dependency relationships between a statement node and /// expressions dependencies: Vec<(NodeId, Handle<crate::Expression>, &'static str)>, /// List of expressions emitted by statement nodes emits: Vec<(NodeId, Handle<crate::Expression>)>, /// List of function calls made by statement nodes calls: Vec<(NodeId, Handle<crate::Function>)>, } impl StatementGraph { /// Adds a new block to the statement graph, returning the first and last node, respectively fn add(&mut self, block: &[crate::Statement], targets: Targets) -> (NodeId, NodeId) { use crate::Statement as S; // The first node of the block isn't a statement but a virtual node let root = self.nodes.len(); self.nodes.push(if root == 0 { "Root" } 
else { "Node" }); // Track the last placed node,
this will be returned to the caller and // will also be used to generate the control flow edges let mut last_node = root; for statement in block { // Reserve a new node for the current statement and link it to the // node of the previous statement let id = self.nodes.len(); self.flow.push((last_node, id, "")); self.nodes.push(""); // reserve space // Track the node identifier for the merge node, the merge node is // the last node of a statement, normally this is the node itself, // but for control flow statements such as `if`s and `switch`s this // is a virtual node where all branches merge back. let mut merge_id = id; self.nodes[id] = match *statement { S::Emit(ref range) => { for handle in range.clone() { self.emits.push((id, handle)); } "Emit" } S::Kill => "Kill", //TODO: link to the beginning S::Break => { // Try to link to the break target, otherwise produce // a broken connection if let Some(target) = targets.break_target { self.jumps.push((id, target, "Break", 5)) } else { self.jumps.push((id, root, "Broken", 7)) } "Break" } S::Continue => { // Try to link to the continue target, otherwise produce // a broken connection if let Some(target) = targets.continue_target { self.jumps.push((id, target, "Continue", 5)) } else { self.jumps.push((id, root, "Broken", 7)) } "Continue" } S::ControlBarrier(_flags) => "ControlBarrier", S::MemoryBarrier(_flags) => "MemoryBarrier", S::Block(ref b) => { let (other, last) = self.add(b, targets); self.flow.push((id, other, "")); // All following nodes should connect to the end of the block // statement so change the merge id to it. 
merge_id = last; "Block" } S::If { condition, ref accept, ref reject, } => { self.dependencies.push((id, condition, "condition")); let (accept_id, accept_last) = self.add(accept, targets); self.flow.push((id, accept_id, "accept")); let (reject_id, reject_last) = self.add(reject, targets); self.flow.push((id, reject_id, "reject")); // Create a merge node, link the branches to it and set it // as the merge node to make the next statement node link to it merge_id = self.nodes.len(); self.nodes.push("Merge"); self.flow.push((accept_last, merge_id, "")); self.flow.push((reject_last, merge_id, "")); "If" } S::Switch { selector, ref cases, } => { self.dependencies.push((id, selector, "selector")); // Create a merge node and set it as the merge node to make // the next statement node link to it merge_id = self.nodes.len(); self.nodes.push("Merge"); // Create a new targets structure and set the break target // to the merge node let mut targets = targets; targets.break_target = Some(merge_id); for case in cases { let (case_id, case_last) = self.add(&case.body, targets); let label = match case.value { crate::SwitchValue::Default => "default", _ => "case", }; self.flow.push((id, case_id, label)); // Link the last node of the branch to the merge node self.flow.push((case_last, merge_id, "")); } "Switch" } S::Loop { ref body, ref continuing, break_if, } => { // Create a new targets structure and set the break target // to the loop node itself; this must happen before generating the // continuing block, since it can break.
let mut targets = targets; targets.break_target = Some(id); let (continuing_id, continuing_last) = self.add(continuing, targets); // Set the continue target to the beginning // of the newly generated continuing block targets.continue_target = Some(continuing_id); let (body_id, body_last) = self.add(body, targets); self.flow.push((id, body_id, "body")); // Link the last node of the body to the continuing block self.flow.push((body_last, continuing_id, "continuing")); // Link the last node of the continuing block back to the // beginning of the loop body self.flow.push((continuing_last, body_id, "continuing")); if let Some(expr) = break_if { self.dependencies.push((continuing_id, expr, "break if")); } "Loop" } S::Return { value } => { if let Some(expr) = value { self.dependencies.push((id, expr, "value")); } "Return" } S::Store { pointer, value } => { self.dependencies.push((id, value, "value")); self.emits.push((id, pointer)); "Store" } S::ImageStore { image, coordinate, array_index, value, } => { self.dependencies.push((id, image, "image")); self.dependencies.push((id, coordinate, "coordinate")); if let Some(expr) = array_index { self.dependencies.push((id, expr, "array_index")); } self.dependencies.push((id, value, "value")); "ImageStore" } S::Call { function, ref arguments, result, } => { for &arg in arguments { self.dependencies.push((id, arg, "arg")); } if let Some(expr) = result { self.emits.push((id, expr)); } self.calls.push((id, function)); "Call" } S::Atomic { pointer, ref fun, value, result, } => { if let Some(result) = result { self.emits.push((id, result)); } self.dependencies.push((id, pointer, "pointer")); self.dependencies.push((id, value, "value")); if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { self.dependencies.push((id, cmp, "cmp")); } "Atomic" } S::ImageAtomic { image, coordinate, array_index, fun: _, value, } => { self.dependencies.push((id, image, "image")); self.dependencies.push((id, coordinate, "coordinate")); if
let Some(expr) = array_index { self.dependencies.push((id, expr, "array_index")); } self.dependencies.push((id, value, "value")); "ImageAtomic" } S::WorkGroupUniformLoad { pointer, result } => { self.emits.push((id, result)); self.dependencies.push((id, pointer, "pointer")); "WorkGroupUniformLoad" } S::RayQuery { query, ref fun } => { self.dependencies.push((id, query, "query")); match *fun { crate::RayQueryFunction::Initialize { acceleration_structure, descriptor, } => { self.dependencies.push(( id, acceleration_structure, "acceleration_structure", )); self.dependencies.push((id, descriptor, "descriptor")); "RayQueryInitialize" } crate::RayQueryFunction::Proceed { result } => { self.emits.push((id, result)); "RayQueryProceed" } crate::RayQueryFunction::GenerateIntersection { hit_t } => { self.dependencies.push((id, hit_t, "hit_t")); "RayQueryGenerateIntersection" } crate::RayQueryFunction::ConfirmIntersection => { "RayQueryConfirmIntersection" } crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } S::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.dependencies.push((id, predicate, "predicate")); } self.emits.push((id, result)); "SubgroupBallot" } S::SubgroupCollectiveOperation { op, collective_op, argument, result, } => { self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); match (collective_op, op) { (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { "SubgroupAll" } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { "SubgroupAny" } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { "SubgroupAdd" } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { "SubgroupMul" } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { "SubgroupMax" } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { "SubgroupMin" } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => 
{ "SubgroupAnd" } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { "SubgroupOr" } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { "SubgroupXor" } ( crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add, ) => "SubgroupExclusiveAdd", ( crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul, ) => "SubgroupExclusiveMul", ( crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add, ) => "SubgroupInclusiveAdd", ( crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul, ) => "SubgroupInclusiveMul", _ => unimplemented!(), } } S::SubgroupGather { mode, argument, result, } => { match mode { crate::GatherMode::BroadcastFirst => {} crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) | crate::GatherMode::ShuffleXor(index) | crate::GatherMode::QuadBroadcast(index) => { self.dependencies.push((id, index, "index")) } crate::GatherMode::QuadSwap(_) => {} } self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); match mode { crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst", crate::GatherMode::Broadcast(_) => "SubgroupBroadcast", crate::GatherMode::Shuffle(_) => "SubgroupShuffle", crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", crate::GatherMode::QuadBroadcast(_) => "SubgroupQuadBroadcast", crate::GatherMode::QuadSwap(direction) => match direction { crate::Direction::X => "SubgroupQuadSwapX", crate::Direction::Y => "SubgroupQuadSwapY", crate::Direction::Diagonal => "SubgroupQuadSwapDiagonal", }, } } S::CooperativeStore { target, data } => { self.dependencies.push((id, target, "target")); self.dependencies.push((id, data.pointer, "pointer")); self.dependencies.push((id, data.stride, "stride")); if data.row_major { 
"CoopStoreT" } else { "CoopStore" } } S::RayPipelineFunction(func) => match func { crate::RayPipelineFunction::TraceRay { acceleration_structure, descriptor, payload, } => { self.dependencies.push(( id, acceleration_structure, "acceleration_structure", )); self.dependencies.push((id, descriptor, "descriptor")); self.dependencies.push((id, payload, "payload")); "TraceRay" } }, }; // Set the last node to the merge node last_node = merge_id; } (root, last_node) } } fn name(option: &Option<String>) -> &str { option.as_deref().unwrap_or_default() } /// The `set39` color scheme (Brewer Set3, 9 colors) from Graphviz const COLORS: &[&str] = &[ "white", // pattern starts at 1 "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69", "#fccde5", "#d9d9d9", ]; struct Prefixed<T>(Handle<T>); impl core::fmt::Display for Prefixed<crate::Expression> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.0.write_prefixed(f, "e") } } impl core::fmt::Display for Prefixed<crate::LocalVariable> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.0.write_prefixed(f, "l") } } impl core::fmt::Display for Prefixed<crate::GlobalVariable> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.0.write_prefixed(f, "g") } } impl core::fmt::Display for Prefixed<crate::Function> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.0.write_prefixed(f, "f") } } fn write_fun( output: &mut String, prefix: String, fun: &crate::Function, info: Option<&FunctionInfo>, options: &Options, ) -> Result<(), FmtError> { writeln!(output, "\t\tnode [ style=filled ]")?; if !options.cfg_only { for (handle, var) in fun.local_variables.iter() { writeln!( output, "\t\t{}_{} [ shape=hexagon label=\"{:?} '{}'\" ]", prefix, Prefixed(handle), handle, name(&var.name), )?; } write_function_expressions(output, &prefix, fun, info)?; } let mut sg = StatementGraph::default(); sg.add(&fun.body, 
Targets::default()); for (index, label) in sg.nodes.into_iter().enumerate() { writeln!( output, "\t\t{prefix}_s{index} [ shape=square label=\"{label}\"
]", )?; } for (from, to, label) in sg.flow { writeln!( output, "\t\t{prefix}_s{from} -> {prefix}_s{to} [ arrowhead=tee label=\"{label}\" ]", )?; } for (from, to, label, color_id) in sg.jumps { writeln!( output, "\t\t{}_s{} -> {}_s{} [ arrowhead=tee style=dashed color=\"{}\" label=\"{}\" ]", prefix, from, prefix, to, COLORS[color_id], label, )?; } if !options.cfg_only { for (to, expr, label) in sg.dependencies { writeln!( output, "\t\t{}_{} -> {}_s{} [ label=\"{}\" ]", prefix, Prefixed(expr), prefix, to, label, )?; } for (from, to) in sg.emits { writeln!( output, "\t\t{}_s{} -> {}_{} [ style=dotted ]", prefix, from, prefix, Prefixed(to), )?; } } assert!(sg.calls.is_empty()); for (from, function) in sg.calls { writeln!( output, "\t\t{}_s{} -> {}_s0", prefix, from, Prefixed(function), )?; } Ok(()) } fn write_function_expressions( output: &mut String, prefix: &str, fun: &crate::Function, info: Option<&FunctionInfo>, ) -> Result<(), FmtError> { enum Payload<'a> { Arguments(&'a [Handle<crate::Expression>]), Local(Handle<crate::LocalVariable>), Global(Handle<crate::GlobalVariable>), } let mut edges = crate::FastHashMap::<&str, _>::default(); let mut payload = None; for (handle, expression) in fun.expressions.iter() { use crate::Expression as E; let (label, color_id) = match *expression { E::Literal(_) => ("Literal".into(), 2), E::Constant(_) => ("Constant".into(), 2), E::Override(_) => ("Override".into(), 2), E::ZeroValue(_) => ("ZeroValue".into(), 2), E::Compose { ref components, ..
} => { payload = Some(Payload::Arguments(components)); ("Compose".into(), 3) } E::Access { base, index } => { edges.insert("base", base); edges.insert("index", index); ("Access".into(), 1) } E::AccessIndex { base, index } => { edges.insert("base", base); (format!("AccessIndex[{index}]").into(), 1) } E::Splat { size, value } => { edges.insert("value", value); (format!("Splat{size:?}").into(), 3) } E::Swizzle { size, vector, pattern, } => { edges.insert("vector", vector); (format!("Swizzle{:?}", &pattern[..size as usize]).into(), 3) } E::FunctionArgument(index) => (format!("Argument[{index}]").into(), 1), E::GlobalVariable(h) => { payload = Some(Payload::Global(h)); ("Global".into(), 2) } E::LocalVariable(h) => { payload = Some(Payload::Local(h)); ("Local".into(), 1) } E::Load { pointer } => { edges.insert("pointer", pointer); ("Load".into(), 4) } E::ImageSample { image, sampler, gather, coordinate, array_index, offset: _, level, depth_ref, clamp_to_edge: _, } => { edges.insert("image", image); edges.insert("sampler", sampler); edges.insert("coordinate", coordinate); if let Some(expr) = array_index { edges.insert("array_index", expr); } match level { crate::SampleLevel::Auto => {} crate::SampleLevel::Zero => {} crate::SampleLevel::Exact(expr) => { edges.insert("level", expr); } crate::SampleLevel::Bias(expr) => { edges.insert("bias", expr); } crate::SampleLevel::Gradient { x, y } => { edges.insert("grad_x", x); edges.insert("grad_y", y); } } if let Some(expr) = depth_ref { edges.insert("depth_ref", expr); } let string = match gather { Some(component) => Cow::Owned(format!("ImageGather{component:?}")), _ => Cow::Borrowed("ImageSample"), }; (string, 5) } E::ImageLoad { image, coordinate, array_index, sample, level, } => { edges.insert("image", image); edges.insert("coordinate", coordinate); if let Some(expr) = array_index { edges.insert("array_index", expr); } if let Some(sample) = sample { edges.insert("sample", sample); } if let Some(level) = level { 
edges.insert("level", level); } ("ImageLoad".into(), 5) } E::ImageQuery { image, query } => { edges.insert("image", image); let args = match query { crate::ImageQuery::Size { level } => { if let Some(expr) = level { edges.insert("level", expr); } Cow::from("ImageSize") } _ => Cow::Owned(format!("{query:?}")), }; (args, 7) } E::Unary { op, expr } => { edges.insert("expr", expr); (format!("{op:?}").into(), 6) } E::Binary { op, left, right } => { edges.insert("left", left); edges.insert("right", right); (format!("{op:?}").into(), 6) } E::Select { condition, accept, reject, } => { edges.insert("condition", condition); edges.insert("accept", accept); edges.insert("reject", reject); ("Select".into(), 3) } E::Derivative { axis, ctrl, expr } => { edges.insert("", expr); (format!("d{axis:?}{ctrl:?}").into(), 8) } E::Relational { fun, argument } => { edges.insert("arg", argument); (format!("{fun:?}").into(), 6) } E::Math { fun, arg, arg1, arg2, arg3, } => { edges.insert("arg", arg); if let Some(expr) = arg1 { edges.insert("arg1", expr); } if let Some(expr) = arg2 { edges.insert("arg2", expr); } if let Some(expr) = arg3 { edges.insert("arg3", expr); } (format!("{fun:?}").into(), 7) } E::As { kind, expr, convert, } => { edges.insert("", expr); let string = match convert { Some(width) => format!("Convert<{kind:?},{width}>"), None => format!("Bitcast<{kind:?}>"), }; (string.into(), 3) } E::CallResult(_function) => ("CallResult".into(), 4), E::AtomicResult { .. } => ("AtomicResult".into(), 4), E::WorkGroupUniformLoadResult { .. 
} => ("WorkGroupUniformLoadResult".into(), 4), E::ArrayLength(expr) => { edges.insert("", expr); ("ArrayLength".into(), 7) } E::RayQueryProceedResult => ("rayQueryProceedResult".into(), 4), E::RayQueryGetIntersection { query, committed } => { edges.insert("", query); let ty = if committed { "Committed" } else { "Candidate" }; (format!("rayQueryGet{ty}Intersection").into(), 4) } E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4), E::RayQueryVertexPositions { query, committed } => { edges.insert("", query); let ty = if committed { "Committed" } else { "Candidate" }; (format!("get{ty}HitVertexPositions").into(), 4) } E::CooperativeLoad { ref data, .. } => { edges.insert("pointer", data.pointer); edges.insert("stride", data.stride); let suffix = if data.row_major { "T " } else { "" }; (format!("coopLoad{suffix}").into(), 4) } E::CooperativeMultiplyAdd { a, b, c } => { edges.insert("a", a); edges.insert("b", b); edges.insert("c", c); ("cooperativeMultiplyAdd".into(), 4) } }; // give uniform expressions an outline let color_attr = match info { Some(info) if info[handle].uniformity.non_uniform_result.is_none() => "fillcolor", _ => "color", }; writeln!( output, "\t\t{}_{} [ {}=\"{}\" label=\"{:?} {}\" ]", prefix, Prefixed(handle), color_attr, COLORS[color_id], handle, label, )?; for (key, edge) in edges.drain() { writeln!( output, "\t\t{}_{} -> {}_{} [ label=\"{}\" ]", prefix, Prefixed(edge), prefix, Prefixed(handle), key, )?; } match payload.take() { Some(Payload::Arguments(list)) => { write!(output, "\t\t{{")?; for &comp in list { write!(output, " {}_{}", prefix, Prefixed(comp))?; } writeln!(output, " }} -> {}_{}", prefix, Prefixed(handle))?; } Some(Payload::Local(h)) => { writeln!( output, "\t\t{}_{} -> {}_{}", prefix, Prefixed(h), prefix, Prefixed(handle), )?; } Some(Payload::Global(h)) => { writeln!( output, "\t\t{} -> {}_{} [fillcolor=gray]", Prefixed(h), prefix, Prefixed(handle), 
)?; } None => {} } } Ok(()) } /// Write shader module to a [`String`]. pub fn write( module: &crate::Module, mod_info: Option<&ModuleInfo>, options: Options, ) -> Result<String, FmtError> { use core::fmt::Write as _; let mut output = String::new(); output += "digraph Module {\n"; if !options.cfg_only { writeln!(output, "\tsubgraph cluster_globals {{")?; writeln!(output, "\t\tlabel=\"Globals\"")?; for (handle, var) in module.global_variables.iter() { writeln!( output, "\t\t{} [ shape=hexagon label=\"{:?} {:?}/'{}'\" ]", Prefixed(handle), handle, var.space, name(&var.name), )?; } writeln!(output, "\t}}")?; } for (handle, fun) in module.functions.iter() { let prefix = Prefixed(handle).to_string(); writeln!(output, "\tsubgraph cluster_{prefix} {{")?; writeln!( output, "\t\tlabel=\"Function{:?}/'{}'\"", handle, name(&fun.name) )?; let info = mod_info.map(|a| &a[handle]); write_fun(&mut output, prefix, fun, info, &options)?; writeln!(output, "\t}}")?; } for (ep_index, ep) in module.entry_points.iter().enumerate() { let prefix = format!("ep{ep_index}"); writeln!(output, "\tsubgraph cluster_{prefix} {{")?; writeln!(output, "\t\tlabel=\"{:?}/'{}'\"", ep.stage, ep.name)?; let info = mod_info.map(|a| a.get_entry_point(ep_index)); write_fun(&mut output, prefix, &ep.function, info, &options)?; writeln!(output, "\t}}")?; } output += "}\n"; Ok(output) } naga-29.0.3/src/back/glsl/conv.rs000064400000000000000000000202331046102023000145700ustar 00000000000000use crate::back::glsl::{BackendResult, Error, VaryingOptions}; /// Structure returned by [`glsl_scalar`] /// /// It contains both a prefix used in other types and the full type name pub(in crate::back::glsl) struct ScalarString<'a> { /// The prefix used to compose other types pub prefix: &'a str, /// The name of the scalar type pub full: &'a str, } /// Helper function that returns scalar-related strings /// /// Check [`ScalarString`] for the information provided /// /// # Errors /// Returns 
[`Error::UnsupportedScalar`] if `scalar` is a [`Float`](crate::ScalarKind::Float) with a width that
isn't 4 or 8, or an abstract scalar type pub(in crate::back::glsl) const fn glsl_scalar( scalar: crate::Scalar, ) -> Result<ScalarString<'static>, Error> { use crate::ScalarKind as Sk; Ok(match scalar.kind { Sk::Sint => ScalarString { prefix: "i", full: "int", }, Sk::Uint => ScalarString { prefix: "u", full: "uint", }, Sk::Float => match scalar.width { 4 => ScalarString { prefix: "", full: "float", }, 8 => ScalarString { prefix: "d", full: "double", }, _ => return Err(Error::UnsupportedScalar(scalar)), }, Sk::Bool => ScalarString { prefix: "b", full: "bool", }, Sk::AbstractInt | Sk::AbstractFloat => { return Err(Error::UnsupportedScalar(scalar)); } }) } /// Helper function that returns the GLSL variable name for a builtin pub(in crate::back::glsl) const fn glsl_built_in( built_in: crate::BuiltIn, options: VaryingOptions, ) -> &'static str { use crate::BuiltIn as Bi; match built_in { Bi::Position { .. } => { if options.output { "gl_Position" } else { "gl_FragCoord" } } Bi::ViewIndex => { if options.targeting_webgl { "gl_ViewID_OVR" } else { "uint(gl_ViewIndex)" } } // vertex Bi::BaseInstance => "uint(gl_BaseInstance)", Bi::BaseVertex => "uint(gl_BaseVertex)", Bi::ClipDistance => "gl_ClipDistance", Bi::CullDistance => "gl_CullDistance", Bi::InstanceIndex => { if options.draw_parameters { "(uint(gl_InstanceID) + uint(gl_BaseInstanceARB))" } else { // Must match FIRST_INSTANCE_BINDING "(uint(gl_InstanceID) + naga_vs_first_instance)" } } Bi::PointSize => "gl_PointSize", Bi::VertexIndex => "uint(gl_VertexID)", Bi::DrawIndex => "gl_DrawID", // fragment Bi::FragDepth => "gl_FragDepth", Bi::PointCoord => "gl_PointCoord", Bi::FrontFacing => "gl_FrontFacing", Bi::PrimitiveIndex => "uint(gl_PrimitiveID)", Bi::Barycentric { perspective: true } => "gl_BaryCoordEXT", Bi::Barycentric { perspective: false } => "gl_BaryCoordNoPerspEXT", Bi::SampleIndex => "gl_SampleID", Bi::SampleMask => { if options.output { "gl_SampleMask" } else { 
"gl_SampleMaskIn" } } // compute Bi::GlobalInvocationId => "gl_GlobalInvocationID",
Bi::LocalInvocationId => "gl_LocalInvocationID", Bi::LocalInvocationIndex => "gl_LocalInvocationIndex", Bi::WorkGroupId => "gl_WorkGroupID", Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", // subgroup Bi::NumSubgroups => "gl_NumSubgroups", Bi::SubgroupId => "gl_SubgroupID", Bi::SubgroupSize => "gl_SubgroupSize", Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", // mesh // TODO: figure out how to map these to glsl things as glsl treats them as arrays Bi::CullPrimitive | Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices | Bi::MeshTaskSize | Bi::VertexCount | Bi::PrimitiveCount | Bi::Vertices | Bi::Primitives | Bi::RayInvocationId | Bi::NumRayInvocations | Bi::InstanceCustomData | Bi::GeometryIndex | Bi::WorldRayOrigin | Bi::WorldRayDirection | Bi::ObjectRayOrigin | Bi::ObjectRayDirection | Bi::RayTmin | Bi::RayTCurrentMax | Bi::ObjectToWorld | Bi::WorldToObject | Bi::HitKind => { unimplemented!() } } } /// Helper function that returns the string corresponding to the address space pub(in crate::back::glsl) const fn glsl_storage_qualifier( space: crate::AddressSpace, ) -> Option<&'static str> { use crate::AddressSpace as As; match space { As::Function => None, As::Private => None, As::Storage { .. } => Some("buffer"), As::Uniform => Some("uniform"), As::Handle => Some("uniform"), As::WorkGroup => Some("shared"), As::Immediate => Some("uniform"), As::TaskPayload | As::RayPayload | As::IncomingRayPayload => unreachable!(), } } /// Helper function that returns the string corresponding to the glsl interpolation qualifier pub(in crate::back::glsl) const fn glsl_interpolation( interpolation: crate::Interpolation, ) -> &'static str { use crate::Interpolation as I; match interpolation { I::Perspective => "smooth", I::Linear => "noperspective", I::Flat => "flat", I::PerVertex => unreachable!(), } } /// Return the GLSL auxiliary qualifier for the given sampling value. 
pub(in crate::back::glsl) const fn glsl_sampling( sampling: crate::Sampling, ) -> BackendResult<Option<&'static str>> { use crate::Sampling as S; Ok(match sampling { S::First => return Err(Error::FirstSamplingNotSupported), S::Center | S::Either => None, S::Centroid => Some("centroid"), S::Sample => Some("sample"), }) } /// Helper function that returns the GLSL dimension string of [`ImageDimension`](crate::ImageDimension) pub(in crate::back::glsl) const fn glsl_dimension(dim: crate::ImageDimension) -> &'static str { use crate::ImageDimension as IDim; match dim { IDim::D1 => "1D", IDim::D2 => "2D", IDim::D3 => "3D", IDim::Cube => "Cube", } } /// Helper function that returns the GLSL storage format string of [`StorageFormat`](crate::StorageFormat) pub(in crate::back::glsl) fn glsl_storage_format( format: crate::StorageFormat, ) -> Result<&'static str, Error> { use crate::StorageFormat as Sf; Ok(match format { Sf::R8Unorm => "r8", Sf::R8Snorm => "r8_snorm", Sf::R8Uint => "r8ui", Sf::R8Sint => "r8i", Sf::R16Uint => "r16ui", Sf::R16Sint => "r16i", Sf::R16Float => "r16f", Sf::Rg8Unorm => "rg8", Sf::Rg8Snorm => "rg8_snorm", Sf::Rg8Uint => "rg8ui", Sf::Rg8Sint => "rg8i", Sf::R32Uint => "r32ui", Sf::R32Sint => "r32i", Sf::R32Float => "r32f", Sf::Rg16Uint => "rg16ui", Sf::Rg16Sint => "rg16i", Sf::Rg16Float => "rg16f", Sf::Rgba8Unorm => "rgba8", Sf::Rgba8Snorm => "rgba8_snorm", Sf::Rgba8Uint => "rgba8ui", Sf::Rgba8Sint => "rgba8i", Sf::Rgb10a2Uint => "rgb10_a2ui", Sf::Rgb10a2Unorm => "rgb10_a2", Sf::Rg11b10Ufloat => "r11f_g11f_b10f", Sf::R64Uint => "r64ui", Sf::Rg32Uint => "rg32ui", Sf::Rg32Sint => "rg32i", Sf::Rg32Float => "rg32f", Sf::Rgba16Uint => "rgba16ui", Sf::Rgba16Sint => "rgba16i", Sf::Rgba16Float => "rgba16f", Sf::Rgba32Uint => "rgba32ui", Sf::Rgba32Sint => "rgba32i", Sf::Rgba32Float => "rgba32f", Sf::R16Unorm => "r16", Sf::R16Snorm => "r16_snorm", Sf::Rg16Unorm => "rg16", Sf::Rg16Snorm => "rg16_snorm", Sf::Rgba16Unorm => "rgba16", Sf::Rgba16Snorm => 
"rgba16_snorm", Sf::Bgra8Unorm => {
return Err(Error::Custom( "Storage format BGRA8 is not implemented".into(), )) } }) } naga-29.0.3/src/back/glsl/features.rs000064400000000000000000000722071046102023000154510ustar 00000000000000use core::fmt::Write; use super::{BackendResult, Error, Version, Writer}; use crate::{ back::glsl::{Options, WriterFlags}, AddressSpace, Binding, Expression, Handle, ImageClass, ImageDimension, Interpolation, SampleLevel, Sampling, Scalar, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner, }; bitflags::bitflags! { /// Structure used to encode additions to GLSL that aren't supported by all versions. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct Features: u32 { /// Buffer address space support. const BUFFER_STORAGE = 1; const ARRAY_OF_ARRAYS = 1 << 1; /// 8-byte floats. const DOUBLE_TYPE = 1 << 2; /// More image formats. const FULL_IMAGE_FORMATS = 1 << 3; const MULTISAMPLED_TEXTURES = 1 << 4; const MULTISAMPLED_TEXTURE_ARRAYS = 1 << 5; const CUBE_TEXTURES_ARRAY = 1 << 6; const COMPUTE_SHADER = 1 << 7; /// Image load and early depth tests. const IMAGE_LOAD_STORE = 1 << 8; const CONSERVATIVE_DEPTH = 1 << 9; /// Interpolation and auxiliary qualifiers. /// /// Perspective, Flat, and Centroid are available in all GLSL versions we support. const NOPERSPECTIVE_QUALIFIER = 1 << 11; const SAMPLE_QUALIFIER = 1 << 12; const CLIP_DISTANCE = 1 << 13; const CULL_DISTANCE = 1 << 14; /// Sample ID. const SAMPLE_VARIABLES = 1 << 15; /// Arrays with a dynamic length.
const DYNAMIC_ARRAY_SIZE = 1 << 16; const MULTI_VIEW = 1 << 17; /// Texture samples query const TEXTURE_SAMPLES = 1 << 18; /// Texture levels query const TEXTURE_LEVELS = 1 << 19; /// Image size query const IMAGE_SIZE = 1 << 20; /// Dual source blending const DUAL_SOURCE_BLENDING = 1 << 21; /// Instance index /// /// We can always support this, either through the language or a polyfill const INSTANCE_INDEX = 1 << 22; /// Sample specific LODs of cube / array shadow textures const TEXTURE_SHADOW_LOD = 1 << 23; /// Subgroup operations const SUBGROUP_OPERATIONS = 1 << 24; /// Image atomics const TEXTURE_ATOMICS = 1 << 25; /// Shader barycentrics const SHADER_BARYCENTRICS = 1 << 26; /// Primitive index builtin const PRIMITIVE_INDEX = 1 << 27; } } /// Helper structure used to store the required [`Features`] needed to output a /// [`Module`](crate::Module) /// /// Provides helper methods to check for availability and to write required extensions pub(crate) struct FeaturesManager(Features); impl FeaturesManager { /// Creates a new [`FeaturesManager`] instance pub const fn new() -> Self { Self(Features::empty()) } /// Adds to the list of required [`Features`] pub fn request(&mut self, features: Features) { self.0 |= features } /// Checks if the required [`Features`] contain the specified [`Features`] pub const fn contains(&mut self, features: Features) -> bool { self.0.contains(features) } /// Checks that all required [`Features`] are available for the specified /// [`Version`], otherwise returns an [`Error::MissingFeatures`]. pub fn check_availability(&self, version: Version) -> BackendResult { // Will store all the features that are unavailable let mut missing = Features::empty(); // Helper macro to check for feature availability macro_rules!
check_feature { // Used when only core glsl supports the feature ($feature:ident, $core:literal) => { if self.0.contains(Features::$feature) && (version < Version::Desktop($core) || version.is_es()) { missing |= Features::$feature; } }; // Used when both core and es support the feature ($feature:ident, $core:literal, $es:literal) => { if self.0.contains(Features::$feature) && (version < Version::Desktop($core) || version < Version::new_gles($es)) { missing |= Features::$feature; } }; } check_feature!(COMPUTE_SHADER, 420, 310); check_feature!(BUFFER_STORAGE, 400, 310); check_feature!(DOUBLE_TYPE, 150); check_feature!(CUBE_TEXTURES_ARRAY, 130, 310); check_feature!(MULTISAMPLED_TEXTURES, 150, 300); check_feature!(MULTISAMPLED_TEXTURE_ARRAYS, 150, 310); check_feature!(ARRAY_OF_ARRAYS, 120, 310); check_feature!(IMAGE_LOAD_STORE, 130, 310); check_feature!(CONSERVATIVE_DEPTH, 130, 300); check_feature!(NOPERSPECTIVE_QUALIFIER, 130); check_feature!(SAMPLE_QUALIFIER, 400, 320); check_feature!(CLIP_DISTANCE, 130, 300 /* with extension */); check_feature!(CULL_DISTANCE, 450, 300 /* with extension */); check_feature!(SAMPLE_VARIABLES, 400, 300); check_feature!(DYNAMIC_ARRAY_SIZE, 400 /* with extension */, 310); check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); check_feature!(SUBGROUP_OPERATIONS, 430, 310); check_feature!(TEXTURE_ATOMICS, 420, 310); match version { Version::Embedded { is_webgl: true, .. 
} => check_feature!(MULTI_VIEW, 140, 300), _ => check_feature!(MULTI_VIEW, 140, 310), }; // Only available in core (desktop) GLSL, which means OpenGL ES can't query the number // of samples or levels in an image, nor perform bounds checks on the sample or // level argument of texelFetch check_feature!(TEXTURE_SAMPLES, 150); check_feature!(TEXTURE_LEVELS, 130); check_feature!(IMAGE_SIZE, 430, 310); check_feature!(TEXTURE_SHADOW_LOD, 200, 300); // Return an error if there are missing features if missing.is_empty() { Ok(()) } else { Err(Error::MissingFeatures(missing)) } } /// Helper method used to write all needed extensions /// /// # Notes /// This won't check for feature availability, so it might output extensions that aren't even /// supported. [`check_availability`](Self::check_availability) will check feature availability pub fn write(&self, options: &Options, mut out: impl Write) -> BackendResult { if self.0.contains(Features::COMPUTE_SHADER) && !options.version.is_es() { // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_compute_shader.txt writeln!(out, "#extension GL_ARB_compute_shader : require")?; } if self.0.contains(Features::BUFFER_STORAGE) && !options.version.is_es() { // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_storage_buffer_object.txt writeln!( out, "#extension GL_ARB_shader_storage_buffer_object : require" )?; } if self.0.contains(Features::DOUBLE_TYPE) && options.version < Version::Desktop(400) { // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gpu_shader_fp64.txt writeln!(out, "#extension GL_ARB_gpu_shader_fp64 : require")?; } if self.0.contains(Features::CUBE_TEXTURES_ARRAY) { if options.version.is_es() { // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_cube_map_array.txt writeln!(out, "#extension GL_EXT_texture_cube_map_array : require")?; } else if options.version < Version::Desktop(400) { // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_cube_map_array.txt
writeln!(out, "#extension GL_ARB_texture_cube_map_array : require")?; } } if self.0.contains(Features::MULTISAMPLED_TEXTURE_ARRAYS) && options.version.is_es() { // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_texture_storage_multisample_2d_array.txt writeln!( out, "#extension GL_OES_texture_storage_multisample_2d_array : require" )?; } if self.0.contains(Features::ARRAY_OF_ARRAYS) && options.version < Version::Desktop(430) { // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_arrays_of_arrays.txt writeln!(out, "#extension GL_ARB_arrays_of_arrays : require")?; } if self.0.contains(Features::IMAGE_LOAD_STORE) { if self.0.contains(Features::FULL_IMAGE_FORMATS) && options.version.is_es() { // https://www.khronos.org/registry/OpenGL/extensions/NV/NV_image_formats.txt writeln!(out, "#extension GL_NV_image_formats : require")?; } if options.version < Version::Desktop(420) { // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_image_load_store.txt writeln!(out, "#extension GL_ARB_shader_image_load_store : require")?; } } if self.0.contains(Features::CONSERVATIVE_DEPTH) { if options.version.is_es() { // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_conservative_depth.txt writeln!(out, "#extension GL_EXT_conservative_depth : require")?; } if options.version < Version::Desktop(420) { // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt writeln!(out, "#extension GL_ARB_conservative_depth : require")?; } } if (self.0.contains(Features::CLIP_DISTANCE) || self.0.contains(Features::CULL_DISTANCE)) && options.version.is_es() { // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_clip_cull_distance.txt writeln!(out, "#extension GL_EXT_clip_cull_distance : require")?; } if self.0.contains(Features::SAMPLE_VARIABLES) && options.version.is_es() { // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_sample_variables.txt writeln!(out, "#extension GL_OES_sample_variables : require")?; } if
self.0.contains(Features::MULTI_VIEW) { if let Version::Embedded { is_webgl: true, .. } = options.version { // https://www.khronos.org/registry/OpenGL/extensions/OVR/OVR_multiview2.txt writeln!(out, "#extension GL_OVR_multiview2 : require")?; } else { // https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_multiview.txt writeln!(out, "#extension GL_EXT_multiview : require")?; } } if self.0.contains(Features::TEXTURE_SAMPLES) { // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_texture_image_samples.txt writeln!( out, "#extension GL_ARB_shader_texture_image_samples : require" )?; } if self.0.contains(Features::TEXTURE_LEVELS) && options.version < Version::Desktop(430) { // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_query_levels.txt writeln!(out, "#extension GL_ARB_texture_query_levels : require")?; } if self.0.contains(Features::DUAL_SOURCE_BLENDING) && options.version.is_es() { // https://registry.khronos.org/OpenGL/extensions/EXT/EXT_blend_func_extended.txt writeln!(out, "#extension GL_EXT_blend_func_extended : require")?; } if self.0.contains(Features::INSTANCE_INDEX) { if options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS) { // https://registry.khronos.org/OpenGL/extensions/ARB/ARB_shader_draw_parameters.txt writeln!(out, "#extension GL_ARB_shader_draw_parameters : require")?; } } if self.0.contains(Features::TEXTURE_SHADOW_LOD) { // https://registry.khronos.org/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?; } if self.0.contains(Features::SUBGROUP_OPERATIONS) { // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?; writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?; writeln!( out, "#extension GL_KHR_shader_subgroup_arithmetic : require" )?; writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?; writeln!(out,
"#extension GL_KHR_shader_subgroup_shuffle : require")?; writeln!( out, "#extension GL_KHR_shader_subgroup_shuffle_relative : require" )?; writeln!(out, "#extension GL_KHR_shader_subgroup_quad : require")?; } if self.0.contains(Features::TEXTURE_ATOMICS) { // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_shader_image_atomic.txt writeln!(out, "#extension GL_OES_shader_image_atomic : require")?; } if self.0.contains(Features::SHADER_BARYCENTRICS) { // https://github.com/KhronosGroup/GLSL/blob/main/extensions/ext/GLSL_EXT_fragment_shader_barycentric.txt writeln!( out, "#extension GL_EXT_fragment_shader_barycentric : require" )?; } if self.0.contains(Features::PRIMITIVE_INDEX) { match options.version { Version::Embedded { version, .. } if version < 320 => { writeln!(out, "#extension GL_OES_geometry_shader : require")?; } Version::Desktop(version) if version < 150 => { writeln!(out, "#extension GL_ARB_geometry_shader4 : require")?; } _ => (), } } Ok(()) } } impl<W> Writer<'_, W> { /// Helper method that searches the module for all the needed [`Features`] /// /// # Errors /// If the version doesn't support any of the needed [`Features`], an /// [`Error::MissingFeatures`] will be returned pub(super) fn collect_required_features(&mut self) -> BackendResult { let ep_info = self.info.get_entry_point(self.entry_point_idx as usize); if let Some(early_depth_test) = self.entry_point.early_depth_test { match early_depth_test { crate::EarlyDepthTest::Force => { if self.options.version.supports_early_depth_test() { self.features.request(Features::IMAGE_LOAD_STORE); } } crate::EarlyDepthTest::Allow { .. 
} => { self.features.request(Features::CONSERVATIVE_DEPTH); } } } for arg in self.entry_point.function.arguments.iter() { self.varying_required_features(arg.binding.as_ref(), arg.ty); } if let Some(ref result) = self.entry_point.function.result { self.varying_required_features(result.binding.as_ref(), result.ty); } if let ShaderStage::Compute = self.entry_point.stage { self.features.request(Features::COMPUTE_SHADER) } if self.multiview.is_some() { self.features.request(Features::MULTI_VIEW); } for (ty_handle, ty) in self.module.types.iter() { match ty.inner { TypeInner::Scalar(scalar) | TypeInner::Vector { scalar, .. } | TypeInner::Matrix { scalar, .. } => self.scalar_required_features(scalar), TypeInner::Array { base, size, .. } => { if let TypeInner::Array { .. } = self.module.types[base].inner { self.features.request(Features::ARRAY_OF_ARRAYS) } // If the array is dynamically sized if size == crate::ArraySize::Dynamic { let mut is_used = false; // Check if this type is used in a global that is needed by the current entry point for (global_handle, global) in self.module.global_variables.iter() { // Skip unused globals if ep_info[global_handle].is_empty() { continue; } // If this array is the type of a global, then this array is used if global.ty == ty_handle { is_used = true; break; } // If the type of this global is a struct if let TypeInner::Struct { ref members, .. } = self.module.types[global.ty].inner { // Check the last member of the struct to see if its type is // this array if let Some(last) = members.last() { if last.ty == ty_handle { is_used = true; break; } } } } // If this dynamically sized array is used, we need dynamic array size support if is_used { self.features.request(Features::DYNAMIC_ARRAY_SIZE); } } } TypeInner::Image { dim, arrayed, class, } => { if arrayed && dim == ImageDimension::Cube { self.features.request(Features::CUBE_TEXTURES_ARRAY) } match class { ImageClass::Sampled { multi: true, .. 
} | ImageClass::Depth { multi: true } => { self.features.request(Features::MULTISAMPLED_TEXTURES); if arrayed { self.features.request(Features::MULTISAMPLED_TEXTURE_ARRAYS); } } ImageClass::Storage { format, .. } => match format { StorageFormat::R8Unorm | StorageFormat::R8Snorm | StorageFormat::R8Uint | StorageFormat::R8Sint | StorageFormat::R16Uint | StorageFormat::R16Sint | StorageFormat::R16Float | StorageFormat::Rg8Unorm | StorageFormat::Rg8Snorm | StorageFormat::Rg8Uint | StorageFormat::Rg8Sint | StorageFormat::Rg16Uint | StorageFormat::Rg16Sint | StorageFormat::Rg16Float | StorageFormat::Rgb10a2Uint | StorageFormat::Rgb10a2Unorm | StorageFormat::Rg11b10Ufloat | StorageFormat::R64Uint | StorageFormat::Rg32Uint | StorageFormat::Rg32Sint | StorageFormat::Rg32Float => { self.features.request(Features::FULL_IMAGE_FORMATS) } _ => {} }, ImageClass::Sampled { multi: false, .. } | ImageClass::Depth { multi: false } | ImageClass::External => {} } } _ => {} } } let mut immediates_used = false; for (handle, global) in self.module.global_variables.iter() { if ep_info[handle].is_empty() { continue; } match global.space { AddressSpace::WorkGroup => self.features.request(Features::COMPUTE_SHADER), AddressSpace::Storage { .. } => self.features.request(Features::BUFFER_STORAGE), AddressSpace::Immediate => { if immediates_used { return Err(Error::MultipleImmediateData); } immediates_used = true; } _ => {} } } // We will need to pass some of the members to a closure, so we need // to separate them; otherwise the borrow checker will complain. This // shouldn't be needed in Rust 2021 let &mut Self { module, info, ref mut features, entry_point, entry_point_idx, ref policies, .. 
} = self; // Loop through all expressions in both functions and the entry point // to check for needed features for (expressions, info) in module .functions .iter() .map(|(h, f)| (&f.expressions, &info[h])) .chain(core::iter::once(( &entry_point.function.expressions, info.get_entry_point(entry_point_idx as usize), ))) { for (_, expr) in expressions.iter() { match *expr { // Check for queries that need additional features Expression::ImageQuery { image, query, .. } => match query { // Storage images use `imageSize`, which is only available // in GLSL 430+ (or GLSL ES 310+) // // Layer queries are also implemented as size queries crate::ImageQuery::Size { .. } | crate::ImageQuery::NumLayers => { if let TypeInner::Image { class: ImageClass::Storage { .. }, .. } = *info[image].ty.inner_with(&module.types) { features.request(Features::IMAGE_SIZE) } }, crate::ImageQuery::NumLevels => features.request(Features::TEXTURE_LEVELS), crate::ImageQuery::NumSamples => features.request(Features::TEXTURE_SAMPLES), } , // Check for image loads that need bounds checking on the sample // or level argument, since this requires a feature Expression::ImageLoad { sample, level, .. } => { if policies.image_load != crate::proc::BoundsCheckPolicy::Unchecked { if sample.is_some() { features.request(Features::TEXTURE_SAMPLES) } if level.is_some() { features.request(Features::TEXTURE_LEVELS) } } } Expression::ImageSample { image, level, offset, .. } => { if let TypeInner::Image { dim, arrayed, class: ImageClass::Depth { .. }, } = *info[image].ty.inner_with(&module.types) { let lod = matches!(level, SampleLevel::Zero | SampleLevel::Exact(_)); let bias = matches!(level, SampleLevel::Bias(_)); let auto = matches!(level, SampleLevel::Auto); let cube = dim == ImageDimension::Cube; let array2d = dim == ImageDimension::D2 && arrayed; let gles = self.options.version.is_es(); // We have a workaround of using `textureGrad` instead of `textureLod` if the LOD is zero, // so we don't *need* this extension for those cases.
// But if we're explicitly allowed to use the extension (`WriterFlags::TEXTURE_SHADOW_LOD`), // we always use it instead of the workaround. let grad_workaround_applicable = (array2d || (cube && !arrayed)) && level == SampleLevel::Zero; let prefer_grad_workaround = grad_workaround_applicable && !self.options.writer_flags.contains(WriterFlags::TEXTURE_SHADOW_LOD); let mut ext_used = false; // float texture(sampler2DArrayShadow sampler, vec4 P [, float bias]) // float texture(samplerCubeArrayShadow sampler, vec4 P, float compare [, float bias]) ext_used |= (array2d || cube && arrayed) && bias; // The non `bias` version of this was standardized in GL 4.3, but never in GLES. // float textureOffset(sampler2DArrayShadow sampler, vec4 P, ivec2 offset [, float bias]) ext_used |= array2d && (bias || (gles && auto)) && offset.is_some(); // float textureLod(sampler2DArrayShadow sampler, vec4 P, float lod) // float textureLodOffset(sampler2DArrayShadow sampler, vec4 P, float lod, ivec2 offset) // float textureLod(samplerCubeShadow sampler, vec4 P, float lod) // float textureLod(samplerCubeArrayShadow sampler, vec4 P, float compare, float lod) ext_used |= (cube || array2d) && lod && !prefer_grad_workaround; if ext_used { features.request(Features::TEXTURE_SHADOW_LOD); } } } Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. } => { features.request(Features::SUBGROUP_OPERATIONS) } _ => {} } } } for blocks in module .functions .iter() .map(|(_, f)| &f.body) .chain(core::iter::once(&entry_point.function.body)) { for (stmt, _) in blocks.span_iter() { match *stmt { crate::Statement::ImageAtomic { .. 
} => { features.request(Features::TEXTURE_ATOMICS) } _ => {} } } } self.features.check_availability(self.options.version) } /// Helper method that checks the [`Features`] needed by a scalar fn scalar_required_features(&mut self, scalar: Scalar) { if scalar.kind == ScalarKind::Float && scalar.width == 8 { self.features.request(Features::DOUBLE_TYPE); } } /// Helper method that checks the [`Features`] needed by an entry point argument or result /// binding, recursing into struct members fn varying_required_features(&mut self, binding: Option<&Binding>, ty: Handle<crate::Type>) { if let TypeInner::Struct { ref members, .. } = self.module.types[ty].inner { for member in members { self.varying_required_features(member.binding.as_ref(), member.ty); } } else if let Some(binding) = binding { match *binding { Binding::BuiltIn(built_in) => match built_in { crate::BuiltIn::ClipDistance => self.features.request(Features::CLIP_DISTANCE), crate::BuiltIn::CullDistance => self.features.request(Features::CULL_DISTANCE), crate::BuiltIn::SampleIndex => { self.features.request(Features::SAMPLE_VARIABLES) } crate::BuiltIn::ViewIndex => self.features.request(Features::MULTI_VIEW), crate::BuiltIn::InstanceIndex | crate::BuiltIn::DrawIndex => { self.features.request(Features::INSTANCE_INDEX) } crate::BuiltIn::Barycentric { .. 
} => { self.features.request(Features::SHADER_BARYCENTRICS) } crate::BuiltIn::PrimitiveIndex => { self.features.request(Features::PRIMITIVE_INDEX) } _ => {} }, Binding::Location { location: _, interpolation, sampling, blend_src, per_primitive: _, } => { if interpolation == Some(Interpolation::Linear) { self.features.request(Features::NOPERSPECTIVE_QUALIFIER); } if sampling == Some(Sampling::Sample) { self.features.request(Features::SAMPLE_QUALIFIER); } if blend_src.is_some() { self.features.request(Features::DUAL_SOURCE_BLENDING); } } } } } } naga-29.0.3/src/back/glsl/keywords.rs000064400000000000000000000326611046102023000155020ustar 00000000000000use crate::proc::KeywordSet; use crate::racy_lock::RacyLock; pub const RESERVED_KEYWORDS: &[&str] = &[ // // GLSL 4.6 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L2004-L2322 // GLSL ES 3.2 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/es/3.2/GLSL_ES_Specification_3.20.html#L2166-L2478 // // Note: The GLSL ES 3.2 keywords are the same as the GLSL 4.6 keywords, with some residing in the reserved section. // The only exceptions are the missing Vulkan keywords, which I think is an oversight (see https://github.com/KhronosGroup/OpenGL-Registry/issues/585).
// "const", "uniform", "buffer", "shared", "attribute", "varying", "coherent", "volatile", "restrict", "readonly", "writeonly", "atomic_uint", "layout", "centroid", "flat", "smooth", "noperspective", "patch", "sample", "invariant", "precise", "break", "continue", "do", "for", "while", "switch", "case", "default", "if", "else", "subroutine", "in", "out", "inout", "int", "void", "bool", "true", "false", "float", "double", "discard", "return", "vec2", "vec3", "vec4", "ivec2", "ivec3", "ivec4", "bvec2", "bvec3", "bvec4", "uint", "uvec2", "uvec3", "uvec4", "dvec2", "dvec3", "dvec4", "mat2", "mat3", "mat4", "mat2x2", "mat2x3", "mat2x4", "mat3x2", "mat3x3", "mat3x4", "mat4x2", "mat4x3", "mat4x4", "dmat2", "dmat3", "dmat4", "dmat2x2", "dmat2x3", "dmat2x4", "dmat3x2", "dmat3x3", "dmat3x4", "dmat4x2", "dmat4x3", "dmat4x4", "lowp", "mediump", "highp", "precision", "sampler1D", "sampler1DShadow", "sampler1DArray", "sampler1DArrayShadow", "isampler1D", "isampler1DArray", "usampler1D", "usampler1DArray", "sampler2D", "sampler2DShadow", "sampler2DArray", "sampler2DArrayShadow", "isampler2D", "isampler2DArray", "usampler2D", "usampler2DArray", "sampler2DRect", "sampler2DRectShadow", "isampler2DRect", "usampler2DRect", "sampler2DMS", "isampler2DMS", "usampler2DMS", "sampler2DMSArray", "isampler2DMSArray", "usampler2DMSArray", "sampler3D", "isampler3D", "usampler3D", "samplerCube", "samplerCubeShadow", "isamplerCube", "usamplerCube", "samplerCubeArray", "samplerCubeArrayShadow", "isamplerCubeArray", "usamplerCubeArray", "samplerBuffer", "isamplerBuffer", "usamplerBuffer", "image1D", "iimage1D", "uimage1D", "image1DArray", "iimage1DArray", "uimage1DArray", "image2D", "iimage2D", "uimage2D", "image2DArray", "iimage2DArray", "uimage2DArray", "image2DRect", "iimage2DRect", "uimage2DRect", "image2DMS", "iimage2DMS", "uimage2DMS", "image2DMSArray", "iimage2DMSArray", "uimage2DMSArray", "image3D", "iimage3D", "uimage3D", "imageCube", "iimageCube", "uimageCube", "imageCubeArray", 
"iimageCubeArray", "uimageCubeArray", "imageBuffer", "iimageBuffer", "uimageBuffer", "struct", // Vulkan keywords "texture1D", "texture1DArray", "itexture1D", "itexture1DArray", "utexture1D", "utexture1DArray", "texture2D", "texture2DArray", "itexture2D", "itexture2DArray", "utexture2D", "utexture2DArray", "texture2DRect", "itexture2DRect", "utexture2DRect", "texture2DMS", "itexture2DMS", "utexture2DMS", "texture2DMSArray", "itexture2DMSArray", "utexture2DMSArray", "texture3D", "itexture3D", "utexture3D", "textureCube", "itextureCube", "utextureCube", "textureCubeArray", "itextureCubeArray", "utextureCubeArray", "textureBuffer", "itextureBuffer", "utextureBuffer", "sampler", "samplerShadow", "subpassInput", "isubpassInput", "usubpassInput", "subpassInputMS", "isubpassInputMS", "usubpassInputMS", // Reserved keywords "common", "partition", "active", "asm", "class", "union", "enum", "typedef", "template", "this", "resource", "goto", "inline", "noinline", "public", "static", "extern", "external", "interface", "long", "short", "half", "fixed", "unsigned", "superp", "input", "output", "hvec2", "hvec3", "hvec4", "fvec2", "fvec3", "fvec4", "filter", "sizeof", "cast", "namespace", "using", "sampler3DRect", // Reserved keywords that were unreserved in GLSL 4.2 "image1DArrayShadow", "image1DShadow", "image2DArrayShadow", "image2DShadow", // Reserved keywords that were unreserved in GLSL 4.4 "packed", "row_major", // // GLSL 4.6 Built-In Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13314 // // Angle and Trigonometry Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13469-L13561C5 "radians", "degrees", "sin", "cos", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh", // Exponential Functions, from 
https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13569-L13620 "pow", "exp", "log", "exp2", "log2", "sqrt", "inversesqrt", // Common Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13628-L13908 "abs", "sign", "floor", "trunc", "round", "roundEven", "ceil", "fract", "mod", "modf", "min", "max", "clamp", "mix", "step", "smoothstep", "isnan", "isinf", "floatBitsToInt", "floatBitsToUint", "intBitsToFloat", "uintBitsToFloat", "fma", "frexp", "ldexp", // Floating-Point Pack and Unpack Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13916-L14007 "packUnorm2x16", "packSnorm2x16", "packUnorm4x8", "packSnorm4x8", "unpackUnorm2x16", "unpackSnorm2x16", "unpackUnorm4x8", "unpackSnorm4x8", "packHalf2x16", "unpackHalf2x16", "packDouble2x32", "unpackDouble2x32", // Geometric Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14014-L14121 "length", "distance", "dot", "cross", "normalize", "ftransform", "faceforward", "reflect", "refract", // Matrix Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14151-L14215 "matrixCompMult", "outerProduct", "transpose", "determinant", "inverse", // Vector Relational Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14259-L14322 "lessThan", "lessThanEqual", "greaterThan", "greaterThanEqual", "equal", "notEqual", "any", "all", "not", // Integer Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14335-L14432 "uaddCarry", 
"usubBorrow", "umulExtended", "imulExtended", "bitfieldExtract", "bitfieldInsert", "bitfieldReverse", "bitCount", "findLSB", "findMSB", // Texture Query Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14645-L14732 "textureSize", "textureQueryLod", "textureQueryLevels", "textureSamples", // Texel Lookup Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14736-L14997 "texture", "textureProj", "textureLod", "textureOffset", "texelFetch", "texelFetchOffset", "textureProjOffset", "textureLodOffset", "textureProjLod", "textureProjLodOffset", "textureGrad", "textureGradOffset", "textureProjGrad", "textureProjGradOffset", // Texture Gather Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15077-L15154 "textureGather", "textureGatherOffset", "textureGatherOffsets", // Compatibility Profile Texture Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15161-L15220 "texture1D", "texture1DProj", "texture1DLod", "texture1DProjLod", "texture2D", "texture2DProj", "texture2DLod", "texture2DProjLod", "texture3D", "texture3DProj", "texture3DLod", "texture3DProjLod", "textureCube", "textureCubeLod", "shadow1D", "shadow2D", "shadow1DProj", "shadow2DProj", "shadow1DLod", "shadow2DLod", "shadow1DProjLod", "shadow2DProjLod", // Atomic Counter Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15241-L15531 "atomicCounterIncrement", "atomicCounterDecrement", "atomicCounter", "atomicCounterAdd", "atomicCounterSubtract", "atomicCounterMin", "atomicCounterMax", "atomicCounterAnd", "atomicCounterOr", "atomicCounterXor", "atomicCounterExchange", 
"atomicCounterCompSwap", // Atomic Memory Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15563-L15624 "atomicAdd", "atomicMin", "atomicMax", "atomicAnd", "atomicOr", "atomicXor", "atomicExchange", "atomicCompSwap", // Image Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15763-L15878 "imageSize", "imageSamples", "imageLoad", "imageStore", "imageAtomicAdd", "imageAtomicMin", "imageAtomicMax", "imageAtomicAnd", "imageAtomicOr", "imageAtomicXor", "imageAtomicExchange", "imageAtomicCompSwap", // Geometry Shader Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15886-L15932 "EmitStreamVertex", "EndStreamPrimitive", "EmitVertex", "EndPrimitive", // Fragment Processing Functions, Derivative Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16041-L16114 "dFdx", "dFdy", "dFdxFine", "dFdyFine", "dFdxCoarse", "dFdyCoarse", "fwidth", "fwidthFine", "fwidthCoarse", // Fragment Processing Functions, Interpolation Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16150-L16198 "interpolateAtCentroid", "interpolateAtSample", "interpolateAtOffset", // Noise Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16214-L16243 "noise1", "noise2", "noise3", "noise4", // Shader Invocation Control Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16255-L16276 "barrier", // Shader Memory Control Functions, from 
https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16336-L16382 "memoryBarrier", "memoryBarrierAtomicCounter", "memoryBarrierBuffer", "memoryBarrierShared", "memoryBarrierImage", "groupMemoryBarrier", // Subpass-Input Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16451-L16470 "subpassLoad", // Shader Invocation Group Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16483-L16511 "anyInvocation", "allInvocations", "allInvocationsEqual", // // entry point name (should not be shadowed) // "main", // Naga utilities: super::MODF_FUNCTION, super::FREXP_FUNCTION, super::FIRST_INSTANCE_BINDING, ]; /// The above set of reserved keywords, turned into a cached HashSet. This saves /// significant time during [`Namer::reset`](crate::proc::Namer::reset). /// /// See for benchmarks. pub static RESERVED_KEYWORD_SET: RacyLock<KeywordSet> = RacyLock::new(|| KeywordSet::from_iter(RESERVED_KEYWORDS)); naga-29.0.3/src/back/glsl/mod.rs000064400000000000000000000560441046102023000144130ustar 00000000000000/*! Backend for [GLSL][glsl] (OpenGL Shading Language). The main structure is [`Writer`], which maintains internal state used to output a [`Module`](crate::Module) as GLSL. # Supported versions ### Core - 140 - 150 - 330 - 400 - 410 - 420 - 430 - 440 - 450 - 460 ### ES - 300 - 310 - 320 [glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php */ // GLSL is mostly a superset of C, but it also removes some parts of it. This is a list of relevant // aspects for this backend. // // The most notable change is the introduction of the version preprocessor directive, which must // always be the first line of a GLSL file and is written as // `#version number profile` // `number` is the version itself (i.e.
300) and `profile` is the // shader profile. We only support "core" and "es": the former is used in desktop applications and // the latter in embedded contexts, mobile devices and browsers. Each one has its own // versions (at the time of writing, the latest version for "core" is 460 and for "es" is 320). // // Another important preprocessor addition is the extension directive, which is written as // `#extension name : behaviour` // Extensions provide additional features in a plugin fashion, but implementations aren't required to // support them, hence the name. `behaviour` specifies // whether the extension is strictly required or should merely be enabled when available. In our case // we always set behaviour to `require` when we use an extension. // // The only thing GLSL removes that makes a difference for this backend is pointers. // // Additions that are relevant for the backend are the discard keyword and the introduction of // vectors, matrices, samplers, image types and functions that provide common shader operations. pub use features::Features; pub use writer::Writer; use alloc::{ borrow::ToOwned, format, string::{String, ToString}, vec, vec::Vec, }; use core::{ cmp::Ordering, fmt::{self, Error as FmtError, Write}, mem, }; use hashbrown::hash_map; use thiserror::Error; use crate::{ back::{self, Baked}, common, proc::{self, NameKey}, valid, Handle, ShaderStage, TypeInner, }; use conv::*; use features::FeaturesManager; /// Contains simple 1:1 conversion functions. mod conv; /// Contains the feature-related code and the feature querying method mod features; /// Contains the `RESERVED_KEYWORDS` constant, a slice of all the reserved keywords mod keywords; /// Contains the [`Writer`] type. mod writer; /// List of supported `core` GLSL versions. pub const SUPPORTED_CORE_VERSIONS: &[u16] = &[140, 150, 330, 400, 410, 420, 430, 440, 450, 460]; /// List of supported `es` GLSL versions.
pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320]; /// The suffix of the variable that will hold the calculated clamped level /// of detail for bounds checking in `ImageLoad` const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod"; pub(crate) const MODF_FUNCTION: &str = "naga_modf"; pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; // Must match code in glsl_built_in pub const FIRST_INSTANCE_BINDING: &str = "naga_vs_first_instance"; #[cfg(feature = "deserialize")] #[derive(serde::Deserialize)] struct BindingMapSerialization { resource_binding: crate::ResourceBinding, bind_target: u8, } /// Deserializes a [`BindingMap`] from a list of `resource_binding`/`bind_target` pairs. #[cfg(feature = "deserialize")] fn deserialize_binding_map<'de, D>(deserializer: D) -> Result<BindingMap, D::Error> where D: serde::Deserializer<'de>, { use serde::Deserialize; let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?; let mut map = BindingMap::default(); for item in vec { map.insert(item.resource_binding, item.bind_target); } Ok(map) } /// Mapping between resources and bindings. pub type BindingMap = alloc::collections::BTreeMap<crate::ResourceBinding, u8>; impl crate::AtomicFunction { const fn to_glsl(self) -> &'static str { match self { Self::Add | Self::Subtract => "Add", Self::And => "And", Self::InclusiveOr => "Or", Self::ExclusiveOr => "Xor", Self::Min => "Min", Self::Max => "Max", Self::Exchange { compare: None } => "Exchange", Self::Exchange { compare: Some(_) } => "", //TODO } } } impl crate::AddressSpace { /// Whether a variable with this address space can be initialized const fn initializable(&self) -> bool { match *self { crate::AddressSpace::Function | crate::AddressSpace::Private => true, crate::AddressSpace::WorkGroup | crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } | crate::AddressSpace::Handle | crate::AddressSpace::Immediate | crate::AddressSpace::TaskPayload => false, crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => { unreachable!() } } } } /// A GLSL version.
#[derive(Debug, Copy, Clone, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub enum Version { /// `core` GLSL. Desktop(u16), /// `es` GLSL. Embedded { version: u16, is_webgl: bool }, } impl Version { /// Create a new GLES version (with `is_webgl` set to `false`) pub const fn new_gles(version: u16) -> Self { Self::Embedded { version, is_webgl: false, } } /// Returns true if self is `Version::Embedded` (i.e. is an ES version) const fn is_es(&self) -> bool { match *self { Version::Desktop(_) => false, Version::Embedded { .. } => true, } } /// Returns true if targeting WebGL const fn is_webgl(&self) -> bool { match *self { Version::Desktop(_) => false, Version::Embedded { is_webgl, .. } => is_webgl, } } /// Checks the list of currently supported versions and returns true if it contains the /// specified version /// /// # Notes /// Since an invalid version number will never be added to the supported version list, /// this also checks for version validity fn is_supported(&self) -> bool { match *self { Version::Desktop(v) => SUPPORTED_CORE_VERSIONS.contains(&v), Version::Embedded { version: v, .. } => SUPPORTED_ES_VERSIONS.contains(&v), } } fn supports_io_locations(&self) -> bool { *self >= Version::Desktop(330) || *self >= Version::new_gles(300) } /// Checks if the version supports all of the explicit layouts: /// - `location=` qualifiers for bindings /// - `binding=` qualifiers for resources /// /// Note: `location=` for vertex inputs and fragment outputs is supported /// unconditionally for GLES 300.
    fn supports_explicit_locations(&self) -> bool {
        *self >= Version::Desktop(420) || *self >= Version::new_gles(310)
    }

    fn supports_early_depth_test(&self) -> bool {
        *self >= Version::Desktop(130) || *self >= Version::new_gles(310)
    }

    fn supports_std140_layout(&self) -> bool {
        *self >= Version::Desktop(140) || *self >= Version::new_gles(300)
    }

    fn supports_std430_layout(&self) -> bool {
        // std430 is available from 400 via GL_ARB_shader_storage_buffer_object.
        *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
    }

    fn supports_fma_function(&self) -> bool {
        *self >= Version::Desktop(400) || *self >= Version::new_gles(320)
    }

    fn supports_integer_functions(&self) -> bool {
        *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
    }

    fn supports_frexp_function(&self) -> bool {
        *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
    }

    fn supports_derivative_control(&self) -> bool {
        *self >= Version::Desktop(450)
    }

    // For supports_pack_unpack_4x8, supports_pack_unpack_snorm_2x16, supports_pack_unpack_unorm_2x16
    // see:
    // https://registry.khronos.org/OpenGL-Refpages/gl4/html/unpackUnorm.xhtml
    // https://registry.khronos.org/OpenGL-Refpages/es3/html/unpackUnorm.xhtml
    // https://registry.khronos.org/OpenGL-Refpages/gl4/html/packUnorm.xhtml
    // https://registry.khronos.org/OpenGL-Refpages/es3/html/packUnorm.xhtml
    fn supports_pack_unpack_4x8(&self) -> bool {
        *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
    }

    fn supports_pack_unpack_snorm_2x16(&self) -> bool {
        *self >= Version::Desktop(420) || *self >= Version::new_gles(300)
    }

    fn supports_pack_unpack_unorm_2x16(&self) -> bool {
        *self >= Version::Desktop(400) || *self >= Version::new_gles(300)
    }

    // https://registry.khronos.org/OpenGL-Refpages/gl4/html/unpackHalf2x16.xhtml
    // https://registry.khronos.org/OpenGL-Refpages/gl4/html/packHalf2x16.xhtml
    // https://registry.khronos.org/OpenGL-Refpages/es3/html/unpackHalf2x16.xhtml
    // https://registry.khronos.org/OpenGL-Refpages/es3/html/packHalf2x16.xhtml
    fn supports_pack_unpack_half_2x16(&self) -> bool {
        *self >= Version::Desktop(420) || *self >= Version::new_gles(300)
    }
}

impl PartialOrd for Version {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        match (*self, *other) {
            (Version::Desktop(x), Version::Desktop(y)) => Some(x.cmp(&y)),
            (Version::Embedded { version: x, .. }, Version::Embedded { version: y, .. }) => {
                Some(x.cmp(&y))
            }
            _ => None,
        }
    }
}

impl fmt::Display for Version {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Version::Desktop(v) => write!(f, "{v} core"),
            Version::Embedded { version: v, .. } => write!(f, "{v} es"),
        }
    }
}

bitflags::bitflags! {
    /// Configuration flags for the [`Writer`].
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct WriterFlags: u32 {
        /// Flip output Y and extend Z from (0, 1) to (-1, 1).
        const ADJUST_COORDINATE_SPACE = 0x1;
        /// Supports GL_EXT_texture_shadow_lod on the host, which provides
        /// additional functions on shadows and arrays of shadows.
        const TEXTURE_SHADOW_LOD = 0x2;
        /// Supports ARB_shader_draw_parameters on the host, which provides
        /// support for `gl_BaseInstanceARB`, `gl_BaseVertexARB`, `gl_DrawIDARB`, and `gl_DrawID`.
        const DRAW_PARAMETERS = 0x4;
        /// Include unused global variables, constants and functions. By default the output will exclude
        /// global variables that are not used in the specified entrypoint (including indirect use),
        /// all constant declarations, and functions that use excluded global variables.
        const INCLUDE_UNUSED_ITEMS = 0x10;
        /// Emit `PointSize` output builtin to vertex shaders, which is
        /// required for drawing with `PointList` topology.
        ///
        /// https://registry.khronos.org/OpenGL/specs/es/3.2/GLSL_ES_Specification_3.20.html#built-in-language-variables
        /// The variable gl_PointSize is intended for a shader to write the size of the point to be rasterized. It is measured in pixels.
        /// If gl_PointSize is not written to, its value is undefined in subsequent pipe stages.
        const FORCE_POINT_SIZE = 0x20;
    }
}

/// Configuration used in the [`Writer`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(feature = "deserialize", serde(default))]
pub struct Options {
    /// The GLSL version to be used.
    pub version: Version,
    /// Configuration flags for the [`Writer`].
    pub writer_flags: WriterFlags,
    /// Map of resources to their binding locations.
    #[cfg_attr(
        feature = "deserialize",
        serde(deserialize_with = "deserialize_binding_map")
    )]
    pub binding_map: BindingMap,
    /// Should workgroup variables be zero initialized (by polyfilling)?
    pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
    fn default() -> Self {
        Options {
            version: Version::new_gles(310),
            writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE,
            binding_map: BindingMap::default(),
            zero_initialize_workgroup_memory: true,
        }
    }
}

/// A subset of options meant to be changed per pipeline.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
    /// The stage of the entry point.
    pub shader_stage: ShaderStage,
    /// The name of the entry point.
    ///
    /// If no matching entry point is found while creating a [`Writer`], an
    /// error will be returned.
    pub entry_point: String,
    /// How many views to render to, if doing multiview rendering.
    pub multiview: Option<core::num::NonZeroU32>,
}

#[derive(Debug)]
pub struct VaryingLocation {
    /// The location of the global.
    /// This corresponds to `layout(location = ..)` in GLSL.
    pub location: u32,
    /// The index which can be used for dual source blending.
    /// This corresponds to `layout(index = ..)` in GLSL.
    pub index: u32,
}

/// Reflection info for texture mappings and uniforms.
#[derive(Debug)]
pub struct ReflectionInfo {
    /// Mapping between texture names and variables/samplers.
    pub texture_mapping: crate::FastHashMap<String, TextureMapping>,
    /// Mapping between uniform variables and names.
    pub uniforms: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
    /// Mapping between names and attribute locations.
    pub varying: crate::FastHashMap<String, VaryingLocation>,
    /// List of immediate data items in the shader.
    pub immediates_items: Vec<ImmediateItem>,
    /// Number of user-defined clip planes. Only applicable to vertex shaders.
    pub clip_distance_count: u32,
}

/// Mapping between a texture and its sampler, if it exists.
///
/// GLSL pre-Vulkan has no concept of separate textures and samplers. Instead, everything is a
/// `gsamplerN` where `g` is the scalar type and `N` is the dimension. But naga uses separate textures
/// and samplers in the IR, so the backend produces a [`FastHashMap`](crate::FastHashMap) with the texture name
/// as a key and a [`TextureMapping`] as a value. This way, the user knows where to bind.
///
/// [`Storage`](crate::ImageClass::Storage) images produce `gimageN` and don't have an associated sampler,
/// so the [`sampler`](Self::sampler) field will be [`None`].
#[derive(Debug, Clone)]
pub struct TextureMapping {
    /// Handle to the image global variable.
    pub texture: Handle<crate::GlobalVariable>,
    /// Handle to the associated sampler global variable, if it exists.
    pub sampler: Option<Handle<crate::GlobalVariable>>,
}

/// All information to bind a single uniform value to the shader.
///
/// Immediates are emulated using traditional uniforms in OpenGL.
///
/// These are composed of a set of primitives (scalar, vector, matrix) that
/// are given names. Because they are not backed by the concept of a buffer,
/// we must do the work of calculating the offset of each primitive in the
/// immediate data block.
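///
/// The sketch below shows how the names in [`ImmediateItem::access_path`]
/// arise. It flattens a nested type into one path per primitive. `SketchTy`
/// and `access_paths` are hypothetical. They model only struct members and
/// fixed-size arrays, compute no offsets, and are not naga's implementation.
///
/// ```rust
/// // Hypothetical, simplified type tree used only for this sketch.
/// enum SketchTy {
///     Primitive,                             // a scalar, vector, or matrix
///     Struct(Vec<(&'static str, SketchTy)>), // named members, in order
///     Array(Box<SketchTy>, usize),           // fixed-size array
/// }
///
/// // Appends one access path per primitive reachable from `ty`.
/// fn access_paths(prefix: &str, ty: &SketchTy, out: &mut Vec<String>) {
///     match ty {
///         SketchTy::Primitive => out.push(prefix.to_string()),
///         SketchTy::Struct(members) => {
///             for (name, member) in members {
///                 access_paths(&format!("{prefix}.{name}"), member, out);
///             }
///         }
///         SketchTy::Array(element, len) => {
///             for i in 0..*len {
///                 access_paths(&format!("{prefix}[{i}]"), element, out);
///             }
///         }
///     }
/// }
/// ```
///
/// Given the `ImmediateData` example documented on `access_path` and the prefix
/// `_immediates_binding_cs`, this sketch yields the same three names listed there.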
#[derive(Debug, Clone)]
pub struct ImmediateItem {
    /// GL uniform name for the item. This name is the same as if you were
    /// to access it directly from a GLSL shader.
    ///
    /// For the following example, one name is generated per GLSL uniform:
    ///
    /// ```glsl
    /// struct InnerStruct {
    ///     float value;
    /// };
    ///
    /// struct ImmediateData {
    ///     InnerStruct inner;
    ///     vec4 array[2];
    /// };
    ///
    /// uniform ImmediateData _immediates_binding_cs;
    /// ```
    ///
    /// ```text
    /// - _immediates_binding_cs.inner.value
    /// - _immediates_binding_cs.array[0]
    /// - _immediates_binding_cs.array[1]
    /// ```
    ///
    pub access_path: String,
    /// Type of the uniform. This will only ever be a scalar, vector, or matrix.
    pub ty: Handle<crate::Type>,
    /// The offset in the immediate data memory block this uniform maps to.
    ///
    /// The size of the uniform can be derived from the type.
    pub offset: u32,
}

/// Helper structure that generates unique numbers.
#[derive(Default)]
struct IdGenerator(u32);

impl IdGenerator {
    /// Generates a number that's guaranteed to be unique for this `IdGenerator`
    const fn generate(&mut self) -> u32 {
        // It's just an increasing number but it does the job
        let ret = self.0;
        self.0 += 1;
        ret
    }
}

/// Assorted options needed for generating varyings.
#[derive(Clone, Copy)]
struct VaryingOptions {
    output: bool,
    targeting_webgl: bool,
    draw_parameters: bool,
}

impl VaryingOptions {
    const fn from_writer_options(options: &Options, output: bool) -> Self {
        Self {
            output,
            targeting_webgl: options.version.is_webgl(),
            draw_parameters: options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS),
        }
    }
}

/// Helper wrapper used to get a name for a varying.
///
/// Varyings have different naming schemes depending on their binding:
/// - Varyings with builtin bindings get their name from [`glsl_built_in`].
/// - Varyings with location bindings are named `_S_locationX`, where `S` is a
///   prefix identifying which pipeline stage the varying connects, and `X` is
///   the location.
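///
/// Below is a standalone sketch of the location-based rule. It assumes only the
/// two non-compute stages handled by the `Display` impl below. `SketchStage` and
/// `location_name` are hypothetical names, not part of naga:
///
/// ```rust
/// #[derive(Clone, Copy)]
/// enum SketchStage {
///     Vertex,
///     Fragment,
/// }
///
/// // Same (stage, is_output) -> prefix table as the `Display` impl of `VaryingName`.
/// fn location_name(stage: SketchStage, output: bool, location: u32) -> String {
///     let prefix = match (stage, output) {
///         // pipeline to vertex
///         (SketchStage::Vertex, false) => "p2vs",
///         // vertex to fragment
///         (SketchStage::Vertex, true) | (SketchStage::Fragment, false) => "vs2fs",
///         // fragment to pipeline
///         (SketchStage::Fragment, true) => "fs2p",
///     };
///     format!("_{prefix}_location{location}")
/// }
/// ```
///
/// A vertex output and a fragment input at the same location get the same name
/// (for location 2, `_vs2fs_location2`), so the declarations on both sides of
/// that interface agree. The sketch omits dual-source blending. A binding with
/// `blend_src: Some(1)` is always named `_fs2p_location1`.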
struct VaryingName<'a> {
    binding: &'a crate::Binding,
    stage: ShaderStage,
    options: VaryingOptions,
}

impl fmt::Display for VaryingName<'_> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self.binding {
            crate::Binding::Location {
                blend_src: Some(1), ..
            } => {
                write!(f, "_fs2p_location1",)
            }
            crate::Binding::Location { location, .. } => {
                let prefix = match (self.stage, self.options.output) {
                    (ShaderStage::Compute, _) => unreachable!(),
                    // pipeline to vertex
                    (ShaderStage::Vertex, false) => "p2vs",
                    // vertex to fragment
                    (ShaderStage::Vertex, true) | (ShaderStage::Fragment, false) => "vs2fs",
                    // fragment to pipeline
                    (ShaderStage::Fragment, true) => "fs2p",
                    (
                        ShaderStage::Task
                        | ShaderStage::Mesh
                        | ShaderStage::RayGeneration
                        | ShaderStage::AnyHit
                        | ShaderStage::ClosestHit
                        | ShaderStage::Miss,
                        _,
                    ) => unreachable!(),
                };

                write!(f, "_{prefix}_location{location}",)
            }
            crate::Binding::BuiltIn(built_in) => {
                write!(f, "{}", glsl_built_in(built_in, self.options))
            }
        }
    }
}

impl ShaderStage {
    const fn to_str(self) -> &'static str {
        match self {
            ShaderStage::Compute => "cs",
            ShaderStage::Fragment => "fs",
            ShaderStage::Vertex => "vs",
            ShaderStage::Task
            | ShaderStage::Mesh
            | ShaderStage::RayGeneration
            | ShaderStage::AnyHit
            | ShaderStage::ClosestHit
            | ShaderStage::Miss => unreachable!(),
        }
    }
}

/// Shorthand result used internally by the backend
type BackendResult<T = ()> = Result<T, Error>;

/// A GLSL compilation error.
#[derive(Debug, Error)]
pub enum Error {
    /// An error occurred while writing to the output.
    #[error("Format error")]
    FmtError(#[from] FmtError),
    /// The specified [`Version`] doesn't have all required [`Features`].
    ///
    /// Contains the missing [`Features`].
    #[error("The selected version doesn't support {0:?}")]
    MissingFeatures(Features),
    /// [`AddressSpace::Immediate`](crate::AddressSpace::Immediate) was used more than
    /// once in the entry point, which isn't supported.
    #[error("Multiple immediates aren't supported")]
    MultipleImmediateData,
    /// The specified [`Version`] isn't supported.
    #[error("The specified version isn't supported")]
    VersionNotSupported,
    /// The entry point couldn't be found.
    #[error("The requested entry point couldn't be found")]
    EntryPointNotFound,
    /// A call was made to an unsupported external.
    #[error("A call was made to an unsupported external: {0}")]
    UnsupportedExternal(String),
    /// A scalar with an unsupported width was requested.
    #[error("A scalar with an unsupported width was requested: {0:?}")]
    UnsupportedScalar(crate::Scalar),
    /// An image was used with multiple samplers, which isn't supported.
    #[error("A image was used with multiple samplers")]
    ImageMultipleSamplers,
    #[error("{0}")]
    Custom(String),
    #[error("overrides should not be present at this stage")]
    Override,
    /// [`crate::Sampling::First`] is unsupported.
    #[error("`{:?}` sampling is unsupported", crate::Sampling::First)]
    FirstSamplingNotSupported,
    #[error(transparent)]
    ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
}

/// Binary operation that needs different logic on the GLSL side.
enum BinaryOperation {
    /// Vector comparisons must use functions like `greaterThan()`, etc.
    VectorCompare,
    /// Vector component-wise operation; used to polyfill unsupported ops like `|` and `&` for `bvecN`s
    VectorComponentWise,
    /// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem`.
    Modulo,
    /// Any plain operation. No additional logic required.
    Other,
}

fn is_value_init_supported(module: &crate::Module, ty: Handle<crate::Type>) -> bool {
    match module.types[ty].inner {
        TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => true,
        TypeInner::Array { base, size, .. } => {
            size != crate::ArraySize::Dynamic && is_value_init_supported(module, base)
        }
        TypeInner::Struct { ref members, .. } => members
            .iter()
            .all(|member| is_value_init_supported(module, member.ty)),
        _ => false,
    }
}

pub fn supported_capabilities() -> valid::Capabilities {
    use valid::Capabilities as Caps;

    // Lots of these aren't supported on GLES in general, but naga is able to write them without panicking.
    Caps::IMMEDIATES
        | Caps::FLOAT64
        | Caps::PRIMITIVE_INDEX
        | Caps::CLIP_DISTANCE
        | Caps::MULTIVIEW
        | Caps::EARLY_DEPTH_TEST
        | Caps::MULTISAMPLED_SHADING
        | Caps::DUAL_SOURCE_BLENDING
        | Caps::CUBE_ARRAY_TEXTURES
        | Caps::SHADER_INT64
        | Caps::SHADER_INT64_ATOMIC_ALL_OPS
        | Caps::TEXTURE_ATOMIC
        | Caps::TEXTURE_INT64_ATOMIC
        | Caps::SUBGROUP
        | Caps::SUBGROUP_BARRIER
        | Caps::SHADER_FLOAT16
        | Caps::SHADER_FLOAT16_IN_FLOAT32
        | Caps::SHADER_BARYCENTRICS
        | Caps::DRAW_INDEX
        | Caps::MEMORY_DECORATION_COHERENT
        | Caps::MEMORY_DECORATION_VOLATILE
}
naga-29.0.3/src/back/glsl/writer.rs000064400000000000000000006106201046102023000151440ustar 00000000000000
use super::*;

/// Writer responsible for all code generation.
pub struct Writer<'a, W> {
    // Inputs
    /// The module being written.
    pub(in crate::back::glsl) module: &'a crate::Module,
    /// The module analysis.
    pub(in crate::back::glsl) info: &'a valid::ModuleInfo,
    /// The output writer.
    out: W,
    /// User defined configuration to be used.
    pub(in crate::back::glsl) options: &'a Options,
    /// The bound checking policies to be used
    pub(in crate::back::glsl) policies: proc::BoundsCheckPolicies,

    // Internal State
    /// Features manager used to store all the needed features and write them.
    pub(in crate::back::glsl) features: FeaturesManager,
    namer: proc::Namer,
    /// A map with all the names needed for writing the module
    /// (generated by a [`Namer`](crate::proc::Namer)).
    names: crate::FastHashMap<NameKey, String>,
    /// A map with the names of global variables needed for reflections.
    reflection_names_globals: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
    /// The selected entry point.
    pub(in crate::back::glsl) entry_point: &'a crate::EntryPoint,
    /// The index of the selected entry point.
    pub(in crate::back::glsl) entry_point_idx: proc::EntryPointIndex,
    /// A generator for unique block numbers.
    block_id: IdGenerator,
    /// Set of expressions that have associated temporary variables.
    named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    need_bake_expressions: back::NeedBakeExpressions,
    /// Information about nesting of loops and switches.
    ///
    /// Used for forwarding continue statements in switches that have been
    /// transformed to `do {} while(false);` loops.
    continue_ctx: back::continue_forward::ContinueCtx,
    /// How many views to render to, if doing multiview rendering.
    pub(in crate::back::glsl) multiview: Option<core::num::NonZeroU32>,
    /// Mapping of varying variables to their location. Needed for reflections.
    varying: crate::FastHashMap<String, VaryingLocation>,
    /// Number of user-defined clip planes. Only non-zero for vertex shaders.
    clip_distance_count: u32,
}

impl<'a, W: Write> Writer<'a, W> {
    /// Creates a new [`Writer`] instance.
    ///
    /// # Errors
    /// - If the version specified is invalid or unsupported.
    /// - If the entry point couldn't be found in the module.
    /// - If the version specified doesn't support some used features.
    pub fn new(
        out: W,
        module: &'a crate::Module,
        info: &'a valid::ModuleInfo,
        options: &'a Options,
        pipeline_options: &'a PipelineOptions,
        policies: proc::BoundsCheckPolicies,
    ) -> Result<Self, Error> {
        // Check if the requested version is supported
        if !options.version.is_supported() {
            log::error!("Version {}", options.version);
            return Err(Error::VersionNotSupported);
        }

        // Try to find the entry point and corresponding index
        let ep_idx = module
            .entry_points
            .iter()
            .position(|ep| {
                pipeline_options.shader_stage == ep.stage && pipeline_options.entry_point == ep.name
            })
            .ok_or(Error::EntryPointNotFound)?;

        // Generate a map with names required to write the module
        let mut names = crate::FastHashMap::default();
        let mut namer = proc::Namer::default();
        namer.reset(
            module,
            &keywords::RESERVED_KEYWORD_SET,
            proc::KeywordSet::empty(),
            proc::CaseInsensitiveKeywordSet::empty(),
            &[
                "gl_",                  // all GL built-in variables
                "_group",               // all normal bindings
                "_immediates_binding_", // all immediate data bindings
            ],
            &mut names,
        );

        // Build the instance
        let mut this = Self {
            module,
            info,
            out,
            options,
            policies,
            namer,
            features: FeaturesManager::new(),
            names,
            reflection_names_globals: crate::FastHashMap::default(),
            entry_point: &module.entry_points[ep_idx],
            entry_point_idx: ep_idx as u16,
            multiview: pipeline_options.multiview,
            block_id: IdGenerator::default(),
            named_expressions: Default::default(),
            need_bake_expressions: Default::default(),
            continue_ctx: back::continue_forward::ContinueCtx::default(),
            varying: Default::default(),
            clip_distance_count: 0,
        };

        // Find all features required to print this module
        this.collect_required_features()?;

        Ok(this)
    }

    /// Writes the [`Module`](crate::Module) as glsl to the output
    ///
    /// # Notes
    /// If an error occurs while writing, the output might have been written partially
    ///
    /// # Panics
    /// Might panic if the module is invalid
    pub fn write(&mut self) -> Result<ReflectionInfo, Error> {
        // We use `writeln!(self.out)` throughout the write to add newlines
        // to make the output more readable

        let es = self.options.version.is_es();

        // Write the version (It must be the first thing or it isn't a valid glsl output)
        writeln!(self.out, "#version {}", self.options.version)?;
        // Write all the needed extensions
        //
        // This used to be the last thing being written as it allowed to search for features while
        // writing the module saving some loops but some older versions (420 or less) required the
        // extensions to appear before being used, even though extensions are part of the
        // preprocessor not the processor ¯\_(ツ)_/¯
        self.features.write(self.options, &mut self.out)?;

        // glsl es requires a precision to be specified for floats and ints
        // TODO: Should this be user configurable?
        if es {
            writeln!(self.out)?;
            writeln!(self.out, "precision highp float;")?;
            writeln!(self.out, "precision highp int;")?;
            writeln!(self.out)?;
        }

        if self.entry_point.stage == ShaderStage::Compute {
            let workgroup_size = self.entry_point.workgroup_size;
            writeln!(
                self.out,
                "layout(local_size_x = {}, local_size_y = {}, local_size_z = {}) in;",
                workgroup_size[0], workgroup_size[1], workgroup_size[2]
            )?;
            writeln!(self.out)?;
        }

        if self.entry_point.stage == ShaderStage::Vertex
            && !self
                .options
                .writer_flags
                .contains(WriterFlags::DRAW_PARAMETERS)
            && self.features.contains(Features::INSTANCE_INDEX)
        {
            writeln!(self.out, "uniform uint {FIRST_INSTANCE_BINDING};")?;
            writeln!(self.out)?;
        }

        // Enable early depth tests if needed
        if let Some(early_depth_test) = self.entry_point.early_depth_test {
            // If early depth test is supported for this version of GLSL
            if self.options.version.supports_early_depth_test() {
                match early_depth_test {
                    crate::EarlyDepthTest::Force => {
                        writeln!(self.out, "layout(early_fragment_tests) in;")?;
                    }
                    crate::EarlyDepthTest::Allow { conservative, .. } => {
                        use crate::ConservativeDepth as Cd;
                        let depth = match conservative {
                            Cd::GreaterEqual => "greater",
                            Cd::LessEqual => "less",
                            Cd::Unchanged => "unchanged",
                        };
                        writeln!(self.out, "layout (depth_{depth}) out float gl_FragDepth;")?;
                    }
                }
            } else {
                log::warn!(
                    "Early depth testing is not supported for this version of GLSL: {}",
                    self.options.version
                );
            }
        }

        if self.entry_point.stage == ShaderStage::Vertex && self.options.version.is_webgl() {
            if let Some(multiview) = self.multiview.as_ref() {
                writeln!(self.out, "layout(num_views = {multiview}) in;")?;
                writeln!(self.out)?;
            }
        }

        // Write struct types.
        //
        // These are always ordered because the IR is structured in a way that
        // you can't make a struct without adding all of its members first.
        for (handle, ty) in self.module.types.iter() {
            if let TypeInner::Struct { ref members, .. } = ty.inner {
                let struct_name = &self.names[&NameKey::Type(handle)];

                // Structures ending with runtime-sized arrays can only be
                // rendered as shader storage blocks in GLSL, not stand-alone
                // struct types.
                if !self.module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&self.module.types)
                {
                    write!(self.out, "struct {struct_name} ")?;
                    self.write_struct_body(handle, members)?;
                    writeln!(self.out, ";")?;
                }
            }
        }

        // Write functions for special types.
        for (type_key, struct_ty) in self.module.special_types.predeclared_types.iter() {
            match type_key {
                &crate::PredeclaredType::ModfResult { size, scalar }
                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
                    let arg_type_name_owner;
                    let arg_type_name = if let Some(size) = size {
                        arg_type_name_owner = format!(
                            "{}vec{}",
                            if scalar.width == 8 { "d" } else { "" },
                            size as u8
                        );
                        &arg_type_name_owner
                    } else if scalar.width == 8 {
                        "double"
                    } else {
                        "float"
                    };

                    let other_type_name_owner;
                    let (defined_func_name, called_func_name, other_type_name) =
                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
                            (MODF_FUNCTION, "modf", arg_type_name)
                        } else {
                            let other_type_name = if let Some(size) = size {
                                other_type_name_owner = format!("ivec{}", size as u8);
                                &other_type_name_owner
                            } else {
                                "int"
                            };
                            (FREXP_FUNCTION, "frexp", other_type_name)
                        };

                    writeln!(self.out)?;
                    if !self.options.version.supports_frexp_function()
                        && matches!(type_key, &crate::PredeclaredType::FrexpResult { .. })
                    {
                        writeln!(
                            self.out,
                            "{struct_name} {defined_func_name}({arg_type_name} arg) {{
    {other_type_name} other = arg == {arg_type_name}(0) ? {other_type_name}(0) : {other_type_name}({arg_type_name}(1) + log2(arg));
    {arg_type_name} fract = arg * exp2({arg_type_name}(-other));
    return {struct_name}(fract, other);
}}",
                        )?;
                    } else {
                        writeln!(
                            self.out,
                            "{struct_name} {defined_func_name}({arg_type_name} arg) {{
    {other_type_name} other;
    {arg_type_name} fract = {called_func_name}(arg, other);
    return {struct_name}(fract, other);
}}",
                        )?;
                    }
                }
                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(_) => {
                    // Handled by the general struct writing loop earlier.
                }
            }
        }

        // Write all named constants
        let mut constants = self
            .module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        let ep_info = self.info.get_entry_point(self.entry_point_idx as usize);

        // Write the globals
        //
        // Unless WriterFlags::INCLUDE_UNUSED_ITEMS is set, we skip all globals that
        // aren't used by the selected entry point, as they might interfere with each
        // other (e.g. two globals with the same location but different classes)
        let include_unused = self
            .options
            .writer_flags
            .contains(WriterFlags::INCLUDE_UNUSED_ITEMS);
        for (handle, global) in self.module.global_variables.iter() {
            let is_unused = ep_info[handle].is_empty();
            if !include_unused && is_unused {
                continue;
            }

            match self.module.types[global.ty].inner {
                // We treat images separately because they might require
                // writing the storage format
                TypeInner::Image {
                    mut dim,
                    arrayed,
                    class,
                } => {
                    // Gather the storage format if needed
                    let storage_format_access = match self.module.types[global.ty].inner {
                        TypeInner::Image {
                            class: crate::ImageClass::Storage { format, access },
                            ..
                        } => Some((format, access)),
                        _ => None,
                    };

                    if dim == crate::ImageDimension::D1 && es {
                        dim = crate::ImageDimension::D2
                    }

                    // Gather the location if needed
                    let layout_binding = if self.options.version.supports_explicit_locations() {
                        let br = global.binding.as_ref().unwrap();
                        self.options.binding_map.get(br).cloned()
                    } else {
                        None
                    };

                    // Write all the layout qualifiers
                    if layout_binding.is_some() || storage_format_access.is_some() {
                        write!(self.out, "layout(")?;
                        if let Some(binding) = layout_binding {
                            write!(self.out, "binding = {binding}")?;
                        }
                        if let Some((format, _)) = storage_format_access {
                            let format_str = glsl_storage_format(format)?;
                            let separator = match layout_binding {
                                Some(_) => ",",
                                None => "",
                            };
                            write!(self.out, "{separator}{format_str}")?;
                        }
                        write!(self.out, ") ")?;
                    }

                    if let Some((_, access)) = storage_format_access {
                        self.write_storage_access(access)?;
                    }

                    // All images in glsl are `uniform`
                    // The trailing space is important
                    write!(self.out, "uniform ")?;

                    // write the type
                    //
                    // This is why we need the leading space below: `write_image_type` doesn't add
                    // any spaces at the beginning or end
                    self.write_image_type(dim, arrayed, class)?;

                    // Finally write the name and end the global with a `;`
                    // The leading space is important
                    let global_name = self.get_global_name(handle, global);
                    writeln!(self.out, " {global_name};")?;
                    writeln!(self.out)?;

                    self.reflection_names_globals.insert(handle, global_name);
                }
                // glsl has no concept of samplers so we just ignore it
                TypeInner::Sampler { .. } => continue,
                // All other globals are written by `write_global`
                _ => {
                    self.write_global(handle, global)?;
                    // Add a newline (only for readability)
                    writeln!(self.out)?;
                }
            }
        }

        for arg in self.entry_point.function.arguments.iter() {
            self.write_varying(arg.binding.as_ref(), arg.ty, false)?;
        }
        if let Some(ref result) = self.entry_point.function.result {
            self.write_varying(result.binding.as_ref(), result.ty, true)?;
        }
        writeln!(self.out)?;

        // Write all regular functions
        for (handle, function) in self.module.functions.iter() {
            // Check that the function doesn't use globals that aren't supported
            // by the current entry point
            if !include_unused && !ep_info.dominates_global_use(&self.info[handle]) {
                continue;
            }

            let fun_info = &self.info[handle];

            // Skip functions that are not compatible with this entry point's stage.
            //
            // When validation is enabled, it rejects modules whose entry points try to call
            // incompatible functions, so if we got this far, then any functions incompatible
            // with our selected entry point must not be used.
            //
            // When validation is disabled, `fun_info.available_stages` is always just
            // `ShaderStages::all()`, so this will write all functions in the module, and
            // the downstream GLSL compiler will catch any problems.
            if !fun_info.available_stages.contains(ep_info.available_stages) {
                continue;
            }

            // Write the function
            self.write_function(back::FunctionType::Function(handle), function, fun_info)?;

            writeln!(self.out)?;
        }

        self.write_function(
            back::FunctionType::EntryPoint(self.entry_point_idx),
            &self.entry_point.function,
            ep_info,
        )?;

        // Add newline at the end of file
        writeln!(self.out)?;

        // Collect all reflection info and return it to the user
        self.collect_reflection_info()
    }

    fn write_array_size(
        &mut self,
        base: Handle<crate::Type>,
        size: crate::ArraySize,
    ) -> BackendResult {
        write!(self.out, "[")?;

        // Write the array size
        // Writes nothing if `IndexableLength::Dynamic`
        match size.resolve(self.module.to_ctx())? {
            proc::IndexableLength::Known(size) => {
                write!(self.out, "{size}")?;
            }
            proc::IndexableLength::Dynamic => (),
        }

        write!(self.out, "]")?;

        if let TypeInner::Array {
            base: next_base,
            size: next_size,
            ..
        } = self.module.types[base].inner
        {
            self.write_array_size(next_base, next_size)?;
        }

        Ok(())
    }

    /// Helper method used to write value types
    ///
    /// # Notes
    /// Adds no trailing or leading whitespace
    fn write_value_type(&mut self, inner: &TypeInner) -> BackendResult {
        match *inner {
            // Scalars are simple: we just get the full name from `glsl_scalar`
            TypeInner::Scalar(scalar)
            | TypeInner::Atomic(scalar)
            | TypeInner::ValuePointer {
                size: None,
                scalar,
                space: _,
            } => write!(self.out, "{}", glsl_scalar(scalar)?.full)?,
            // Vectors are just `gvecN` where `g` is the scalar prefix and `N` is the vector size
            TypeInner::Vector { size, scalar }
            | TypeInner::ValuePointer {
                size: Some(size),
                scalar,
                space: _,
            } => write!(self.out, "{}vec{}", glsl_scalar(scalar)?.prefix, size as u8)?,
            // Matrices are written with `gmatMxN` where `g` is the scalar prefix (only floats and
            // doubles are allowed), `M` is the columns count and `N` is the rows count
            //
            // glsl supports a matrix shorthand `gmatN` where `N` = `M` but it doesn't justify the
            // extra branch to write matrices this way
            TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => write!(
                self.out,
                "{}mat{}x{}",
                glsl_scalar(scalar)?.prefix,
                columns as u8,
                rows as u8
            )?,
            // GLSL arrays are written as `type name[size]`
            // Here we only write the size of the array i.e. `[size]`
            // Base `type` and `name` should be written outside
            TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?,
            // Write all variants instead of `_` so that if new variants are added
            // an exhaustiveness error is raised
            TypeInner::Pointer { .. }
            | TypeInner::Struct { .. }
            | TypeInner::Image { .. }
            | TypeInner::Sampler { .. }
            | TypeInner::AccelerationStructure { .. }
            | TypeInner::RayQuery { .. }
            | TypeInner::BindingArray { .. }
            | TypeInner::CooperativeMatrix { .. } => {
                return Err(Error::Custom(format!("Unable to write type {inner:?}")))
            }
        }

        Ok(())
    }

    /// Helper method used to write non image/sampler types
    ///
    /// # Notes
    /// Adds no trailing or leading whitespace
    fn write_type(&mut self, ty: Handle<crate::Type>) -> BackendResult {
        match self.module.types[ty].inner {
            // glsl has no pointer types so just write types as normal and loads are skipped
            TypeInner::Pointer { base, .. } => self.write_type(base),
            // glsl structs are written as just the struct name
            TypeInner::Struct { .. } => {
                // Get the struct name
                let name = &self.names[&NameKey::Type(ty)];
                write!(self.out, "{name}")?;
                Ok(())
            }
            // glsl array has the size separated from the base type
            TypeInner::Array { base, .. } => self.write_type(base),
            ref other => self.write_value_type(other),
        }
    }

    /// Helper method to write an image type
    ///
    /// # Notes
    /// Adds no leading or trailing whitespace
    fn write_image_type(
        &mut self,
        dim: crate::ImageDimension,
        arrayed: bool,
        class: crate::ImageClass,
    ) -> BackendResult {
        // glsl images consist of four parts: the scalar prefix, the image "type", the dimensions
        // and modifiers
        //
        // There are two image types
        // - sampler - for sampled images
        // - image - for storage images
        //
        // There are three possible modifiers that can be used together and must be written in
        // this order to be valid
        // - MS - used if it's a multisampled image
        // - Array - used if it's an image array
        // - Shadow - used if it's a depth image
        use crate::ImageClass as Ic;
        use crate::Scalar as S;
        let float = S {
            kind: crate::ScalarKind::Float,
            width: 4,
        };
        let (base, scalar, ms, comparison) = match class {
            Ic::Sampled { kind, multi: true } => ("sampler", S { kind, width: 4 }, "MS", ""),
            Ic::Sampled { kind, multi: false } => ("sampler", S { kind, width: 4 }, "", ""),
            Ic::Depth { multi: true } => ("sampler", float, "MS", ""),
            Ic::Depth { multi: false } => ("sampler", float, "", "Shadow"),
            Ic::Storage { format, .. } => ("image", format.into(), "", ""),
            Ic::External => unimplemented!(),
        };

        let precision = if self.options.version.is_es() {
            "highp "
        } else {
            ""
        };

        write!(
            self.out,
            "{}{}{}{}{}{}{}",
            precision,
            glsl_scalar(scalar)?.prefix,
            base,
            glsl_dimension(dim),
            ms,
            if arrayed { "Array" } else { "" },
            comparison
        )?;

        Ok(())
    }

    /// Helper method used by [Self::write_global] to write just the layout part of
    /// a non image/sampler global variable, if applicable.
/// /// # Notes /// /// Adds trailing whitespace if any layout qualifier is written fn write_global_layout(&mut self, global: &crate::GlobalVariable) -> BackendResult { // Determine which (if any) explicit memory layout to use, and whether we support it let layout = match global.space { crate::AddressSpace::Uniform => { if !self.options.version.supports_std140_layout() { return Err(Error::Custom( "Uniform address space requires std140 layout support".to_string(), )); } Some("std140") } crate::AddressSpace::Storage { .. } => { if !self.options.version.supports_std430_layout() { return Err(Error::Custom( "Storage address space requires std430 layout support".to_string(), )); } Some("std430") } _ => None, }; // If our version supports explicit layouts, we can also output the explicit binding // if we have it if self.options.version.supports_explicit_locations() { if let Some(ref br) = global.binding { match self.options.binding_map.get(br) { Some(binding) => { write!(self.out, "layout(")?; if let Some(layout) = layout { write!(self.out, "{layout}, ")?; } write!(self.out, "binding = {binding}) ")?; return Ok(()); } None => { log::debug!("unassigned binding for {:?}", global.name); } } } } // Either no explicit bindings are supported or we didn't have any. // Write just the memory layout. 
        if let Some(layout) = layout {
            write!(self.out, "layout({layout}) ")?;
        }

        Ok(())
    }

    /// Helper method used to write non-image/sampler globals
    ///
    /// # Notes
    /// Adds a newline
    ///
    /// # Panics
    /// If the global has type sampler
    fn write_global(
        &mut self,
        handle: Handle<crate::GlobalVariable>,
        global: &crate::GlobalVariable,
    ) -> BackendResult {
        self.write_global_layout(global)?;

        if let crate::AddressSpace::Storage { access } = global.space {
            self.write_storage_access(access)?;
            if global
                .memory_decorations
                .contains(crate::MemoryDecorations::COHERENT)
            {
                write!(self.out, "coherent ")?;
            }
            if global
                .memory_decorations
                .contains(crate::MemoryDecorations::VOLATILE)
            {
                write!(self.out, "volatile ")?;
            }
        }

        if let Some(storage_qualifier) = glsl_storage_qualifier(global.space) {
            write!(self.out, "{storage_qualifier} ")?;
        }

        match global.space {
            crate::AddressSpace::Private => {
                self.write_simple_global(handle, global)?;
            }
            crate::AddressSpace::WorkGroup => {
                self.write_simple_global(handle, global)?;
            }
            crate::AddressSpace::Immediate => {
                self.write_simple_global(handle, global)?;
            }
            crate::AddressSpace::Uniform => {
                self.write_interface_block(handle, global)?;
            }
            crate::AddressSpace::Storage { .. } => {
                self.write_interface_block(handle, global)?;
            }
            crate::AddressSpace::TaskPayload => {
                self.write_interface_block(handle, global)?;
            }
            // A global variable in the `Function` address space is a
            // contradiction in terms.
            crate::AddressSpace::Function => unreachable!(),
            // Textures and samplers are handled directly in `Writer::write`.
            crate::AddressSpace::Handle => unreachable!(),
            // ray tracing pipelines unsupported
            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
                unreachable!()
            }
        }

        Ok(())
    }

    fn write_simple_global(
        &mut self,
        handle: Handle<crate::GlobalVariable>,
        global: &crate::GlobalVariable,
    ) -> BackendResult {
        self.write_type(global.ty)?;
        write!(self.out, " ")?;
        self.write_global_name(handle, global)?;

        if let TypeInner::Array { base, size, ..
        } = self.module.types[global.ty].inner {
            self.write_array_size(base, size)?;
        }

        if global.space.initializable() && is_value_init_supported(self.module, global.ty) {
            write!(self.out, " = ")?;
            if let Some(init) = global.init {
                self.write_const_expr(init, &self.module.global_expressions)?;
            } else {
                self.write_zero_init_value(global.ty)?;
            }
        }

        writeln!(self.out, ";")?;

        if let crate::AddressSpace::Immediate = global.space {
            let global_name = self.get_global_name(handle, global);
            self.reflection_names_globals.insert(handle, global_name);
        }

        Ok(())
    }

    /// Write an interface block for a single Naga global.
    ///
    /// Write `block_name { members }`. Since `block_name` must be unique
    /// between blocks and structs, we add `_block_ID` where `ID` is an
    /// `IdGenerator`-generated number. Write `members` in the same way we write
    /// a struct's members.
    fn write_interface_block(
        &mut self,
        handle: Handle<crate::GlobalVariable>,
        global: &crate::GlobalVariable,
    ) -> BackendResult {
        // Write the block name, it's just the struct name appended with `_block_ID`
        let ty_name = &self.names[&NameKey::Type(global.ty)];
        let block_name = format!(
            "{}_block_{}{:?}",
            // avoid double underscores as they are reserved in GLSL
            ty_name.trim_end_matches('_'),
            self.block_id.generate(),
            self.entry_point.stage,
        );
        write!(self.out, "{block_name} ")?;
        self.reflection_names_globals.insert(handle, block_name);

        match self.module.types[global.ty].inner {
            TypeInner::Struct { ref members, .. }
                if self.module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&self.module.types) =>
            {
                // Structs with dynamically sized arrays must have their
                // members lifted up as members of the interface block. GLSL
                // can't write such struct types anyway.
                self.write_struct_body(global.ty, members)?;
                write!(self.out, " ")?;
                self.write_global_name(handle, global)?;
            }
            _ => {
                // A global of any other type is written as the sole member
                // of the interface block. Since the interface block is
                // anonymous, this becomes visible in the global scope.
                write!(self.out, "{{ ")?;
                self.write_type(global.ty)?;
                write!(self.out, " ")?;
                self.write_global_name(handle, global)?;
                if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner {
                    self.write_array_size(base, size)?;
                }
                write!(self.out, "; }}")?;
            }
        }

        writeln!(self.out, ";")?;

        Ok(())
    }

    /// Helper method used to find which expressions of a given function require baking
    ///
    /// # Notes
    /// Clears `need_bake_expressions` set before adding to it
    fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
        use crate::Expression;
        self.need_bake_expressions.clear();
        for (fun_handle, expr) in func.expressions.iter() {
            let expr_info = &info[fun_handle];
            let min_ref_count = func.expressions[fun_handle].bake_ref_count();
            if min_ref_count <= expr_info.ref_count {
                self.need_bake_expressions.insert(fun_handle);
            }

            let inner = expr_info.ty.inner_with(&self.module.types);

            if let Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                ..
            } = *expr
            {
                match fun {
                    crate::MathFunction::Dot => {
                        // if the expression is a Dot product with integer arguments,
                        // then the arguments need baking as well
                        if let TypeInner::Scalar(crate::Scalar {
                            kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
                            ..
                        }) = *inner
                        {
                            self.need_bake_expressions.insert(arg);
                            self.need_bake_expressions.insert(arg1.unwrap());
                        }
                    }
                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
                        self.need_bake_expressions.insert(arg);
                        self.need_bake_expressions.insert(arg1.unwrap());
                    }
                    crate::MathFunction::Pack4xI8
                    | crate::MathFunction::Pack4xU8
                    | crate::MathFunction::Pack4xI8Clamp
                    | crate::MathFunction::Pack4xU8Clamp
                    | crate::MathFunction::Unpack4xI8
                    | crate::MathFunction::Unpack4xU8
                    | crate::MathFunction::QuantizeToF16 => {
                        self.need_bake_expressions.insert(arg);
                    }
                    /* crate::MathFunction::Pack4x8snorm | */
                    crate::MathFunction::Unpack4x8snorm
                        if !self.options.version.supports_pack_unpack_4x8() =>
                    {
                        // We have a fallback if the platform doesn't natively support these
                        self.need_bake_expressions.insert(arg);
                    }
                    /* crate::MathFunction::Pack4x8unorm | */
                    crate::MathFunction::Unpack4x8unorm
                        if !self.options.version.supports_pack_unpack_4x8() =>
                    {
                        self.need_bake_expressions.insert(arg);
                    }
                    /* crate::MathFunction::Pack2x16snorm | */
                    crate::MathFunction::Unpack2x16snorm
                        if !self.options.version.supports_pack_unpack_snorm_2x16() =>
                    {
                        self.need_bake_expressions.insert(arg);
                    }
                    /* crate::MathFunction::Pack2x16unorm | */
                    crate::MathFunction::Unpack2x16unorm
                        if !self.options.version.supports_pack_unpack_unorm_2x16() =>
                    {
                        self.need_bake_expressions.insert(arg);
                    }
                    crate::MathFunction::ExtractBits => {
                        // Only argument 1 is re-used.
                        self.need_bake_expressions.insert(arg1.unwrap());
                    }
                    crate::MathFunction::InsertBits => {
                        // Only argument 2 is re-used.
                        self.need_bake_expressions.insert(arg2.unwrap());
                    }
                    crate::MathFunction::CountLeadingZeros => {
                        if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
                            self.need_bake_expressions.insert(arg);
                        }
                    }
                    _ => {}
                }
            }
        }

        for statement in func.body.iter() {
            match *statement {
                crate::Statement::Atomic {
                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
                    ..
                } => {
                    self.need_bake_expressions.insert(cmp);
                }
                _ => {}
            }
        }
    }

    /// Helper method used to get a name for a global
    ///
    /// Globals have different naming schemes depending on their binding:
    /// - Globals without bindings use the name from the [`Namer`](crate::proc::Namer)
    /// - Globals with resource binding are named `_group_X_binding_Y_S` where `X`
    ///   is the group, `Y` is the binding and `S` is the entry point's stage
    /// - Globals in the `Immediate` address space without a binding are named
    ///   `_immediates_binding_S`, where `S` is the entry point's stage
    fn get_global_name(
        &self,
        handle: Handle<crate::GlobalVariable>,
        global: &crate::GlobalVariable,
    ) -> String {
        match (&global.binding, global.space) {
            (&Some(ref br), _) => {
                format!(
                    "_group_{}_binding_{}_{}",
                    br.group,
                    br.binding,
                    self.entry_point.stage.to_str()
                )
            }
            (&None, crate::AddressSpace::Immediate) => {
                format!("_immediates_binding_{}", self.entry_point.stage.to_str())
            }
            (&None, _) => self.names[&NameKey::GlobalVariable(handle)].clone(),
        }
    }

    /// Helper method used to write a name for a global without additional heap allocation
    fn write_global_name(
        &mut self,
        handle: Handle<crate::GlobalVariable>,
        global: &crate::GlobalVariable,
    ) -> BackendResult {
        match (&global.binding, global.space) {
            (&Some(ref br), _) => write!(
                self.out,
                "_group_{}_binding_{}_{}",
                br.group,
                br.binding,
                self.entry_point.stage.to_str()
            )?,
            (&None, crate::AddressSpace::Immediate) => write!(
                self.out,
                "_immediates_binding_{}",
                self.entry_point.stage.to_str()
            )?,
            (&None, _) => write!(
                self.out,
                "{}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?,
        }

        Ok(())
    }

    /// Write a GLSL global that will carry a Naga entry point's argument or return value.
    ///
    /// A Naga entry point's arguments and return value are rendered in GLSL as
    /// variables at global scope with the `in` and `out` storage qualifiers.
    /// The code we generate for `main` loads from all the `in` globals into
    /// appropriately named locals. Before it returns, `main` assigns the
    /// components of its return value into all the `out` globals.
    ///
    /// This function writes a declaration for one such GLSL global,
    /// representing a value passed into or returned from [`self.entry_point`]
    /// that has a [`Location`] binding. The global's name is generated based on
    /// the location index and the shader stages being connected; see
    /// [`VaryingName`]. This means we don't need to know the names of
    /// arguments, just their types and bindings.
    ///
    /// Emit nothing for entry point arguments or return values with [`BuiltIn`]
    /// bindings; `main` will read from or assign to the appropriate GLSL
    /// special variable; these are pre-declared. As exceptions, we do declare
    /// `gl_Position` or `gl_FragCoord` with the `invariant` qualifier if
    /// needed, and we re-declare `gl_ClipDistance` with the number of clip
    /// planes in use.
    ///
    /// Use `output` together with [`self.entry_point.stage`] to determine which
    /// shader stages are being connected, and choose the `in` or `out` storage
    /// qualifier.
    ///
    /// [`self.entry_point`]: Writer::entry_point
    /// [`self.entry_point.stage`]: crate::EntryPoint::stage
    /// [`Location`]: crate::Binding::Location
    /// [`BuiltIn`]: crate::Binding::BuiltIn
    fn write_varying(
        &mut self,
        binding: Option<&crate::Binding>,
        ty: Handle<crate::Type>,
        output: bool,
    ) -> Result<(), Error> {
        // For a struct, emit a separate global for each member with a binding.
        if let TypeInner::Struct { ref members, ..
        } = self.module.types[ty].inner {
            for member in members {
                self.write_varying(member.binding.as_ref(), member.ty, output)?;
            }
            return Ok(());
        }

        let binding = match binding {
            None => return Ok(()),
            Some(binding) => binding,
        };

        let (location, interpolation, sampling, blend_src) = match *binding {
            crate::Binding::Location {
                location,
                interpolation,
                sampling,
                blend_src,
                per_primitive: _,
            } => (location, interpolation, sampling, blend_src),
            crate::Binding::BuiltIn(built_in) => {
                match built_in {
                    crate::BuiltIn::Position { invariant: true } => {
                        match (self.options.version, self.entry_point.stage) {
                            (
                                Version::Embedded {
                                    version: 300,
                                    is_webgl: true,
                                },
                                ShaderStage::Fragment,
                            ) => {
                                // `invariant gl_FragCoord` is not allowed in WebGL2 and possibly
                                // OpenGL ES in general (waiting on confirmation).
                                //
                                // See https://github.com/KhronosGroup/WebGL/issues/3518
                            }
                            _ => {
                                writeln!(
                                    self.out,
                                    "invariant {};",
                                    glsl_built_in(
                                        built_in,
                                        VaryingOptions::from_writer_options(self.options, output)
                                    )
                                )?;
                            }
                        }
                    }
                    crate::BuiltIn::ClipDistance => {
                        // Re-declare `gl_ClipDistance` with number of clip planes.
                        let TypeInner::Array { size, .. } = self.module.types[ty].inner else {
                            unreachable!();
                        };
                        let proc::IndexableLength::Known(size) =
                            size.resolve(self.module.to_ctx())?
                        else {
                            unreachable!();
                        };
                        self.clip_distance_count = size;
                        writeln!(self.out, "out float gl_ClipDistance[{size}];")?;
                    }
                    _ => {}
                }
                return Ok(());
            }
        };

        // Write the interpolation modifier if needed
        //
        // We ignore all interpolation and auxiliary modifiers that aren't used in fragment
        // shaders' input globals or vertex shaders' output globals.
        let emit_interpolation_and_auxiliary = match self.entry_point.stage {
            ShaderStage::Vertex => output,
            ShaderStage::Fragment => !output,
            ShaderStage::Compute => false,
            ShaderStage::Task
            | ShaderStage::Mesh
            | ShaderStage::RayGeneration
            | ShaderStage::AnyHit
            | ShaderStage::ClosestHit
            | ShaderStage::Miss => unreachable!(),
        };

        // Write the I/O locations, if allowed
        let io_location = if self.options.version.supports_explicit_locations()
            || !emit_interpolation_and_auxiliary
        {
            if self.options.version.supports_io_locations() {
                if let Some(blend_src) = blend_src {
                    write!(
                        self.out,
                        "layout(location = {location}, index = {blend_src}) "
                    )?;
                } else {
                    write!(self.out, "layout(location = {location}) ")?;
                }
                None
            } else {
                Some(VaryingLocation {
                    location,
                    index: blend_src.unwrap_or(0),
                })
            }
        } else {
            None
        };

        // Write the interpolation qualifier.
        if let Some(interp) = interpolation {
            if emit_interpolation_and_auxiliary {
                write!(self.out, "{} ", glsl_interpolation(interp))?;
            }
        }

        // Write the sampling auxiliary qualifier.
        //
        // Before GLSL 4.2, the `centroid` and `sample` qualifiers were required to appear
        // immediately before the `in` / `out` qualifier, so we'll just follow that rule
        // here, regardless of the version.
        if let Some(sampling) = sampling {
            if emit_interpolation_and_auxiliary {
                if let Some(qualifier) = glsl_sampling(sampling)? {
                    write!(self.out, "{qualifier} ")?;
                }
            }
        }

        // Write the input/output qualifier.
        write!(self.out, "{} ", if output { "out" } else { "in" })?;

        // Write the type
        // `write_type` adds no leading or trailing spaces
        self.write_type(ty)?;

        // Finally write the global name and end the global with a `;` and a newline
        // Leading space is important
        let vname = VaryingName {
            binding: &crate::Binding::Location {
                location,
                interpolation: None,
                sampling: None,
                blend_src,
                per_primitive: false,
            },
            stage: self.entry_point.stage,
            options: VaryingOptions::from_writer_options(self.options, output),
        };
        writeln!(self.out, " {vname};")?;

        if let Some(location) = io_location {
            self.varying.insert(vname.to_string(), location);
        }

        Ok(())
    }

    /// Helper method used to write functions (both entry points and regular functions)
    ///
    /// # Notes
    /// Adds a newline
    fn write_function(
        &mut self,
        ty: back::FunctionType,
        func: &crate::Function,
        info: &valid::FunctionInfo,
    ) -> BackendResult {
        // Create a function context for the function being written
        let ctx = back::FunctionCtx {
            ty,
            info,
            expressions: &func.expressions,
            named_expressions: &func.named_expressions,
        };

        self.named_expressions.clear();
        self.update_expressions_to_bake(func, info);

        // Write the function header
        //
        // glsl headers are the same as in c:
        // `ret_type name(args)`
        // `ret_type` is the return type
        // `name` is the function name
        // `args` is a comma separated list of `type name`
        //  | - `type` is the argument type
        //  | - `name` is the argument name

        // Start by writing the return type if any otherwise write void
        // This is the only place where `void` is a valid type
        // (though it's more a keyword than a type)
        if let back::FunctionType::EntryPoint(_) = ctx.ty {
            write!(self.out, "void")?;
        } else if let Some(ref result) = func.result {
            self.write_type(result.ty)?;
            if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner {
                self.write_array_size(base, size)?
            }
        } else {
            write!(self.out, "void")?;
        }

        // Write the function name and open parentheses for the argument list
        let function_name = match ctx.ty {
            back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
            back::FunctionType::EntryPoint(_) => "main",
        };
        write!(self.out, " {function_name}(")?;

        // Write the comma separated argument list
        //
        // We need access to `Self` here so we use the reference passed to the closure as an
        // argument instead of capturing as that would cause a borrow checker error
        let arguments = match ctx.ty {
            back::FunctionType::EntryPoint(_) => &[][..],
            back::FunctionType::Function(_) => &func.arguments,
        };
        let arguments: Vec<_> = arguments
            .iter()
            .enumerate()
            .filter(|&(_, arg)| match self.module.types[arg.ty].inner {
                TypeInner::Sampler { .. } => false,
                _ => true,
            })
            .collect();
        self.write_slice(&arguments, |this, _, &(i, arg)| {
            // Write the argument type
            match this.module.types[arg.ty].inner {
                // We treat images separately because they might require
                // writing the storage format
                TypeInner::Image {
                    dim,
                    arrayed,
                    class,
                } => {
                    // Write the storage format if needed
                    if let TypeInner::Image {
                        class: crate::ImageClass::Storage { format, .. },
                        ..
                    } = this.module.types[arg.ty].inner
                    {
                        write!(this.out, "layout({}) ", glsl_storage_format(format)?)?;
                    }

                    // write the type
                    //
                    // `write_image_type` doesn't add any spaces at the beginning or end,
                    // which is why the argument name below is written with a leading space
                    this.write_image_type(dim, arrayed, class)?;
                }
                TypeInner::Pointer { base, .. } => {
                    // write parameter qualifiers
                    write!(this.out, "inout ")?;
                    this.write_type(base)?;
                }
                // All other types are written by `write_type`
                _ => {
                    this.write_type(arg.ty)?;
                }
            }

            // Write the argument name
            // The leading space is important
            write!(this.out, " {}", &this.names[&ctx.argument_key(i as u32)])?;

            // Write array size
            match this.module.types[arg.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    this.write_array_size(base, size)?;
                }
                TypeInner::Pointer { base, ..
                } => {
                    if let TypeInner::Array { base, size, .. } = this.module.types[base].inner {
                        this.write_array_size(base, size)?;
                    }
                }
                _ => {}
            }

            Ok(())
        })?;

        // Close the parentheses and open braces to start the function body
        writeln!(self.out, ") {{")?;

        if self.options.zero_initialize_workgroup_memory
            && ctx.ty.is_compute_like_entry_point(self.module)
        {
            self.write_workgroup_variables_initialization(&ctx)?;
        }

        // Compose the function arguments from globals, in case of an entry point.
        if let back::FunctionType::EntryPoint(ep_index) = ctx.ty {
            let stage = self.module.entry_points[ep_index as usize].stage;
            for (index, arg) in func.arguments.iter().enumerate() {
                write!(self.out, "{}", back::INDENT)?;
                self.write_type(arg.ty)?;
                let name = &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
                write!(self.out, " {name}")?;
                write!(self.out, " = ")?;
                match self.module.types[arg.ty].inner {
                    TypeInner::Struct { ref members, .. } => {
                        self.write_type(arg.ty)?;
                        write!(self.out, "(")?;
                        for (index, member) in members.iter().enumerate() {
                            let varying_name = VaryingName {
                                binding: member.binding.as_ref().unwrap(),
                                stage,
                                options: VaryingOptions::from_writer_options(self.options, false),
                            };
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            write!(self.out, "{varying_name}")?;
                        }
                        writeln!(self.out, ");")?;
                    }
                    _ => {
                        let varying_name = VaryingName {
                            binding: arg.binding.as_ref().unwrap(),
                            stage,
                            options: VaryingOptions::from_writer_options(self.options, false),
                        };
                        writeln!(self.out, "{varying_name};")?;
                    }
                }
            }
        }

        // Write all function locals
        // Locals are `type name (= init)?;` where the init part (including the =) is optional
        //
        // Always adds a newline
        for (handle, local) in func.local_variables.iter() {
            // Write indentation (only for readability) and the type
            // `write_type` adds no trailing space
            write!(self.out, "{}", back::INDENT)?;
            self.write_type(local.ty)?;

            // Write the local name
            // The leading space is important
            write!(self.out, " {}", self.names[&ctx.name_key(handle)])?;
            // Write
            // size for array type
            if let TypeInner::Array { base, size, .. } = self.module.types[local.ty].inner {
                self.write_array_size(base, size)?;
            }

            // Write the local initializer if needed
            if let Some(init) = local.init {
                // Write the `=` only if there's an initializer
                // The leading and trailing spaces aren't needed but help with readability
                write!(self.out, " = ")?;

                // Write the initializer expression
                // `write_expr` adds no trailing or leading space/newline
                self.write_expr(init, &ctx)?;
            } else if is_value_init_supported(self.module, local.ty) {
                write!(self.out, " = ")?;
                self.write_zero_init_value(local.ty)?;
            }

            // Finish the local with `;` and add a newline (only for readability)
            writeln!(self.out, ";")?
        }

        // Write the function body (statement list)
        for sta in func.body.iter() {
            // Write a statement, the indentation should always be 1 when writing the function body
            // `write_stmt` adds a newline
            self.write_stmt(sta, &ctx, back::Level(1))?;
        }

        // Close braces and add a newline
        writeln!(self.out, "}}")?;

        Ok(())
    }

    fn write_workgroup_variables_initialization(
        &mut self,
        ctx: &back::FunctionCtx,
    ) -> BackendResult {
        let mut vars = self
            .module
            .global_variables
            .iter()
            .filter(|&(handle, var)| {
                !ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
            })
            .peekable();

        if vars.peek().is_some() {
            let level = back::Level(1);

            writeln!(self.out, "{level}if (gl_LocalInvocationID == uvec3(0u)) {{")?;

            for (handle, var) in vars {
                let name = &self.names[&NameKey::GlobalVariable(handle)];
                write!(self.out, "{}{} = ", level.next(), name)?;
                self.write_zero_init_value(var.ty)?;
                writeln!(self.out, ";")?;
            }

            writeln!(self.out, "{level}}}")?;
            self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
        }

        Ok(())
    }

    /// Write a list of comma separated `T` values using a writer function `F`.
    ///
    /// The writer function `F` receives a mutable reference to `self`, which avoids
    /// borrow checker issues (capturing `self` in a closure, for example, would cause them);
    /// the second argument is the 0-based index of the element in the list, and the last
    /// argument is a reference to the element `T` being written
    ///
    /// # Notes
    /// - Adds no newlines or leading/trailing whitespace
    /// - The last element won't have a trailing `,`
    fn write_slice<T, F: FnMut(&mut Self, u32, &T) -> BackendResult>(
        &mut self,
        data: &[T],
        mut f: F,
    ) -> BackendResult {
        // Loop through `data` invoking `f` for each element
        for (index, item) in data.iter().enumerate() {
            if index != 0 {
                write!(self.out, ", ")?;
            }
            f(self, index as u32, item)?;
        }

        Ok(())
    }

    /// Helper method used to write global constants
    fn write_global_constant(&mut self, handle: Handle<crate::Constant>) -> BackendResult {
        write!(self.out, "const ")?;
        let constant = &self.module.constants[handle];
        self.write_type(constant.ty)?;
        let name = &self.names[&NameKey::Constant(handle)];
        write!(self.out, " {name}")?;
        if let TypeInner::Array { base, size, .. } = self.module.types[constant.ty].inner {
            self.write_array_size(base, size)?;
        }
        write!(self.out, " = ")?;
        self.write_const_expr(constant.init, &self.module.global_expressions)?;
        writeln!(self.out, ";")?;
        Ok(())
    }

    /// Helper method used to output a dot product as an arithmetic expression
    fn write_dot_product(
        &mut self,
        arg: Handle<crate::Expression>,
        arg1: Handle<crate::Expression>,
        size: usize,
        ctx: &back::FunctionCtx,
    ) -> BackendResult {
        // Write parentheses around the dot product expression to prevent operators
        // with different precedences from applying earlier.
        write!(self.out, "(")?;

        // Cycle through all the components of the vector
        for index in 0..size {
            let component = back::COMPONENTS[index];
            // Write the addition to the previous product
            // This will print an extra '+' at the beginning but that is fine in glsl
            write!(self.out, " + ")?;
            // Write the first vector expression, this expression is marked to be
            // cached so unless it can't be cached (for example, it's a Constant)
            // it shouldn't produce large expressions.
            self.write_expr(arg, ctx)?;
            // Access the current component on the first vector
            write!(self.out, ".{component} * ")?;
            // Write the second vector expression, this expression is marked to be
            // cached so unless it can't be cached (for example, it's a Constant)
            // it shouldn't produce large expressions.
            self.write_expr(arg1, ctx)?;
            // Access the current component on the second vector
            write!(self.out, ".{component}")?;
        }

        write!(self.out, ")")?;
        Ok(())
    }

    /// Helper method used to write the body of a struct (`{ members }`)
    ///
    /// # Notes
    /// Adds no trailing newline after the closing brace
    fn write_struct_body(
        &mut self,
        handle: Handle<crate::Type>,
        members: &[crate::StructMember],
    ) -> BackendResult {
        // glsl structs are written as in C
        // `struct name { members };`
        //  | `struct` is a keyword
        //  | `name` is the struct name
        //  | `members` is a semicolon separated list of `type name`
        //      | `type` is the member type
        //      | `name` is the member name
        writeln!(self.out, "{{")?;

        for (idx, member) in members.iter().enumerate() {
            // The indentation is only for readability
            write!(self.out, "{}", back::INDENT)?;

            match self.module.types[member.ty].inner {
                TypeInner::Array {
                    base,
                    size,
                    stride: _,
                } => {
                    self.write_type(base)?;
                    write!(
                        self.out,
                        " {}",
                        &self.names[&NameKey::StructMember(handle, idx as u32)]
                    )?;
                    // Write [size]
                    self.write_array_size(base, size)?;
                    // Newline is important
                    writeln!(self.out, ";")?;
                }
                _ => {
                    // Write the member type
                    // Adds no trailing space
                    self.write_type(member.ty)?;

                    // Write the member name and put a semicolon
                    // The leading space is important
                    // All members must
                    // have a semicolon even the last one
                    writeln!(
                        self.out,
                        " {};",
                        &self.names[&NameKey::StructMember(handle, idx as u32)]
                    )?;
                }
            }
        }

        write!(self.out, "}}")?;
        Ok(())
    }

    /// Helper method used to write statements
    ///
    /// # Notes
    /// Always adds a newline
    fn write_stmt(
        &mut self,
        sta: &crate::Statement,
        ctx: &back::FunctionCtx,
        level: back::Level,
    ) -> BackendResult {
        use crate::Statement;

        match *sta {
            // This is where we can generate intermediate constants for some expression types.
            Statement::Emit(ref range) => {
                for handle in range.clone() {
                    let ptr_class = ctx.resolve_type(handle, &self.module.types).pointer_space();
                    let expr_name = if ptr_class.is_some() {
                        // GLSL can't save a pointer-valued expression in a variable,
                        // but we shouldn't ever need to: they should never be named expressions,
                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
                        None
                    } else if let Some(name) = ctx.named_expressions.get(&handle) {
                        // The front end provides names for all variables at the start of writing,
                        // but we write them step by step, so we need to recache them.
                        // Otherwise, we could accidentally write the variable name instead of the
                        // full expression. We also use sanitized names, which prevents the backend
                        // from generating variables named after reserved keywords.
                        Some(self.namer.call(name))
                    } else if self.need_bake_expressions.contains(&handle) {
                        Some(Baked(handle).to_string())
                    } else {
                        None
                    };

                    // If we are going to write an `ImageLoad` next and the target image
                    // is sampled and we are using the `Restrict` policy for bounds
                    // checking images we need to write a local holding the clamped lod.
                    if let crate::Expression::ImageLoad {
                        image,
                        level: Some(level_expr),
                        ..
                    } = ctx.expressions[handle]
                    {
                        if let TypeInner::Image {
                            class: crate::ImageClass::Sampled { .. },
                            ..
                        } = *ctx.resolve_type(image, &self.module.types)
                        {
                            if let proc::BoundsCheckPolicy::Restrict = self.policies.image_load {
                                write!(self.out, "{level}")?;
                                self.write_clamped_lod(ctx, handle, image, level_expr)?
                            }
                        }
                    }

                    if let Some(name) = expr_name {
                        write!(self.out, "{level}")?;
                        self.write_named_expr(handle, name, handle, ctx)?;
                    }
                }
            }
            // Blocks are simple: we just need to write the block statements between braces
            // We could also just print the statements but this is more readable and maps more
            // closely to the IR
            Statement::Block(ref block) => {
                write!(self.out, "{level}")?;
                writeln!(self.out, "{{")?;
                for sta in block.iter() {
                    // Increase the indentation to help with readability
                    self.write_stmt(sta, ctx, level.next())?
                }
                writeln!(self.out, "{level}}}")?
            }
            // Ifs are written as in C:
            // ```
            // if(condition) {
            //  accept
            // } else {
            //  reject
            // }
            // ```
            Statement::If {
                condition,
                ref accept,
                ref reject,
            } => {
                write!(self.out, "{level}")?;
                write!(self.out, "if (")?;
                self.write_expr(condition, ctx)?;
                writeln!(self.out, ") {{")?;

                for sta in accept {
                    // Increase indentation to help with readability
                    self.write_stmt(sta, ctx, level.next())?;
                }

                // If there are no statements in the reject block we skip writing it
                // This is only for readability
                if !reject.is_empty() {
                    writeln!(self.out, "{level}}} else {{")?;

                    for sta in reject {
                        // Increase indentation to help with readability
                        self.write_stmt(sta, ctx, level.next())?;
                    }
                }

                writeln!(self.out, "{level}}}")?
            }
            // Switches are written as in C:
            // ```
            // switch (selector) {
            //      // Fallthrough
            //      case label:
            //          block
            //      // Non fallthrough
            //      case label:
            //          block
            //          break;
            //      default:
            //          block
            //  }
            //  ```
            // Where the `default` case happens isn't important but we put it last
            // so that we don't need to print a `break` for it
            Statement::Switch {
                selector,
                ref cases,
            } => {
                let l2 = level.next();
                // Some GLSL consumers may not handle switches with a single
                // body correctly: See wgpu#4514. Write such switch statements
                // as a `do {} while(false);` loop instead.
                //
                // Since doing so may inadvertently capture `continue`
                // statements in the switch body, we must apply continue
                // forwarding. See the `naga::back::continue_forward` module
See the `naga::back::continue_forward` module // docs for details. let one_body = cases .iter() .rev() .skip(1) .all(|case| case.fall_through && case.body.is_empty()); if one_body { // Unlike HLSL, in GLSL `continue_ctx` only needs to know // about [`Switch`] statements that are being rendered as // `do-while` loops. if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) { writeln!(self.out, "{level}bool {variable} = false;",)?; }; writeln!(self.out, "{level}do {{")?; // Note: Expressions have no side-effects so we don't need to emit selector expression. // Body if let Some(case) = cases.last() { for sta in case.body.iter() { self.write_stmt(sta, ctx, l2)?; } } // End do-while writeln!(self.out, "{level}}} while(false);")?; // Handle any forwarded continue statements. use back::continue_forward::ExitControlFlow; let op = match self.continue_ctx.exit_switch() { ExitControlFlow::None => None, ExitControlFlow::Continue { variable } => Some(("continue", variable)), ExitControlFlow::Break { variable } => Some(("break", variable)), }; if let Some((control_flow, variable)) = op { writeln!(self.out, "{level}if ({variable}) {{")?; writeln!(self.out, "{l2}{control_flow};")?; writeln!(self.out, "{level}}}")?; } } else { // Start the switch write!(self.out, "{level}")?; write!(self.out, "switch(")?; self.write_expr(selector, ctx)?; writeln!(self.out, ") {{")?; // Write all cases for case in cases { match case.value { crate::SwitchValue::I32(value) => { write!(self.out, "{l2}case {value}:")? } crate::SwitchValue::U32(value) => { write!(self.out, "{l2}case {value}u:")? 
                            }
                            crate::SwitchValue::Default => write!(self.out, "{l2}default:")?,
                        }

                        let write_block_braces = !(case.fall_through && case.body.is_empty());
                        if write_block_braces {
                            writeln!(self.out, " {{")?;
                        } else {
                            writeln!(self.out)?;
                        }

                        for sta in case.body.iter() {
                            self.write_stmt(sta, ctx, l2.next())?;
                        }

                        if !case.fall_through
                            && case.body.last().is_none_or(|s| !s.is_terminator())
                        {
                            writeln!(self.out, "{}break;", l2.next())?;
                        }

                        if write_block_braces {
                            writeln!(self.out, "{l2}}}")?;
                        }
                    }

                    writeln!(self.out, "{level}}}")?
                }
            }
            // Loops in naga IR are based on wgsl loops. glsl can emulate the behaviour by using a
            // while true loop and running the gated continuing block at the start of each
            // iteration after the first, resulting in:
            // ```
            // bool loop_init = true;
            // while(true) {
            //  if (!loop_init) { <continuing> }
            //  loop_init = false;
            //  <body>
            // }
            // ```
            Statement::Loop {
                ref body,
                ref continuing,
                break_if,
            } => {
                self.continue_ctx.enter_loop();
                if !continuing.is_empty() || break_if.is_some() {
                    let gate_name = self.namer.call("loop_init");
                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
                    writeln!(self.out, "{level}while(true) {{")?;
                    let l2 = level.next();
                    let l3 = l2.next();
                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
                    for sta in continuing {
                        self.write_stmt(sta, ctx, l3)?;
                    }
                    if let Some(condition) = break_if {
                        write!(self.out, "{l3}if (")?;
                        self.write_expr(condition, ctx)?;
                        writeln!(self.out, ") {{")?;
                        writeln!(self.out, "{}break;", l3.next())?;
                        writeln!(self.out, "{l3}}}")?;
                    }
                    writeln!(self.out, "{l2}}}")?;
                    writeln!(self.out, "{}{} = false;", level.next(), gate_name)?;
                } else {
                    writeln!(self.out, "{level}while(true) {{")?;
                }
                for sta in body {
                    self.write_stmt(sta, ctx, level.next())?;
                }
                writeln!(self.out, "{level}}}")?;
                self.continue_ctx.exit_loop();
            }
            // Break, continue and return are written as in C
            // `break;`
            Statement::Break => {
                write!(self.out, "{level}")?;
                writeln!(self.out, "break;")?
            }
            // `continue;`
            Statement::Continue => {
                // Sometimes we must render a `Continue` statement as a `break`.
// See the docs for the `back::continue_forward` module. if let Some(variable) = self.continue_ctx.continue_encountered() { writeln!(self.out, "{level}{variable} = true;",)?; writeln!(self.out, "{level}break;")? } else { writeln!(self.out, "{level}continue;")? } } // `return expr;`, `expr` is optional Statement::Return { value } => { write!(self.out, "{level}")?; match ctx.ty { back::FunctionType::Function(_) => { write!(self.out, "return")?; // Write the expression to be returned if needed if let Some(expr) = value { write!(self.out, " ")?; self.write_expr(expr, ctx)?; } writeln!(self.out, ";")?; } back::FunctionType::EntryPoint(ep_index) => { let mut has_point_size = false; let ep = &self.module.entry_points[ep_index as usize]; if let Some(ref result) = ep.function.result { let value = value.unwrap(); match self.module.types[result.ty].inner { TypeInner::Struct { ref members, .. } => { let temp_struct_name = match ctx.expressions[value] { crate::Expression::Compose { .. } => { let return_struct = "_tmp_return"; write!( self.out, "{} {} = ", &self.names[&NameKey::Type(result.ty)], return_struct )?; self.write_expr(value, ctx)?; writeln!(self.out, ";")?; write!(self.out, "{level}")?; Some(return_struct) } _ => None, }; for (index, member) in members.iter().enumerate() { if let Some(crate::Binding::BuiltIn( crate::BuiltIn::PointSize, )) = member.binding { has_point_size = true; } let varying_name = VaryingName { binding: member.binding.as_ref().unwrap(), stage: ep.stage, options: VaryingOptions::from_writer_options( self.options, true, ), }; write!(self.out, "{varying_name} = ")?; if let Some(struct_name) = temp_struct_name { write!(self.out, "{struct_name}")?; } else { self.write_expr(value, ctx)?; } // Write field name writeln!( self.out, ".{};", &self.names [&NameKey::StructMember(result.ty, index as u32)] )?; write!(self.out, "{level}")?; } } _ => { let name = VaryingName { binding: result.binding.as_ref().unwrap(), stage: ep.stage, options: 
VaryingOptions::from_writer_options( self.options, true, ), }; write!(self.out, "{name} = ")?; self.write_expr(value, ctx)?; writeln!(self.out, ";")?; write!(self.out, "{level}")?; } } } let is_vertex_stage = self.module.entry_points[ep_index as usize].stage == ShaderStage::Vertex; if is_vertex_stage && self .options .writer_flags .contains(WriterFlags::ADJUST_COORDINATE_SPACE) { writeln!( self.out, "gl_Position.yz = vec2(-gl_Position.y, gl_Position.z * 2.0 - gl_Position.w);", )?; write!(self.out, "{level}")?; } if is_vertex_stage && self .options .writer_flags .contains(WriterFlags::FORCE_POINT_SIZE) && !has_point_size { writeln!(self.out, "gl_PointSize = 1.0;")?; write!(self.out, "{level}")?; } writeln!(self.out, "return;")?; } } } // This is one of the places where glsl adds to the syntax of C: the `discard` // keyword, which ceases all further processing in a fragment shader. It's called OpKill // in SPIR-V, which is why it's called `Statement::Kill`. Statement::Kill => writeln!(self.out, "{level}discard;")?, Statement::ControlBarrier(flags) => { self.write_control_barrier(flags, level)?; } Statement::MemoryBarrier(flags) => { self.write_memory_barrier(flags, level)?; } // Stores in glsl are just variable assignments written as `pointer = value;` Statement::Store { pointer, value } => { write!(self.out, "{level}")?; self.write_expr(pointer, ctx)?; write!(self.out, " = ")?; self.write_expr(value, ctx)?; writeln!(self.out, ";")? } Statement::WorkGroupUniformLoad { pointer, result } => { // GLSL doesn't have pointers, which means that this backend needs to ensure that // the actual "loading" happens between the two barriers. // This is done in `Emit` by never emitting a variable name for pointer variables self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?; let result_name = Baked(result).to_string(); write!(self.out, "{level}")?; // Expressions cannot have side effects, so just writing the expression here is fine.
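// Sketch (hypothetical helper, not part of naga's API): the arithmetic performed by the
// `gl_Position.yz = vec2(-gl_Position.y, gl_Position.z * 2.0 - gl_Position.w);` line that
// `WriterFlags::ADJUST_COORDINATE_SPACE` emits, applied to a clip-space `[x, y, z, w]`.
// Y is flipped, and Z is remapped from the [0, w] clip-space depth range used by WebGPU
// to the [-w, w] range that OpenGL uses by default.
fn sketch_adjust_coordinate_space(pos: [f32; 4]) -> [f32; 4] {
    let [x, y, z, w] = pos;
    // z = 0 maps to -w and z = w maps to w; x and w are unchanged.
    [x, -y, z * 2.0 - w, w]
}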
self.write_named_expr(pointer, result_name, result, ctx)?; self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?; } // Stores a value into an image. Statement::ImageStore { image, coordinate, array_index, value, } => { write!(self.out, "{level}")?; self.write_image_store(ctx, image, coordinate, array_index, value)? } // A `Call` is written `name(arguments)` where `arguments` is a comma separated expressions list Statement::Call { function, ref arguments, result, } => { write!(self.out, "{level}")?; if let Some(expr) = result { let name = Baked(expr).to_string(); let result = self.module.functions[function].result.as_ref().unwrap(); self.write_type(result.ty)?; write!(self.out, " {name}")?; if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner { self.write_array_size(base, size)? } write!(self.out, " = ")?; self.named_expressions.insert(expr, name); } write!(self.out, "{}(", &self.names[&NameKey::Function(function)])?; let arguments: Vec<_> = arguments .iter() .enumerate() .filter_map(|(i, arg)| { let arg_ty = self.module.functions[function].arguments[i].ty; match self.module.types[arg_ty].inner { TypeInner::Sampler { .. } => None, _ => Some(*arg), } }) .collect(); self.write_slice(&arguments, |this, _, arg| this.write_expr(*arg, ctx))?; writeln!(self.out, ");")? 
} Statement::Atomic { pointer, ref fun, value, result, } => { write!(self.out, "{level}")?; match *fun { crate::AtomicFunction::Exchange { compare: Some(compare_expr), } => { let result_handle = result.expect("CompareExchange must have a result"); let res_name = Baked(result_handle).to_string(); self.write_type(ctx.info[result_handle].ty.handle().unwrap())?; write!(self.out, " {res_name};")?; write!(self.out, " {res_name}.old_value = atomicCompSwap(")?; self.write_expr(pointer, ctx)?; write!(self.out, ", ")?; self.write_expr(compare_expr, ctx)?; write!(self.out, ", ")?; self.write_expr(value, ctx)?; writeln!(self.out, ");")?; write!( self.out, "{level}{res_name}.exchanged = ({res_name}.old_value == " )?; self.write_expr(compare_expr, ctx)?; writeln!(self.out, ");")?; self.named_expressions.insert(result_handle, res_name); } _ => { if let Some(result) = result { let res_name = Baked(result).to_string(); self.write_type(ctx.info[result].ty.handle().unwrap())?; write!(self.out, " {res_name} = ")?; self.named_expressions.insert(result, res_name); } let fun_str = fun.to_glsl(); write!(self.out, "atomic{fun_str}(")?; self.write_expr(pointer, ctx)?; write!(self.out, ", ")?; if let crate::AtomicFunction::Subtract = *fun { // Emulate `atomicSub` with `atomicAdd` by negating the value. write!(self.out, "-")?; } self.write_expr(value, ctx)?; writeln!(self.out, ");")?; } } } // Performs an atomic operation on an image texel. Statement::ImageAtomic { image, coordinate, array_index, fun, value, } => { write!(self.out, "{level}")?; self.write_image_atomic(ctx, image, coordinate, array_index, fun, value)? } Statement::RayQuery { ..
} => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); let res_ty = ctx.info[result].ty.inner_with(&self.module.types); self.write_value_type(res_ty)?; write!(self.out, " {res_name} = ")?; self.named_expressions.insert(result, res_name); write!(self.out, "subgroupBallot(")?; match predicate { Some(predicate) => self.write_expr(predicate, ctx)?, None => write!(self.out, "true")?, } writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { op, collective_op, argument, result, } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); let res_ty = ctx.info[result].ty.inner_with(&self.module.types); self.write_value_type(res_ty)?; write!(self.out, " {res_name} = ")?; self.named_expressions.insert(result, res_name); match (collective_op, op) { (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { write!(self.out, "subgroupAll(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { write!(self.out, "subgroupAny(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { write!(self.out, "subgroupAdd(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { write!(self.out, "subgroupMul(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { write!(self.out, "subgroupMax(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { write!(self.out, "subgroupMin(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { write!(self.out, "subgroupAnd(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { write!(self.out, "subgroupOr(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { write!(self.out, "subgroupXor(")? } (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { write!(self.out, "subgroupExclusiveAdd(")? 
} (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { write!(self.out, "subgroupExclusiveMul(")? } (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { write!(self.out, "subgroupInclusiveAdd(")? } (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { write!(self.out, "subgroupInclusiveMul(")? } _ => unimplemented!(), } self.write_expr(argument, ctx)?; writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); let res_ty = ctx.info[result].ty.inner_with(&self.module.types); self.write_value_type(res_ty)?; write!(self.out, " {res_name} = ")?; self.named_expressions.insert(result, res_name); match mode { crate::GatherMode::BroadcastFirst => { write!(self.out, "subgroupBroadcastFirst(")?; } crate::GatherMode::Broadcast(_) => { write!(self.out, "subgroupBroadcast(")?; } crate::GatherMode::Shuffle(_) => { write!(self.out, "subgroupShuffle(")?; } crate::GatherMode::ShuffleDown(_) => { write!(self.out, "subgroupShuffleDown(")?; } crate::GatherMode::ShuffleUp(_) => { write!(self.out, "subgroupShuffleUp(")?; } crate::GatherMode::ShuffleXor(_) => { write!(self.out, "subgroupShuffleXor(")?; } crate::GatherMode::QuadBroadcast(_) => { write!(self.out, "subgroupQuadBroadcast(")?; } crate::GatherMode::QuadSwap(direction) => match direction { crate::Direction::X => { write!(self.out, "subgroupQuadSwapHorizontal(")?; } crate::Direction::Y => { write!(self.out, "subgroupQuadSwapVertical(")?; } crate::Direction::Diagonal => { write!(self.out, "subgroupQuadSwapDiagonal(")?; } }, } self.write_expr(argument, ctx)?; match mode { crate::GatherMode::BroadcastFirst => {} crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) | crate::GatherMode::ShuffleXor(index) | crate::GatherMode::QuadBroadcast(index) => { 
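// Sketch (hypothetical, not part of naga's API): the (collective, operation) to GLSL
// builtin mapping used by the `SubgroupCollectiveOperation` arm above, with simplified
// stand-in enums in place of `crate::CollectiveOperation` and `crate::SubgroupOperation`.
// `None` marks combinations that the writer rejects with `unimplemented!()`.
#[derive(Clone, Copy)]
enum SketchCollective {
    Reduce,
    InclusiveScan,
    ExclusiveScan,
}

#[derive(Clone, Copy)]
enum SketchSubgroupOp {
    All,
    Any,
    Add,
    Mul,
    Min,
    Max,
    And,
    Or,
    Xor,
}

fn sketch_subgroup_builtin(collective: SketchCollective, op: SketchSubgroupOp) -> Option<String> {
    // Every operation can be reduced; the writer only maps scans of `Add` and `Mul`.
    let prefix = match (collective, op) {
        (SketchCollective::Reduce, _) => "subgroup",
        (SketchCollective::InclusiveScan, SketchSubgroupOp::Add | SketchSubgroupOp::Mul) => {
            "subgroupInclusive"
        }
        (SketchCollective::ExclusiveScan, SketchSubgroupOp::Add | SketchSubgroupOp::Mul) => {
            "subgroupExclusive"
        }
        _ => return None,
    };
    let suffix = match op {
        SketchSubgroupOp::All => "All",
        SketchSubgroupOp::Any => "Any",
        SketchSubgroupOp::Add => "Add",
        SketchSubgroupOp::Mul => "Mul",
        SketchSubgroupOp::Min => "Min",
        SketchSubgroupOp::Max => "Max",
        SketchSubgroupOp::And => "And",
        SketchSubgroupOp::Or => "Or",
        SketchSubgroupOp::Xor => "Xor",
    };
    Some(format!("{prefix}{suffix}"))
}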
write!(self.out, ", ")?; self.write_expr(index, ctx)?; } crate::GatherMode::QuadSwap(_) => {} } writeln!(self.out, ");")?; } Statement::CooperativeStore { .. } => unimplemented!(), Statement::RayPipelineFunction(_) => unimplemented!(), } Ok(()) } /// Write a const expression. /// /// Write `expr`, a handle to an [`Expression`] in the current [`Module`]'s /// constant expression arena, as a GLSL expression. /// /// # Notes /// Adds no newlines or leading/trailing whitespace /// /// [`Expression`]: crate::Expression /// [`Module`]: crate::Module fn write_const_expr( &mut self, expr: Handle<crate::Expression>, arena: &crate::Arena<crate::Expression>, ) -> BackendResult { self.write_possibly_const_expr( expr, arena, |expr| &self.info[expr], |writer, expr| writer.write_const_expr(expr, arena), ) } /// Write [`Expression`] variants that can occur in both runtime and const expressions. /// /// Write `expr`, a handle to an [`Expression`] in the arena `expressions`, /// as a GLSL expression. This must be one of the [`Expression`] variants /// that is allowed to occur in constant expressions. /// /// Use `write_expression` to write subexpressions. /// /// This is the common code for `write_expr`, which handles arbitrary /// runtime expressions, and `write_const_expr`, which only handles /// const-expressions. Each of those callers passes itself (essentially) as /// the `write_expression` callback, so that subexpressions are restricted /// to the appropriate variants.
/// /// # Notes /// Adds no newlines or leading/trailing whitespace /// /// [`Expression`]: crate::Expression fn write_possibly_const_expr<'w, I, E>( &'w mut self, expr: Handle<crate::Expression>, expressions: &crate::Arena<crate::Expression>, info: I, write_expression: E, ) -> BackendResult where I: Fn(Handle<crate::Expression>) -> &'w proc::TypeResolution, E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult, { use crate::Expression; match expressions[expr] { Expression::Literal(literal) => { match literal { // Floats are written using `Debug` instead of `Display` because it keeps the // decimal part even when it is zero (`1.0` rather than `1`), which is needed for a valid glsl float constant crate::Literal::F64(value) => write!(self.out, "{value:?}LF")?, crate::Literal::F32(value) => write!(self.out, "{value:?}")?, crate::Literal::F16(_) => { return Err(Error::Custom("GLSL has no 16-bit float type".into())); } // Unsigned integers need a `u` at the end // // While `core` doesn't necessarily need it, it's allowed, and since `es` needs it we // always write it; an extra branch would not improve readability crate::Literal::U32(value) => write!(self.out, "{value}u")?, crate::Literal::I32(value) => write!(self.out, "{value}")?, crate::Literal::Bool(value) => write!(self.out, "{value}")?, crate::Literal::I64(_) => { return Err(Error::Custom("GLSL has no 64-bit integer type".into())); } crate::Literal::U64(_) => { return Err(Error::Custom("GLSL has no 64-bit integer type".into())); } crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { return Err(Error::Custom( "Abstract types should not appear in IR presented to backends".into(), )); } } } Expression::Constant(handle) => { let constant = &self.module.constants[handle]; if constant.name.is_some() { write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; } else { self.write_const_expr(constant.init, &self.module.global_expressions)?; } } Expression::ZeroValue(ty) => { self.write_zero_init_value(ty)?; } Expression::Compose { ty, ref components } => {
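// Sketch (hypothetical, not part of naga's API): the literal formatting rules of the
// `Literal` arm above, using a simplified stand-in enum instead of `crate::Literal`
// (the error cases for f16, 64-bit integers and abstract types are left out).
// `{:?}` prints `1.0` where `{}` prints `1`, so float literals stay floats in GLSL;
// `u` marks unsigned literals and `LF` marks double-precision ones.
enum SketchLiteral {
    F64(f64),
    F32(f32),
    U32(u32),
    I32(i32),
    Bool(bool),
}

fn sketch_glsl_literal(lit: SketchLiteral) -> String {
    match lit {
        SketchLiteral::F64(v) => format!("{v:?}LF"),
        SketchLiteral::F32(v) => format!("{v:?}"),
        SketchLiteral::U32(v) => format!("{v}u"),
        SketchLiteral::I32(v) => format!("{v}"),
        SketchLiteral::Bool(v) => format!("{v}"),
    }
}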
self.write_type(ty)?; if let TypeInner::Array { base, size, .. } = self.module.types[ty].inner { self.write_array_size(base, size)?; } write!(self.out, "(")?; for (index, component) in components.iter().enumerate() { if index != 0 { write!(self.out, ", ")?; } write_expression(self, *component)?; } write!(self.out, ")")? } // `Splat` must be written as an explicit vector constructor; GLSL does not always infer it. Expression::Splat { size: _, value } => { let resolved = info(expr).inner_with(&self.module.types); self.write_value_type(resolved)?; write!(self.out, "(")?; write_expression(self, value)?; write!(self.out, ")")? } _ => { return Err(Error::Override); } } Ok(()) } /// Helper method to write expressions /// /// # Notes /// Doesn't add any newlines or leading/trailing spaces fn write_expr( &mut self, expr: Handle<crate::Expression>, ctx: &back::FunctionCtx, ) -> BackendResult { use crate::Expression; if let Some(name) = self.named_expressions.get(&expr) { write!(self.out, "{name}")?; return Ok(()); } match ctx.expressions[expr] { Expression::Literal(_) | Expression::Constant(_) | Expression::ZeroValue(_) | Expression::Compose { .. } | Expression::Splat { .. } => { self.write_possibly_const_expr( expr, ctx.expressions, |expr| &ctx.info[expr].ty, |writer, expr| writer.write_expr(expr, ctx), )?; } Expression::Override(_) => return Err(Error::Override), // `Access` is applied to arrays, vectors and matrices and is written as indexing Expression::Access { base, index } => { self.write_expr(base, ctx)?; write!(self.out, "[")?; self.write_expr(index, ctx)?; write!(self.out, "]")?
} // `AccessIndex` is the same as `Access` except that the index is a constant and it can // be applied to structs; in that case we need to find the name of the field at that // index and write `base.field_name` Expression::AccessIndex { base, index } => { self.write_expr(base, ctx)?; let base_ty_res = &ctx.info[base].ty; let mut resolved = base_ty_res.inner_with(&self.module.types); let base_ty_handle = match *resolved { TypeInner::Pointer { base, space: _ } => { resolved = &self.module.types[base].inner; Some(base) } _ => base_ty_res.handle(), }; match *resolved { TypeInner::Vector { .. } => { // Write vector access as a swizzle write!(self.out, ".{}", back::COMPONENTS[index as usize])? } TypeInner::Matrix { .. } | TypeInner::Array { .. } | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?, TypeInner::Struct { .. } => { // This will never panic when the type is a `Struct`; that is not true // for other types, so we can only unwrap inside this match arm let ty = base_ty_handle.unwrap(); write!( self.out, ".{}", &self.names[&NameKey::StructMember(ty, index)] )? } ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))), } } // `Swizzle` appends the component letters after a dot. Expression::Swizzle { size, vector, pattern, } => { self.write_expr(vector, ctx)?; write!(self.out, ".")?; for &sc in pattern[..size as usize].iter() { self.out.write_char(back::COMPONENTS[sc as usize])?; } } // Function arguments are written as the argument name Expression::FunctionArgument(pos) => { write!(self.out, "{}", &self.names[&ctx.argument_key(pos)])? } // Global variables need some special work for their name but // `write_global_name` does the work for us Expression::GlobalVariable(handle) => { let global = &self.module.global_variables[handle]; self.write_global_name(handle, global)? } // A local is written as its name Expression::LocalVariable(handle) => { write!(self.out, "{}", self.names[&ctx.name_key(handle)])?
} // glsl has no pointers so there's no load operation, just write the pointer expression Expression::Load { pointer } => self.write_expr(pointer, ctx)?, // `ImageSample` is a bit complicated compared to the rest of the IR. // // First, the function to call depends on how the sample level is specified: // `texture(image, coordinate)` - Automatic sample level // `texture(image, coordinate, bias)` - Bias sample level // `textureLod(image, coordinate, level)` - Zero or Exact sample level // `textureGrad(image, coordinate, x, y)` - Gradient sample level // // Furthermore, if `depth_ref` is `Some`, it is appended to the coordinate vector when // that vector has fewer than four components and this is not a gather; otherwise it is // passed as a separate argument Expression::ImageSample { image, sampler: _, //TODO? gather, coordinate, array_index, offset, level, depth_ref, clamp_to_edge: _, } => { let (dim, class, arrayed) = match *ctx.resolve_type(image, &self.module.types) { TypeInner::Image { dim, class, arrayed, .. } => (dim, class, arrayed), _ => unreachable!(), }; let mut err = None; if dim == crate::ImageDimension::Cube { if offset.is_some() { err = Some("gsamplerCube[Array][Shadow] doesn't support texture sampling with offsets"); } if arrayed && matches!(class, crate::ImageClass::Depth { .. }) && matches!(level, crate::SampleLevel::Gradient { .. }) { err = Some("samplerCubeArrayShadow don't support textureGrad"); } } if gather.is_some() && level != crate::SampleLevel::Zero { err = Some("textureGather doesn't support LOD parameters"); } if let Some(err) = err { return Err(Error::Custom(String::from(err))); } // `textureLod[Offset]` on `sampler2DArrayShadow` and `samplerCubeShadow` does not exist in GLSL, // unless `GL_EXT_texture_shadow_lod` is present. // But if the target LOD is zero, we can emulate that by using `textureGrad[Offset]` with a constant gradient of 0. let workaround_lod_with_grad = ((dim == crate::ImageDimension::Cube && !arrayed) || (dim == crate::ImageDimension::D2 && arrayed)) && level == crate::SampleLevel::Zero && matches!(class, crate::ImageClass::Depth { ..
}) && !self.features.contains(Features::TEXTURE_SHADOW_LOD); // Write the function to be used depending on the sample level let fun_name = match level { crate::SampleLevel::Zero if gather.is_some() => "textureGather", crate::SampleLevel::Zero if workaround_lod_with_grad => "textureGrad", crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture", crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => "textureLod", crate::SampleLevel::Gradient { .. } => "textureGrad", }; let offset_name = match offset { Some(_) => "Offset", None => "", }; write!(self.out, "{fun_name}{offset_name}(")?; // Write the image that will be used self.write_expr(image, ctx)?; // The space here isn't required but it helps with readability write!(self.out, ", ")?; // TODO: handle clamp_to_edge // https://github.com/gfx-rs/wgpu/issues/7791 // Get the coordinate's component count (1 for a scalar) so we can later build a larger vector // that also holds the array index and `depth_ref`; any other type means the module is invalid let mut coord_dim = match *ctx.resolve_type(coordinate, &self.module.types) { TypeInner::Vector { size, .. } => size as u8, TypeInner::Scalar { ..
} => 1, _ => unreachable!(), }; if array_index.is_some() { coord_dim += 1; } let merge_depth_ref = depth_ref.is_some() && gather.is_none() && coord_dim < 4; if merge_depth_ref { coord_dim += 1; } let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es(); let is_vec = tex_1d_hack || coord_dim != 1; // Compose a new texture coordinates vector if is_vec { write!(self.out, "vec{}(", coord_dim + tex_1d_hack as u8)?; } self.write_expr(coordinate, ctx)?; if tex_1d_hack { write!(self.out, ", 0.0")?; } if let Some(expr) = array_index { write!(self.out, ", ")?; self.write_expr(expr, ctx)?; } if merge_depth_ref { write!(self.out, ", ")?; self.write_expr(depth_ref.unwrap(), ctx)?; } if is_vec { write!(self.out, ")")?; } if let (Some(expr), false) = (depth_ref, merge_depth_ref) { write!(self.out, ", ")?; self.write_expr(expr, ctx)?; } match level { // Auto needs no more arguments crate::SampleLevel::Auto => (), // Zero needs level set to 0 crate::SampleLevel::Zero => { if workaround_lod_with_grad { let vec_dim = match dim { crate::ImageDimension::Cube => 3, _ => 2, }; write!(self.out, ", vec{vec_dim}(0.0), vec{vec_dim}(0.0)")?; } else if gather.is_none() { write!(self.out, ", 0.0")?; } } // Exact and bias require another argument crate::SampleLevel::Exact(expr) => { write!(self.out, ", ")?; self.write_expr(expr, ctx)?; } crate::SampleLevel::Bias(_) => { // This needs to be done after the offset writing } crate::SampleLevel::Gradient { x, y } => { // If we are using sampler2D to replace sampler1D, we also // need to make sure to use vec2 gradients if tex_1d_hack { write!(self.out, ", vec2(")?; self.write_expr(x, ctx)?; write!(self.out, ", 0.0)")?; write!(self.out, ", vec2(")?; self.write_expr(y, ctx)?; write!(self.out, ", 0.0)")?; } else { write!(self.out, ", ")?; self.write_expr(x, ctx)?; write!(self.out, ", ")?; self.write_expr(y, ctx)?; } } } if let Some(constant) = offset { write!(self.out, ", ")?; if tex_1d_hack { write!(self.out, "ivec2(")?; } 
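// Sketch (hypothetical helper, not part of naga's API): the `coord_dim` /
// `merge_depth_ref` computation from the `ImageSample` arm above, returning the
// component count of the composed coordinate vector and whether `depth_ref` was folded
// into it. The `tex_1d_hack` padding used for 1D textures on ES is left out.
fn sketch_sample_coord_arity(
    coord_dim: u8, // 1 for a scalar coordinate, otherwise the vector size
    arrayed: bool, // an `array_index` is appended to the coordinate
    has_depth_ref: bool,
    is_gather: bool,
) -> (u8, bool) {
    let mut dim = coord_dim + u8::from(arrayed);
    // The reference value rides in the coordinate vector only if it fits in four
    // components and this is not a gather; otherwise it is a separate argument.
    let merge_depth_ref = has_depth_ref && !is_gather && dim < 4;
    if merge_depth_ref {
        dim += 1;
    }
    (dim, merge_depth_ref)
}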
self.write_const_expr(constant, ctx.expressions)?; if tex_1d_hack { write!(self.out, ", 0)")?; } } // Bias is always the last argument if let crate::SampleLevel::Bias(expr) = level { write!(self.out, ", ")?; self.write_expr(expr, ctx)?; } if let (Some(component), None) = (gather, depth_ref) { write!(self.out, ", {}", component as usize)?; } // End the function write!(self.out, ")")? } Expression::ImageLoad { image, coordinate, array_index, sample, level, } => self.write_image_load(expr, ctx, image, coordinate, array_index, sample, level)?, // Query translates into one of: // - textureSize/imageSize // - textureQueryLevels // - textureSamples/imageSamples Expression::ImageQuery { image, query } => { use crate::ImageClass; // This will only panic if the module is invalid let (dim, class) = match *ctx.resolve_type(image, &self.module.types) { TypeInner::Image { dim, arrayed: _, class, } => (dim, class), _ => unreachable!(), }; let components = match dim { crate::ImageDimension::D1 => 1, crate::ImageDimension::D2 => 2, crate::ImageDimension::D3 => 3, crate::ImageDimension::Cube => 2, }; if let crate::ImageQuery::Size { .. } = query { match components { 1 => write!(self.out, "uint(")?, _ => write!(self.out, "uvec{components}(")?, } } else { write!(self.out, "uint(")?; } match query { crate::ImageQuery::Size { level } => { match class { ImageClass::Sampled { multi, .. } | ImageClass::Depth { multi } => { write!(self.out, "textureSize(")?; self.write_expr(image, ctx)?; if let Some(expr) = level { let cast_to_int = matches!( *ctx.resolve_type(expr, &self.module.types), TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Uint, .. }) ); write!(self.out, ", ")?; if cast_to_int { write!(self.out, "int(")?; } self.write_expr(expr, ctx)?; if cast_to_int { write!(self.out, ")")?; } } else if !multi { // All textureSize calls require an lod argument // except for multisampled samplers write!(self.out, ", 0")?; } } ImageClass::Storage { ..
} => { write!(self.out, "imageSize(")?; self.write_expr(image, ctx)?; } ImageClass::External => unimplemented!(), } write!(self.out, ")")?; if components != 1 || self.options.version.is_es() { write!(self.out, ".{}", &"xyz"[..components])?; } } crate::ImageQuery::NumLevels => { write!(self.out, "textureQueryLevels(",)?; self.write_expr(image, ctx)?; write!(self.out, ")",)?; } crate::ImageQuery::NumLayers => { let fun_name = match class { ImageClass::Sampled { .. } | ImageClass::Depth { .. } => "textureSize", ImageClass::Storage { .. } => "imageSize", ImageClass::External => unimplemented!(), }; write!(self.out, "{fun_name}(")?; self.write_expr(image, ctx)?; // All textureSize calls require an lod argument // except for multisampled samplers if !class.is_multisampled() { write!(self.out, ", 0")?; } write!(self.out, ")")?; if components != 1 || self.options.version.is_es() { write!(self.out, ".{}", back::COMPONENTS[components])?; } } crate::ImageQuery::NumSamples => { let fun_name = match class { ImageClass::Sampled { .. } | ImageClass::Depth { .. } => { "textureSamples" } ImageClass::Storage { .. } => "imageSamples", ImageClass::External => unimplemented!(), }; write!(self.out, "{fun_name}(")?; self.write_expr(image, ctx)?; write!(self.out, ")",)?; } } write!(self.out, ")")?; } Expression::Unary { op, expr } => { let operator_or_fn = match op { crate::UnaryOperator::Negate => "-", crate::UnaryOperator::LogicalNot => { match *ctx.resolve_type(expr, &self.module.types) { TypeInner::Vector { .. } => "not", _ => "!", } } crate::UnaryOperator::BitwiseNot => "~", }; write!(self.out, "{operator_or_fn}(")?; self.write_expr(expr, ctx)?; write!(self.out, ")")? } // For `Binary` we write `left op right`, except where the operation needs special // lowering: vector comparisons use builtin functions, float modulo is emulated, and // boolean `&`/`|` become logical operators.
// Once again we wrap everything in parentheses to avoid precedence issues Expression::Binary { mut op, left, right, } => { // Classify how the operation is lowered: as a builtin function call, // a component-wise expansion, a modulo emulation, or a plain infix operator use crate::{BinaryOperator as Bo, ScalarKind as Sk, TypeInner as Ti}; let left_inner = ctx.resolve_type(left, &self.module.types); let right_inner = ctx.resolve_type(right, &self.module.types); let function = match (left_inner, right_inner) { (&Ti::Vector { scalar, .. }, &Ti::Vector { .. }) => match op { Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual | Bo::Equal | Bo::NotEqual => BinaryOperation::VectorCompare, Bo::Modulo if scalar.kind == Sk::Float => BinaryOperation::Modulo, Bo::And if scalar.kind == Sk::Bool => { op = crate::BinaryOperator::LogicalAnd; BinaryOperation::VectorComponentWise } Bo::InclusiveOr if scalar.kind == Sk::Bool => { op = crate::BinaryOperator::LogicalOr; BinaryOperation::VectorComponentWise } _ => BinaryOperation::Other, }, _ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) { (Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op { Bo::Modulo => BinaryOperation::Modulo, _ => BinaryOperation::Other, }, (Some(Sk::Bool), Some(Sk::Bool)) => match op { Bo::InclusiveOr => { op = crate::BinaryOperator::LogicalOr; BinaryOperation::Other } Bo::And => { op = crate::BinaryOperator::LogicalAnd; BinaryOperation::Other } _ => BinaryOperation::Other, }, _ => BinaryOperation::Other, }, }; match function { BinaryOperation::VectorCompare => { let op_str = match op { Bo::Less => "lessThan(", Bo::LessEqual => "lessThanEqual(", Bo::Greater => "greaterThan(", Bo::GreaterEqual => "greaterThanEqual(", Bo::Equal => "equal(", Bo::NotEqual => "notEqual(", _ => unreachable!(), }; write!(self.out, "{op_str}")?; self.write_expr(left, ctx)?; write!(self.out, ", ")?; self.write_expr(right, ctx)?; write!(self.out, ")")?; } BinaryOperation::VectorComponentWise => { self.write_value_type(left_inner)?; write!(self.out, "(")?; let size =
match *left_inner { Ti::Vector { size, .. } => size, _ => unreachable!(), }; for i in 0..size as usize { if i != 0 { write!(self.out, ", ")?; } self.write_expr(left, ctx)?; write!(self.out, ".{}", back::COMPONENTS[i])?; write!(self.out, " {} ", back::binary_operation_str(op))?; self.write_expr(right, ctx)?; write!(self.out, ".{}", back::COMPONENTS[i])?; } write!(self.out, ")")?; } // TODO: handle undefined behavior of BinaryOperator::Modulo // // sint: // if right == 0 return 0 // if left == min(type_of(left)) && right == -1 return 0 // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL // // uint: // if right == 0 return 0 // // float: // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 BinaryOperation::Modulo => { write!(self.out, "(")?; // write `e1 - e2 * trunc(e1 / e2)` self.write_expr(left, ctx)?; write!(self.out, " - ")?; self.write_expr(right, ctx)?; write!(self.out, " * ")?; write!(self.out, "trunc(")?; self.write_expr(left, ctx)?; write!(self.out, " / ")?; self.write_expr(right, ctx)?; write!(self.out, ")")?; write!(self.out, ")")?; } BinaryOperation::Other => { write!(self.out, "(")?; self.write_expr(left, ctx)?; write!(self.out, " {} ", back::binary_operation_str(op))?; self.write_expr(right, ctx)?; write!(self.out, ")")?; } } } // `Select` is written as `condition ? accept : reject` // We wrap everything in parentheses to avoid precedence issues Expression::Select { condition, accept, reject, } => { let cond_ty = ctx.resolve_type(condition, &self.module.types); let vec_select = if let TypeInner::Vector { .. 
} = *cond_ty { true } else { false }; // TODO: Boolean mix on desktop requires GL_EXT_shader_integer_mix if vec_select { // For `mix` with a boolean selector, GLSL picks the first argument where the // condition is false and the second where it is true write!(self.out, "mix(")?; self.write_expr(reject, ctx)?; write!(self.out, ", ")?; self.write_expr(accept, ctx)?; write!(self.out, ", ")?; self.write_expr(condition, ctx)?; } else { write!(self.out, "(")?; self.write_expr(condition, ctx)?; write!(self.out, " ? ")?; self.write_expr(accept, ctx)?; write!(self.out, " : ")?; self.write_expr(reject, ctx)?; } write!(self.out, ")")? } // `Derivative` is a function call to a glsl provided function Expression::Derivative { axis, ctrl, expr } => { use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; let fun_name = if self.options.version.supports_derivative_control() { match (axis, ctrl) { (Axis::X, Ctrl::Coarse) => "dFdxCoarse", (Axis::X, Ctrl::Fine) => "dFdxFine", (Axis::X, Ctrl::None) => "dFdx", (Axis::Y, Ctrl::Coarse) => "dFdyCoarse", (Axis::Y, Ctrl::Fine) => "dFdyFine", (Axis::Y, Ctrl::None) => "dFdy", (Axis::Width, Ctrl::Coarse) => "fwidthCoarse", (Axis::Width, Ctrl::Fine) => "fwidthFine", (Axis::Width, Ctrl::None) => "fwidth", } } else { match axis { Axis::X => "dFdx", Axis::Y => "dFdy", Axis::Width => "fwidth", } }; write!(self.out, "{fun_name}(")?; self.write_expr(expr, ctx)?; write!(self.out, ")")? } // `Relational` is a call to a glsl builtin function Expression::Relational { fun, argument } => { use crate::RelationalFunction as Rf; let fun_name = match fun { Rf::IsInf => "isinf", Rf::IsNan => "isnan", Rf::All => "all", Rf::Any => "any", }; write!(self.out, "{fun_name}(")?; self.write_expr(argument, ctx)?; write!(self.out, ")")?
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                use crate::MathFunction as Mf;

                let fun_name = match fun {
                    // comparison
                    Mf::Abs => "abs",
                    Mf::Min => "min",
                    Mf::Max => "max",
                    Mf::Clamp => {
                        let scalar_kind = ctx
                            .resolve_type(arg, &self.module.types)
                            .scalar_kind()
                            .unwrap();
                        match scalar_kind {
                            crate::ScalarKind::Float => "clamp",
                            // Clamp is undefined if min > max. In practice this means it can use a median-of-three
                            // instruction to determine the value. This is fine according to the WGSL spec for float
                            // clamp, but integer clamp _must_ use min-max. As such we write out min/max.
                            _ => {
                                write!(self.out, "min(max(")?;
                                self.write_expr(arg, ctx)?;
                                write!(self.out, ", ")?;
                                self.write_expr(arg1.unwrap(), ctx)?;
                                write!(self.out, "), ")?;
                                self.write_expr(arg2.unwrap(), ctx)?;
                                write!(self.out, ")")?;

                                return Ok(());
                            }
                        }
                    }
                    Mf::Saturate => {
                        write!(self.out, "clamp(")?;

                        self.write_expr(arg, ctx)?;

                        match *ctx.resolve_type(arg, &self.module.types) {
                            TypeInner::Vector { size, .. } => write!(
                                self.out,
                                ", vec{}(0.0), vec{0}(1.0)",
                                common::vector_size_str(size)
                            )?,
                            _ => write!(self.out, ", 0.0, 1.0")?,
                        }

                        write!(self.out, ")")?;

                        return Ok(());
                    }
                    // trigonometry
                    Mf::Cos => "cos",
                    Mf::Cosh => "cosh",
                    Mf::Sin => "sin",
                    Mf::Sinh => "sinh",
                    Mf::Tan => "tan",
                    Mf::Tanh => "tanh",
                    Mf::Acos => "acos",
                    Mf::Asin => "asin",
                    Mf::Atan => "atan",
                    Mf::Asinh => "asinh",
                    Mf::Acosh => "acosh",
                    Mf::Atanh => "atanh",
                    Mf::Radians => "radians",
                    Mf::Degrees => "degrees",
                    // GLSL doesn't have an atan2 function, so
                    // use the two-argument variant of the atan function
                    Mf::Atan2 => "atan",
                    // decomposition
                    Mf::Ceil => "ceil",
                    Mf::Floor => "floor",
                    Mf::Round => "roundEven",
                    Mf::Fract => "fract",
                    Mf::Trunc => "trunc",
                    Mf::Modf => MODF_FUNCTION,
                    Mf::Frexp => FREXP_FUNCTION,
                    Mf::Ldexp => "ldexp",
                    // exponent
                    Mf::Exp => "exp",
                    Mf::Exp2 => "exp2",
                    Mf::Log => "log",
                    Mf::Log2 => "log2",
                    Mf::Pow => "pow",
                    // geometry
                    Mf::Dot => match *ctx.resolve_type(arg, &self.module.types) {
                        TypeInner::Vector {
                            scalar:
                                crate::Scalar {
                                    kind: crate::ScalarKind::Float,
                                    ..
                                },
                            ..
                        } => "dot",
                        TypeInner::Vector { size, .. } => {
                            return self.write_dot_product(arg, arg1.unwrap(), size as usize, ctx)
                        }
                        _ => unreachable!(
                            "Correct TypeInner for dot product should be already validated"
                        ),
                    },
                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
                        let conversion = match fun {
                            Mf::Dot4I8Packed => "int",
                            Mf::Dot4U8Packed => "",
                            _ => unreachable!(),
                        };

                        let arg1 = arg1.unwrap();

                        // Write parentheses around the dot product expression to prevent operators
                        // with different precedences from applying earlier.
                        write!(self.out, "(")?;
                        for i in 0..4 {
                            // Since `bitfieldExtract` only sign extends if the value is signed, we
                            // need to convert the inputs to `int` in case of `Dot4I8Packed`. For
                            // `Dot4U8Packed`, the code below only introduces parentheses around
                            // each factor, which aren't strictly needed because both operands are
                            // baked, but which don't hurt either.
                            write!(self.out, "bitfieldExtract({conversion}(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, "), {}, 8)", i * 8)?;

                            write!(self.out, " * bitfieldExtract({conversion}(")?;
                            self.write_expr(arg1, ctx)?;
                            write!(self.out, "), {}, 8)", i * 8)?;

                            if i != 3 {
                                write!(self.out, " + ")?;
                            }
                        }
                        write!(self.out, ")")?;

                        return Ok(());
                    }
                    Mf::Outer => "outerProduct",
                    Mf::Cross => "cross",
                    Mf::Distance => "distance",
                    Mf::Length => "length",
                    Mf::Normalize => "normalize",
                    Mf::FaceForward => "faceforward",
                    Mf::Reflect => "reflect",
                    Mf::Refract => "refract",
                    // computational
                    Mf::Sign => "sign",
                    Mf::Fma => {
                        if self.options.version.supports_fma_function() {
                            // Use the fma function when available
                            "fma"
                        } else {
                            // No fma support.
                            // Transform the function call into an arithmetic expression
                            write!(self.out, "(")?;

                            self.write_expr(arg, ctx)?;
                            write!(self.out, " * ")?;

                            let arg1 =
                                arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
                            self.write_expr(arg1, ctx)?;
                            write!(self.out, " + ")?;

                            let arg2 =
                                arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
                            self.write_expr(arg2, ctx)?;
                            write!(self.out, ")")?;

                            return Ok(());
                        }
                    }
                    Mf::Mix => "mix",
                    Mf::Step => "step",
                    Mf::SmoothStep => "smoothstep",
                    Mf::Sqrt => "sqrt",
                    Mf::InverseSqrt => "inversesqrt",
                    Mf::Inverse => "inverse",
                    Mf::Transpose => "transpose",
                    Mf::Determinant => "determinant",
                    Mf::QuantizeToF16 => match *ctx.resolve_type(arg, &self.module.types) {
                        TypeInner::Scalar { .. } => {
                            write!(self.out, "unpackHalf2x16(packHalf2x16(vec2(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, "))).x")?;
                            return Ok(());
                        }
                        TypeInner::Vector {
                            size: crate::VectorSize::Bi,
                            ..
                        } => {
                            write!(self.out, "unpackHalf2x16(packHalf2x16(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, "))")?;
                            return Ok(());
                        }
                        TypeInner::Vector {
                            size: crate::VectorSize::Tri,
                            ..
                        } => {
                            write!(self.out, "vec3(unpackHalf2x16(packHalf2x16(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, ".zz)).x)")?;
                            return Ok(());
                        }
                        TypeInner::Vector {
                            size: crate::VectorSize::Quad,
                            ..
                        } => {
                            write!(self.out, "vec4(unpackHalf2x16(packHalf2x16(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, ".zw)))")?;
                            return Ok(());
                        }
                        _ => unreachable!(
                            "Correct TypeInner for QuantizeToF16 should be already validated"
                        ),
                    },
                    // bits
                    Mf::CountTrailingZeros => {
                        match *ctx.resolve_type(arg, &self.module.types) {
                            TypeInner::Vector { size, scalar, ..
                            } => {
                                let s = common::vector_size_str(size);
                                if let crate::ScalarKind::Uint = scalar.kind {
                                    write!(self.out, "min(uvec{s}(findLSB(")?;
                                    self.write_expr(arg, ctx)?;
                                    write!(self.out, ")), uvec{s}(32u))")?;
                                } else {
                                    write!(self.out, "ivec{s}(min(uvec{s}(findLSB(")?;
                                    self.write_expr(arg, ctx)?;
                                    write!(self.out, ")), uvec{s}(32u)))")?;
                                }
                            }
                            TypeInner::Scalar(scalar) => {
                                if let crate::ScalarKind::Uint = scalar.kind {
                                    write!(self.out, "min(uint(findLSB(")?;
                                    self.write_expr(arg, ctx)?;
                                    write!(self.out, ")), 32u)")?;
                                } else {
                                    write!(self.out, "int(min(uint(findLSB(")?;
                                    self.write_expr(arg, ctx)?;
                                    write!(self.out, ")), 32u))")?;
                                }
                            }
                            _ => unreachable!(),
                        };
                        return Ok(());
                    }
                    Mf::CountLeadingZeros => {
                        if self.options.version.supports_integer_functions() {
                            match *ctx.resolve_type(arg, &self.module.types) {
                                TypeInner::Vector { size, scalar } => {
                                    let s = common::vector_size_str(size);

                                    if let crate::ScalarKind::Uint = scalar.kind {
                                        write!(self.out, "uvec{s}(ivec{s}(31) - findMSB(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, "))")?;
                                    } else {
                                        write!(self.out, "mix(ivec{s}(31) - findMSB(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, "), ivec{s}(0), lessThan(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, ", ivec{s}(0)))")?;
                                    }
                                }
                                TypeInner::Scalar(scalar) => {
                                    if let crate::ScalarKind::Uint = scalar.kind {
                                        write!(self.out, "uint(31 - findMSB(")?;
                                    } else {
                                        write!(self.out, "(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, " < 0 ? 0 : 31 - findMSB(")?;
                                    }

                                    self.write_expr(arg, ctx)?;
                                    write!(self.out, "))")?;
                                }
                                _ => unreachable!(),
                            };
                        } else {
                            match *ctx.resolve_type(arg, &self.module.types) {
                                TypeInner::Vector { size, scalar } => {
                                    let s = common::vector_size_str(size);

                                    if let crate::ScalarKind::Uint = scalar.kind {
                                        write!(self.out, "uvec{s}(")?;
                                        write!(self.out, "vec{s}(31.0) - floor(log2(vec{s}(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, ") + 0.5)))")?;
                                    } else {
                                        write!(self.out, "ivec{s}(")?;
                                        write!(self.out, "mix(vec{s}(31.0) - floor(log2(vec{s}(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, ") + 0.5)), ")?;
                                        write!(self.out, "vec{s}(0.0), lessThan(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, ", ivec{s}(0u))))")?;
                                    }
                                }
                                TypeInner::Scalar(scalar) => {
                                    if let crate::ScalarKind::Uint = scalar.kind {
                                        write!(self.out, "uint(31.0 - floor(log2(float(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, ") + 0.5)))")?;
                                    } else {
                                        write!(self.out, "(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, " < 0 ? 0 : int(")?;
                                        write!(self.out, "31.0 - floor(log2(float(")?;
                                        self.write_expr(arg, ctx)?;
                                        write!(self.out, ") + 0.5))))")?;
                                    }
                                }
                                _ => unreachable!(),
                            };
                        }

                        return Ok(());
                    }
                    Mf::CountOneBits => "bitCount",
                    Mf::ReverseBits => "bitfieldReverse",
                    Mf::ExtractBits => {
                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
                        // to sanitize the offset and count first. If we don't do this, AMD and Intel chips
                        // will return out-of-spec values if the extracted range is not within the bit width.
                        //
                        // This encodes the exact formula specified by the WGSL spec, without temporary values:
                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
                        //
                        // w = sizeof(x) * 8
                        // o = min(offset, w)
                        // c = min(count, w - o)
                        //
                        // bitfieldExtract(x, o, c)
                        //
                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))
                        let scalar_bits = ctx
                            .resolve_type(arg, &self.module.types)
                            .scalar_width()
                            .unwrap()
                            * 8;

                        write!(self.out, "bitfieldExtract(")?;
                        self.write_expr(arg, ctx)?;
                        write!(self.out, ", int(min(")?;
                        self.write_expr(arg1.unwrap(), ctx)?;
                        write!(self.out, ", {scalar_bits}u)), int(min(",)?;
                        self.write_expr(arg2.unwrap(), ctx)?;
                        write!(self.out, ", {scalar_bits}u - min(")?;
                        self.write_expr(arg1.unwrap(), ctx)?;
                        write!(self.out, ", {scalar_bits}u))))")?;

                        return Ok(());
                    }
                    Mf::InsertBits => {
                        // InsertBits has the same considerations as ExtractBits above
                        let scalar_bits = ctx
                            .resolve_type(arg, &self.module.types)
                            .scalar_width()
                            .unwrap()
                            * 8;

                        write!(self.out, "bitfieldInsert(")?;
                        self.write_expr(arg, ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(arg1.unwrap(), ctx)?;
                        write!(self.out, ", int(min(")?;
                        self.write_expr(arg2.unwrap(), ctx)?;
                        write!(self.out, ", {scalar_bits}u)), int(min(",)?;
                        self.write_expr(arg3.unwrap(), ctx)?;
                        write!(self.out, ", {scalar_bits}u - min(")?;
                        self.write_expr(arg2.unwrap(), ctx)?;
                        write!(self.out, ", {scalar_bits}u))))")?;

                        return Ok(());
                    }
                    Mf::FirstTrailingBit => "findLSB",
                    Mf::FirstLeadingBit => "findMSB",
                    // data packing
                    Mf::Pack4x8snorm => {
                        if self.options.version.supports_pack_unpack_4x8() {
                            "packSnorm4x8"
                        } else {
                            // polyfill should go here.
                            // Needs a corresponding entry in `need_bake_expression`
                            return Err(Error::UnsupportedExternal("packSnorm4x8".into()));
                        }
                    }
                    Mf::Pack4x8unorm => {
                        if self.options.version.supports_pack_unpack_4x8() {
                            "packUnorm4x8"
                        } else {
                            return Err(Error::UnsupportedExternal("packUnorm4x8".to_owned()));
                        }
                    }
                    Mf::Pack2x16snorm => {
                        if self.options.version.supports_pack_unpack_snorm_2x16() {
                            "packSnorm2x16"
                        } else {
                            return Err(Error::UnsupportedExternal("packSnorm2x16".to_owned()));
                        }
                    }
                    Mf::Pack2x16unorm => {
                        if self.options.version.supports_pack_unpack_unorm_2x16() {
                            "packUnorm2x16"
                        } else {
                            return Err(Error::UnsupportedExternal("packUnorm2x16".to_owned()));
                        }
                    }
                    Mf::Pack2x16float => {
                        if self.options.version.supports_pack_unpack_half_2x16() {
                            "packHalf2x16"
                        } else {
                            return Err(Error::UnsupportedExternal("packHalf2x16".to_owned()));
                        }
                    }
                    fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
                        let was_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
                        let clamp_bounds = match fun {
                            Mf::Pack4xI8Clamp => Some(("-128", "127")),
                            Mf::Pack4xU8Clamp => Some(("0", "255")),
                            _ => None,
                        };
                        let const_suffix = if was_signed { "" } else { "u" };
                        if was_signed {
                            write!(self.out, "uint(")?;
                        }
                        let write_arg = |this: &mut Self| -> BackendResult {
                            if let Some((min, max)) = clamp_bounds {
                                write!(this.out, "clamp(")?;
                                this.write_expr(arg, ctx)?;
                                write!(this.out, ", {min}{const_suffix}, {max}{const_suffix})")?;
                            } else {
                                this.write_expr(arg, ctx)?;
                            }
                            Ok(())
                        };
                        write!(self.out, "(")?;
                        write_arg(self)?;
                        write!(self.out, "[0] & 0xFF{const_suffix}) | ((")?;
                        write_arg(self)?;
                        write!(self.out, "[1] & 0xFF{const_suffix}) << 8) | ((")?;
                        write_arg(self)?;
                        write!(self.out, "[2] & 0xFF{const_suffix}) << 16) | ((")?;
                        write_arg(self)?;
                        write!(self.out, "[3] & 0xFF{const_suffix}) << 24)")?;
                        if was_signed {
                            write!(self.out, ")")?;
                        }

                        return Ok(());
                    }
                    // data unpacking
                    Mf::Unpack2x16float => {
                        if self.options.version.supports_pack_unpack_half_2x16() {
                            "unpackHalf2x16"
                        } else {
                            return
                                Err(Error::UnsupportedExternal("unpackHalf2x16".into()));
                        }
                    }
                    Mf::Unpack2x16snorm => {
                        if self.options.version.supports_pack_unpack_snorm_2x16() {
                            "unpackSnorm2x16"
                        } else {
                            let scale = 32767;

                            write!(self.out, "(vec2(ivec2(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " << 16, ")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, ") >> 16) / {scale}.0)")?;
                            return Ok(());
                        }
                    }
                    Mf::Unpack2x16unorm => {
                        if self.options.version.supports_pack_unpack_unorm_2x16() {
                            "unpackUnorm2x16"
                        } else {
                            let scale = 65535;

                            write!(self.out, "(vec2(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " & 0xFFFFu, ")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " >> 16) / {scale}.0)")?;
                            return Ok(());
                        }
                    }
                    Mf::Unpack4x8snorm => {
                        if self.options.version.supports_pack_unpack_4x8() {
                            "unpackSnorm4x8"
                        } else {
                            let scale = 127;

                            write!(self.out, "(vec4(ivec4(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " << 24, ")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " << 16, ")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " << 8, ")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, ") >> 24) / {scale}.0)")?;
                            return Ok(());
                        }
                    }
                    Mf::Unpack4x8unorm => {
                        if self.options.version.supports_pack_unpack_4x8() {
                            "unpackUnorm4x8"
                        } else {
                            let scale = 255;

                            write!(self.out, "(vec4(")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " & 0xFFu, ")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " >> 8 & 0xFFu, ")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " >> 16 & 0xFFu, ")?;
                            self.write_expr(arg, ctx)?;
                            write!(self.out, " >> 24) / {scale}.0)")?;
                            return Ok(());
                        }
                    }
                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
                        let sign_prefix = match fun {
                            Mf::Unpack4xI8 => 'i',
                            Mf::Unpack4xU8 => 'u',
                            _ => unreachable!(),
                        };
                        write!(self.out, "{sign_prefix}vec4(")?;
                        for i in 0..4 {
                            write!(self.out, "bitfieldExtract(")?;
                            // Since bitfieldExtract only sign extends if the value is signed, this
                            // cast is needed
                            match fun {
                                Mf::Unpack4xI8 => {
                                    write!(self.out, "int(")?;
                                    self.write_expr(arg, ctx)?;
                                    write!(self.out, ")")?;
                                }
                                Mf::Unpack4xU8 => self.write_expr(arg, ctx)?,
                                _ => unreachable!(),
                            };
                            write!(self.out, ", {}, 8)", i * 8)?;
                            if i != 3 {
                                write!(self.out, ", ")?;
                            }
                        }
                        write!(self.out, ")")?;

                        return Ok(());
                    }
                };

                let extract_bits = fun == Mf::ExtractBits;
                let insert_bits = fun == Mf::InsertBits;

                // Some GLSL functions always return signed integers (like findMSB),
                // so they need to be cast to uint if the argument is also a uint.
                let ret_might_need_int_to_uint = matches!(
                    fun,
                    Mf::FirstTrailingBit | Mf::FirstLeadingBit | Mf::CountOneBits | Mf::Abs
                );

                // Some GLSL functions only accept signed integers (like abs),
                // so they need their argument cast from uint to int.
                let arg_might_need_uint_to_int = matches!(fun, Mf::Abs);

                // Check if the argument is an unsigned integer and return the vector size
                // in case it's a vector
                let maybe_uint_size = match *ctx.resolve_type(arg, &self.module.types) {
                    TypeInner::Scalar(crate::Scalar {
                        kind: crate::ScalarKind::Uint,
                        ..
                    }) => Some(None),
                    TypeInner::Vector {
                        scalar:
                            crate::Scalar {
                                kind: crate::ScalarKind::Uint,
                                ..
                            },
                        size,
                    } => Some(Some(size)),
                    _ => None,
                };

                // Cast to uint if the function needs it
                if ret_might_need_int_to_uint {
                    if let Some(maybe_size) = maybe_uint_size {
                        match maybe_size {
                            Some(size) => write!(self.out, "uvec{}(", size as u8)?,
                            None => write!(self.out, "uint(")?,
                        }
                    }
                }

                write!(self.out, "{fun_name}(")?;

                // Cast to int if the function needs it
                if arg_might_need_uint_to_int {
                    if let Some(maybe_size) = maybe_uint_size {
                        match maybe_size {
                            Some(size) => write!(self.out, "ivec{}(", size as u8)?,
                            None => write!(self.out, "int(")?,
                        }
                    }
                }

                self.write_expr(arg, ctx)?;

                // Close the cast from uint to int
                if arg_might_need_uint_to_int && maybe_uint_size.is_some() {
                    write!(self.out, ")")?
                }

                if let Some(arg) = arg1 {
                    write!(self.out, ", ")?;
                    if extract_bits {
                        write!(self.out, "int(")?;
                        self.write_expr(arg, ctx)?;
                        write!(self.out, ")")?;
                    } else {
                        self.write_expr(arg, ctx)?;
                    }
                }
                if let Some(arg) = arg2 {
                    write!(self.out, ", ")?;
                    if extract_bits || insert_bits {
                        write!(self.out, "int(")?;
                        self.write_expr(arg, ctx)?;
                        write!(self.out, ")")?;
                    } else {
                        self.write_expr(arg, ctx)?;
                    }
                }
                if let Some(arg) = arg3 {
                    write!(self.out, ", ")?;
                    if insert_bits {
                        write!(self.out, "int(")?;
                        self.write_expr(arg, ctx)?;
                        write!(self.out, ")")?;
                    } else {
                        self.write_expr(arg, ctx)?;
                    }
                }
                write!(self.out, ")")?;

                // Close the cast from int to uint
                if ret_might_need_int_to_uint && maybe_uint_size.is_some() {
                    write!(self.out, ")")?
                }
            }
            // `As` is always a call.
            // If `convert` is `Some(width)`, the function name is the target type (a constructor conversion).
            // Otherwise it is a bitcast, written with one of the GLSL-provided bitcast functions
            // or, where GLSL has none, with a constructor.
            Expression::As {
                expr,
                kind: target_kind,
                convert,
            } => {
                let inner = ctx.resolve_type(expr, &self.module.types);
                match convert {
                    Some(width) => {
                        // this is similar to `write_type`, but with the target kind
                        let scalar = glsl_scalar(crate::Scalar {
                            kind: target_kind,
                            width,
                        })?;
                        match *inner {
                            TypeInner::Matrix { columns, rows, .. } => write!(
                                self.out,
                                "{}mat{}x{}",
                                scalar.prefix, columns as u8, rows as u8
                            )?,
                            TypeInner::Vector { size, .. } => {
                                write!(self.out, "{}vec{}", scalar.prefix, size as u8)?
                            }
                            _ => write!(self.out, "{}", scalar.full)?,
                        }

                        write!(self.out, "(")?;
                        self.write_expr(expr, ctx)?;
                        write!(self.out, ")")?
                    }
                    None => {
                        use crate::ScalarKind as Sk;

                        let target_vector_type = match *inner {
                            TypeInner::Vector { size, scalar } => Some(TypeInner::Vector {
                                size,
                                scalar: crate::Scalar {
                                    kind: target_kind,
                                    width: scalar.width,
                                },
                            }),
                            _ => None,
                        };

                        let source_kind = inner.scalar_kind().unwrap();

                        match (source_kind, target_kind, target_vector_type) {
                            // No conversion needed
                            (Sk::Sint, Sk::Sint, _)
                            | (Sk::Uint, Sk::Uint, _)
                            | (Sk::Float, Sk::Float, _)
                            | (Sk::Bool, Sk::Bool, _) => {
                                self.write_expr(expr, ctx)?;
                                return Ok(());
                            }

                            // Cast to/from floats
                            (Sk::Float, Sk::Sint, _) => write!(self.out, "floatBitsToInt")?,
                            (Sk::Float, Sk::Uint, _) => write!(self.out, "floatBitsToUint")?,
                            (Sk::Sint, Sk::Float, _) => write!(self.out, "intBitsToFloat")?,
                            (Sk::Uint, Sk::Float, _) => write!(self.out, "uintBitsToFloat")?,

                            // Cast between vector types
                            (_, _, Some(vector)) => {
                                self.write_value_type(&vector)?;
                            }

                            // There is no way to bitcast between Uint/Sint in glsl. Use constructor conversion
                            (Sk::Uint | Sk::Bool, Sk::Sint, None) => write!(self.out, "int")?,
                            (Sk::Sint | Sk::Bool, Sk::Uint, None) => write!(self.out, "uint")?,
                            (Sk::Bool, Sk::Float, None) => write!(self.out, "float")?,
                            (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => {
                                write!(self.out, "bool")?
                            }

                            (Sk::AbstractInt | Sk::AbstractFloat, _, _)
                            | (_, Sk::AbstractInt | Sk::AbstractFloat, _) => unreachable!(),
                        };

                        write!(self.out, "(")?;
                        self.write_expr(expr, ctx)?;
                        write!(self.out, ")")?;
                    }
                }
            }
            // These expressions never show up in `Emit`.
            Expression::CallResult(_)
            | Expression::AtomicResult { .. }
            | Expression::RayQueryProceedResult
            | Expression::WorkGroupUniformLoadResult { .. }
            | Expression::SubgroupOperationResult { .. }
            | Expression::SubgroupBallotResult => unreachable!(),
            // `ArrayLength` is written as `expr.length()` and we convert it to a uint
            Expression::ArrayLength(expr) => {
                write!(self.out, "uint(")?;
                self.write_expr(expr, ctx)?;
                write!(self.out, ".length())")?
            }
            // not supported yet
            Expression::RayQueryGetIntersection { ..
            }
            | Expression::RayQueryVertexPositions { .. }
            | Expression::CooperativeLoad { .. }
            | Expression::CooperativeMultiplyAdd { .. } => unreachable!(),
        }

        Ok(())
    }

    /// Helper function to write the local holding the clamped lod
    fn write_clamped_lod(
        &mut self,
        ctx: &back::FunctionCtx,
        expr: Handle<crate::Expression>,
        image: Handle<crate::Expression>,
        level_expr: Handle<crate::Expression>,
    ) -> Result<(), Error> {
        // Define our local and start a call to `clamp`
        write!(
            self.out,
            "int {}{} = clamp(",
            Baked(expr),
            CLAMPED_LOD_SUFFIX
        )?;
        // Write the lod that will be clamped
        self.write_expr(level_expr, ctx)?;
        // Set the min value to 0 and start a call to `textureQueryLevels` to get
        // the maximum value
        write!(self.out, ", 0, textureQueryLevels(")?;
        // Write the target image as an argument to `textureQueryLevels`
        self.write_expr(image, ctx)?;
        // Close the call to `textureQueryLevels`, subtract 1 from it since
        // the lod argument is 0 based, close the `clamp` call and end the
        // local declaration statement.
        writeln!(self.out, ") - 1);")?;

        Ok(())
    }

    // Helper method used to retrieve how many elements a coordinate vector
    // for image operations needs.
    fn get_coordinate_vector_size(&self, dim: crate::ImageDimension, arrayed: bool) -> u8 {
        // OpenGL ES doesn't have 1D images, so we need to work around it
        let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es();
        // Get how many components the coordinate vector needs for the dimensions only
        let tex_coord_size = match dim {
            crate::ImageDimension::D1 => 1,
            crate::ImageDimension::D2 => 2,
            crate::ImageDimension::D3 => 3,
            crate::ImageDimension::Cube => 2,
        };
        // Calculate the true size of the coordinate vector by adding 1 for arrayed images
        // and another 1 if we need to work around 1D images by making them 2D
        tex_coord_size + tex_1d_hack as u8 + arrayed as u8
    }

    /// Helper method to write the coordinate vector for image operations
    fn write_texture_coord(
        &mut self,
        ctx: &back::FunctionCtx,
        vector_size: u8,
        coordinate: Handle<crate::Expression>,
        array_index: Option<Handle<crate::Expression>>,
        // Emulate 1D images as 2D for profiles that don't support them (GLSL ES)
        tex_1d_hack: bool,
    ) -> Result<(), Error> {
        match array_index {
            // If the image needs an array index, we need to add it to the end of our
            // coordinate vector. To do so we use the `ivec(ivec, scalar)`
            // constructor notation (NOTE: the inner `ivec` can also be a scalar; this
            // is important for 1D arrayed images).
            Some(layer_expr) => {
                write!(self.out, "ivec{vector_size}(")?;
                self.write_expr(coordinate, ctx)?;
                write!(self.out, ", ")?;
                // If we are replacing sampler1D with sampler2D we also need
                // to add another zero to the coordinates vector for the y component
                if tex_1d_hack {
                    write!(self.out, "0, ")?;
                }
                self.write_expr(layer_expr, ctx)?;
                write!(self.out, ")")?;
            }
            // Otherwise write just the expression (and the 1D hack if needed)
            None => {
                let uvec_size = match *ctx.resolve_type(coordinate, &self.module.types) {
                    TypeInner::Scalar(crate::Scalar {
                        kind: crate::ScalarKind::Uint,
                        ..
                    }) => Some(None),
                    TypeInner::Vector {
                        size,
                        scalar:
                            crate::Scalar {
                                kind: crate::ScalarKind::Uint,
                                ..
                            },
                    } => Some(Some(size as u32)),
                    _ => None,
                };
                if tex_1d_hack {
                    write!(self.out, "ivec2(")?;
                } else if uvec_size.is_some() {
                    match uvec_size {
                        Some(None) => write!(self.out, "int(")?,
                        Some(Some(size)) => write!(self.out, "ivec{size}(")?,
                        _ => {}
                    }
                }
                self.write_expr(coordinate, ctx)?;
                if tex_1d_hack {
                    write!(self.out, ", 0)")?;
                } else if uvec_size.is_some() {
                    write!(self.out, ")")?;
                }
            }
        }

        Ok(())
    }

    /// Helper method to write the `ImageStore` statement
    fn write_image_store(
        &mut self,
        ctx: &back::FunctionCtx,
        image: Handle<crate::Expression>,
        coordinate: Handle<crate::Expression>,
        array_index: Option<Handle<crate::Expression>>,
        value: Handle<crate::Expression>,
    ) -> Result<(), Error> {
        use crate::ImageDimension as IDim;

        // NOTE: OpenGL requires that `imageStore`s have no effect when the texel is invalid,
        // so we don't need to generate bounds checks (OpenGL 4.2 Core §3.9.20)

        // This will only panic if the module is invalid
        let dim = match *ctx.resolve_type(image, &self.module.types) {
            TypeInner::Image { dim, .. } => dim,
            _ => unreachable!(),
        };

        // Begin our call to `imageStore`
        write!(self.out, "imageStore(")?;
        self.write_expr(image, ctx)?;
        // Separate the image argument from the coordinates
        write!(self.out, ", ")?;

        // OpenGL ES doesn't have 1D images, so we need to work around it
        let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
        // Write the coordinate vector
        self.write_texture_coord(
            ctx,
            // Get the size of the coordinate vector
            self.get_coordinate_vector_size(dim, array_index.is_some()),
            coordinate,
            array_index,
            tex_1d_hack,
        )?;

        // Separate the coordinate from the value to write and write the expression
        // of the value to write.
        write!(self.out, ", ")?;
        self.write_expr(value, ctx)?;
        // End the call to `imageStore` and the statement.
        writeln!(self.out, ");")?;

        Ok(())
    }

    /// Helper method to write the `ImageAtomic` statement
    fn write_image_atomic(
        &mut self,
        ctx: &back::FunctionCtx,
        image: Handle<crate::Expression>,
        coordinate: Handle<crate::Expression>,
        array_index: Option<Handle<crate::Expression>>,
        fun: crate::AtomicFunction,
        value: Handle<crate::Expression>,
    ) -> Result<(), Error> {
        use crate::ImageDimension as IDim;

        // NOTE: OpenGL requires that `imageAtomic`s have no effect when the texel is invalid,
        // so we don't need to generate bounds checks (OpenGL 4.2 Core §3.9.20)

        // This will only panic if the module is invalid
        let dim = match *ctx.resolve_type(image, &self.module.types) {
            TypeInner::Image { dim, .. } => dim,
            _ => unreachable!(),
        };

        // Begin our call to `imageAtomic`
        let fun_str = fun.to_glsl();
        write!(self.out, "imageAtomic{fun_str}(")?;
        self.write_expr(image, ctx)?;
        // Separate the image argument from the coordinates
        write!(self.out, ", ")?;

        // OpenGL ES doesn't have 1D images, so we need to work around it
        let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
        // Write the coordinate vector. The size must account for the array index
        // (if any), since `write_texture_coord` appends it to the coordinates.
        self.write_texture_coord(
            ctx,
            // Get the size of the coordinate vector
            self.get_coordinate_vector_size(dim, array_index.is_some()),
            coordinate,
            array_index,
            tex_1d_hack,
        )?;

        // Separate the coordinate from the value to write and write the expression
        // of the value to write.
        write!(self.out, ", ")?;
        self.write_expr(value, ctx)?;
        // End the call to `imageAtomic` and the statement.
        writeln!(self.out, ");")?;

        Ok(())
    }

    /// Helper method for writing an `ImageLoad` expression.
    #[allow(clippy::too_many_arguments)]
    fn write_image_load(
        &mut self,
        handle: Handle<crate::Expression>,
        ctx: &back::FunctionCtx,
        image: Handle<crate::Expression>,
        coordinate: Handle<crate::Expression>,
        array_index: Option<Handle<crate::Expression>>,
        sample: Option<Handle<crate::Expression>>,
        level: Option<Handle<crate::Expression>>,
    ) -> Result<(), Error> {
        use crate::ImageDimension as IDim;

        // `ImageLoad` is a bit complicated.
        // There are two functions, one for sampled images and another for
        // storage images: the former uses `texelFetch` and the
        // latter uses `imageLoad`.
        //
        // Furthermore we have `level` which is always `Some` for sampled images
        // and `None` for storage images, so we end up with two functions:
        // - `texelFetch(image, coordinate, level)` for sampled images
        // - `imageLoad(image, coordinate)` for storage images
        //
        // Finally we also have to consider bounds checking. For storage images on
        // desktop OpenGL this is easy, since invalid texels always return 0 (ES only
        // guarantees this for the R, G and B components). For sampled images we need
        // to either verify that all arguments are in bounds (`ReadZeroSkipWrite`) or
        // make them a valid texel (`Restrict`).

        // This will only panic if the module is invalid
        let (dim, class) = match *ctx.resolve_type(image, &self.module.types) {
            TypeInner::Image {
                dim,
                arrayed: _,
                class,
            } => (dim, class),
            _ => unreachable!(),
        };

        // Get the name of the function to be used for the load operation
        // and the policy to be used with it.
        let (fun_name, policy) = match class {
            // Sampled images inherit the policy from the user passed policies
            crate::ImageClass::Sampled { .. } => ("texelFetch", self.policies.image_load),
            crate::ImageClass::Storage { .. } => {
                // OpenGL ES 3.1 mentions in Chapter "8.22 Texture Image Loads and Stores" that:
                // "Invalid image loads will return a vector where the value of R, G, and B components
                // is 0 and the value of the A component is undefined."
                //
                // OpenGL 4.2 Core mentions in Chapter "3.9.20 Texture Image Loads and Stores" that:
                // "Invalid image loads will return zero."
                //
                // So, we only inject bounds checks for ES
                let policy = if self.options.version.is_es() {
                    self.policies.image_load
                } else {
                    proc::BoundsCheckPolicy::Unchecked
                };
                ("imageLoad", policy)
            }
            // TODO: Is there even a function for this?
            crate::ImageClass::Depth { multi: _ } => {
                return Err(Error::Custom(
                    "WGSL `textureLoad` from depth textures is not supported in GLSL".to_string(),
                ))
            }
            crate::ImageClass::External => unimplemented!(),
        };

        // OpenGL ES doesn't have 1D images, so we need to work around it
        let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
        // Get the size of the coordinate vector
        let vector_size = self.get_coordinate_vector_size(dim, array_index.is_some());

        if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy {
            // To write the bounds checks for `ReadZeroSkipWrite` we will use a
            // ternary operator since we are in the middle of an expression and
            // need to return a value.
            //
            // NOTE: GLSL short-circuits when evaluating logical
            // expressions, so we can be sure that once a condition
            // has been tested it holds for the checks that follow

            // Write parentheses around the ternary operator to prevent problems with
            // expressions emitted before or after it having higher precedence
            write!(self.out, "(",)?;

            // The lod check needs to precede the size check since we need
            // to use the lod to get the size of the image at that level.
            if let Some(level_expr) = level {
                self.write_expr(level_expr, ctx)?;
                write!(self.out, " < textureQueryLevels(",)?;
                self.write_expr(image, ctx)?;
                // Chain the next check
                write!(self.out, ") && ")?;
            }

            // Check that the sample argument doesn't exceed the number of samples
            if let Some(sample_expr) = sample {
                self.write_expr(sample_expr, ctx)?;
                write!(self.out, " < textureSamples(",)?;
                self.write_expr(image, ctx)?;
                // Chain the next check
                write!(self.out, ") && ")?;
            }

            // We now need to write the size checks for the coordinates and array index.
            // If the image is 1D and non-arrayed (and no 1D-to-2D hack was needed), we are
            // comparing scalars, so the less-than operator suffices. Otherwise we are
            // comparing two vectors, so we use the `lessThan` function; it returns a vector
            // of booleans (one per component), and since every comparison must pass we fold
            // it into a single boolean with `all`, which is `true` only if every element is.
            //
            // So we'll end up with one of the following forms
            // - `coord < textureSize(image, lod)` for 1D images
            // - `all(lessThan(coord, textureSize(image, lod)))` for normal images
            // - `all(lessThan(ivec(coord, array_index), textureSize(image, lod)))`
            //    for arrayed images
            // - `all(lessThan(coord, textureSize(image)))` for multisampled images

            if vector_size != 1 {
                write!(self.out, "all(lessThan(")?;
            }

            // Write the coordinate vector
            self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;

            if vector_size != 1 {
                // If we used the `lessThan` function we need to separate the
                // coordinates from the image size.
                write!(self.out, ", ")?;
            } else {
                // If we didn't use it (i.e. 1D images) we perform the comparison
                // using the less-than operator.
                write!(self.out, " < ")?;
            }

            // Call `textureSize` to get our image size
            write!(self.out, "textureSize(")?;
            self.write_expr(image, ctx)?;
            // `textureSize` uses the lod as a second argument for mipmapped images
            if let Some(level_expr) = level {
                // Separate the image from the lod
                write!(self.out, ", ")?;
                self.write_expr(level_expr, ctx)?;
            }
            // Close the `textureSize` call
            write!(self.out, ")")?;

            if vector_size != 1 {
                // Close the `all` and `lessThan` calls
                write!(self.out, "))")?;
            }

            // Finally end the condition part of the ternary operator
            write!(self.out, " ? ")?;
        }

        // Begin the call to the function used to load the texel
        write!(self.out, "{fun_name}(")?;
        self.write_expr(image, ctx)?;
        write!(self.out, ", ")?;

        // If we are using `Restrict` bounds checking we need to pass valid texel
        // coordinates, so we use the `clamp` function to get a value between
        // 0 and the image size - 1 (indexing begins at 0)
        if let proc::BoundsCheckPolicy::Restrict = policy {
            write!(self.out, "clamp(")?;
        }

        // Write the coordinate vector
        self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;

        // If we are using `Restrict` bounds checking we need to write the rest of the
        // clamp we started before writing the coordinates.
        if let proc::BoundsCheckPolicy::Restrict = policy {
            // Write the min value 0
            if vector_size == 1 {
                write!(self.out, ", 0")?;
            } else {
                write!(self.out, ", ivec{vector_size}(0)")?;
            }
            // Start the `textureSize` call to use as the max value.
            write!(self.out, ", textureSize(")?;
            self.write_expr(image, ctx)?;
            // If the image is mipmapped we need to add the lod argument to the
            // `textureSize` call. This must be the clamped lod, which should
            // have been generated earlier and stored in a local.
            if class.is_mipmapped() {
                write!(self.out, ", {}{}", Baked(handle), CLAMPED_LOD_SUFFIX)?;
            }
            // Close the `textureSize` call
            write!(self.out, ")")?;

            // Subtract 1 from the `textureSize` call since the coordinates are zero based.
if vector_size == 1 { write!(self.out, " - 1")?; } else { write!(self.out, " - ivec{vector_size}(1)")?; } // Close the `clamp` call write!(self.out, ")")?; // Add the clamped lod (if present) as the second argument to the // image load function. if level.is_some() { write!(self.out, ", {}{}", Baked(handle), CLAMPED_LOD_SUFFIX)?; } // If a sample argument is needed we need to clamp it between 0 and // the number of samples the image has. if let Some(sample_expr) = sample { write!(self.out, ", clamp(")?; self.write_expr(sample_expr, ctx)?; // Set the min value to 0 and start the call to `textureSamples` write!(self.out, ", 0, textureSamples(")?; self.write_expr(image, ctx)?; // Close the `textureSamples` call, subtract 1 from it since the sample // argument is zero based, and close the `clamp` call writeln!(self.out, ") - 1)")?; } } else if let Some(sample_or_level) = sample.or(level) { // GLSL only support SInt on this field while WGSL support also UInt let cast_to_int = matches!( *ctx.resolve_type(sample_or_level, &self.module.types), TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Uint, .. }) ); // If no bounds checking is need just add the sample or level argument // after the coordinates write!(self.out, ", ")?; if cast_to_int { write!(self.out, "int(")?; } self.write_expr(sample_or_level, ctx)?; if cast_to_int { write!(self.out, ")")?; } } // Close the image load function. write!(self.out, ")")?; // If we were using the `ReadZeroSkipWrite` policy we need to end the first branch // (which is taken if the condition is `true`) with a colon (`:`) and write the // second branch which is just a 0 value. if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy { // Get the kind of the output value. let kind = match class { // Only sampled images can reach here since storage images // don't need bounds checks and depth images aren't implemented crate::ImageClass::Sampled { kind, .. 
} => kind, _ => unreachable!(), }; // End the first branch write!(self.out, " : ")?; // Write the 0 value write!( self.out, "{}vec4(", glsl_scalar(crate::Scalar { kind, width: 4 })?.prefix, )?; self.write_zero_init_scalar(kind)?; // Close the zero value constructor write!(self.out, ")")?; // Close the parentheses surrounding our ternary write!(self.out, ")")?; } Ok(()) } fn write_named_expr( &mut self, handle: Handle, name: String, // The expression which is being named. // Generally, this is the same as handle, except in WorkGroupUniformLoad named: Handle, ctx: &back::FunctionCtx, ) -> BackendResult { match ctx.info[named].ty { proc::TypeResolution::Handle(ty_handle) => match self.module.types[ty_handle].inner { TypeInner::Struct { .. } => { let ty_name = &self.names[&NameKey::Type(ty_handle)]; write!(self.out, "{ty_name}")?; } _ => { self.write_type(ty_handle)?; } }, proc::TypeResolution::Value(ref inner) => { self.write_value_type(inner)?; } } let resolved = ctx.resolve_type(named, &self.module.types); write!(self.out, " {name}")?; if let TypeInner::Array { base, size, .. } = *resolved { self.write_array_size(base, size)?; } write!(self.out, " = ")?; self.write_expr(handle, ctx)?; writeln!(self.out, ";")?; self.named_expressions.insert(named, name); Ok(()) } /// Helper function that write string with default zero initialization for supported types fn write_zero_init_value(&mut self, ty: Handle) -> BackendResult { let inner = &self.module.types[ty].inner; match *inner { TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => { self.write_zero_init_scalar(scalar.kind)?; } TypeInner::Vector { scalar, .. } => { self.write_value_type(inner)?; write!(self.out, "(")?; self.write_zero_init_scalar(scalar.kind)?; write!(self.out, ")")?; } TypeInner::Matrix { .. } => { self.write_value_type(inner)?; write!(self.out, "(")?; self.write_zero_init_scalar(crate::ScalarKind::Float)?; write!(self.out, ")")?; } TypeInner::Array { base, size, .. 
} => { let count = match size.resolve(self.module.to_ctx())? { proc::IndexableLength::Known(count) => count, proc::IndexableLength::Dynamic => return Ok(()), }; self.write_type(base)?; self.write_array_size(base, size)?; write!(self.out, "(")?; for _ in 1..count { self.write_zero_init_value(base)?; write!(self.out, ", ")?; } // write last parameter without comma and space self.write_zero_init_value(base)?; write!(self.out, ")")?; } TypeInner::Struct { ref members, .. } => { let name = &self.names[&NameKey::Type(ty)]; write!(self.out, "{name}(")?; for (index, member) in members.iter().enumerate() { if index != 0 { write!(self.out, ", ")?; } self.write_zero_init_value(member.ty)?; } write!(self.out, ")")?; } _ => unreachable!(), } Ok(()) } /// Helper function that write string with zero initialization for scalar fn write_zero_init_scalar(&mut self, kind: crate::ScalarKind) -> BackendResult { match kind { crate::ScalarKind::Bool => write!(self.out, "false")?, crate::ScalarKind::Uint => write!(self.out, "0u")?, crate::ScalarKind::Float => write!(self.out, "0.0")?, crate::ScalarKind::Sint => write!(self.out, "0")?, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { return Err(Error::Custom( "Abstract types should not appear in IR presented to backends".to_string(), )) } } Ok(()) } /// Issue a control barrier. fn write_control_barrier( &mut self, flags: crate::Barrier, level: back::Level, ) -> BackendResult { self.write_memory_barrier(flags, level)?; writeln!(self.out, "{level}barrier();")?; Ok(()) } /// Issue a memory barrier. 
fn write_memory_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult { if flags.contains(crate::Barrier::STORAGE) { writeln!(self.out, "{level}memoryBarrierBuffer();")?; } if flags.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}memoryBarrierShared();")?; } if flags.contains(crate::Barrier::SUB_GROUP) { writeln!(self.out, "{level}subgroupMemoryBarrier();")?; } if flags.contains(crate::Barrier::TEXTURE) { writeln!(self.out, "{level}memoryBarrierImage();")?; } Ok(()) } /// Helper function that return the glsl storage access string of [`StorageAccess`](crate::StorageAccess) /// /// glsl allows adding both `readonly` and `writeonly` but this means that /// they can only be used to query information about the resource which isn't what /// we want here so when storage access is both `LOAD` and `STORE` add no modifiers fn write_storage_access(&mut self, storage_access: crate::StorageAccess) -> BackendResult { if storage_access.contains(crate::StorageAccess::ATOMIC) { return Ok(()); } if !storage_access.contains(crate::StorageAccess::STORE) { write!(self.out, "readonly ")?; } if !storage_access.contains(crate::StorageAccess::LOAD) { write!(self.out, "writeonly ")?; } Ok(()) } /// Helper method used to produce the reflection info that's returned to the user fn collect_reflection_info(&mut self) -> Result { let info = self.info.get_entry_point(self.entry_point_idx as usize); let mut texture_mapping = crate::FastHashMap::default(); let mut uniforms = crate::FastHashMap::default(); for sampling in info.sampling_set.iter() { let tex_name = self.reflection_names_globals[&sampling.image].clone(); match texture_mapping.entry(tex_name) { hash_map::Entry::Vacant(v) => { v.insert(TextureMapping { texture: sampling.image, sampler: Some(sampling.sampler), }); } hash_map::Entry::Occupied(e) => { if e.get().sampler != Some(sampling.sampler) { log::error!("Conflicting samplers for {}", e.key()); return Err(Error::ImageMultipleSamplers); } } } } 
let mut immediates_info = None; for (handle, var) in self.module.global_variables.iter() { if info[handle].is_empty() { continue; } match self.module.types[var.ty].inner { TypeInner::Image { .. } => { let tex_name = self.reflection_names_globals[&handle].clone(); match texture_mapping.entry(tex_name) { hash_map::Entry::Vacant(v) => { v.insert(TextureMapping { texture: handle, sampler: None, }); } hash_map::Entry::Occupied(_) => { // already used with a sampler, do nothing } } } _ => match var.space { crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => { let name = self.reflection_names_globals[&handle].clone(); uniforms.insert(handle, name); } crate::AddressSpace::Immediate => { let name = self.reflection_names_globals[&handle].clone(); immediates_info = Some((name, var.ty)); } _ => (), }, } } let mut immediates_segments = Vec::new(); let mut immediates_items = vec![]; if let Some((name, ty)) = immediates_info { // We don't have a layouter available to us, so we need to create one. // // This is potentially a bit wasteful, but the set of types in the program // shouldn't be too large. let mut layouter = proc::Layouter::default(); layouter.update(self.module.to_ctx()).unwrap(); // We start with the name of the binding itself. immediates_segments.push(name); // We then recursively collect all the uniform fields of the immediate data. self.collect_immediates_items( ty, &mut immediates_segments, &layouter, &mut 0, &mut immediates_items, ); } Ok(ReflectionInfo { texture_mapping, uniforms, varying: mem::take(&mut self.varying), immediates_items, clip_distance_count: self.clip_distance_count, }) } fn collect_immediates_items( &mut self, ty: Handle, segments: &mut Vec, layouter: &proc::Layouter, offset: &mut u32, items: &mut Vec, ) { // At this point in the recursion, `segments` contains the path // needed to access `ty` from the root. 
let layout = &layouter[ty]; *offset = layout.alignment.round_up(*offset); match self.module.types[ty].inner { // All these types map directly to GL uniforms. TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => { // Build the full name, by combining all current segments. let name: String = segments.iter().map(String::as_str).collect(); items.push(ImmediateItem { access_path: name, offset: *offset, ty, }); *offset += layout.size; } // Arrays are recursed into. TypeInner::Array { base, size, .. } => { let crate::ArraySize::Constant(count) = size else { unreachable!("Cannot have dynamic arrays in immediates"); }; for i in 0..count.get() { // Add the array accessor and recurse. segments.push(format!("[{i}]")); self.collect_immediates_items(base, segments, layouter, offset, items); segments.pop(); } // Ensure the stride is kept by rounding up to the alignment. *offset = layout.alignment.round_up(*offset) } TypeInner::Struct { ref members, .. } => { for (index, member) in members.iter().enumerate() { // Add struct accessor and recurse. segments.push(format!( ".{}", self.names[&NameKey::StructMember(ty, index as u32)] )); self.collect_immediates_items(member.ty, segments, layouter, offset, items); segments.pop(); } // Ensure ending padding is kept by rounding up to the alignment. 
                *offset = layout.alignment.round_up(*offset)
            }
            _ => unreachable!(),
        }
    }
}
naga-29.0.3/src/back/hlsl/conv.rs
use crate::common;

use alloc::{borrow::Cow, format, string::String};

use super::Error;
use crate::proc::Alignment;

impl crate::ScalarKind {
    pub(super) fn to_hlsl_cast(self) -> &'static str {
        match self {
            Self::Float => "asfloat",
            Self::Sint => "asint",
            Self::Uint => "asuint",
            Self::Bool | Self::AbstractInt | Self::AbstractFloat => unreachable!(),
        }
    }
}

impl crate::Scalar {
    /// Helper function that returns scalar related strings
    pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> {
        match self.kind {
            crate::ScalarKind::Sint => match self.width {
                4 => Ok("int"),
                8 => Ok("int64_t"),
                _ => Err(Error::UnsupportedScalar(self)),
            },
            crate::ScalarKind::Uint => match self.width {
                4 => Ok("uint"),
                8 => Ok("uint64_t"),
                _ => Err(Error::UnsupportedScalar(self)),
            },
            crate::ScalarKind::Float => match self.width {
                2 => Ok("half"),
                4 => Ok("float"),
                8 => Ok("double"),
                _ => Err(Error::UnsupportedScalar(self)),
            },
            crate::ScalarKind::Bool => Ok("bool"),
            crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => {
                Err(Error::UnsupportedScalar(self))
            }
        }
    }
}

impl crate::TypeInner {
    pub(super) const fn is_matrix(&self) -> bool {
        match *self {
            Self::Matrix { .. } => true,
            _ => false,
        }
    }

    pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> Result<u32, Error> {
        match *self {
            Self::Matrix {
                columns,
                rows,
                scalar,
            } => {
                let stride = Alignment::from(rows) * scalar.width as u32;
                let last_row_size = rows as u32 * scalar.width as u32;
                Ok(((columns as u32 - 1) * stride) + last_row_size)
            }
            Self::Array { base, size, stride } => {
                let count = match size.resolve(gctx)? {
                    crate::proc::IndexableLength::Known(size) => size,
                    // A dynamically-sized array has to have at least one element
                    crate::proc::IndexableLength::Dynamic => 1,
                };
                let last_el_size = gctx.types[base].inner.size_hlsl(gctx)?;
                Ok(((count - 1) * stride) + last_el_size)
            }
            _ => Ok(self.size(gctx)),
        }
    }

    /// Used to generate the name of the wrapped type constructor
    pub(super) fn hlsl_type_id<'a>(
        base: crate::Handle<crate::Type>,
        gctx: crate::proc::GlobalCtx,
        names: &'a crate::FastHashMap<crate::proc::NameKey, String>,
    ) -> Result<Cow<'a, str>, Error> {
        Ok(match gctx.types[base].inner {
            crate::TypeInner::Scalar(scalar) => Cow::Borrowed(scalar.to_hlsl_str()?),
            crate::TypeInner::Vector { size, scalar } => Cow::Owned(format!(
                "{}{}",
                scalar.to_hlsl_str()?,
                common::vector_size_str(size)
            )),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => Cow::Owned(format!(
                "{}{}x{}",
                scalar.to_hlsl_str()?,
                common::vector_size_str(columns),
                common::vector_size_str(rows),
            )),
            crate::TypeInner::Array {
                base,
                size: crate::ArraySize::Constant(size),
                ..
            } => Cow::Owned(format!(
                "array{size}_{}_",
                Self::hlsl_type_id(base, gctx, names)?
            )),
            crate::TypeInner::Struct { ..
            } => {
                Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)])
            }
            _ => unreachable!(),
        })
    }
}

impl crate::StorageFormat {
    pub(super) const fn to_hlsl_str(self) -> &'static str {
        match self {
            Self::R16Float | Self::R32Float => "float",
            Self::R8Unorm | Self::R16Unorm => "unorm float",
            Self::R8Snorm | Self::R16Snorm => "snorm float",
            Self::R8Uint | Self::R16Uint | Self::R32Uint => "uint",
            Self::R8Sint | Self::R16Sint | Self::R32Sint => "int",
            Self::R64Uint => "uint64_t",

            Self::Rg16Float | Self::Rg32Float => "float4",
            Self::Rg8Unorm | Self::Rg16Unorm => "unorm float4",
            Self::Rg8Snorm | Self::Rg16Snorm => "snorm float4",

            // Signed formats map to `int4` and unsigned formats to `uint4`,
            // matching the single- and four-channel cases.
            Self::Rg8Sint | Self::Rg16Sint | Self::Rg32Sint => "int4",
            Self::Rg8Uint | Self::Rg16Uint | Self::Rg32Uint => "uint4",

            Self::Rg11b10Ufloat => "float4",

            Self::Rgba16Float | Self::Rgba32Float => "float4",
            Self::Rgba8Unorm | Self::Bgra8Unorm | Self::Rgba16Unorm | Self::Rgb10a2Unorm => {
                "unorm float4"
            }
            Self::Rgba8Snorm | Self::Rgba16Snorm => "snorm float4",

            Self::Rgba8Uint | Self::Rgba16Uint | Self::Rgba32Uint | Self::Rgb10a2Uint => "uint4",
            Self::Rgba8Sint | Self::Rgba16Sint | Self::Rgba32Sint => "int4",
        }
    }
}

impl crate::BuiltIn {
    pub(super) fn to_hlsl_str(self) -> Result<&'static str, Error> {
        Ok(match self {
            Self::Position { .. } => "SV_Position",
            // vertex
            Self::ClipDistance => "SV_ClipDistance",
            Self::CullDistance => "SV_CullDistance",
            Self::InstanceIndex => "SV_InstanceID",
            Self::VertexIndex => "SV_VertexID",
            // fragment
            Self::FragDepth => "SV_Depth",
            Self::FrontFacing => "SV_IsFrontFace",
            Self::PrimitiveIndex => "SV_PrimitiveID",
            Self::Barycentric { .. } => "SV_Barycentrics",
            Self::SampleIndex => "SV_SampleIndex",
            Self::SampleMask => "SV_Coverage",
            // compute
            Self::GlobalInvocationId => "SV_DispatchThreadID",
            Self::LocalInvocationId => "SV_GroupThreadID",
            Self::LocalInvocationIndex => "SV_GroupIndex",
            Self::WorkGroupId => "SV_GroupID",
            // The specific semantic we use here doesn't matter, because references
            // to this field will get replaced with references to `SPECIAL_CBUF_VAR`
            // in `Writer::write_expr`.
            Self::NumWorkGroups => "SV_GroupID",
            Self::ViewIndex => "SV_ViewID",
            // These builtins map to functions
            Self::SubgroupSize
            | Self::SubgroupInvocationId
            | Self::NumSubgroups
            | Self::SubgroupId => unreachable!(),
            Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
                return Err(Error::Unimplemented(format!("builtin {self:?}")))
            }
            Self::PointSize | Self::PointCoord | Self::DrawIndex => {
                return Err(Error::Custom(format!("Unsupported builtin {self:?}")))
            }
            Self::CullPrimitive => "SV_CullPrimitive",
            Self::PointIndex | Self::LineIndices | Self::TriangleIndices => unimplemented!(),
            Self::MeshTaskSize
            | Self::VertexCount
            | Self::PrimitiveCount
            | Self::Vertices
            | Self::Primitives => unreachable!(),
            Self::RayInvocationId
            | Self::NumRayInvocations
            | Self::InstanceCustomData
            | Self::GeometryIndex
            | Self::WorldRayOrigin
            | Self::WorldRayDirection
            | Self::ObjectRayOrigin
            | Self::ObjectRayDirection
            | Self::RayTmin
            | Self::RayTCurrentMax
            | Self::ObjectToWorld
            | Self::WorldToObject
            | Self::HitKind => unreachable!(),
        })
    }
}

impl crate::Interpolation {
    /// Return the string corresponding to the HLSL interpolation qualifier.
    pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
        match self {
            // Would be "linear", but it's the default interpolation in SM4 and up
            // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-struct#interpolation-modifiers-introduced-in-shader-model-4
            Self::Perspective => None,
            Self::Linear => Some("noperspective"),
            Self::Flat => Some("nointerpolation"),
            Self::PerVertex => unreachable!(),
        }
    }
}

impl crate::Sampling {
    /// Return the HLSL auxiliary qualifier for the given sampling value.
    pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
        match self {
            Self::Center | Self::First | Self::Either => None,
            Self::Centroid => Some("centroid"),
            Self::Sample => Some("sample"),
        }
    }
}

impl crate::AtomicFunction {
    /// Return the HLSL suffix for the `InterlockedXxx` method.
    pub(super) const fn to_hlsl_suffix(self) -> &'static str {
        match self {
            Self::Add | Self::Subtract => "Add",
            Self::And => "And",
            Self::InclusiveOr => "Or",
            Self::ExclusiveOr => "Xor",
            Self::Min => "Min",
            Self::Max => "Max",
            Self::Exchange { compare: None } => "Exchange",
            Self::Exchange { .. } => "CompareExchange",
        }
    }
}
naga-29.0.3/src/back/hlsl/help.rs
/*!
Helpers for the hlsl backend

Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and the hlsl backend:

Due to the way the HLSL `GetDimensions` function is implemented, the backend can't use it as an expression.
Instead, it generates a unique wrapped function per `Expression::ImageQuery`, based on the texture info and query function.
See the `WrappedImageQuery` struct, which represents one unique function; these functions are generated before any statements and expressions are written.
This makes it possible to treat `Expression::ImageQuery` as an expression by calling the wrapped function.

For example:
```wgsl
let dim_1d = textureDimensions(image_1d);
```

```hlsl
int NagaDimensions1D(Texture1D<float4> image_1d)
{
    uint4 ret;
    image_1d.GetDimensions(ret.x);
    return ret.x;
}

int dim_1d = NagaDimensions1D(image_1d);
```
*/

use alloc::format;
use core::fmt::Write;

use super::{
    super::FunctionCtx,
    writer::{
        ABS_FUNCTION, DIV_FUNCTION, EXTRACT_BITS_FUNCTION, F2I32_FUNCTION, F2I64_FUNCTION,
        F2U32_FUNCTION, F2U64_FUNCTION, IMAGE_LOAD_EXTERNAL_FUNCTION,
        IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION, INSERT_BITS_FUNCTION, MOD_FUNCTION,
        NEG_FUNCTION,
    },
    BackendResult, WrappedType,
};
use crate::{arena::Handle, proc::NameKey, ScalarKind};

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedArrayLength {
    pub(super) writable: bool,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedImageLoad {
    pub(super) class: crate::ImageClass,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedImageSample {
    pub(super) class: crate::ImageClass,
    pub(super) clamp_to_edge: bool,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedImageQuery {
    pub(super) dim: crate::ImageDimension,
    pub(super) arrayed: bool,
    pub(super) class: crate::ImageClass,
    pub(super) query: ImageQuery,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedConstructor {
    pub(super) ty: Handle<crate::Type>,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedStructMatrixAccess {
    pub(super) ty: Handle<crate::Type>,
    pub(super) index: u32,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedMatCx2 {
    pub(super) columns: crate::VectorSize,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedMath {
    pub(super) fun: crate::MathFunction,
    pub(super) scalar: crate::Scalar,
    pub(super) components: Option<u32>,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedZeroValue {
    pub(super) ty: Handle<crate::Type>,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedUnaryOp {
    pub(super) op: crate::UnaryOperator,
    // This can only represent scalar or vector types. If we ever need to wrap
    // unary ops with other types, we'll need a better representation.
    pub(super) ty: (Option<crate::VectorSize>, crate::Scalar),
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedBinaryOp {
    pub(super) op: crate::BinaryOperator,
    // This can only represent scalar or vector types. If we ever need to wrap
    // binary ops with other types, we'll need a better representation.
    pub(super) left_ty: (Option<crate::VectorSize>, crate::Scalar),
    pub(super) right_ty: (Option<crate::VectorSize>, crate::Scalar),
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedCast {
    // This can only represent scalar or vector types. If we ever need to wrap
    // casts with other types, we'll need a better representation.
    pub(super) vector_size: Option<crate::VectorSize>,
    pub(super) src_scalar: crate::Scalar,
    pub(super) dst_scalar: crate::Scalar,
}

/// HLSL backend requires its own `ImageQuery` enum.
///
/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
/// The IR version can't be unique per function, because it stores the mipmap level as an expression.
///
/// For example:
/// ```wgsl
/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1);
/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1);
/// ```
///
/// ```ir
/// ImageQuery {
///     image: [1],
///     query: Size {
///         level: Some(
///             [1],
///         ),
///     },
/// },
/// ImageQuery {
///     image: [1],
///     query: Size {
///         level: Some(
///             [2],
///         ),
///     },
/// },
/// ```
///
/// HLSL should generate only 1 function for this case.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) enum ImageQuery {
    Size,
    SizeLevel,
    NumLevels,
    NumLayers,
    NumSamples,
}

impl From<crate::ImageQuery> for ImageQuery {
    fn from(q: crate::ImageQuery) -> Self {
        use crate::ImageQuery as Iq;
        match q {
            Iq::Size { level: Some(_) } => ImageQuery::SizeLevel,
            Iq::Size { level: None } => ImageQuery::Size,
            Iq::NumLevels => ImageQuery::NumLevels,
            Iq::NumLayers => ImageQuery::NumLayers,
            Iq::NumSamples => ImageQuery::NumSamples,
        }
    }
}

pub(super) const IMAGE_STORAGE_LOAD_SCALAR_WRAPPER: &str = "LoadedStorageValueFrom";

impl<W: Write> super::Writer<'_, W> {
    pub(super) fn write_image_type(
        &mut self,
        dim: crate::ImageDimension,
        arrayed: bool,
        class: crate::ImageClass,
    ) -> BackendResult {
        let access_str = match class {
            crate::ImageClass::Storage { .. } => "RW",
            _ => "",
        };
        let dim_str = dim.to_hlsl_str();
        let arrayed_str = if arrayed { "Array" } else { "" };
        write!(self.out, "{access_str}Texture{dim_str}{arrayed_str}")?;
        match class {
            crate::ImageClass::Depth { multi } => {
                let multi_str = if multi { "MS" } else { "" };
                write!(self.out, "{multi_str}")?
            }
            crate::ImageClass::Sampled { kind, multi } => {
                let multi_str = if multi { "MS" } else { "" };
                let scalar_kind_str = crate::Scalar { kind, width: 4 }.to_hlsl_str()?;
                write!(self.out, "{multi_str}<{scalar_kind_str}4>")?
            }
            crate::ImageClass::Storage { format, .. } => {
                let storage_format_str = format.to_hlsl_str();
                write!(self.out, "<{storage_format_str}>")?
            }
            crate::ImageClass::External => {
                unreachable!(
                    "external images should be handled by `write_global_external_texture`"
                );
            }
        }

        Ok(())
    }

    pub(super) fn write_wrapped_array_length_function_name(
        &mut self,
        query: WrappedArrayLength,
    ) -> BackendResult {
        let access_str = if query.writable { "RW" } else { "" };
        write!(self.out, "NagaBufferLength{access_str}",)?;

        Ok(())
    }

    /// Helper function that writes a wrapped function for `Expression::ArrayLength`
    pub(super) fn write_wrapped_array_length_function(
        &mut self,
        wal: WrappedArrayLength,
    ) -> BackendResult {
        use crate::back::INDENT;

        const ARGUMENT_VARIABLE_NAME: &str = "buffer";
        const RETURN_VARIABLE_NAME: &str = "ret";

        // Write function return type and name
        write!(self.out, "uint ")?;
        self.write_wrapped_array_length_function_name(wal)?;

        // Write function parameters
        write!(self.out, "(")?;
        let access_str = if wal.writable { "RW" } else { "" };
        writeln!(
            self.out,
            "{access_str}ByteAddressBuffer {ARGUMENT_VARIABLE_NAME})"
        )?;
        // Write function body
        writeln!(self.out, "{{")?;

        // Write `GetDimensions` function.
        writeln!(self.out, "{INDENT}uint {RETURN_VARIABLE_NAME};")?;
        writeln!(
            self.out,
            "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions({RETURN_VARIABLE_NAME});"
        )?;

        // Write return value
        writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;

        // End of function body
        writeln!(self.out, "}}")?;
        // Write extra new line
        writeln!(self.out)?;

        Ok(())
    }

    /// Helper function used by [`Self::write_wrapped_image_load_function`] and
    /// [`Self::write_wrapped_image_sample_function`] to write the shared YUV
    /// to RGB conversion code for external textures. Expects the preceding
    /// code to declare the Y component as a `float` variable of name `y`, the
    /// UV components as a `float2` variable of name `uv`, and the external
    /// texture params as a variable of name `params`. The emitted code will
    /// return the result.
fn write_convert_yuv_to_rgb_and_return( &mut self, level: crate::back::Level, y: &str, uv: &str, params: &str, ) -> BackendResult { let l1 = level; let l2 = l1.next(); // Convert from YUV to non-linear RGB in the source color space. We // declare our matrices as row_major in HLSL, therefore we must reverse // the order of this multiplication writeln!( self.out, "{l1}float3 srcGammaRgb = mul(float4({y}, {uv}, 1.0), {params}.yuv_conversion_matrix).rgb;" )?; // Apply the inverse of the source transfer function to convert to // linear RGB in the source color space. writeln!( self.out, "{l1}float3 srcLinearRgb = srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b ?" )?; writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k :")?; writeln!(self.out, "{l2}pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g);")?; // Multiply by the gamut conversion matrix to convert to linear RGB in // the destination color space. We declare our matrices as row_major in // HLSL, therefore we must reverse the order of this multiplication. writeln!( self.out, "{l1}float3 dstLinearRgb = mul(srcLinearRgb, {params}.gamut_conversion_matrix);" )?; // Finally, apply the dest transfer function to convert to non-linear // RGB in the destination color space, and return the result. writeln!( self.out, "{l1}float3 dstGammaRgb = dstLinearRgb < {params}.dst_tf.b ?" 
)?; writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb :")?; writeln!(self.out, "{l2}{params}.dst_tf.a * pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1);")?; writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?; Ok(()) } pub(super) fn write_wrapped_image_load_function( &mut self, module: &crate::Module, load: WrappedImageLoad, ) -> BackendResult { match load { WrappedImageLoad { class: crate::ImageClass::External, } => { let l1 = crate::back::Level(1); let l2 = l1.next(); let l3 = l2.next(); let params_ty_name = &self.names [&NameKey::Type(module.special_types.external_texture_params.unwrap())]; writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}(")?; writeln!(self.out, "{l1}Texture2D plane0,")?; writeln!(self.out, "{l1}Texture2D plane1,")?; writeln!(self.out, "{l1}Texture2D plane2,")?; writeln!(self.out, "{l1}{params_ty_name} params,")?; writeln!(self.out, "{l1}uint2 coords)")?; writeln!(self.out, "{{")?; writeln!(self.out, "{l1}uint2 plane0_size;")?; writeln!( self.out, "{l1}plane0.GetDimensions(plane0_size.x, plane0_size.y);" )?; // Clamp coords to provided size of external texture to prevent OOB read. // If params.size is zero then clamp to the actual size of the texture. writeln!( self.out, "{l1}uint2 cropped_size = any(params.size) ? params.size : plane0_size;" )?; writeln!(self.out, "{l1}coords = min(coords, cropped_size - 1);")?; // Apply load transformation. 
We declare our matrices as row_major in // HLSL, therefore we must reverse the order of this multiplication writeln!(self.out, "{l1}float3x2 load_transform = float3x2(")?; writeln!(self.out, "{l2}params.load_transform_0,")?; writeln!(self.out, "{l2}params.load_transform_1,")?; writeln!(self.out, "{l2}params.load_transform_2")?; writeln!(self.out, "{l1});")?; writeln!(self.out, "{l1}uint2 plane0_coords = uint2(round(mul(float3(coords, 1.0), load_transform)));")?; writeln!(self.out, "{l1}if (params.num_planes == 1u) {{")?; // For single plane, simply read from plane0 writeln!( self.out, "{l2}return plane0.Load(uint3(plane0_coords, 0u));" )?; writeln!(self.out, "{l1}}} else {{")?; // Chroma planes may be subsampled so we must scale the coords accordingly. writeln!(self.out, "{l2}uint2 plane1_size;")?; writeln!( self.out, "{l2}plane1.GetDimensions(plane1_size.x, plane1_size.y);" )?; writeln!(self.out, "{l2}uint2 plane1_coords = uint2(floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?; // For multi-plane, read the Y value from plane 0 writeln!( self.out, "{l2}float y = plane0.Load(uint3(plane0_coords, 0u)).x;" )?; writeln!(self.out, "{l2}float2 uv;")?; writeln!(self.out, "{l2}if (params.num_planes == 2u) {{")?; // Read UV from interleaved plane 1 writeln!( self.out, "{l3}uv = plane1.Load(uint3(plane1_coords, 0u)).xy;" )?; writeln!(self.out, "{l2}}} else {{")?; // Read U and V from planes 1 and 2 respectively writeln!(self.out, "{l3}uint2 plane2_size;")?; writeln!( self.out, "{l3}plane2.GetDimensions(plane2_size.x, plane2_size.y);" )?; writeln!(self.out, "{l3}uint2 plane2_coords = uint2(floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?; writeln!(self.out, "{l3}uv = float2(plane1.Load(uint3(plane1_coords, 0u)).x, plane2.Load(uint3(plane2_coords, 0u)).x);")?; writeln!(self.out, "{l2}}}")?; self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "params")?; writeln!(self.out, "{l1}}}")?; writeln!(self.out, "}}")?; 
writeln!(self.out)?; } _ => {} } Ok(()) } pub(super) fn write_wrapped_image_sample_function( &mut self, module: &crate::Module, sample: WrappedImageSample, ) -> BackendResult { match sample { WrappedImageSample { class: crate::ImageClass::External, clamp_to_edge: true, } => { let l1 = crate::back::Level(1); let l2 = l1.next(); let l3 = l2.next(); let params_ty_name = &self.names [&NameKey::Type(module.special_types.external_texture_params.unwrap())]; writeln!( self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(" )?; writeln!(self.out, "{l1}Texture2D plane0,")?; writeln!(self.out, "{l1}Texture2D plane1,")?; writeln!(self.out, "{l1}Texture2D plane2,")?; writeln!(self.out, "{l1}{params_ty_name} params,")?; writeln!(self.out, "{l1}SamplerState samp,")?; writeln!(self.out, "{l1}float2 coords)")?; writeln!(self.out, "{{")?; writeln!(self.out, "{l1}float2 plane0_size;")?; writeln!( self.out, "{l1}plane0.GetDimensions(plane0_size.x, plane0_size.y);" )?; writeln!(self.out, "{l1}float3x2 sample_transform = float3x2(")?; writeln!(self.out, "{l2}params.sample_transform_0,")?; writeln!(self.out, "{l2}params.sample_transform_1,")?; writeln!(self.out, "{l2}params.sample_transform_2")?; writeln!(self.out, "{l1});")?; // Apply sample transformation. We declare our matrices as row_major in // HLSL, therefore we must reverse the order of this multiplication writeln!( self.out, "{l1}coords = mul(float3(coords, 1.0), sample_transform);" )?; // Calculate the sample bounds. The purported size of the texture // (params.size) is irrelevant here as we are dealing with normalized // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must // apply the sample transformation to that, also bearing in mind that it // may contain a flip on either axis. We calculate and adjust for the // half-texel separately for each plane as it depends on the actual // texture size which may vary between planes. 
                writeln!(
                    self.out,
                    "{l1}float2 bounds_min = mul(float3(0.0, 0.0, 1.0), sample_transform);"
                )?;
                writeln!(
                    self.out,
                    "{l1}float2 bounds_max = mul(float3(1.0, 1.0, 1.0), sample_transform);"
                )?;
                writeln!(self.out, "{l1}float4 bounds = float4(min(bounds_min, bounds_max), max(bounds_min, bounds_max));")?;
                writeln!(
                    self.out,
                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / plane0_size;"
                )?;
                writeln!(
                    self.out,
                    "{l1}float2 plane0_coords = clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
                )?;
                writeln!(self.out, "{l1}if (params.num_planes == 1u) {{")?;
                // For single plane, simply sample from plane0
                writeln!(
                    self.out,
                    "{l2}return plane0.SampleLevel(samp, plane0_coords, 0.0f);"
                )?;
                writeln!(self.out, "{l1}}} else {{")?;

                writeln!(self.out, "{l2}float2 plane1_size;")?;
                writeln!(
                    self.out,
                    "{l2}plane1.GetDimensions(plane1_size.x, plane1_size.y);"
                )?;
                writeln!(
                    self.out,
                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / plane1_size;"
                )?;
                writeln!(
                    self.out,
                    "{l2}float2 plane1_coords = clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
                )?;

                // For multi-plane, sample the Y value from plane 0
                writeln!(
                    self.out,
                    "{l2}float y = plane0.SampleLevel(samp, plane0_coords, 0.0f).x;"
                )?;
                writeln!(self.out, "{l2}float2 uv;")?;
                writeln!(self.out, "{l2}if (params.num_planes == 2u) {{")?;
                // Sample UV from interleaved plane 1
                writeln!(
                    self.out,
                    "{l3}uv = plane1.SampleLevel(samp, plane1_coords, 0.0f).xy;"
                )?;
                writeln!(self.out, "{l2}}} else {{")?;
                // Sample U and V from planes 1 and 2 respectively
                writeln!(self.out, "{l3}float2 plane2_size;")?;
                writeln!(
                    self.out,
                    "{l3}plane2.GetDimensions(plane2_size.x, plane2_size.y);"
                )?;
                writeln!(
                    self.out,
                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / plane2_size;"
                )?;
                writeln!(self.out, "{l3}float2 plane2_coords = clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane2_half_texel);")?;
                writeln!(self.out, "{l3}uv = float2(plane1.SampleLevel(samp, plane1_coords, 0.0f).x, plane2.SampleLevel(samp, plane2_coords, 0.0f).x);")?;
                writeln!(self.out, "{l2}}}")?;

                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "params")?;

                writeln!(self.out, "{l1}}}")?;
                writeln!(self.out, "}}")?;
                writeln!(self.out)?;
            }
            WrappedImageSample {
                class:
                    crate::ImageClass::Sampled {
                        kind: ScalarKind::Float,
                        multi: false,
                    },
                clamp_to_edge: true,
            } => {
                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(Texture2D tex, SamplerState samp, float2 coords) {{")?;
                let l1 = crate::back::Level(1);
                writeln!(self.out, "{l1}float2 size;")?;
                writeln!(self.out, "{l1}tex.GetDimensions(size.x, size.y);")?;
                writeln!(self.out, "{l1}float2 half_texel = float2(0.5, 0.5) / size;")?;
                writeln!(
                    self.out,
                    "{l1}return tex.SampleLevel(samp, clamp(coords, half_texel, 1.0 - half_texel), 0.0);"
                )?;
                writeln!(self.out, "}}")?;
                writeln!(self.out)?;
            }
            _ => {}
        }

        Ok(())
    }

    pub(super) fn write_wrapped_image_query_function_name(
        &mut self,
        query: WrappedImageQuery,
    ) -> BackendResult {
        let dim_str = query.dim.to_hlsl_str();
        let class_str = match query.class {
            crate::ImageClass::Sampled { multi: true, .. } => "MS",
            crate::ImageClass::Depth { multi: true } => "DepthMS",
            crate::ImageClass::Depth { multi: false } => "Depth",
            crate::ImageClass::Sampled { multi: false, .. } => "",
            crate::ImageClass::Storage { ..
            } => "RW",
            crate::ImageClass::External => "External",
        };
        let arrayed_str = if query.arrayed { "Array" } else { "" };
        let query_str = match query.query {
            ImageQuery::Size => "Dimensions",
            ImageQuery::SizeLevel => "MipDimensions",
            ImageQuery::NumLevels => "NumLevels",
            ImageQuery::NumLayers => "NumLayers",
            ImageQuery::NumSamples => "NumSamples",
        };

        write!(self.out, "Naga{class_str}{query_str}{dim_str}{arrayed_str}")?;

        Ok(())
    }

    /// Helper function that writes a wrapped function for `Expression::ImageQuery`.
    pub(super) fn write_wrapped_image_query_function(
        &mut self,
        module: &crate::Module,
        wiq: WrappedImageQuery,
        expr_handle: Handle<crate::Expression>,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        use crate::{
            back::{COMPONENTS, INDENT},
            ImageDimension as IDim,
        };

        match wiq.class {
            crate::ImageClass::External => {
                if wiq.query != ImageQuery::Size {
                    return Err(super::Error::Custom(
                        "External images only support `Size` queries".into(),
                    ));
                }

                write!(self.out, "uint2 ")?;
                self.write_wrapped_image_query_function_name(wiq)?;
                let params_name = &self.names
                    [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
                // Only plane0 and params are used by this implementation, but it's easier to
                // always take all of them as arguments so that we can unconditionally expand an
                // external texture expression into each of its parts.
                writeln!(self.out, "(Texture2D plane0, Texture2D plane1, Texture2D plane2, {params_name} params) {{")?;
                let l1 = crate::back::Level(1);
                let l2 = l1.next();
                writeln!(self.out, "{l1}if (any(params.size)) {{")?;
                writeln!(self.out, "{l2}return params.size;")?;
                writeln!(self.out, "{l1}}} else {{")?;
                // params.size == (0, 0) indicates to query and return plane 0's actual size
                writeln!(self.out, "{l2}uint2 ret;")?;
                writeln!(self.out, "{l2}plane0.GetDimensions(ret.x, ret.y);")?;
                writeln!(self.out, "{l2}return ret;")?;
                writeln!(self.out, "{l1}}}")?;
                writeln!(self.out, "}}")?;
                writeln!(self.out)?;
            }
            _ => {
                const ARGUMENT_VARIABLE_NAME: &str = "tex";
                const RETURN_VARIABLE_NAME: &str = "ret";
                const MIP_LEVEL_PARAM: &str = "mip_level";

                // Write function return type and name
                let ret_ty = func_ctx.resolve_type(expr_handle, &module.types);
                self.write_value_type(module, ret_ty)?;
                write!(self.out, " ")?;
                self.write_wrapped_image_query_function_name(wiq)?;

                // Write function parameters
                write!(self.out, "(")?;
                // The texture is always the first parameter
                self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?;
                write!(self.out, " {ARGUMENT_VARIABLE_NAME}")?;
                // The mip level is the second parameter, if present
                if let ImageQuery::SizeLevel = wiq.query {
                    write!(self.out, ", uint {MIP_LEVEL_PARAM}")?;
                }
                writeln!(self.out, ")")?;

                // Write function body
                writeln!(self.out, "{{")?;

                let array_coords = usize::from(wiq.arrayed);
                // The extra parameter is the mip level count or the sample count
                let extra_coords = match wiq.class {
                    crate::ImageClass::Storage { .. } => 0,
                    crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { ..
                    } => 1,
                    crate::ImageClass::External => unreachable!(),
                };

                // GetDimensions Overloaded Methods
                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods
                let (ret_swizzle, number_of_params) = match wiq.query {
                    ImageQuery::Size | ImageQuery::SizeLevel => {
                        let ret = match wiq.dim {
                            IDim::D1 => "x",
                            IDim::D2 => "xy",
                            IDim::D3 => "xyz",
                            IDim::Cube => "xy",
                        };
                        (ret, ret.len() + array_coords + extra_coords)
                    }
                    ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => {
                        if wiq.arrayed || wiq.dim == IDim::D3 {
                            ("w", 4)
                        } else {
                            ("z", 3)
                        }
                    }
                };

                // Write `GetDimensions` function.
                writeln!(self.out, "{INDENT}uint4 {RETURN_VARIABLE_NAME};")?;
                write!(self.out, "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions(")?;
                match wiq.query {
                    ImageQuery::SizeLevel => {
                        write!(self.out, "{MIP_LEVEL_PARAM}, ")?;
                    }
                    _ => match wiq.class {
                        crate::ImageClass::Sampled { multi: true, .. }
                        | crate::ImageClass::Depth { multi: true }
                        | crate::ImageClass::Storage { .. } => {}
                        _ => {
                            // Write a zero mip level for the types that take one
                            write!(self.out, "0, ")?;
                        }
                    },
                }

                for component in COMPONENTS[..number_of_params - 1].iter() {
                    write!(self.out, "{RETURN_VARIABLE_NAME}.{component}, ")?;
                }

                // Write the last parameter without a trailing comma and space
                write!(
                    self.out,
                    "{}.{}",
                    RETURN_VARIABLE_NAME,
                    COMPONENTS[number_of_params - 1]
                )?;

                writeln!(self.out, ");")?;

                // Write return value
                writeln!(
                    self.out,
                    "{INDENT}return {RETURN_VARIABLE_NAME}.{ret_swizzle};"
                )?;

                // End of function body
                writeln!(self.out, "}}")?;
                // Write extra new line
                writeln!(self.out)?;
            }
        }
        Ok(())
    }

    pub(super) fn write_wrapped_constructor_function_name(
        &mut self,
        module: &crate::Module,
        constructor: WrappedConstructor,
    ) -> BackendResult {
        let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?;
        write!(self.out, "Construct{name}")?;
        Ok(())
    }

    /// Helper function that writes a wrapped function for `Expression::Compose` for structures and arrays.
    fn write_wrapped_constructor_function(
        &mut self,
        module: &crate::Module,
        constructor: WrappedConstructor,
    ) -> BackendResult {
        use crate::back::INDENT;

        const ARGUMENT_VARIABLE_NAME: &str = "arg";
        const RETURN_VARIABLE_NAME: &str = "ret";

        // Write function return type and name
        if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner {
            write!(self.out, "typedef ")?;
            self.write_type(module, constructor.ty)?;
            write!(self.out, " ret_")?;
            self.write_wrapped_constructor_function_name(module, constructor)?;
            self.write_array_size(module, base, size)?;
            writeln!(self.out, ";")?;

            write!(self.out, "ret_")?;
            self.write_wrapped_constructor_function_name(module, constructor)?;
        } else {
            self.write_type(module, constructor.ty)?;
        }
        write!(self.out, " ")?;
        self.write_wrapped_constructor_function_name(module, constructor)?;

        // Write function parameters
        write!(self.out, "(")?;

        let mut write_arg = |i, ty| -> BackendResult {
            if i != 0 {
                write!(self.out, ", ")?;
            }
            self.write_type(module, ty)?;
            write!(self.out, " {ARGUMENT_VARIABLE_NAME}{i}")?;
            if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner {
                self.write_array_size(module, base, size)?;
            }
            Ok(())
        };

        match module.types[constructor.ty].inner {
            crate::TypeInner::Struct { ref members, .. } => {
                for (i, member) in members.iter().enumerate() {
                    write_arg(i, member.ty)?;
                }
            }
            crate::TypeInner::Array {
                base,
                size: crate::ArraySize::Constant(size),
                ..
            } => {
                for i in 0..size.get() as usize {
                    write_arg(i, base)?;
                }
            }
            _ => unreachable!(),
        };

        write!(self.out, ")")?;

        // Write function body
        writeln!(self.out, " {{")?;

        match module.types[constructor.ty].inner {
            crate::TypeInner::Struct { ref members, ..
            } => {
                let struct_name = &self.names[&NameKey::Type(constructor.ty)];
                writeln!(
                    self.out,
                    "{INDENT}{struct_name} {RETURN_VARIABLE_NAME} = ({struct_name})0;"
                )?;
                for (i, member) in members.iter().enumerate() {
                    let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)];

                    match module.types[member.ty].inner {
                        crate::TypeInner::Matrix {
                            columns,
                            rows: crate::VectorSize::Bi,
                            ..
                        } if member.binding.is_none() => {
                            for j in 0..columns as u8 {
                                writeln!(
                                    self.out,
                                    "{INDENT}{RETURN_VARIABLE_NAME}.{field_name}_{j} = {ARGUMENT_VARIABLE_NAME}{i}[{j}];"
                                )?;
                            }
                        }
                        ref other => {
                            // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s
                            // (where the inner matrix is represented by a struct with C `float2` members).
                            // See the module-level block comment in mod.rs for details.
                            if let Some(super::writer::MatrixType {
                                columns,
                                rows: crate::VectorSize::Bi,
                                width: 4,
                            }) = super::writer::get_inner_matrix_data(module, member.ty)
                            {
                                write!(
                                    self.out,
                                    "{}{}.{} = (__mat{}x2",
                                    INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8
                                )?;
                                if let crate::TypeInner::Array { base, size, .. } = *other {
                                    self.write_array_size(module, base, size)?;
                                }
                                writeln!(self.out, "){ARGUMENT_VARIABLE_NAME}{i};",)?;
                            } else {
                                writeln!(
                                    self.out,
                                    "{INDENT}{RETURN_VARIABLE_NAME}.{field_name} = {ARGUMENT_VARIABLE_NAME}{i};",
                                )?;
                            }
                        }
                    }
                }
            }
            crate::TypeInner::Array {
                base,
                size: crate::ArraySize::Constant(size),
                ..
            } => {
                write!(self.out, "{INDENT}")?;
                self.write_type(module, base)?;
                write!(self.out, " {RETURN_VARIABLE_NAME}")?;
                self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
                write!(self.out, " = {{ ")?;
                for i in 0..size.get() {
                    if i != 0 {
                        write!(self.out, ", ")?;
                    }
                    write!(self.out, "{ARGUMENT_VARIABLE_NAME}{i}")?;
                }
                writeln!(self.out, " }};",)?;
            }
            _ => unreachable!(),
        }

        // Write return value
        writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;

        // End of function body
        writeln!(self.out, "}}")?;
        // Write extra new line
        writeln!(self.out)?;

        Ok(())
    }

    /// Writes the conversion from a single-component storage texture load to a 4-component
    /// vector with the loaded scalar in its `x` component, 1 in its `w` (alpha) component
    /// and 0 everywhere else.
    fn write_loaded_scalar_to_storage_loaded_value(
        &mut self,
        scalar_type: crate::Scalar,
    ) -> BackendResult {
        const ARGUMENT_VARIABLE_NAME: &str = "arg";
        const RETURN_VARIABLE_NAME: &str = "ret";

        let zero;
        let one;
        match scalar_type.kind {
            ScalarKind::Sint => {
                assert_eq!(
                    scalar_type.width, 4,
                    "Scalar {scalar_type:?} is not a result from any storage format"
                );
                zero = "0";
                one = "1";
            }
            ScalarKind::Uint => match scalar_type.width {
                4 => {
                    zero = "0u";
                    one = "1u";
                }
                8 => {
                    zero = "0uL";
                    one = "1uL"
                }
                _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
            },
            ScalarKind::Float => {
                assert_eq!(
                    scalar_type.width, 4,
                    "Scalar {scalar_type:?} is not a result from any storage format"
                );
                zero = "0.0";
                one = "1.0";
            }
            _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
        }

        let ty = scalar_type.to_hlsl_str()?;
        writeln!(
            self.out,
            "{ty}4 {IMAGE_STORAGE_LOAD_SCALAR_WRAPPER}{ty}({ty} {ARGUMENT_VARIABLE_NAME}) {{\
    {ty}4 {RETURN_VARIABLE_NAME} = {ty}4({ARGUMENT_VARIABLE_NAME}, {zero}, {zero}, {one});\
    return {RETURN_VARIABLE_NAME};\
}}"
        )?;

        Ok(())
    }

    pub(super) fn write_wrapped_struct_matrix_get_function_name(
        &mut self,
        access: WrappedStructMatrixAccess,
    ) -> BackendResult {
        let name =
            &self.names[&NameKey::Type(access.ty)];
        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
        write!(self.out, "GetMat{field_name}On{name}")?;
        Ok(())
    }

    /// Writes a function used to get a matCx2 from within a structure.
    pub(super) fn write_wrapped_struct_matrix_get_function(
        &mut self,
        module: &crate::Module,
        access: WrappedStructMatrixAccess,
    ) -> BackendResult {
        use crate::back::INDENT;

        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";

        // Write function return type and name
        let member = match module.types[access.ty].inner {
            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
            _ => unreachable!(),
        };
        let ret_ty = &module.types[member.ty].inner;
        self.write_value_type(module, ret_ty)?;
        write!(self.out, " ")?;
        self.write_wrapped_struct_matrix_get_function_name(access)?;

        // Write function parameters
        write!(self.out, "(")?;
        let struct_name = &self.names[&NameKey::Type(access.ty)];
        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}")?;

        // Write function body
        writeln!(self.out, ") {{")?;

        // Write return value
        write!(self.out, "{INDENT}return ")?;
        self.write_value_type(module, ret_ty)?;
        write!(self.out, "(")?;
        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
        match module.types[member.ty].inner {
            crate::TypeInner::Matrix { columns, ..
            } => {
                for i in 0..columns as u8 {
                    if i != 0 {
                        write!(self.out, ", ")?;
                    }
                    write!(self.out, "{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}")?;
                }
            }
            _ => unreachable!(),
        }
        writeln!(self.out, ");")?;

        // End of function body
        writeln!(self.out, "}}")?;
        // Write extra new line
        writeln!(self.out)?;

        Ok(())
    }

    pub(super) fn write_wrapped_struct_matrix_set_function_name(
        &mut self,
        access: WrappedStructMatrixAccess,
    ) -> BackendResult {
        let name = &self.names[&NameKey::Type(access.ty)];
        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
        write!(self.out, "SetMat{field_name}On{name}")?;
        Ok(())
    }

    /// Writes a function used to set a matCx2 from within a structure.
    pub(super) fn write_wrapped_struct_matrix_set_function(
        &mut self,
        module: &crate::Module,
        access: WrappedStructMatrixAccess,
    ) -> BackendResult {
        use crate::back::INDENT;

        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
        const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat";

        // Write function return type and name
        write!(self.out, "void ")?;
        self.write_wrapped_struct_matrix_set_function_name(access)?;

        // Write function parameters
        write!(self.out, "(")?;
        let struct_name = &self.names[&NameKey::Type(access.ty)];
        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
        let member = match module.types[access.ty].inner {
            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
            _ => unreachable!(),
        };
        self.write_type(module, member.ty)?;
        write!(self.out, " {MATRIX_ARGUMENT_VARIABLE_NAME}")?;
        // Write function body
        writeln!(self.out, ") {{")?;

        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];

        match module.types[member.ty].inner {
            crate::TypeInner::Matrix { columns, ..
            } => {
                for i in 0..columns as u8 {
                    writeln!(
                        self.out,
                        "{INDENT}{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {MATRIX_ARGUMENT_VARIABLE_NAME}[{i}];"
                    )?;
                }
            }
            _ => unreachable!(),
        }

        // End of function body
        writeln!(self.out, "}}")?;
        // Write extra new line
        writeln!(self.out)?;

        Ok(())
    }

    pub(super) fn write_wrapped_struct_matrix_set_vec_function_name(
        &mut self,
        access: WrappedStructMatrixAccess,
    ) -> BackendResult {
        let name = &self.names[&NameKey::Type(access.ty)];
        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
        write!(self.out, "SetMatVec{field_name}On{name}")?;
        Ok(())
    }

    /// Writes a function used to set a vec2 on a matCx2 from within a structure.
    pub(super) fn write_wrapped_struct_matrix_set_vec_function(
        &mut self,
        module: &crate::Module,
        access: WrappedStructMatrixAccess,
    ) -> BackendResult {
        use crate::back::INDENT;

        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
        const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec";
        const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";

        // Write function return type and name
        write!(self.out, "void ")?;
        self.write_wrapped_struct_matrix_set_vec_function_name(access)?;

        // Write function parameters
        write!(self.out, "(")?;
        let struct_name = &self.names[&NameKey::Type(access.ty)];
        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
        let member = match module.types[access.ty].inner {
            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
            _ => unreachable!(),
        };
        let vec_ty = match module.types[member.ty].inner {
            crate::TypeInner::Matrix { rows, scalar, ..
            } => {
                crate::TypeInner::Vector { size: rows, scalar }
            }
            _ => unreachable!(),
        };
        self.write_value_type(module, &vec_ty)?;
        write!(
            self.out,
            " {VECTOR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}"
        )?;

        // Write function body
        writeln!(self.out, ") {{")?;

        writeln!(
            self.out,
            "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
        )?;

        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];

        match module.types[member.ty].inner {
            crate::TypeInner::Matrix { columns, .. } => {
                for i in 0..columns as u8 {
                    writeln!(
                        self.out,
                        "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {VECTOR_ARGUMENT_VARIABLE_NAME}; break; }}"
                    )?;
                }
            }
            _ => unreachable!(),
        }

        writeln!(self.out, "{INDENT}}}")?;

        // End of function body
        writeln!(self.out, "}}")?;
        // Write extra new line
        writeln!(self.out)?;

        Ok(())
    }

    pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name(
        &mut self,
        access: WrappedStructMatrixAccess,
    ) -> BackendResult {
        let name = &self.names[&NameKey::Type(access.ty)];
        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
        write!(self.out, "SetMatScalar{field_name}On{name}")?;
        Ok(())
    }

    /// Writes a function used to set a float on a matCx2 from within a structure.
    pub(super) fn write_wrapped_struct_matrix_set_scalar_function(
        &mut self,
        module: &crate::Module,
        access: WrappedStructMatrixAccess,
    ) -> BackendResult {
        use crate::back::INDENT;

        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
        const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar";
        const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
        const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx";

        // Write function return type and name
        write!(self.out, "void ")?;
        self.write_wrapped_struct_matrix_set_scalar_function_name(access)?;

        // Write function parameters
        write!(self.out, "(")?;
        let struct_name = &self.names[&NameKey::Type(access.ty)];
        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
        let member = match module.types[access.ty].inner {
            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
            _ => unreachable!(),
        };
        let scalar_ty = match module.types[member.ty].inner {
            crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar),
            _ => unreachable!(),
        };
        self.write_value_type(module, &scalar_ty)?;
        write!(
            self.out,
            " {SCALAR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}, uint {VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}"
        )?;

        // Write function body
        writeln!(self.out, ") {{")?;

        writeln!(
            self.out,
            "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
        )?;

        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];

        match module.types[member.ty].inner {
            crate::TypeInner::Matrix { columns, .. } => {
                for i in 0..columns as u8 {
                    writeln!(
                        self.out,
                        "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}[{VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}] = {SCALAR_ARGUMENT_VARIABLE_NAME}; break; }}"
                    )?;
                }
            }
            _ => unreachable!(),
        }

        writeln!(self.out, "{INDENT}}}")?;

        // End of function body
        writeln!(self.out, "}}")?;
        // Write extra new line
        writeln!(self.out)?;

        Ok(())
    }

    /// Writes functions to create special types.
    pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
            match type_key {
                &crate::PredeclaredType::ModfResult { size, scalar }
                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
                    let arg_type_name_owner;
                    let arg_type_name = if let Some(size) = size {
                        arg_type_name_owner = format!(
                            "{}{}",
                            if scalar.width == 8 { "double" } else { "float" },
                            size as u8
                        );
                        &arg_type_name_owner
                    } else if scalar.width == 8 {
                        "double"
                    } else {
                        "float"
                    };

                    let (defined_func_name, called_func_name, second_field_name, sign_multiplier) =
                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
                            (super::writer::MODF_FUNCTION, "modf", "whole", "")
                        } else {
                            (
                                super::writer::FREXP_FUNCTION,
                                "frexp",
                                "exp_",
                                "sign(arg) * ",
                            )
                        };

                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    writeln!(
                        self.out,
                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
    {arg_type_name} other;
    {struct_name} result;
    result.fract = {sign_multiplier}{called_func_name}(arg, other);
    result.{second_field_name} = other;
    return result;
}}"
                    )?;
                    writeln!(self.out)?;
                }
                &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
            }
        }
        if module.special_types.ray_desc.is_some() {
            self.write_ray_desc_from_ray_desc_constructor_function(module)?;
        }

        Ok(())
    }

    /// Helper function that writes wrapped functions for expressions in a function
    pub(super) fn write_wrapped_expression_functions(
        &mut self,
        module: &crate::Module,
        expressions: &crate::Arena<crate::Expression>,
        context: Option<&FunctionCtx>,
    ) -> BackendResult {
        for (handle, _) in expressions.iter() {
            match expressions[handle] {
                crate::Expression::Compose { ty, .. } => {
                    match module.types[ty].inner {
                        crate::TypeInner::Struct { .. } | crate::TypeInner::Array { ..
                        } => {
                            let constructor = WrappedConstructor { ty };
                            if self.wrapped.insert(WrappedType::Constructor(constructor)) {
                                self.write_wrapped_constructor_function(module, constructor)?;
                            }
                        }
                        _ => {}
                    };
                }
                crate::Expression::ImageLoad { image, .. } => {
                    // This can only happen in a function as this is not a valid const expression
                    match *context.as_ref().unwrap().resolve_type(image, &module.types) {
                        crate::TypeInner::Image {
                            class: crate::ImageClass::Storage { format, .. },
                            ..
                        } => {
                            if format.single_component() {
                                let scalar: crate::Scalar = format.into();
                                if self.wrapped.insert(WrappedType::ImageLoadScalar(scalar)) {
                                    self.write_loaded_scalar_to_storage_loaded_value(scalar)?;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                crate::Expression::RayQueryGetIntersection { committed, .. } => {
                    if committed {
                        if !self.written_committed_intersection {
                            self.write_committed_intersection_function(module)?;
                            self.written_committed_intersection = true;
                        }
                    } else if !self.written_candidate_intersection {
                        self.write_candidate_intersection_function(module)?;
                        self.written_candidate_intersection = true;
                    }
                }
                _ => {}
            }
        }
        Ok(())
    }

    // TODO: we could merge this with iteration in write_wrapped_expression_functions...
    //
    /// Helper function that writes wrapped zero-value functions
    pub(super) fn write_wrapped_zero_value_functions(
        &mut self,
        module: &crate::Module,
        expressions: &crate::Arena<crate::Expression>,
    ) -> BackendResult {
        for (handle, _) in expressions.iter() {
            if let crate::Expression::ZeroValue(ty) = expressions[handle] {
                let zero_value = WrappedZeroValue { ty };
                if self.wrapped.insert(WrappedType::ZeroValue(zero_value)) {
                    self.write_wrapped_zero_value_function(module, zero_value)?;
                }
            }
        }
        Ok(())
    }

    pub(super) fn write_wrapped_math_functions(
        &mut self,
        module: &crate::Module,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        for (_, expression) in func_ctx.expressions.iter() {
            if let crate::Expression::Math {
                fun,
                arg,
                arg1: _arg1,
                arg2: _arg2,
                arg3: _arg3,
            } = *expression
            {
                let arg_ty = func_ctx.resolve_type(arg, &module.types);

                match fun {
                    crate::MathFunction::ExtractBits => {
                        // The behavior of our extractBits polyfill is undefined if offset + count > bit_width. We need
                        // to sanitize the offset and count first. If we don't do this, we will get out-of-spec
                        // values if the extracted range is not within the bit width.
                        //
                        // This encodes the exact formula specified by the WGSL spec:
                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
                        //
                        // w = sizeof(x) * 8
                        // o = min(offset, w)
                        // c = min(count, w - o)
                        //
                        // bitfieldExtract(x, o, c)
                        let scalar = arg_ty.scalar().unwrap();
                        let components = arg_ty.components();

                        let wrapped = WrappedMath {
                            fun,
                            scalar,
                            components,
                        };

                        if !self.wrapped.insert(WrappedType::Math(wrapped)) {
                            continue;
                        }

                        // Write return type
                        self.write_value_type(module, arg_ty)?;

                        let scalar_width: u8 = scalar.width * 8;

                        // Write function name and parameters
                        writeln!(self.out, " {EXTRACT_BITS_FUNCTION}(")?;
                        write!(self.out, "    ")?;
                        self.write_value_type(module, arg_ty)?;
                        writeln!(self.out, " e,")?;
                        writeln!(self.out, "    uint offset,")?;
                        writeln!(self.out, "    uint count")?;
                        writeln!(self.out, ") {{")?;

                        // Write function body
                        writeln!(self.out, "    uint w = {scalar_width};")?;
                        writeln!(self.out, "    uint o = min(offset, w);")?;
                        writeln!(self.out, "    uint c = min(count, w - o);")?;
                        writeln!(
                            self.out,
                            "    return (c == 0 ? 0 : (e << (w - c - o)) >> (w - c));"
                        )?;

                        // End of function body
                        writeln!(self.out, "}}")?;
                    }
                    crate::MathFunction::InsertBits => {
                        // The behavior of our insertBits polyfill has the same constraints as the extractBits polyfill.
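                        // Worked example (illustrative values only) for a 32-bit `e` with
                        // offset = 4 and count = 8: w = 32, o = 4 and c = 8, so
                        // mask = (0xFFFFFFFF >> (32 - 8)) << 4 = 0xFF << 4 = 0xFF0, and bits
                        // 4..=11 of `e` are replaced by the low 8 bits of `newbits`. When c == 0
                        // the shift amount `w - c` would equal the full bit width, which is why
                        // that case returns `e` unchanged instead.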
                        let scalar = arg_ty.scalar().unwrap();
                        let components = arg_ty.components();

                        let wrapped = WrappedMath {
                            fun,
                            scalar,
                            components,
                        };

                        if !self.wrapped.insert(WrappedType::Math(wrapped)) {
                            continue;
                        }

                        // Write return type
                        self.write_value_type(module, arg_ty)?;

                        let scalar_width: u8 = scalar.width * 8;
                        let scalar_max: u64 = match scalar.width {
                            1 => 0xFF,
                            2 => 0xFFFF,
                            4 => 0xFFFFFFFF,
                            8 => 0xFFFFFFFFFFFFFFFF,
                            _ => unreachable!(),
                        };

                        // Write function name and parameters
                        writeln!(self.out, " {INSERT_BITS_FUNCTION}(")?;
                        write!(self.out, "    ")?;
                        self.write_value_type(module, arg_ty)?;
                        writeln!(self.out, " e,")?;
                        write!(self.out, "    ")?;
                        self.write_value_type(module, arg_ty)?;
                        writeln!(self.out, " newbits,")?;
                        writeln!(self.out, "    uint offset,")?;
                        writeln!(self.out, "    uint count")?;
                        writeln!(self.out, ") {{")?;

                        // Write function body
                        writeln!(self.out, "    uint w = {scalar_width}u;")?;
                        writeln!(self.out, "    uint o = min(offset, w);")?;
                        writeln!(self.out, "    uint c = min(count, w - o);")?;

                        // The `u` suffix on the literals is _extremely_ important. Otherwise it will use
                        // i32 shifting instead of the intended u32 shifting.
                        writeln!(
                            self.out,
                            "    uint mask = (({scalar_max}u >> ({scalar_width}u - c)) << o);"
                        )?;
                        writeln!(
                            self.out,
                            "    return (c == 0 ? e : ((e & ~mask) | ((newbits << o) & mask)));"
                        )?;

                        // End of function body
                        writeln!(self.out, "}}")?;
                    }
                    // Taking the absolute value of the minimum value of a two's
                    // complement signed integer type causes overflow, which is
                    // undefined behaviour in HLSL. To avoid this, when the value is
                    // negative we bitcast the value to unsigned and negate it, then
                    // bitcast back to signed.
                    // This adheres to the WGSL spec in that the absolute value of the
                    // type's minimum value should equal the minimum value.
                    //
                    // TODO(#7109): asint()/asuint() only support 32-bit integers, so we
                    // must find another solution for different bit-widths.
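                    // Worked example: for val = -2147483648 (the minimum i32, bit pattern
                    // 0x80000000), asuint(val) is 0x80000000, and negating it wraps modulo
                    // 2^32 back to 0x80000000, so asint() yields -2147483648 again, matching
                    // the WGSL result. For val = -5, asuint(val) is 0xFFFFFFFB, whose
                    // negation is 5.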
                    crate::MathFunction::Abs
                        if matches!(arg_ty.scalar(), Some(crate::Scalar::I32)) =>
                    {
                        let scalar = arg_ty.scalar().unwrap();
                        let components = arg_ty.components();

                        let wrapped = WrappedMath {
                            fun,
                            scalar,
                            components,
                        };

                        if !self.wrapped.insert(WrappedType::Math(wrapped)) {
                            continue;
                        }

                        self.write_value_type(module, arg_ty)?;
                        write!(self.out, " {ABS_FUNCTION}(")?;
                        self.write_value_type(module, arg_ty)?;
                        writeln!(self.out, " val) {{")?;

                        let level = crate::back::Level(1);
                        writeln!(
                            self.out,
                            "{level}return val >= 0 ? val : asint(-asuint(val));"
                        )?;
                        writeln!(self.out, "}}")?;
                        writeln!(self.out)?;
                    }
                    _ => {}
                }
            }
        }

        Ok(())
    }

    pub(super) fn write_wrapped_unary_ops(
        &mut self,
        module: &crate::Module,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        for (_, expression) in func_ctx.expressions.iter() {
            if let crate::Expression::Unary { op, expr } = *expression {
                let expr_ty = func_ctx.resolve_type(expr, &module.types);
                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
                    continue;
                };
                let wrapped = WrappedUnaryOp {
                    op,
                    ty: (vector_size, scalar),
                };

                // Negating the minimum value of a two's complement signed integer type
                // causes overflow, which is undefined behaviour in HLSL. To avoid this
                // we bitcast the value to unsigned and negate it, then bitcast back to
                // signed. This adheres to the WGSL spec in that the negative of the
                // type's minimum value should equal the minimum value.
                //
                // TODO(#7109): asint()/asuint() only support 32-bit integers, so we must
                // find another solution for different bit-widths.
                match (op, scalar) {
                    (crate::UnaryOperator::Negate, crate::Scalar::I32) => {
                        if !self.wrapped.insert(WrappedType::UnaryOp(wrapped)) {
                            continue;
                        }

                        self.write_value_type(module, expr_ty)?;
                        write!(self.out, " {NEG_FUNCTION}(")?;
                        self.write_value_type(module, expr_ty)?;
                        writeln!(self.out, " val) {{")?;

                        let level = crate::back::Level(1);
                        writeln!(self.out, "{level}return asint(-asuint(val));",)?;
                        writeln!(self.out, "}}")?;
                        writeln!(self.out)?;
                    }
                    _ => {}
                }
            }
        }

        Ok(())
    }

    pub(super) fn write_wrapped_binary_ops(
        &mut self,
        module: &crate::Module,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        for (expr_handle, expression) in func_ctx.expressions.iter() {
            if let crate::Expression::Binary { op, left, right } = *expression {
                let expr_ty = func_ctx.resolve_type(expr_handle, &module.types);
                let left_ty = func_ctx.resolve_type(left, &module.types);
                let right_ty = func_ctx.resolve_type(right, &module.types);

                match (op, expr_ty.scalar()) {
                    // Signed integer division of the type's minimum representable value
                    // divided by -1, or signed or unsigned division by zero, is
                    // undefined behaviour in HLSL. We override the divisor to 1 in these
                    // cases.
                    // This adheres to the WGSL spec in that:
                    // * TYPE_MIN / -1 == TYPE_MIN
                    // * x / 0 == x
                    (
                        crate::BinaryOperator::Divide,
                        Some(
                            scalar @ crate::Scalar {
                                kind: ScalarKind::Sint | ScalarKind::Uint,
                                ..
                            },
                        ),
                    ) => {
                        let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
                            continue;
                        };
                        let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
                            continue;
                        };
                        let wrapped = WrappedBinaryOp {
                            op,
                            left_ty: left_wrapped_ty,
                            right_ty: right_wrapped_ty,
                        };
                        if !self.wrapped.insert(WrappedType::BinaryOp(wrapped)) {
                            continue;
                        }

                        self.write_value_type(module, expr_ty)?;
                        write!(self.out, " {DIV_FUNCTION}(")?;
                        self.write_value_type(module, left_ty)?;
                        write!(self.out, " lhs, ")?;
                        self.write_value_type(module, right_ty)?;
                        writeln!(self.out, " rhs) {{")?;
                        let level = crate::back::Level(1);
                        match scalar.kind {
                            ScalarKind::Sint => {
                                let min_val = match scalar.width {
                                    4 => crate::Literal::I32(i32::MIN),
                                    8 => crate::Literal::I64(i64::MIN),
                                    _ => {
                                        return Err(super::Error::UnsupportedScalar(scalar));
                                    }
                                };
                                write!(self.out, "{level}return lhs / (((lhs == ")?;
                                self.write_literal(min_val)?;
                                writeln!(self.out, " & rhs == -1) | (rhs == 0)) ? 1 : rhs);")?
                            }
                            ScalarKind::Uint => {
                                writeln!(self.out, "{level}return lhs / (rhs == 0u ? 1u : rhs);")?
                            }
                            _ => unreachable!(),
                        }
                        writeln!(self.out, "}}")?;
                        writeln!(self.out)?;
                    }
                    // The modulus operator is only defined for integers in HLSL when
                    // either both sides are positive or both sides are negative. To
                    // avoid this undefined behaviour we use the following equation:
                    //
                    // dividend - (dividend / divisor) * divisor
                    //
                    // overriding the divisor to 1 if either it is 0, or it is -1
                    // and the dividend is the minimum representable value.
                    //
                    // This adheres to the WGSL spec in that:
                    // * min_value % -1 == 0
                    // * x % 0 == 0
                    (
                        crate::BinaryOperator::Modulo,
                        Some(
                            scalar @ crate::Scalar {
                                kind: ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float,
                                ..
                            },
                        ),
                    ) => {
                        let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
                            continue;
                        };
                        let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
                            continue;
                        };
                        let wrapped = WrappedBinaryOp {
                            op,
                            left_ty: left_wrapped_ty,
                            right_ty: right_wrapped_ty,
                        };
                        if !self.wrapped.insert(WrappedType::BinaryOp(wrapped)) {
                            continue;
                        }

                        self.write_value_type(module, expr_ty)?;
                        write!(self.out, " {MOD_FUNCTION}(")?;
                        self.write_value_type(module, left_ty)?;
                        write!(self.out, " lhs, ")?;
                        self.write_value_type(module, right_ty)?;
                        writeln!(self.out, " rhs) {{")?;
                        let level = crate::back::Level(1);
                        match scalar.kind {
                            ScalarKind::Sint => {
                                let min_val = match scalar.width {
                                    4 => crate::Literal::I32(i32::MIN),
                                    8 => crate::Literal::I64(i64::MIN),
                                    _ => {
                                        return Err(super::Error::UnsupportedScalar(scalar));
                                    }
                                };
                                write!(self.out, "{level}")?;
                                self.write_value_type(module, right_ty)?;
                                write!(self.out, " divisor = ((lhs == ")?;
                                self.write_literal(min_val)?;
                                writeln!(self.out, " & rhs == -1) | (rhs == 0)) ? 1 : rhs;")?;
                                writeln!(
                                    self.out,
                                    "{level}return lhs - (lhs / divisor) * divisor;"
                                )?
                            }
                            ScalarKind::Uint => {
                                writeln!(self.out, "{level}return lhs % (rhs == 0u ? 1u : rhs);")?
                            }
                            // HLSL's fmod has the same definition as WGSL's % operator but due
                            // to its implementation in DXC it is not as accurate as the WGSL spec
                            // requires it to be. See:
                            // - https://shader-playground.timjones.io/0c8572816dbb6fc4435cc5d016a978a7
                            // - https://github.com/llvm/llvm-project/blob/50f9b8acafdca48e87e6b8e393c1f116a2d193ee/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h#L78-L81
                            ScalarKind::Float => {
                                writeln!(self.out, "{level}return lhs - rhs * trunc(lhs / rhs);")?
                            }
                            _ => unreachable!(),
                        }
                        writeln!(self.out, "}}")?;
                        writeln!(self.out)?;
                    }
                    _ => {}
                }
            }
        }

        Ok(())
    }

    fn write_wrapped_cast_functions(
        &mut self,
        module: &crate::Module,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        for (_, expression) in func_ctx.expressions.iter() {
            if let crate::Expression::As {
                expr,
                kind,
                convert: Some(width),
            } = *expression
            {
                // Avoid undefined behaviour when casting from a float to integer
                // when the value is out of range for the target type. Additionally
                // ensure we clamp to the correct value as per the WGSL spec.
                //
                // https://www.w3.org/TR/WGSL/#floating-point-conversion:
                // * If X is exactly representable in the target type T, then the
                //   result is that value.
                // * Otherwise, the result is the value in T closest to
                //   truncate(X) and also exactly representable in the original
                //   floating point type.
                let src_ty = func_ctx.resolve_type(expr, &module.types);
                let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
                    continue;
                };
                let dst_scalar = crate::Scalar { kind, width };
                if src_scalar.kind != ScalarKind::Float
                    || (dst_scalar.kind != ScalarKind::Sint && dst_scalar.kind != ScalarKind::Uint)
                {
                    continue;
                }

                let wrapped = WrappedCast {
                    src_scalar,
                    vector_size,
                    dst_scalar,
                };
                if !self.wrapped.insert(WrappedType::Cast(wrapped)) {
                    continue;
                }

                let (src_ty, dst_ty) = match vector_size {
                    None => (
                        crate::TypeInner::Scalar(src_scalar),
                        crate::TypeInner::Scalar(dst_scalar),
                    ),
                    Some(vector_size) => (
                        crate::TypeInner::Vector {
                            scalar: src_scalar,
                            size: vector_size,
                        },
                        crate::TypeInner::Vector {
                            scalar: dst_scalar,
                            size: vector_size,
                        },
                    ),
                };
                let (min, max) =
                    crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
                let cast_str = format!(
                    "{}{}",
                    dst_scalar.to_hlsl_str()?,
                    vector_size
                        .map(crate::common::vector_size_str)
                        .unwrap_or(""),
                );
                let fun_name = match dst_scalar {
                    crate::Scalar::I32 => F2I32_FUNCTION,
                    crate::Scalar::U32 => F2U32_FUNCTION,
                    crate::Scalar::I64 => F2I64_FUNCTION,
                    crate::Scalar::U64 => F2U64_FUNCTION,
                    _
                    => unreachable!(),
                };

                self.write_value_type(module, &dst_ty)?;
                write!(self.out, " {fun_name}(")?;
                self.write_value_type(module, &src_ty)?;
                writeln!(self.out, " value) {{")?;
                let level = crate::back::Level(1);
                write!(self.out, "{level}return {cast_str}(clamp(value, ")?;
                self.write_literal(min)?;
                write!(self.out, ", ")?;
                self.write_literal(max)?;
                writeln!(self.out, "));",)?;
                writeln!(self.out, "}}")?;
                writeln!(self.out)?;
            }
        }

        Ok(())
    }

    /// Helper function that writes various wrapped functions
    pub(super) fn write_wrapped_functions(
        &mut self,
        module: &crate::Module,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        self.write_wrapped_math_functions(module, func_ctx)?;
        self.write_wrapped_unary_ops(module, func_ctx)?;
        self.write_wrapped_binary_ops(module, func_ctx)?;
        self.write_wrapped_expression_functions(module, func_ctx.expressions, Some(func_ctx))?;
        self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;
        self.write_wrapped_cast_functions(module, func_ctx)?;

        for (handle, _) in func_ctx.expressions.iter() {
            match func_ctx.expressions[handle] {
                crate::Expression::ArrayLength(expr) => {
                    let global_expr = match func_ctx.expressions[expr] {
                        crate::Expression::GlobalVariable(_) => expr,
                        crate::Expression::AccessIndex { base, index: _ } => base,
                        ref other => unreachable!("Array length of {:?}", other),
                    };
                    let global_var = match func_ctx.expressions[global_expr] {
                        crate::Expression::GlobalVariable(var_handle) => {
                            &module.global_variables[var_handle]
                        }
                        ref other => {
                            return Err(super::Error::Unimplemented(format!(
                                "Array length of base {other:?}"
                            )))
                        }
                    };
                    let storage_access = match global_var.space {
                        crate::AddressSpace::Storage { access } => access,
                        _ => crate::StorageAccess::default(),
                    };
                    let wal = WrappedArrayLength {
                        writable: storage_access.contains(crate::StorageAccess::STORE),
                    };

                    if self.wrapped.insert(WrappedType::ArrayLength(wal)) {
                        self.write_wrapped_array_length_function(wal)?;
                    }
                }
                crate::Expression::ImageLoad { image, ..
                } => {
                    let class = match *func_ctx.resolve_type(image, &module.types) {
                        crate::TypeInner::Image { class, .. } => class,
                        _ => unreachable!(),
                    };
                    let wrapped = WrappedImageLoad { class };
                    if self.wrapped.insert(WrappedType::ImageLoad(wrapped)) {
                        self.write_wrapped_image_load_function(module, wrapped)?;
                    }
                }
                crate::Expression::ImageSample {
                    image,
                    clamp_to_edge,
                    ..
                } => {
                    let class = match *func_ctx.resolve_type(image, &module.types) {
                        crate::TypeInner::Image { class, .. } => class,
                        _ => unreachable!(),
                    };
                    let wrapped = WrappedImageSample {
                        class,
                        clamp_to_edge,
                    };
                    if self.wrapped.insert(WrappedType::ImageSample(wrapped)) {
                        self.write_wrapped_image_sample_function(module, wrapped)?;
                    }
                }
                crate::Expression::ImageQuery { image, query } => {
                    let wiq = match *func_ctx.resolve_type(image, &module.types) {
                        crate::TypeInner::Image {
                            dim,
                            arrayed,
                            class,
                        } => WrappedImageQuery {
                            dim,
                            arrayed,
                            class,
                            query: query.into(),
                        },
                        _ => unreachable!("we only query images"),
                    };

                    if self.wrapped.insert(WrappedType::ImageQuery(wiq)) {
                        self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?;
                    }
                }
                // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage`
                // since they will later be used by the fn `write_storage_load`
                crate::Expression::Load { pointer } => {
                    let pointer_space = func_ctx
                        .resolve_type(pointer, &module.types)
                        .pointer_space();

                    if let Some(crate::AddressSpace::Storage { .. }) = pointer_space {
                        if let Some(ty) = func_ctx.info[handle].ty.handle() {
                            write_wrapped_constructor(self, ty, module)?;
                        }
                    }

                    fn write_wrapped_constructor<W: Write>(
                        writer: &mut super::Writer<'_, W>,
                        ty: Handle<crate::Type>,
                        module: &crate::Module,
                    ) -> BackendResult {
                        match module.types[ty].inner {
                            crate::TypeInner::Struct { ref members, ..
                            } => {
                                for member in members {
                                    write_wrapped_constructor(writer, member.ty, module)?;
                                }

                                let constructor = WrappedConstructor { ty };
                                if writer.wrapped.insert(WrappedType::Constructor(constructor)) {
                                    writer
                                        .write_wrapped_constructor_function(module, constructor)?;
                                }
                            }
                            crate::TypeInner::Array { base, .. } => {
                                write_wrapped_constructor(writer, base, module)?;

                                let constructor = WrappedConstructor { ty };
                                if writer.wrapped.insert(WrappedType::Constructor(constructor)) {
                                    writer
                                        .write_wrapped_constructor_function(module, constructor)?;
                                }
                            }
                            _ => {}
                        };

                        Ok(())
                    }
                }
                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s
                // (see top level module docs for details).
                //
                // The functions injected here are required to get the matrix accesses working.
                crate::Expression::AccessIndex { base, index } => {
                    let base_ty_res = &func_ctx.info[base].ty;
                    let mut resolved = base_ty_res.inner_with(&module.types);
                    let base_ty_handle = match *resolved {
                        crate::TypeInner::Pointer { base, .. } => {
                            resolved = &module.types[base].inner;
                            Some(base)
                        }
                        _ => base_ty_res.handle(),
                    };
                    if let crate::TypeInner::Struct { ref members, .. } = *resolved {
                        let member = &members[index as usize];

                        match module.types[member.ty].inner {
                            crate::TypeInner::Matrix {
                                rows: crate::VectorSize::Bi,
                                ..
                            } if member.binding.is_none() => {
                                let ty = base_ty_handle.unwrap();
                                let access = WrappedStructMatrixAccess { ty, index };

                                if self.wrapped.insert(WrappedType::StructMatrixAccess(access)) {
                                    self.write_wrapped_struct_matrix_get_function(module, access)?;
                                    self.write_wrapped_struct_matrix_set_function(module, access)?;
                                    self.write_wrapped_struct_matrix_set_vec_function(
                                        module, access,
                                    )?;
                                    self.write_wrapped_struct_matrix_set_scalar_function(
                                        module, access,
                                    )?;
                                }
                            }
                            _ => {}
                        }
                    }
                }
                _ => {}
            };
        }

        Ok(())
    }

    /// Writes out the sampler heap declarations if they haven't been written yet.
    pub(super) fn write_sampler_heaps(&mut self) -> BackendResult {
        if self.wrapped.sampler_heaps {
            return Ok(());
        }

        writeln!(
            self.out,
            "SamplerState {}[2048]: register(s{}, space{});",
            super::writer::SAMPLER_HEAP_VAR,
            self.options.sampler_heap_target.standard_samplers.register,
            self.options.sampler_heap_target.standard_samplers.space
        )?;
        writeln!(
            self.out,
            "SamplerComparisonState {}[2048]: register(s{}, space{});",
            super::writer::COMPARISON_SAMPLER_HEAP_VAR,
            self.options
                .sampler_heap_target
                .comparison_samplers
                .register,
            self.options.sampler_heap_target.comparison_samplers.space
        )?;

        self.wrapped.sampler_heaps = true;

        Ok(())
    }

    /// Writes out the sampler index buffer declaration if it hasn't been written yet.
    pub(super) fn write_wrapped_sampler_buffer(
        &mut self,
        key: super::SamplerIndexBufferKey,
    ) -> BackendResult {
        // The astute reader will notice that we do a double hash lookup, but we do this to avoid
        // holding a mutable reference to `self` while trying to call `write_sampler_heaps`.
        //
        // We only pay this double lookup cost when we actually need to write out the sampler
        // buffer, which should not be common.
        if self.wrapped.sampler_index_buffers.contains_key(&key) {
            return Ok(());
        };

        self.write_sampler_heaps()?;

        // Because the group number can be arbitrary, we use the namer to generate a unique name
        // instead of adding it to the reserved name list.
        let sampler_array_name = self
            .namer
            .call(&format!("nagaGroup{}SamplerIndexArray", key.group));

        let bind_target = match self.options.sampler_buffer_binding_map.get(&key) {
            Some(&bind_target) => bind_target,
            None if self.options.fake_missing_bindings => super::BindTarget {
                space: u8::MAX,
                register: key.group,
                binding_array_size: None,
                dynamic_storage_buffer_offsets_index: None,
                restrict_indexing: false,
            },
            None => {
                unreachable!("Sampler buffer of group {key:?} not bound to a register");
            }
        };

        writeln!(
            self.out,
            "StructuredBuffer<uint> {sampler_array_name} : register(t{}, space{});",
            bind_target.register, bind_target.space
        )?;

        self.wrapped
            .sampler_index_buffers
            .insert(key, sampler_array_name);

        Ok(())
    }

    pub(super) fn write_texture_coordinates(
        &mut self,
        kind: &str,
        coordinate: Handle<crate::Expression>,
        array_index: Option<Handle<crate::Expression>>,
        mip_level: Option<Handle<crate::Expression>>,
        module: &crate::Module,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        // HLSL expects the array index to be merged with the coordinate
        let extra = array_index.is_some() as usize + (mip_level.is_some()) as usize;
        if extra == 0 {
            self.write_expr(module, coordinate, func_ctx)?;
        } else {
            let num_coords = match *func_ctx.resolve_type(coordinate, &module.types) {
                crate::TypeInner::Scalar { .. } => 1,
                crate::TypeInner::Vector { size, .. } => size as usize,
                _ => unreachable!(),
            };
            write!(self.out, "{}{}(", kind, num_coords + extra)?;
            self.write_expr(module, coordinate, func_ctx)?;
            if let Some(expr) = array_index {
                write!(self.out, ", ")?;
                self.write_expr(module, expr, func_ctx)?;
            }
            if let Some(expr) = mip_level {
                // Explicit cast if needed
                let cast_to_int = matches!(
                    *func_ctx.resolve_type(expr, &module.types),
                    crate::TypeInner::Scalar(crate::Scalar {
                        kind: ScalarKind::Uint,
                        ..
                    })
                );
                write!(self.out, ", ")?;
                if cast_to_int {
                    write!(self.out, "int(")?;
                }
                self.write_expr(module, expr, func_ctx)?;
                if cast_to_int {
                    write!(self.out, ")")?;
                }
            }
            write!(self.out, ")")?;
        }

        Ok(())
    }

    pub(super) fn write_mat_cx2_typedef_and_functions(
        &mut self,
        WrappedMatCx2 { columns }: WrappedMatCx2,
    ) -> BackendResult {
        use crate::back::INDENT;

        // typedef
        write!(self.out, "typedef struct {{ ")?;
        for i in 0..columns as u8 {
            write!(self.out, "float2 _{i}; ")?;
        }
        writeln!(self.out, "}} __mat{}x2;", columns as u8)?;

        // __get_col_of_mat
        writeln!(
            self.out,
            "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{",
            columns as u8, columns as u8
        )?;
        writeln!(self.out, "{INDENT}switch(idx) {{")?;
        for i in 0..columns as u8 {
            writeln!(self.out, "{INDENT}case {i}: {{ return mat._{i}; }}")?;
        }
        writeln!(self.out, "{INDENT}default: {{ return (float2)0; }}")?;
        writeln!(self.out, "{INDENT}}}")?;
        writeln!(self.out, "}}")?;

        // __set_col_of_mat
        writeln!(
            self.out,
            "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{",
            columns as u8, columns as u8
        )?;
        writeln!(self.out, "{INDENT}switch(idx) {{")?;
        for i in 0..columns as u8 {
            writeln!(self.out, "{INDENT}case {i}: {{ mat._{i} = value; break; }}")?;
        }
        writeln!(self.out, "{INDENT}}}")?;
        writeln!(self.out, "}}")?;

        // __set_el_of_mat
        writeln!(
            self.out,
            "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{",
            columns as u8, columns as u8
        )?;
        writeln!(self.out, "{INDENT}switch(idx) {{")?;
        for i in 0..columns as u8 {
            writeln!(
                self.out,
                "{INDENT}case {i}: {{ mat._{i}[vec_idx] = value; break; }}"
            )?;
        }
        writeln!(self.out, "{INDENT}}}")?;
        writeln!(self.out, "}}")?;

        writeln!(self.out)?;

        Ok(())
    }

    pub(super) fn write_all_mat_cx2_typedefs_and_functions(
        &mut self,
        module: &crate::Module,
    ) -> BackendResult {
        for (handle, _) in module.global_variables.iter() {
            let global = &module.global_variables[handle];

            if global.space == crate::AddressSpace::Uniform {
                if let Some(super::writer::MatrixType {
                    columns,
                    rows: crate::VectorSize::Bi,
                    width: 4,
                }) = super::writer::get_inner_matrix_data(module, global.ty)
                {
                    let entry = WrappedMatCx2 { columns };
                    if self.wrapped.insert(WrappedType::MatCx2(entry)) {
                        self.write_mat_cx2_typedef_and_functions(entry)?;
                    }
                }
            }
        }

        for (_, ty) in module.types.iter() {
            if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
                for member in members.iter() {
                    if let crate::TypeInner::Array { .. } = module.types[member.ty].inner {
                        if let Some(super::writer::MatrixType {
                            columns,
                            rows: crate::VectorSize::Bi,
                            width: 4,
                        }) = super::writer::get_inner_matrix_data(module, member.ty)
                        {
                            let entry = WrappedMatCx2 { columns };
                            if self.wrapped.insert(WrappedType::MatCx2(entry)) {
                                self.write_mat_cx2_typedef_and_functions(entry)?;
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }

    pub(super) fn write_wrapped_zero_value_function_name(
        &mut self,
        module: &crate::Module,
        zero_value: WrappedZeroValue,
    ) -> BackendResult {
        let name = crate::TypeInner::hlsl_type_id(zero_value.ty, module.to_ctx(), &self.names)?;
        write!(self.out, "ZeroValue{name}")?;
        Ok(())
    }

    /// Helper function that writes a wrapped function for `Expression::ZeroValue`.
    ///
    /// This is necessary since we might have a member access after the zero value expression, e.g.
    /// `.y` (in practice this can come up when consuming SPIRV that's been produced by glslc).
    ///
    /// So we can't just write `(float4)0` since `(float4)0.y` won't parse correctly.
    ///
    /// Parenthesizing the expression like `((float4)0).y` would work... except DXC can't handle
    /// cases like:
    ///
    /// ```text
    /// tests\out\hlsl\access.hlsl:183:41: error: cannot compile this l-value expression yet
    ///     t_1.am = (__mat4x2[2])((float4x2[2])0);
    ///                                         ^
    /// ```
    fn write_wrapped_zero_value_function(
        &mut self,
        module: &crate::Module,
        zero_value: WrappedZeroValue,
    ) -> BackendResult {
        use crate::back::INDENT;

        // Write function return type and name
        if let crate::TypeInner::Array { base, size, ..
        } = module.types[zero_value.ty].inner
        {
            write!(self.out, "typedef ")?;
            self.write_type(module, zero_value.ty)?;
            write!(self.out, " ret_")?;
            self.write_wrapped_zero_value_function_name(module, zero_value)?;
            self.write_array_size(module, base, size)?;
            writeln!(self.out, ";")?;

            write!(self.out, "ret_")?;
            self.write_wrapped_zero_value_function_name(module, zero_value)?;
        } else {
            self.write_type(module, zero_value.ty)?;
        }
        write!(self.out, " ")?;
        self.write_wrapped_zero_value_function_name(module, zero_value)?;

        // Write function parameters (none) and start function body
        writeln!(self.out, "() {{")?;

        // Write `ZeroValue` function.
        write!(self.out, "{INDENT}return ")?;
        self.write_default_init(module, zero_value.ty)?;
        writeln!(self.out, ";")?;

        // End of function body
        writeln!(self.out, "}}")?;
        // Write extra new line
        writeln!(self.out)?;

        Ok(())
    }
}

impl crate::StorageFormat {
    /// Returns `true` if there is just one component, otherwise `false`
    pub(super) const fn single_component(&self) -> bool {
        match *self {
            crate::StorageFormat::R16Float
            | crate::StorageFormat::R32Float
            | crate::StorageFormat::R8Unorm
            | crate::StorageFormat::R16Unorm
            | crate::StorageFormat::R8Snorm
            | crate::StorageFormat::R16Snorm
            | crate::StorageFormat::R8Uint
            | crate::StorageFormat::R16Uint
            | crate::StorageFormat::R32Uint
            | crate::StorageFormat::R8Sint
            | crate::StorageFormat::R16Sint
            | crate::StorageFormat::R32Sint
            | crate::StorageFormat::R64Uint => true,
            _ => false,
        }
    }
}
naga-29.0.3/src/back/hlsl/keywords.rs
use crate::proc::{CaseInsensitiveKeywordSet, KeywordSet};
use crate::racy_lock::RacyLock;

// When compiling with FXC without strict mode, these keywords are actually case insensitive.
// If you compile with strict mode and specify a different casing like "Pass" instead in an identifier, FXC will give this error: // "error X3086: alternate cases for 'pass' are deprecated in strict mode" // This behavior is not documented anywhere, but as far as I can tell this is the full list. pub const RESERVED_CASE_INSENSITIVE: &[&str] = &[ "asm", "decl", "pass", "technique", "Texture1D", "Texture2D", "Texture3D", "TextureCube", ]; pub const RESERVED: &[&str] = &[ // FXC keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-keywords.md?plain=1#L99-L118 "AppendStructuredBuffer", "asm", "asm_fragment", "BlendState", "bool", "break", "Buffer", "ByteAddressBuffer", "case", "cbuffer", "centroid", "class", "column_major", "compile", "compile_fragment", "CompileShader", "const", "continue", "ComputeShader", "ConsumeStructuredBuffer", "default", "DepthStencilState", "DepthStencilView", "discard", "do", "double", "DomainShader", "dword", "else", "export", "extern", "false", "float", "for", "fxgroup", "GeometryShader", "groupshared", "half", "Hullshader", "if", "in", "inline", "inout", "InputPatch", "int", "interface", "line", "lineadj", "linear", "LineStream", "matrix", "min16float", "min10float", "min16int", "min12int", "min16uint", "namespace", "nointerpolation", "noperspective", "NULL", "out", "OutputPatch", "packoffset", "pass", "pixelfragment", "PixelShader", "point", "PointStream", "precise", "RasterizerState", "RenderTargetView", "return", "register", "row_major", "RWBuffer", "RWByteAddressBuffer", "RWStructuredBuffer", "RWTexture1D", "RWTexture1DArray", "RWTexture2D", "RWTexture2DArray", "RWTexture3D", "sample", "sampler", "SamplerState", "SamplerComparisonState", "shared", "snorm", "stateblock", "stateblock_state", "static", "string", "struct", "switch", "StructuredBuffer", "tbuffer", "technique", "technique10", "technique11", "texture", "Texture1D", 
"Texture1DArray", "Texture2D", "Texture2DArray", "Texture2DMS", "Texture2DMSArray", "Texture3D", "TextureCube", "TextureCubeArray", "true", "typedef", "triangle", "triangleadj", "TriangleStream", "uint", "uniform", "unorm", "unsigned", "vector", "vertexfragment", "VertexShader", "void", "volatile", "while", // FXC reserved keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-reserved-words.md?plain=1#L19-L38 "auto", "case", "catch", "char", "class", "const_cast", "default", "delete", "dynamic_cast", "enum", "explicit", "friend", "goto", "long", "mutable", "new", "operator", "private", "protected", "public", "reinterpret_cast", "short", "signed", "sizeof", "static_cast", "template", "this", "throw", "try", "typename", "union", "unsigned", "using", "virtual", // FXC intrinsics, from https://github.com/MicrosoftDocs/win32/blob/1682b99e203708f6f5eda972d966e30f3c1588de/desktop-src/direct3dhlsl/dx-graphics-hlsl-intrinsic-functions.md?plain=1#L26-L165 "abort", "abs", "acos", "all", "AllMemoryBarrier", "AllMemoryBarrierWithGroupSync", "any", "asdouble", "asfloat", "asin", "asint", "asuint", "atan", "atan2", "ceil", "CheckAccessFullyMapped", "clamp", "clip", "cos", "cosh", "countbits", "cross", "D3DCOLORtoUBYTE4", "ddx", "ddx_coarse", "ddx_fine", "ddy", "ddy_coarse", "ddy_fine", "degrees", "determinant", "DeviceMemoryBarrier", "DeviceMemoryBarrierWithGroupSync", "distance", "dot", "dst", "errorf", "EvaluateAttributeCentroid", "EvaluateAttributeAtSample", "EvaluateAttributeSnapped", "exp", "exp2", "f16tof32", "f32tof16", "faceforward", "firstbithigh", "firstbitlow", "floor", "fma", "fmod", "frac", "frexp", "fwidth", "GetRenderTargetSampleCount", "GetRenderTargetSamplePosition", "GroupMemoryBarrier", "GroupMemoryBarrierWithGroupSync", "InterlockedAdd", "InterlockedAnd", "InterlockedCompareExchange", "InterlockedCompareStore", "InterlockedExchange", "InterlockedMax", 
"InterlockedMin", "InterlockedOr", "InterlockedXor", "isfinite", "isinf", "isnan", "ldexp", "length", "lerp", "lit", "log", "log10", "log2", "mad", "max", "min", "modf", "msad4", "mul", "noise", "normalize", "pow", "printf", "Process2DQuadTessFactorsAvg", "Process2DQuadTessFactorsMax", "Process2DQuadTessFactorsMin", "ProcessIsolineTessFactors", "ProcessQuadTessFactorsAvg", "ProcessQuadTessFactorsMax", "ProcessQuadTessFactorsMin", "ProcessTriTessFactorsAvg", "ProcessTriTessFactorsMax", "ProcessTriTessFactorsMin", "radians", "rcp", "reflect", "refract", "reversebits", "round", "rsqrt", "saturate", "sign", "sin", "sincos", "sinh", "smoothstep", "sqrt", "step", "tan", "tanh", "tex1D", "tex1Dbias", "tex1Dgrad", "tex1Dlod", "tex1Dproj", "tex2D", "tex2Dbias", "tex2Dgrad", "tex2Dlod", "tex2Dproj", "tex3D", "tex3Dbias", "tex3Dgrad", "tex3Dlod", "tex3Dproj", "texCUBE", "texCUBEbias", "texCUBEgrad", "texCUBElod", "texCUBEproj", "transpose", "trunc", // DXC (reserved) keywords, from https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/include/clang/Basic/TokenKinds.def#L222-L648 // with the KEYALL, KEYCXX, BOOLSUPPORT, WCHARSUPPORT, KEYHLSL options enabled (see https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/lib/Frontend/CompilerInvocation.cpp#L1199) "auto", "break", "case", "char", "const", "continue", "default", "do", "double", "else", "enum", "extern", "float", "for", "goto", "if", "inline", "int", "long", "register", "return", "short", "signed", "sizeof", "static", "struct", "switch", "typedef", "union", "unsigned", "void", "volatile", "while", "_Alignas", "_Alignof", "_Atomic", "_Complex", "_Generic", "_Imaginary", "_Noreturn", "_Static_assert", "_Thread_local", "__func__", "__objc_yes", "__objc_no", "asm", "bool", "catch", "class", "const_cast", "delete", "dynamic_cast", "explicit", "export", "false", "friend", "mutable", "namespace", "new", "operator", 
"private", "protected", "public", "reinterpret_cast", "static_cast", "template", "this", "throw", "true", "try", "typename", "typeid", "using", "virtual", "wchar_t", "_Decimal32", "_Decimal64", "_Decimal128", "__null", "__alignof", "__attribute", "__builtin_choose_expr", "__builtin_offsetof", "__builtin_va_arg", "__extension__", "__imag", "__int128", "__label__", "__real", "__thread", "__FUNCTION__", "__PRETTY_FUNCTION__", "__is_nothrow_assignable", "__is_constructible", "__is_nothrow_constructible", "__has_nothrow_assign", "__has_nothrow_move_assign", "__has_nothrow_copy", "__has_nothrow_constructor", "__has_trivial_assign", "__has_trivial_move_assign", "__has_trivial_copy", "__has_trivial_constructor", "__has_trivial_move_constructor", "__has_trivial_destructor", "__has_virtual_destructor", "__is_abstract", "__is_base_of", "__is_class", "__is_convertible_to", "__is_empty", "__is_enum", "__is_final", "__is_literal", "__is_literal_type", "__is_pod", "__is_polymorphic", "__is_trivial", "__is_union", "__is_trivially_constructible", "__is_trivially_copyable", "__is_trivially_assignable", "__underlying_type", "__is_lvalue_expr", "__is_rvalue_expr", "__is_arithmetic", "__is_floating_point", "__is_integral", "__is_complete_type", "__is_void", "__is_array", "__is_function", "__is_reference", "__is_lvalue_reference", "__is_rvalue_reference", "__is_fundamental", "__is_object", "__is_scalar", "__is_compound", "__is_pointer", "__is_member_object_pointer", "__is_member_function_pointer", "__is_member_pointer", "__is_const", "__is_volatile", "__is_standard_layout", "__is_signed", "__is_unsigned", "__is_same", "__is_convertible", "__array_rank", "__array_extent", "__private_extern__", "__module_private__", "__declspec", "__cdecl", "__stdcall", "__fastcall", "__thiscall", "__vectorcall", "cbuffer", "tbuffer", "packoffset", "linear", "centroid", "nointerpolation", "noperspective", "sample", "column_major", "row_major", "in", "out", "inout", "uniform", "precise", "center", 
"shared", "groupshared", "discard", "snorm", "unorm", "point", "line", "lineadj", "triangle", "triangleadj", "globallycoherent", "interface", "sampler_state", "technique", "indices", "vertices", "primitives", "payload", "Technique", "technique10", "technique11", "__builtin_omp_required_simd_align", "__pascal", "__fp16", "__alignof__", "__asm", "__asm__", "__attribute__", "__complex", "__complex__", "__const", "__const__", "__decltype", "__imag__", "__inline", "__inline__", "__nullptr", "__real__", "__restrict", "__restrict__", "__signed", "__signed__", "__typeof", "__typeof__", "__volatile", "__volatile__", "_Nonnull", "_Nullable", "_Null_unspecified", "__builtin_convertvector", "__char16_t", "__char32_t", // DXC intrinsics, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/utils/hct/gen_intrin_main.txt#L86-L376 "D3DCOLORtoUBYTE4", "GetRenderTargetSampleCount", "GetRenderTargetSamplePosition", "abort", "abs", "acos", "all", "AllMemoryBarrier", "AllMemoryBarrierWithGroupSync", "any", "asdouble", "asfloat", "asfloat16", "asint16", "asin", "asint", "asuint", "asuint16", "atan", "atan2", "ceil", "clamp", "clip", "cos", "cosh", "countbits", "cross", "ddx", "ddx_coarse", "ddx_fine", "ddy", "ddy_coarse", "ddy_fine", "degrees", "determinant", "DeviceMemoryBarrier", "DeviceMemoryBarrierWithGroupSync", "distance", "dot", "dst", "EvaluateAttributeAtSample", "EvaluateAttributeCentroid", "EvaluateAttributeSnapped", "GetAttributeAtVertex", "exp", "exp2", "f16tof32", "f32tof16", "faceforward", "firstbithigh", "firstbitlow", "floor", "fma", "fmod", "frac", "frexp", "fwidth", "GroupMemoryBarrier", "GroupMemoryBarrierWithGroupSync", "InterlockedAdd", "InterlockedMin", "InterlockedMax", "InterlockedAnd", "InterlockedOr", "InterlockedXor", "InterlockedCompareStore", "InterlockedExchange", "InterlockedCompareExchange", "InterlockedCompareStoreFloatBitwise", "InterlockedCompareExchangeFloatBitwise", "isfinite", "isinf", "isnan", 
"ldexp", "length", "lerp", "lit", "log", "log10", "log2", "mad", "max", "min", "modf", "msad4", "mul", "normalize", "pow", "printf", "Process2DQuadTessFactorsAvg", "Process2DQuadTessFactorsMax", "Process2DQuadTessFactorsMin", "ProcessIsolineTessFactors", "ProcessQuadTessFactorsAvg", "ProcessQuadTessFactorsMax", "ProcessQuadTessFactorsMin", "ProcessTriTessFactorsAvg", "ProcessTriTessFactorsMax", "ProcessTriTessFactorsMin", "radians", "rcp", "reflect", "refract", "reversebits", "round", "rsqrt", "saturate", "sign", "sin", "sincos", "sinh", "smoothstep", "source_mark", "sqrt", "step", "tan", "tanh", "tex1D", "tex1Dbias", "tex1Dgrad", "tex1Dlod", "tex1Dproj", "tex2D", "tex2Dbias", "tex2Dgrad", "tex2Dlod", "tex2Dproj", "tex3D", "tex3Dbias", "tex3Dgrad", "tex3Dlod", "tex3Dproj", "texCUBE", "texCUBEbias", "texCUBEgrad", "texCUBElod", "texCUBEproj", "transpose", "trunc", "CheckAccessFullyMapped", "AddUint64", "NonUniformResourceIndex", "WaveIsFirstLane", "WaveGetLaneIndex", "WaveGetLaneCount", "WaveActiveAnyTrue", "WaveActiveAllTrue", "WaveActiveAllEqual", "WaveActiveBallot", "WaveReadLaneAt", "WaveReadLaneFirst", "WaveActiveCountBits", "WaveActiveSum", "WaveActiveProduct", "WaveActiveBitAnd", "WaveActiveBitOr", "WaveActiveBitXor", "WaveActiveMin", "WaveActiveMax", "WavePrefixCountBits", "WavePrefixSum", "WavePrefixProduct", "WaveMatch", "WaveMultiPrefixBitAnd", "WaveMultiPrefixBitOr", "WaveMultiPrefixBitXor", "WaveMultiPrefixCountBits", "WaveMultiPrefixProduct", "WaveMultiPrefixSum", "QuadReadLaneAt", "QuadReadAcrossX", "QuadReadAcrossY", "QuadReadAcrossDiagonal", "QuadAny", "QuadAll", "TraceRay", "ReportHit", "CallShader", "IgnoreHit", "AcceptHitAndEndSearch", "DispatchRaysIndex", "DispatchRaysDimensions", "WorldRayOrigin", "WorldRayDirection", "ObjectRayOrigin", "ObjectRayDirection", "RayTMin", "RayTCurrent", "PrimitiveIndex", "InstanceID", "InstanceIndex", "GeometryIndex", "HitKind", "RayFlags", "ObjectToWorld", "WorldToObject", "ObjectToWorld3x4", "WorldToObject3x4", 
"ObjectToWorld4x3", "WorldToObject4x3", "dot4add_u8packed", "dot4add_i8packed", "dot2add", "unpack_s8s16", "unpack_u8u16", "unpack_s8s32", "unpack_u8u32", "pack_s8", "pack_u8", "pack_clamp_s8", "pack_clamp_u8", "SetMeshOutputCounts", "DispatchMesh", "IsHelperLane", "AllocateRayQuery", "CreateResourceFromHeap", "and", "or", "select", // DXC resource and other types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/HlslTypes.cpp#L441-#L572 "InputPatch", "OutputPatch", "PointStream", "LineStream", "TriangleStream", "Texture1D", "RWTexture1D", "Texture2D", "RWTexture2D", "Texture2DMS", "RWTexture2DMS", "Texture3D", "RWTexture3D", "TextureCube", "RWTextureCube", "Texture1DArray", "RWTexture1DArray", "Texture2DArray", "RWTexture2DArray", "Texture2DMSArray", "RWTexture2DMSArray", "TextureCubeArray", "RWTextureCubeArray", "FeedbackTexture2D", "FeedbackTexture2DArray", "RasterizerOrderedTexture1D", "RasterizerOrderedTexture2D", "RasterizerOrderedTexture3D", "RasterizerOrderedTexture1DArray", "RasterizerOrderedTexture2DArray", "RasterizerOrderedBuffer", "RasterizerOrderedByteAddressBuffer", "RasterizerOrderedStructuredBuffer", "ByteAddressBuffer", "RWByteAddressBuffer", "StructuredBuffer", "RWStructuredBuffer", "AppendStructuredBuffer", "ConsumeStructuredBuffer", "Buffer", "RWBuffer", "SamplerState", "SamplerComparisonState", "ConstantBuffer", "TextureBuffer", "RaytracingAccelerationStructure", // DXC templated types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp // look for `BuiltinTypeDeclBuilder` "matrix", "vector", "TextureBuffer", "ConstantBuffer", "RayQuery", "RayDesc", // Naga utilities super::writer::MODF_FUNCTION, super::writer::FREXP_FUNCTION, super::writer::EXTRACT_BITS_FUNCTION, super::writer::INSERT_BITS_FUNCTION, super::writer::SAMPLER_HEAP_VAR, super::writer::COMPARISON_SAMPLER_HEAP_VAR, 
super::writer::SAMPLE_EXTERNAL_TEXTURE_FUNCTION, super::writer::ABS_FUNCTION, super::writer::DIV_FUNCTION, super::writer::MOD_FUNCTION, super::writer::NEG_FUNCTION, super::writer::F2I32_FUNCTION, super::writer::F2U32_FUNCTION, super::writer::F2I64_FUNCTION, super::writer::F2U64_FUNCTION, super::writer::IMAGE_LOAD_EXTERNAL_FUNCTION, super::writer::IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION, ]; // DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254 // + vector and matrix shorthands pub const TYPES: &[&str] = &{ const L: usize = 23 * (1 + 4 + 4 * 4); let mut res = [""; L]; let mut c = 0; /// For each scalar type, it will additionally generate vector and matrix shorthands macro_rules! generate { ([$($roots:literal),*], $x:tt) => { $( generate!(@inner push $roots); generate!(@inner $roots, $x); )* }; (@inner $root:literal, [$($x:literal),*]) => { generate!(@inner vector $root, $($x)*); generate!(@inner matrix $root, $($x)*); }; (@inner vector $root:literal, $($x:literal)*) => { $( generate!(@inner push concat!($root, $x)); )* }; (@inner matrix $root:literal, $($x:literal)*) => { // Duplicate the list generate!(@inner matrix $root, $($x)*; $($x)*); }; // The head/tail recursion: pick the first element of the first list and recursively do it for the tail. 
(@inner matrix $root:literal, $head:literal $($tail:literal)*; $($x:literal)*) => { $( generate!(@inner push concat!($root, $head, "x", $x)); )* generate!(@inner matrix $root, $($tail)*; $($x)*); }; // The end of iteration: we exhausted the list (@inner matrix $root:literal, ; $($x:literal)*) => {}; (@inner push $v:expr) => { res[c] = $v; c += 1; }; } generate!( [ "bool", "int", "uint", "dword", "half", "float", "double", "min10float", "min16float", "min12int", "min16int", "min16uint", "int16_t", "int32_t", "int64_t", "uint16_t", "uint32_t", "uint64_t", "float16_t", "float32_t", "float64_t", "int8_t4_packed", "uint8_t4_packed" ], ["1", "2", "3", "4"] ); debug_assert!(c == L); res }; /// The above set of reserved keywords, turned into a cached HashSet. This saves /// significant time during [`Namer::reset`](crate::proc::Namer::reset). /// /// See for benchmarks. pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| KeywordSet::from_iter(RESERVED.iter().chain(TYPES))); pub static RESERVED_CASE_INSENSITIVE_SET: RacyLock<CaseInsensitiveKeywordSet> = RacyLock::new(|| CaseInsensitiveKeywordSet::from_iter(RESERVED_CASE_INSENSITIVE)); pub const RESERVED_PREFIXES: &[&str] = &[ "__dynamic_buffer_offsets", super::help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER, super::writer::RAY_QUERY_TRACKER_VARIABLE_PREFIX, super::writer::INTERNAL_PREFIX, ]; naga-29.0.3/src/back/hlsl/mod.rs000064400000000000000000000727751046102023000144250ustar 00000000000000/*! Backend for [HLSL][hlsl] (High-Level Shading Language). # Supported shader model versions: - 5.0 - 5.1 - 6.0 # Layout of values in `uniform` buffers WGSL's ["Internal Layout of Values"][ilov] rules specify how each WGSL type should be stored in `uniform` and `storage` buffers. The HLSL we generate must access values in that form, even when it is not what HLSL would use normally. Matching the WGSL memory layout is a concern only for `uniform` variables.
WGSL `storage` buffers are translated as HLSL `ByteAddressBuffers`, for which we generate `Load` and `Store` method calls with explicit byte offsets. WGSL pipeline inputs must be scalars or vectors; they cannot be matrices, which is where the interesting problems arise. However, when an affected type appears in a struct definition, the transformations described here are applied without consideration of where the struct is used. Access to storage buffers is implemented in `storage.rs`. Access to uniform buffers is implemented where applicable in `writer.rs`. ## Row- and column-major ordering for matrices WGSL specifies that matrices in uniform buffers are stored in column-major order. This matches HLSL's default, so one might expect things to be straightforward. Unfortunately, WGSL and HLSL disagree on what indexing a matrix means: in WGSL, `m[i]` retrieves the `i`'th *column* of `m`, whereas in HLSL it retrieves the `i`'th *row*. We want to avoid translating `m[i]` into some complicated reassembly of a vector from individually fetched components, so this is a problem. However, with a bit of trickery, it is possible to use HLSL's `m[i]` as the translation of WGSL's `m[i]`: - We declare all matrices in uniform buffers in HLSL with the `row_major` qualifier, and transpose the row and column counts: a WGSL `mat3x4`, say, becomes an HLSL `row_major float3x4`. (Note that WGSL and HLSL type names put the row and column in reverse order.) Since the HLSL type is the transpose of how WebGPU directs the user to store the data, HLSL will load all matrices transposed. - Since matrices are transposed, an HLSL indexing expression retrieves the "columns" of the intended WGSL value, as desired. - For vector-matrix multiplication, since `mul(transpose(m), v)` is equivalent to `mul(v, m)` (note the reversal of the arguments), and `mul(v, transpose(m))` is equivalent to `mul(m, v)`, we can translate WGSL `m * v` and `v * m` to HLSL by simply reversing the arguments to `mul`. 
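As a sanity check of the trick described above, here is a minimal sketch that models the GPU matrices with plain Rust arrays. The helper names are hypothetical and not part of Naga: a WGSL `mat3x4<f32>` (3 columns of 4 rows) sits in the buffer column by column; reading the same 12 floats as an HLSL `row_major float3x4` (3 rows of 4) makes HLSL row `i` equal WGSL column `i`, and HLSL `mul(v, m)` reproduces WGSL `m * v`.

```rust
// A WGSL `mat3x4<f32>` has 3 columns of 4 rows each.
const COLS: usize = 3;
const ROWS: usize = 4;

/// WGSL semantics for `m * v`: a sum of the columns of `m`, weighted by `v`.
fn wgsl_mat_times_vec(columns: &[[f32; ROWS]; COLS], v: [f32; COLS]) -> [f32; ROWS] {
    let mut out = [0.0; ROWS];
    for j in 0..COLS {
        for r in 0..ROWS {
            out[r] += columns[j][r] * v[j];
        }
    }
    out
}

/// Reads the column-major buffer as an HLSL `row_major float3x4`:
/// 3 rows of 4 components, taken from the buffer in order.
fn hlsl_row_major_view(buffer: &[f32; COLS * ROWS]) -> [[f32; ROWS]; COLS] {
    let mut h = [[0.0; ROWS]; COLS];
    for i in 0..COLS {
        for k in 0..ROWS {
            h[i][k] = buffer[i * ROWS + k];
        }
    }
    h
}

/// HLSL `mul(v, h)`: the row vector `v` times the 3x4 matrix `h`.
fn hlsl_mul_vec_mat(v: [f32; COLS], h: &[[f32; ROWS]; COLS]) -> [f32; ROWS] {
    let mut out = [0.0; ROWS];
    for k in 0..ROWS {
        for i in 0..COLS {
            out[k] += v[i] * h[i][k];
        }
    }
    out
}

fn main() {
    // The uniform buffer holds the WGSL matrix column by column.
    let buffer: [f32; COLS * ROWS] =
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let columns: [[f32; ROWS]; COLS] =
        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]];

    let h = hlsl_row_major_view(&buffer);
    // HLSL `h[1]` is WGSL column 1, so `m[i]` translates directly.
    assert_eq!(h[1], columns[1]);

    // WGSL `m * v` becomes HLSL `mul(v, m)`.
    let v: [f32; COLS] = [1.0, 0.5, -1.0];
    assert_eq!(hlsl_mul_vec_mat(v, &h), wgsl_mat_times_vec(&columns, v));
}
```

Both products accumulate the same terms in the same order, so the results compare exactly equal rather than merely approximately.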
## Padding in two-row matrices An HLSL `row_major floatKx2` matrix has padding between its rows that the WGSL `matKx2` matrix it represents does not. HLSL stores all matrix rows [aligned on 16-byte boundaries][16bb], whereas WGSL says that the columns of a `matKx2` need only be [aligned as required for `vec2`][ilov], which is [eight-byte alignment][8bb]. To compensate for this, any time a `matKx2` appears in a WGSL `uniform` value or as part of a struct/array, we actually emit `K` separate `float2` members, and assemble/disassemble the matrix from its columns (in WGSL; rows in HLSL) upon load and store. For example, the following WGSL struct type: ```ignore struct Baz { m: mat3x2, } ``` is rendered as the HLSL struct type: ```ignore struct Baz { float2 m_0; float2 m_1; float2 m_2; }; ``` The `wrapped_struct_matrix` functions in `help.rs` generate HLSL helper functions to access such members, converting between the stored form and the HLSL matrix types appropriately. For example, for reading the member `m` of the `Baz` struct above, we emit: ```ignore float3x2 GetMatmOnBaz(Baz obj) { return float3x2(obj.m_0, obj.m_1, obj.m_2); } ``` We also emit an analogous `Set` function, as well as functions for accessing individual columns by dynamic index. ## Sampler Handling Due to limitations in how sampler heaps work in D3D12, we need to access samplers through a layer of indirection. Instead of directly binding samplers, we bind the entire sampler heap as both a standard and a comparison sampler heap. We then use a sampler index buffer for each bind group. This buffer is accessed in the shader to get the actual sampler index within the heap. See the wgpu_hal dx12 backend documentation for more information. # External textures Support for [`crate::ImageClass::External`] textures is implemented by lowering each external texture global variable to 3 `Texture2D`s, and a `cbuffer` of type `NagaExternalTextureParams`. 
This provides up to 3 planes of texture data (for example single planar RGBA, or separate Y, Cb, and Cr planes), and the parameters buffer containing information describing how to handle these correctly. The bind target to use for each of these globals is specified via [`Options::external_texture_binding_map`]. External textures are supported by WGSL's `textureDimensions()`, `textureLoad()`, and `textureSampleBaseClampToEdge()` built-in functions. These are implemented using helper functions. See the following functions for how these are generated: * `Writer::write_wrapped_image_query_function` * `Writer::write_wrapped_image_load_function` * `Writer::write_wrapped_image_sample_function` Ideally the set of global variables could be wrapped in a single struct that could conveniently be passed around. But, alas, HLSL does not allow structs to have `Texture2D` members. Fortunately, however, external textures can only be used as arguments to either built-in or user-defined functions. We therefore expand any external texture function argument to four consecutive arguments (3 textures and the params struct) when declaring user-defined functions, and ensure our built-in function implementations take the same arguments. Then, whenever we need to emit an external texture in `Writer::write_expr`, which fortunately can only ever be for a global variable or function argument, we simply emit the variable name of each of the three textures and the parameters struct in a comma-separated list. This won't win any awards for elegance, but it works for our purposes. 
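To make the argument expansion concrete, the following sketch shows how one external texture parameter could turn into four HLSL parameters at the declaration, and into a comma-separated list of four names at a use site. The `_plane0`/`_plane1`/`_plane2`/`_params` suffixes, the `Texture2D<float4>` element type, and the helper functions are illustrative assumptions only; Naga's actual identifiers come from its `Namer`.

```rust
/// Hypothetical naming scheme: the three plane textures and the params
/// cbuffer that stand in for one external texture called `base`.
fn expanded_names(base: &str) -> [String; 4] {
    [
        format!("{base}_plane0"),
        format!("{base}_plane1"),
        format!("{base}_plane2"),
        format!("{base}_params"),
    ]
}

/// Parameter list for a user-defined function that declares one
/// external texture argument: four HLSL parameters instead of one.
fn declare_params(base: &str) -> String {
    let [p0, p1, p2, params] = expanded_names(base);
    format!(
        "Texture2D<float4> {p0}, Texture2D<float4> {p1}, \
         Texture2D<float4> {p2}, NagaExternalTextureParams {params}"
    )
}

/// Argument list emitted wherever the external texture is passed to a
/// function: the four names, comma-separated.
fn use_site(base: &str) -> String {
    expanded_names(base).join(", ")
}

fn main() {
    // A WGSL call `f(tex)` becomes an HLSL call with four arguments.
    assert_eq!(use_site("tex"), "tex_plane0, tex_plane1, tex_plane2, tex_params");
    println!("void f({})", declare_params("tex"));
}
```

Because declaration and use site are derived from the same name list, the two stay in sync, which is the property the expansion relies on.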
[hlsl]: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl [ilov]: https://gpuweb.github.io/gpuweb/wgsl/#internal-value-layout [16bb]: https://github.com/microsoft/DirectXShaderCompiler/wiki/Buffer-Packing#constant-buffer-packing [8bb]: https://gpuweb.github.io/gpuweb/wgsl/#alignment-and-size */ mod conv; mod help; mod keywords; mod ray; mod storage; mod writer; use alloc::{string::String, vec::Vec}; use core::fmt::Error as FmtError; use thiserror::Error; use crate::{back, ir, proc}; /// Direct3D 12 binding information for a global variable. /// /// This type provides the HLSL-specific information Naga needs to declare and /// access an HLSL global variable that cannot be derived from the `Module` /// itself. /// /// An HLSL global variable declaration includes details that the Direct3D API /// will use to refer to it. For example: /// /// RWByteAddressBuffer s_sasm : register(u0, space2); /// /// This defines a global `s_sasm` that a Direct3D root signature would refer to /// as register `0` in register space `2` in a `UAV` descriptor range. Naga can /// infer the register's descriptor range type from the variable's address class /// (writable [`Storage`] variables are implemented by Direct3D Unordered Access /// Views, the `u` register type), but the register number and register space /// must be supplied by the user. /// /// The [`back::hlsl::Options`] structure provides `BindTarget`s for various /// situations in which Naga may need to generate an HLSL global variable, like /// [`binding_map`] for Naga global variables, or [`immediates_target`] for /// a module's sole [`Immediate`] variable. See those fields' documentation /// for details. 
/// /// [`Storage`]: crate::ir::AddressSpace::Storage /// [`back::hlsl::Options`]: Options /// [`binding_map`]: Options::binding_map /// [`immediates_target`]: Options::immediates_target /// [`Immediate`]: crate::ir::AddressSpace::Immediate #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct BindTarget { pub space: u8, /// For regular bindings this is the register number. /// /// For sampler bindings, this is the index into the bind group's sampler index buffer. pub register: u32, /// If the binding is an unsized binding array, this overrides the size. pub binding_array_size: Option<u32>, /// This is the index in the buffer at [`Options::dynamic_storage_buffer_offsets_targets`]. pub dynamic_storage_buffer_offsets_index: Option<u32>, /// This is a hint that we need to restrict indexing of vectors, matrices and arrays. /// /// If [`Options::restrict_indexing`] is also `true`, we will restrict indexing.
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))] pub restrict_indexing: bool, } #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] /// Bind target for dynamic storage buffer offsets. pub struct OffsetsBindTarget { pub space: u8, pub register: u32, pub size: u32, } #[cfg(feature = "deserialize")] #[derive(serde::Deserialize)] struct BindingMapSerialization { resource_binding: crate::ResourceBinding, bind_target: BindTarget, } #[cfg(feature = "deserialize")] fn deserialize_binding_map<'de, D>(deserializer: D) -> Result<BindingMap, D::Error> where D: serde::Deserializer<'de>, { use serde::Deserialize; let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?; let mut map = BindingMap::default(); for item in vec { map.insert(item.resource_binding, item.bind_target); } Ok(map) } // Using `BTreeMap` instead of `HashMap` so that we can hash the map itself. pub type BindingMap = alloc::collections::BTreeMap<crate::ResourceBinding, BindTarget>; /// An HLSL shader model version.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub enum ShaderModel { V5_0, V5_1, V6_0, V6_1, V6_2, V6_3, V6_4, V6_5, V6_6, V6_7, V6_8, V6_9, } impl ShaderModel { pub const fn to_str(self) -> &'static str { match self { Self::V5_0 => "5_0", Self::V5_1 => "5_1", Self::V6_0 => "6_0", Self::V6_1 => "6_1", Self::V6_2 => "6_2", Self::V6_3 => "6_3", Self::V6_4 => "6_4", Self::V6_5 => "6_5", Self::V6_6 => "6_6", Self::V6_7 => "6_7", Self::V6_8 => "6_8", Self::V6_9 => "6_9", } } } impl crate::ShaderStage { pub const fn to_hlsl_str(self) -> &'static str { match self { Self::Vertex => "vs", Self::Fragment => "ps", Self::Compute => "cs", Self::Task => "as", Self::Mesh => "ms", Self::RayGeneration | Self::AnyHit | Self::ClosestHit | Self::Miss => "lib", } } } impl crate::ImageDimension { const fn to_hlsl_str(self) -> &'static str { match self { Self::D1 => "1D", Self::D2 => "2D", Self::D3 => "3D", Self::Cube => "Cube", } } } #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct SamplerIndexBufferKey { pub group: u32, } #[derive(Clone, Debug, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(feature = "deserialize", serde(default))] pub struct SamplerHeapBindTargets { pub standard_samplers: BindTarget, pub comparison_samplers: BindTarget, } impl Default for SamplerHeapBindTargets { fn default() -> Self { Self { standard_samplers: BindTarget { space: 0, register: 0, binding_array_size: None, dynamic_storage_buffer_offsets_index: None, restrict_indexing: false, }, comparison_samplers: BindTarget { space: 1, register: 0, binding_array_size: None, dynamic_storage_buffer_offsets_index: 
None, restrict_indexing: false, }, } } } #[cfg(feature = "deserialize")] #[derive(serde::Deserialize)] struct SamplerIndexBufferBindingSerialization { group: u32, bind_target: BindTarget, } #[cfg(feature = "deserialize")] fn deserialize_sampler_index_buffer_bindings<'de, D>( deserializer: D, ) -> Result<SamplerIndexBufferBindingMap, D::Error> where D: serde::Deserializer<'de>, { use serde::Deserialize; let vec = Vec::<SamplerIndexBufferBindingSerialization>::deserialize(deserializer)?; let mut map = SamplerIndexBufferBindingMap::default(); for item in vec { map.insert( SamplerIndexBufferKey { group: item.group }, item.bind_target, ); } Ok(map) } // We use a BTreeMap here so that we can hash it. pub type SamplerIndexBufferBindingMap = alloc::collections::BTreeMap<SamplerIndexBufferKey, BindTarget>; #[cfg(feature = "deserialize")] #[derive(serde::Deserialize)] struct DynamicStorageBufferOffsetTargetSerialization { index: u32, bind_target: OffsetsBindTarget, } #[cfg(feature = "deserialize")] fn deserialize_storage_buffer_offsets<'de, D>( deserializer: D, ) -> Result<DynamicStorageBufferOffsetsTargets, D::Error> where D: serde::Deserializer<'de>, { use serde::Deserialize; let vec = Vec::<DynamicStorageBufferOffsetTargetSerialization>::deserialize(deserializer)?; let mut map = DynamicStorageBufferOffsetsTargets::default(); for item in vec { map.insert(item.index, item.bind_target); } Ok(map) } pub type DynamicStorageBufferOffsetsTargets = alloc::collections::BTreeMap<u32, OffsetsBindTarget>; /// HLSL binding information for a Naga [`External`] image global variable. /// /// See the module documentation's section on [External textures][mod] for details. 
/// /// [`External`]: crate::ir::ImageClass::External /// [mod]: #external-textures #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ExternalTextureBindTarget { /// HLSL binding information for the individual plane textures. /// /// Each of these should refer to an HLSL `Texture2D` holding one /// plane of data for the external texture.
The exact meaning of each plane /// varies at runtime depending on where the external texture's data /// originated. pub planes: [BindTarget; 3], /// HLSL binding information for a buffer holding the sampling parameters. /// /// This should refer to a cbuffer of type `NagaExternalTextureParams`, which /// the code Naga generates for `textureSampleBaseClampToEdge` consults to /// decide how to combine the data in [`planes`] to get the result required /// by the spec. /// /// [`planes`]: Self::planes pub params: BindTarget, } #[cfg(feature = "deserialize")] #[derive(serde::Deserialize)] struct ExternalTextureBindingMapSerialization { resource_binding: crate::ResourceBinding, bind_target: ExternalTextureBindTarget, } #[cfg(feature = "deserialize")] fn deserialize_external_texture_binding_map<'de, D>( deserializer: D, ) -> Result<ExternalTextureBindingMap, D::Error> where D: serde::Deserializer<'de>, { use serde::Deserialize; let vec = Vec::<ExternalTextureBindingMapSerialization>::deserialize(deserializer)?; let mut map = ExternalTextureBindingMap::default(); for item in vec { map.insert(item.resource_binding, item.bind_target); } Ok(map) } pub type ExternalTextureBindingMap = alloc::collections::BTreeMap<crate::ResourceBinding, ExternalTextureBindTarget>; /// Shorthand result used internally by the backend type BackendResult = Result<(), Error>; #[derive(Clone, Debug, PartialEq, thiserror::Error)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub enum EntryPointError { #[error("mapping of {0:?} is missing")] MissingBinding(crate::ResourceBinding), } /// Configuration used in the [`Writer`]. #[derive(Clone, Debug, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(feature = "deserialize", serde(default))] pub struct Options { /// The HLSL shader model to use. 
pub shader_model: ShaderModel, /// HLSL binding information for each Naga global variable.
/// /// This maps Naga [`GlobalVariable`]'s [`ResourceBinding`]s to a /// [`BindTarget`] specifying its register number and space, along with /// other details necessary to generate a full HLSL declaration for it, /// or to access its value. /// /// This must provide a [`BindTarget`] for every [`GlobalVariable`] in the /// [`Module`] that has a [`binding`]. /// /// [`GlobalVariable`]: crate::ir::GlobalVariable /// [`ResourceBinding`]: crate::ir::ResourceBinding /// [`Module`]: crate::ir::Module /// [`binding`]: crate::ir::GlobalVariable::binding #[cfg_attr( feature = "deserialize", serde(deserialize_with = "deserialize_binding_map") )] pub binding_map: BindingMap, /// Don't fail on missing bindings; instead, generate HLSL that uses the /// resource's group as the register space and its binding as the register number. pub fake_missing_bindings: bool, /// Add special constants to `SV_VertexIndex` and `SV_InstanceIndex`, /// to make them work like in Vulkan/Metal, with the help of the host. pub special_constants_binding: Option<BindTarget>, /// HLSL binding information for the [`Immediate`] global, if present. /// /// If a module contains a global in the [`Immediate`] address space, the /// `dx12` backend stores its value directly in the root signature as a /// series of [`D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS`], whose binding /// information is given here. /// /// [`Immediate`]: crate::ir::AddressSpace::Immediate /// [`D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS`]: https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_root_parameter_type pub immediates_target: Option<BindTarget>, /// HLSL binding information for the sampler heap and comparison sampler heap. pub sampler_heap_target: SamplerHeapBindTargets, /// Mapping of each bind group's sampler index buffer to a bind target.
#[cfg_attr( feature = "deserialize", serde(deserialize_with = "deserialize_sampler_index_buffer_bindings") )] pub sampler_buffer_binding_map: SamplerIndexBufferBindingMap, /// Bind targets for dynamic storage buffer offsets. #[cfg_attr( feature = "deserialize", serde(deserialize_with = "deserialize_storage_buffer_offsets") )] pub dynamic_storage_buffer_offsets_targets: DynamicStorageBufferOffsetsTargets, #[cfg_attr( feature = "deserialize", serde(deserialize_with = "deserialize_external_texture_binding_map") )] /// HLSL binding information for [`External`] image global variables. /// /// See [`ExternalTextureBindTarget`] for details. /// /// [`External`]: crate::ir::ImageClass::External pub external_texture_binding_map: ExternalTextureBindingMap, /// Should workgroup variables be zero initialized (by polyfilling)? pub zero_initialize_workgroup_memory: bool, /// Should we restrict indexing of vectors, matrices and arrays? pub restrict_indexing: bool, /// If set, loops will have code injected into them, forcing the compiler /// to think the number of iterations is bounded. pub force_loop_bounding: bool, /// If set, ray queries will get a variable to track their state to prevent
pub ray_query_initialization_tracking: bool, } impl Default for Options { fn default() -> Self { Options { shader_model: ShaderModel::V5_1, binding_map: BindingMap::default(), fake_missing_bindings: true, special_constants_binding: None, sampler_heap_target: SamplerHeapBindTargets::default(), sampler_buffer_binding_map: alloc::collections::BTreeMap::default(), immediates_target: None, dynamic_storage_buffer_offsets_targets: alloc::collections::BTreeMap::new(), external_texture_binding_map: ExternalTextureBindingMap::default(), zero_initialize_workgroup_memory: true, restrict_indexing: true, force_loop_bounding: true, ray_query_initialization_tracking: true, } } } impl Options { fn resolve_resource_binding( &self, res_binding: &crate::ResourceBinding, ) -> Result<BindTarget, EntryPointError> { match self.binding_map.get(res_binding) { Some(target) => Ok(*target), None if self.fake_missing_bindings => Ok(BindTarget { space: res_binding.group as u8, register: res_binding.binding, binding_array_size: None, dynamic_storage_buffer_offsets_index: None, restrict_indexing: false, }), None => Err(EntryPointError::MissingBinding(*res_binding)), } } fn resolve_external_texture_resource_binding( &self, res_binding: &crate::ResourceBinding, ) -> Result<ExternalTextureBindTarget, EntryPointError> { match self.external_texture_binding_map.get(res_binding) { Some(target) => Ok(*target), None if self.fake_missing_bindings => { let fake = BindTarget { space: res_binding.group as u8, register: res_binding.binding, binding_array_size: None, dynamic_storage_buffer_offsets_index: None, restrict_indexing: false, }; Ok(ExternalTextureBindTarget { planes: [fake, fake, fake], params: fake, }) } None => Err(EntryPointError::MissingBinding(*res_binding)), } } } /// Reflection info for entry point names. #[derive(Default)] pub struct ReflectionInfo { /// Mapping of the entry point names. /// /// Each item in the array corresponds to an entry point index. 
The real entry point name may differ if the original name is one of the /// reserved words.
/// /// Note: Some entry points may fail translation because of missing bindings. pub entry_point_names: Vec<Result<String, EntryPointError>>, } /// A subset of options that are meant to be changed per pipeline. #[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(feature = "deserialize", serde(default))] pub struct PipelineOptions { /// The entry point to write. /// /// Entry points are identified by a shader stage /// and a name. /// /// If `None`, all entry points will be written. If `Some` and the entry /// point is not found, writing returns an error. pub entry_point: Option<(ir::ShaderStage, String)>, } #[derive(Error, Debug)] pub enum Error { #[error(transparent)] IoError(#[from] FmtError), #[error("A scalar with an unsupported width was requested: {0:?}")] UnsupportedScalar(crate::Scalar), #[error("{0}")] Unimplemented(String), // TODO: Error used only during development #[error("{0}")] Custom(String), #[error("overrides should not be present at this stage")] Override, #[error(transparent)] ResolveArraySizeError(#[from] proc::ResolveArraySizeError), #[error("entry point with stage {0:?} and name '{1}' not found")] EntryPointNotFound(ir::ShaderStage, String), #[error("requires shader model {1:?} for reason: {0}")] ShaderModelTooLow(String, ShaderModel), } #[derive(PartialEq, Eq, Hash)] enum WrappedType { ZeroValue(help::WrappedZeroValue), ArrayLength(help::WrappedArrayLength), ImageSample(help::WrappedImageSample), ImageQuery(help::WrappedImageQuery), ImageLoad(help::WrappedImageLoad), ImageLoadScalar(crate::Scalar), Constructor(help::WrappedConstructor), StructMatrixAccess(help::WrappedStructMatrixAccess), MatCx2(help::WrappedMatCx2), Math(help::WrappedMath), UnaryOp(help::WrappedUnaryOp), BinaryOp(help::WrappedBinaryOp), Cast(help::WrappedCast), } #[derive(Default)] struct Wrapped { types: crate::FastHashSet<WrappedType>, /// If true, the sampler heaps have
been written out. sampler_heaps: bool, // Mapping from SamplerIndexBufferKey to the name the namer returned. sampler_index_buffers: crate::FastHashMap, } impl Wrapped { fn insert(&mut self, r#type: WrappedType) -> bool { self.types.insert(r#type) } fn clear(&mut self) { self.types.clear(); } } /// A fragment entry point to be considered when generating HLSL for the output interface of vertex /// entry points. /// /// This is provided as an optional parameter to [`Writer::write`]. /// /// If this is provided, vertex outputs will be removed if they are not inputs of this fragment /// entry point. This is necessary for generating correct HLSL when some of the vertex shader /// outputs are not consumed by the fragment shader. pub struct FragmentEntryPoint<'a> { module: &'a crate::Module, func: &'a crate::Function, } impl<'a> FragmentEntryPoint<'a> { /// Returns `None` if the entry point with the provided name can't be found or isn't a fragment /// entry point. pub fn new(module: &'a crate::Module, ep_name: &'a str) -> Option { module .entry_points .iter() .find(|ep| ep.name == ep_name) .filter(|ep| ep.stage == crate::ShaderStage::Fragment) .map(|ep| Self { module, func: &ep.function, }) } } pub struct Writer<'a, W> { out: W, names: crate::FastHashMap, namer: proc::Namer, /// HLSL backend options options: &'a Options, /// Per-stage backend options pipeline_options: &'a PipelineOptions, /// Information about entry point arguments and result types. entry_point_io: crate::FastHashMap, /// Set of expressions that have associated temporary variables named_expressions: crate::NamedExpressions, wrapped: Wrapped, written_committed_intersection: bool, written_candidate_intersection: bool, continue_ctx: back::continue_forward::ContinueCtx, /// A reference to some part of a global variable, lowered to a series of /// byte offset calculations. /// /// See the [`storage`] module for background on why we need this. 
/// /// Each [`SubAccess`] in the vector is a lowering of some [`Access`] or /// [`AccessIndex`] expression to the level of byte strides and offsets. See /// [`SubAccess`] for details. /// /// This field is a member of [`Writer`] solely to allow re-use of /// the `Vec`'s dynamic allocation. The value is no longer needed /// once HLSL for the access has been generated. /// /// [`Storage`]: crate::AddressSpace::Storage /// [`SubAccess`]: storage::SubAccess /// [`Access`]: crate::Expression::Access /// [`AccessIndex`]: crate::Expression::AccessIndex temp_access_chain: Vec<storage::SubAccess>, need_bake_expressions: back::NeedBakeExpressions, } pub fn supported_capabilities() -> crate::valid::Capabilities { use crate::valid::Capabilities as Caps; Caps::IMMEDIATES | Caps::FLOAT64 // Unsupported by wgpu but supported by naga | Caps::PRIMITIVE_INDEX | Caps::TEXTURE_AND_SAMPLER_BINDING_ARRAY // No BUFFER_BINDING_ARRAY | Caps::STORAGE_TEXTURE_BINDING_ARRAY // No STORAGE_BUFFER_BINDING_ARRAY | Caps::ACCELERATION_STRUCTURE_BINDING_ARRAY // No CLIP_DISTANCE // No CULL_DISTANCE | Caps::STORAGE_TEXTURE_16BIT_NORM_FORMATS | Caps::MULTIVIEW // No EARLY_DEPTH_TEST | Caps::MULTISAMPLED_SHADING | Caps::RAY_QUERY | Caps::DUAL_SOURCE_BLENDING | Caps::CUBE_ARRAY_TEXTURES | Caps::SHADER_INT64 | Caps::SUBGROUP // No SUBGROUP_BARRIER // No SUBGROUP_VERTEX_STAGE | Caps::SHADER_INT64_ATOMIC_MIN_MAX | Caps::SHADER_INT64_ATOMIC_ALL_OPS // No SHADER_FLOAT32_ATOMIC | Caps::TEXTURE_ATOMIC | Caps::TEXTURE_INT64_ATOMIC // No RAY_HIT_VERTEX_POSITION | Caps::SHADER_FLOAT16 | Caps::TEXTURE_EXTERNAL | Caps::SHADER_FLOAT16_IN_FLOAT32 | Caps::SHADER_BARYCENTRICS // No MESH_SHADER // No MESH_SHADER_POINT_TOPOLOGY | Caps::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING // No BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING | Caps::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING | Caps::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING // No COOPERATIVE_MATRIX // No PER_VERTEX // No 
RAY_TRACING_PIPELINE // No DRAW_INDEX //
No MEMORY_DECORATION_VOLATILE | Caps::MEMORY_DECORATION_COHERENT } naga-29.0.3/src/back/hlsl/ray.rs000064400000000000000000000544571046102023000144360ustar 00000000000000use alloc::{ format, string::{String, ToString}, vec, vec::Vec, }; use core::fmt::Write; use crate::{ back::{hlsl::BackendResult, Baked, Level}, Handle, }; use crate::{RayQueryIntersection, TypeInner}; impl<W: Write> super::Writer<'_, W> { // https://sakibsaikia.github.io/graphics/2022/01/04/Nan-Checks-In-HLSL.html suggests that isnan may not work, unsure if this has changed. fn write_not_finite(&mut self, expr: &str) -> BackendResult { self.write_contains_flags(&format!("asuint({expr})"), 0x7f800000) } fn write_nan(&mut self, expr: &str) -> BackendResult { write!(self.out, "(")?; self.write_not_finite(expr)?; write!(self.out, " && ((asuint({expr}) & 0x7fffff) != 0))")?; Ok(()) } fn write_contains_flags(&mut self, expr: &str, flags: u32) -> BackendResult { write!(self.out, "(({expr} & {flags}) == {flags})")?; Ok(()) } // constructs hlsl RayDesc from wgsl RayDesc pub(super) fn write_ray_desc_from_ray_desc_constructor_function( &mut self, module: &crate::Module, ) -> BackendResult { write!(self.out, "RayDesc RayDescFromRayDesc_(")?; self.write_type(module, module.special_types.ray_desc.unwrap())?; writeln!(self.out, " arg0) {{")?; writeln!(self.out, " RayDesc ret = (RayDesc)0;")?; writeln!(self.out, " ret.Origin = arg0.origin;")?; writeln!(self.out, " ret.TMin = arg0.tmin;")?; writeln!(self.out, " ret.Direction = arg0.dir;")?; writeln!(self.out, " ret.TMax = arg0.tmax;")?; writeln!(self.out, " return ret;")?; writeln!(self.out, "}}")?; writeln!(self.out)?; Ok(()) } pub(super) fn write_committed_intersection_function( &mut self, module: &crate::Module, ) -> BackendResult { self.write_type(module, module.special_types.ray_intersection.unwrap())?; write!(self.out, " GetCommittedIntersection(")?; self.write_value_type( module, &TypeInner::RayQuery { vertex_return: false, }, )?; write!(self.out, " 
rq, ")?;
self.write_value_type(module, &TypeInner::Scalar(crate::Scalar::U32))?; writeln!(self.out, " rq_tracker) {{")?; write!(self.out, " ")?; self.write_type(module, module.special_types.ray_intersection.unwrap())?; write!(self.out, " ret = (")?; self.write_type(module, module.special_types.ray_intersection.unwrap())?; writeln!(self.out, ")0;")?; let mut extra_level = Level(0); if self.options.ray_query_initialization_tracking { // *Technically*, `CommittedStatus` is valid as long as the ray query is initialized, but the metal backend // doesn't support this function unless it has finished traversal, so to encourage portable behaviour we // disallow it here too. write!(self.out, " if (")?; self.write_contains_flags( "rq_tracker", crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), )?; writeln!(self.out, ") {{")?; extra_level = extra_level.next(); } writeln!( self.out, " {extra_level}ret.kind = rq.CommittedStatus();" )?; writeln!( self.out, " {extra_level}if( rq.CommittedStatus() == COMMITTED_NOTHING) {{}} else {{" )?; writeln!(self.out, " {extra_level}ret.t = rq.CommittedRayT();")?; writeln!( self.out, " {extra_level}ret.instance_custom_data = rq.CommittedInstanceID();" )?; writeln!( self.out, " {extra_level}ret.instance_index = rq.CommittedInstanceIndex();" )?; writeln!( self.out, " {extra_level}ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex();" )?; writeln!( self.out, " {extra_level}ret.geometry_index = rq.CommittedGeometryIndex();" )?; writeln!( self.out, " {extra_level}ret.primitive_index = rq.CommittedPrimitiveIndex();" )?; writeln!( self.out, " {extra_level}if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) {{" )?; writeln!( self.out, " {extra_level}ret.barycentrics = rq.CommittedTriangleBarycentrics();" )?; writeln!( self.out, " {extra_level}ret.front_face = rq.CommittedTriangleFrontFace();" )?; writeln!(self.out, " {extra_level}}}")?; writeln!( self.out, " {extra_level}ret.object_to_world = rq.CommittedObjectToWorld4x3();" )?; 
writeln!( self.out, " {extra_level}ret.world_to_object = rq.CommittedWorldToObject4x3();" )?; writeln!(self.out, " {extra_level}}}")?; if self.options.ray_query_initialization_tracking { writeln!(self.out, " }}")?; } writeln!(self.out, " return ret;")?; writeln!(self.out, "}}")?; writeln!(self.out)?; Ok(()) } pub(super) fn write_candidate_intersection_function( &mut self, module: &crate::Module, ) -> BackendResult { self.write_type(module, module.special_types.ray_intersection.unwrap())?; write!(self.out, " GetCandidateIntersection(")?; self.write_value_type( module, &TypeInner::RayQuery { vertex_return: false, }, )?; write!(self.out, " rq, ")?; self.write_value_type(module, &TypeInner::Scalar(crate::Scalar::U32))?; writeln!(self.out, " rq_tracker) {{")?; write!(self.out, " ")?; self.write_type(module, module.special_types.ray_intersection.unwrap())?; write!(self.out, " ret = (")?; self.write_type(module, module.special_types.ray_intersection.unwrap())?; writeln!(self.out, ")0;")?; let mut extra_level = Level(0); if self.options.ray_query_initialization_tracking { write!(self.out, " if (")?; self.write_contains_flags("rq_tracker", crate::back::RayQueryPoint::PROCEED.bits())?; write!(self.out, " && !")?; self.write_contains_flags( "rq_tracker", crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), )?; writeln!(self.out, ") {{")?; extra_level = extra_level.next(); } writeln!( self.out, " {extra_level}CANDIDATE_TYPE kind = rq.CandidateType();" )?; writeln!( self.out, " {extra_level}if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{" )?; writeln!( self.out, " {extra_level}ret.kind = {};", RayQueryIntersection::Triangle as u32 )?; writeln!( self.out, " {extra_level}ret.t = rq.CandidateTriangleRayT();" )?; writeln!( self.out, " {extra_level}ret.barycentrics = rq.CandidateTriangleBarycentrics();" )?; writeln!( self.out, " {extra_level}ret.front_face = rq.CandidateTriangleFrontFace();" )?; writeln!(self.out, " {extra_level}}} else {{")?; writeln!( self.out, " 
{extra_level}ret.kind = {};", RayQueryIntersection::Aabb as u32 )?; writeln!(self.out, " {extra_level}}}")?; writeln!( self.out, " {extra_level}ret.instance_custom_data = rq.CandidateInstanceID();" )?; writeln!( self.out, " {extra_level}ret.instance_index = rq.CandidateInstanceIndex();" )?; writeln!( self.out, " {extra_level}ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex();" )?; writeln!( self.out, " {extra_level}ret.geometry_index = rq.CandidateGeometryIndex();" )?; writeln!( self.out, " {extra_level}ret.primitive_index = rq.CandidatePrimitiveIndex();" )?; writeln!( self.out, " {extra_level}ret.object_to_world = rq.CandidateObjectToWorld4x3();" )?; writeln!( self.out, " {extra_level}ret.world_to_object = rq.CandidateWorldToObject4x3();" )?; if self.options.ray_query_initialization_tracking { writeln!(self.out, " }}")?; } writeln!(self.out, " return ret;")?; writeln!(self.out, "}}")?; writeln!(self.out)?; Ok(()) } #[expect(clippy::too_many_arguments)] pub(super) fn write_initialize_function( &mut self, module: &crate::Module, mut level: Level, query: Handle<crate::Expression>, acceleration_structure: Handle<crate::Expression>, descriptor: Handle<crate::Expression>, rq_tracker: &str, func_ctx: &crate::back::FunctionCtx<'_>, ) -> BackendResult { let base_level = level; // This prevents variables flowing down a level and causing compile errors.
writeln!(self.out, "{level}{{")?; level = level.next(); write!(self.out, "{level}")?; self.write_type( module, module .special_types .ray_desc .expect("should have been generated"), )?; write!(self.out, " naga_desc = ")?; self.write_expr(module, descriptor, func_ctx)?; writeln!(self.out, ";")?; if self.options.ray_query_initialization_tracking { // Validate the ray extents; see https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#ray-extents // The descriptor fields are copied into locals just for convenience. writeln!(self.out, "{level}float naga_tmin = naga_desc.tmin;")?; writeln!(self.out, "{level}float naga_tmax = naga_desc.tmax;")?; writeln!(self.out, "{level}float3 naga_origin = naga_desc.origin;")?; writeln!(self.out, "{level}float3 naga_dir = naga_desc.dir;")?; writeln!(self.out, "{level}uint naga_flags = naga_desc.flags;")?; write!( self.out, "{level}bool naga_tmin_valid = (naga_tmin >= 0.0) && (naga_tmin <= naga_tmax) && !" )?; self.write_nan("naga_tmin")?; writeln!(self.out, ";")?; write!(self.out, "{level}bool naga_tmax_valid = !")?; self.write_nan("naga_tmax")?; writeln!(self.out, ";")?; // Unlike Vulkan, DX12 appears to treat only NaN components of the origin and direction as invalid. write!(self.out, "{level}bool naga_origin_valid = !any(")?; self.write_nan("naga_origin")?; writeln!(self.out, ");")?; write!(self.out, "{level}bool naga_dir_valid = !any(")?; self.write_nan("naga_dir")?; writeln!(self.out, ");")?; write!(self.out, "{level}bool naga_contains_opaque = ")?; self.write_contains_flags("naga_flags", crate::RayFlag::FORCE_OPAQUE.bits())?; writeln!(self.out, ";")?; write!(self.out, "{level}bool naga_contains_no_opaque = ")?; self.write_contains_flags("naga_flags", crate::RayFlag::FORCE_NO_OPAQUE.bits())?; writeln!(self.out, ";")?; write!(self.out, "{level}bool naga_contains_cull_opaque = ")?; self.write_contains_flags("naga_flags", crate::RayFlag::CULL_OPAQUE.bits())?; writeln!(self.out, ";")?; write!(self.out, "{level}bool naga_contains_cull_no_opaque = ")?;
self.write_contains_flags("naga_flags", crate::RayFlag::CULL_NO_OPAQUE.bits())?; writeln!(self.out, ";")?; write!(self.out, "{level}bool naga_contains_cull_front = ")?; self.write_contains_flags("naga_flags", crate::RayFlag::CULL_FRONT_FACING.bits())?; writeln!(self.out, ";")?; write!(self.out, "{level}bool naga_contains_cull_back = ")?; self.write_contains_flags("naga_flags", crate::RayFlag::CULL_BACK_FACING.bits())?; writeln!(self.out, ";")?; write!(self.out, "{level}bool naga_contains_skip_triangles = ")?; self.write_contains_flags("naga_flags", crate::RayFlag::SKIP_TRIANGLES.bits())?; writeln!(self.out, ";")?; write!(self.out, "{level}bool naga_contains_skip_aabbs = ")?; self.write_contains_flags("naga_flags", crate::RayFlag::SKIP_AABBS.bits())?; writeln!(self.out, ";")?; // A textual (HLSL source) version of the same check in the SPIR-V writer. The returned expression is true when at least two of `bools` are true, i.e. it is the negation of the "fewer than two are true" validity condition. fn less_than_two_true(mut bools: Vec<&str>) -> Result<String, core::fmt::Error> { assert!(bools.len() > 1, "Must have multiple booleans!"); let mut final_expr = String::new(); while let Some(last_bool) = bools.pop() { for &bool in &bools { if !final_expr.is_empty() { final_expr.push_str("||"); } write!(final_expr, " ({last_bool} && {bool}) ")?; } } Ok(final_expr) } writeln!( self.out, "{level}bool naga_contains_skip_triangles_aabbs = {};", less_than_two_true(vec![ "naga_contains_skip_triangles", "naga_contains_skip_aabbs" ])? )?; writeln!( self.out, "{level}bool naga_contains_skip_triangles_cull = {};", less_than_two_true(vec![ "naga_contains_skip_triangles", "naga_contains_cull_back", "naga_contains_cull_front" ])? )?; writeln!( self.out, "{level}bool naga_contains_multiple_opaque = {};", less_than_two_true(vec![ "naga_contains_opaque", "naga_contains_no_opaque", "naga_contains_cull_opaque", "naga_contains_cull_no_opaque" ])?
)?; writeln!( self.out, "{level}if (naga_tmin_valid && naga_tmax_valid && naga_origin_valid && naga_dir_valid && !(naga_contains_skip_triangles_aabbs || naga_contains_skip_triangles_cull || naga_contains_multiple_opaque)) {{" )?; level = level.next(); writeln!( self.out, "{level}{rq_tracker} = {rq_tracker} | {};", crate::back::RayQueryPoint::INITIALIZED.bits() )?; } write!(self.out, "{level}")?; self.write_expr(module, query, func_ctx)?; write!(self.out, ".TraceRayInline(")?; self.write_expr(module, acceleration_structure, func_ctx)?; writeln!( self.out, ", naga_desc.flags, naga_desc.cull_mask, RayDescFromRayDesc_(naga_desc));" )?; if self.options.ray_query_initialization_tracking { writeln!(self.out, "{base_level} }}")?; } writeln!(self.out, "{base_level}}}")?; Ok(()) } pub(super) fn write_proceed( &mut self, module: &crate::Module, mut level: Level, query: Handle<crate::Expression>, result: Handle<crate::Expression>, rq_tracker: &str, func_ctx: &crate::back::FunctionCtx<'_>, ) -> BackendResult { let base_level = level; write!(self.out, "{level}")?; let name = Baked(result).to_string(); writeln!(self.out, "bool {name} = false;")?; // This prevents variables flowing down a level and causing compile errors.
if self.options.ray_query_initialization_tracking { writeln!(self.out, "{level}{{")?; level = level.next(); write!(self.out, "{level}bool naga_has_initialized = ")?; self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::INITIALIZED.bits())?; writeln!(self.out, ";")?; write!(self.out, "{level}bool naga_has_finished = ")?; self.write_contains_flags( rq_tracker, crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), )?; writeln!(self.out, ";")?; writeln!( self.out, "{level}if (naga_has_initialized && !naga_has_finished) {{" )?; level = level.next(); } write!(self.out, "{level}{name} = ")?; self.write_expr(module, query, func_ctx)?; writeln!(self.out, ".Proceed();")?; if self.options.ray_query_initialization_tracking { writeln!( self.out, "{level}{rq_tracker} = {rq_tracker} | {};", crate::back::RayQueryPoint::PROCEED.bits() )?; writeln!( self.out, "{level}if (!{name}) {{ {rq_tracker} = {rq_tracker} | {}; }}", crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits() )?; writeln!(self.out, "{base_level}}}}}")?; } self.named_expressions.insert(result, name); Ok(()) } pub(super) fn write_generate_intersection( &mut self, module: &crate::Module, mut level: Level, query: Handle<crate::Expression>, hit_t: Handle<crate::Expression>, rq_tracker: &str, func_ctx: &crate::back::FunctionCtx<'_>, ) -> BackendResult { let base_level = level; if self.options.ray_query_initialization_tracking { write!(self.out, "{level}if (")?; self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::PROCEED.bits())?; write!(self.out, " && !")?; self.write_contains_flags( rq_tracker, crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), )?; writeln!(self.out, ") {{")?; level = level.next(); write!(self.out, "{level}CANDIDATE_TYPE naga_kind = ")?; self.write_expr(module, query, func_ctx)?; writeln!(self.out, ".CandidateType();")?; write!(self.out, "{level}float naga_tmin = ")?; self.write_expr(module, query, func_ctx)?; writeln!(self.out, ".RayTMin();")?; write!(self.out, "{level}float naga_tcurrentmax = ")?;
self.write_expr(module, query, func_ctx)?; // This is initialized to tmax and updated after each committed intersection, so it is valid to call here. // Note: a bug in DXC's SPIR-V backend makes this technically UB in SPIR-V, but the HLSL backend // targets DXIL, so this should be fine (hopefully). writeln!(self.out, ".CommittedRayT();")?; write!( self.out, "{level}if ((naga_kind == CANDIDATE_PROCEDURAL_PRIMITIVE) && (naga_tmin <=" )?; self.write_expr(module, hit_t, func_ctx)?; write!(self.out, ") && (")?; self.write_expr(module, hit_t, func_ctx)?; writeln!(self.out, " <= naga_tcurrentmax)) {{")?; level = level.next(); } write!(self.out, "{level}")?; self.write_expr(module, query, func_ctx)?; write!(self.out, ".CommitProceduralPrimitiveHit(")?; self.write_expr(module, hit_t, func_ctx)?; writeln!(self.out, ");")?; if self.options.ray_query_initialization_tracking { writeln!(self.out, "{base_level}}}}}")?; } Ok(()) } pub(super) fn write_confirm_intersection( &mut self, module: &crate::Module, mut level: Level, query: Handle<crate::Expression>, rq_tracker: &str, func_ctx: &crate::back::FunctionCtx<'_>, ) -> BackendResult { let base_level = level; if self.options.ray_query_initialization_tracking { write!(self.out, "{level}if (")?; self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::PROCEED.bits())?; write!(self.out, " && !")?; self.write_contains_flags( rq_tracker, crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), )?; writeln!(self.out, ") {{")?; level = level.next(); write!(self.out, "{level}CANDIDATE_TYPE naga_kind = ")?; self.write_expr(module, query, func_ctx)?; writeln!(self.out, ".CandidateType();")?; writeln!( self.out, "{level}if (naga_kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{" )?; level = level.next(); } write!(self.out, "{level}")?; self.write_expr(module, query, func_ctx)?; writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?; if self.options.ray_query_initialization_tracking { writeln!(self.out, "{base_level}}}}}")?; } Ok(()) } pub(super) fn
write_terminate( &mut self, module: &crate::Module, mut level: Level, query: Handle<crate::Expression>, rq_tracker: &str, func_ctx: &crate::back::FunctionCtx<'_>, ) -> BackendResult { let base_level = level; if self.options.ray_query_initialization_tracking { write!(self.out, "{level}if (")?; // RayQuery::Abort() can be called any time after RayQuery::TraceRayInline() has been called. // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#rayquery-abort self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::INITIALIZED.bits())?; writeln!(self.out, ") {{")?; level = level.next(); } write!(self.out, "{level}")?; self.write_expr(module, query, func_ctx)?; writeln!(self.out, ".Abort();")?; if self.options.ray_query_initialization_tracking { writeln!(self.out, "{base_level}}}")?; } Ok(()) } } naga-29.0.3/src/back/hlsl/storage.rs000064400000000000000000000664001046102023000152760ustar 00000000000000/*! Generating accesses to [`ByteAddressBuffer`] contents. Naga IR globals in the [`Storage`] address space are rendered as [`ByteAddressBuffer`]s or [`RWByteAddressBuffer`]s in HLSL. These buffers don't have HLSL types (structs, arrays, etc.); instead, they are just raw blocks of bytes, with methods to load and store values of specific types at particular byte offsets. This means that Naga must translate chains of [`Access`] and [`AccessIndex`] expressions into HLSL expressions that compute byte offsets into the buffer. To generate code for a [`Storage`] access: - Call [`Writer::fill_access_chain`] on the expression referring to the value. This populates [`Writer::temp_access_chain`] with the appropriate byte offset calculations, as a vector of [`SubAccess`] values. - Call [`Writer::write_storage_address`] to emit an HLSL expression for a given slice of [`SubAccess`] values.
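The following standalone sketch, which is not Naga code, roughly illustrates the rendering step. It makes several assumptions. The `Step` enum and the `storage_address` function are hypothetical stand-ins for [`SubAccess`] and `write_storage_address`. Runtime indices are plain HLSL variable names rather than IR expression handles. The terms are joined with `+`, and an empty chain yields `0`. Because the terms are summed, their order does not change the resulting offset. For example, the chain `Offset(16)`, `Index { value: "i", stride: 32 }`, `Offset(8)` renders as `16+i*32+8`.

```rust
/// Hypothetical stand-in for `SubAccess`, for illustration only.
enum Step {
    /// A per-group dynamic offset read from a constant buffer.
    BufferOffset { group: u32, offset: u32 },
    /// A compile-time byte offset (struct member or constant index).
    Offset(u32),
    /// A runtime index, named by an HLSL variable, scaled by `stride`.
    Index { value: &'static str, stride: u32 },
}

/// Render `chain` as an HLSL byte-offset expression.
fn storage_address(chain: &[Step]) -> String {
    if chain.is_empty() {
        return "0".to_string();
    }
    let terms: Vec<String> = chain
        .iter()
        .map(|step| match *step {
            Step::BufferOffset { group, offset } => {
                format!("__dynamic_buffer_offsets{group}._{offset}")
            }
            Step::Offset(offset) => offset.to_string(),
            Step::Index { value, stride } => format!("{value}*{stride}"),
        })
        .collect();
    terms.join("+")
}
```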
Naga IR expressions can operate on composite values of any type, but [`ByteAddressBuffer`] and [`RWByteAddressBuffer`] have only a fixed set of `Load` and `Store` methods, which access one through four consecutive 32-bit values. To synthesize a Naga access, you can initialize [`temp_access_chain`] to refer to the composite, and then temporarily push and pop additional steps on [`Writer::temp_access_chain`] to generate accesses to the individual elements/members. The [`temp_access_chain`] field is a member of [`Writer`] solely to allow re-use of the `Vec`'s dynamic allocation. Its value is no longer needed once HLSL for the access has been generated. Note about DXC and Load/Store functions: DXC's HLSL has generic [`Load` and `Store`] functions for [`ByteAddressBuffer`] and [`RWByteAddressBuffer`]. These are not available in FXC's HLSL, so we use them only for types that only DXC supports, notably 64-bit and 16-bit types. FXC's HLSL has the functions `Load`, `Load2`, `Load3`, and `Load4`, and `Store`, `Store2`, `Store3`, and `Store4`, which load or store 1, 2, 3, or 4 consecutive 32-bit values. We use those for 32-bit types, bitcasting to the correct type if necessary.
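The method selection described above can be roughly sketched as follows. This is a simplified, standalone illustration and not Naga's implementation. The `Kind` enum, the `load_call` function, and the HLSL names chosen for the 16-bit and 64-bit scalar types are all assumptions made for the example. With these assumptions, a 32-bit `float3` at byte 16 becomes `asfloat(buf.Load3(16))`. A `double2` at byte 8 needs the DXC-only form `buf.Load<double2>(8)`.

```rust
/// Simplified scalar kinds (hypothetical stand-in for Naga's `ScalarKind`).
#[derive(Clone, Copy)]
enum Kind {
    Sint,
    Uint,
    Float,
}

/// Build an HLSL load of `size` components (1 = scalar) whose elements are
/// `width` bytes wide, reading from `buffer` at byte address `addr`.
fn load_call(buffer: &str, kind: Kind, width: u8, size: u8, addr: &str) -> String {
    if width == 4 {
        // FXC-compatible path: `Load` through `Load4` return `uint` data,
        // so bitcast the result to the element type.
        let cast = match kind {
            Kind::Sint => "asint",
            Kind::Uint => "asuint",
            Kind::Float => "asfloat",
        };
        let suffix = if size == 1 { String::new() } else { size.to_string() };
        format!("{cast}({buffer}.Load{suffix}({addr}))")
    } else {
        // DXC-only templated load, needed for 16-bit and 64-bit types.
        let base = match (kind, width) {
            (Kind::Float, 2) => "half",
            (Kind::Sint, 2) => "int16_t",
            (Kind::Uint, 2) => "uint16_t",
            (Kind::Float, 8) => "double",
            (Kind::Sint, 8) => "int64_t",
            (Kind::Uint, 8) => "uint64_t",
            _ => panic!("unsupported scalar width: {width}"),
        };
        let ty = if size == 1 { base.to_string() } else { format!("{base}{size}") };
        format!("{buffer}.Load<{ty}>({addr})")
    }
}
```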
[`Storage`]: crate::AddressSpace::Storage [`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer [`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer [`Access`]: crate::Expression::Access [`AccessIndex`]: crate::Expression::AccessIndex [`Writer::fill_access_chain`]: super::Writer::fill_access_chain [`Writer::write_storage_address`]: super::Writer::write_storage_address [`Writer::temp_access_chain`]: super::Writer::temp_access_chain [`temp_access_chain`]: super::Writer::temp_access_chain [`Writer`]: super::Writer [`Load` and `Store`]: https://github.com/microsoft/DirectXShaderCompiler/wiki/ByteAddressBuffer-Load-Store-Additions */ use alloc::format; use core::{fmt, mem}; use super::{super::FunctionCtx, BackendResult, Error}; use crate::{ proc::{Alignment, NameKey, TypeResolution}, Handle, }; const STORE_TEMP_NAME: &str = "_value"; /// One step in accessing a [`Storage`] global's component or element. /// /// [`Writer::temp_access_chain`] holds a series of these structures, /// describing how to compute the byte offset of a particular element /// or member of some global variable in the [`Storage`] address /// space. /// /// [`Writer::temp_access_chain`]: super::Writer::temp_access_chain /// [`Storage`]: crate::AddressSpace::Storage #[derive(Debug)] pub(super) enum SubAccess { BufferOffset { group: u32, offset: u32, }, /// Add the given byte offset. This is used for struct members, or /// known components of a vector or matrix. In all those cases, /// the byte offset is a compile-time constant. Offset(u32), /// Scale `value` by `stride`, and add that to the current byte /// offset. This is used to compute the offset of an array element /// whose index is computed at runtime. 
Index { value: Handle<crate::Expression>, stride: u32, }, } pub(super) enum StoreValue { Expression(Handle<crate::Expression>), TempIndex { depth: usize, index: u32, ty: TypeResolution, }, TempAccess { depth: usize, base: Handle<crate::Type>, member_index: u32, }, // Access to a single column of a Cx2 matrix within a struct TempColumnAccess { depth: usize, base: Handle<crate::Type>, member_index: u32, column: u32, }, } impl<W: fmt::Write> super::Writer<'_, W> { pub(super) fn write_storage_address( &mut self, module: &crate::Module, chain: &[SubAccess], func_ctx: &FunctionCtx, ) -> BackendResult { if chain.is_empty() { write!(self.out, "0")?; } for (i, access) in chain.iter().enumerate() { if i != 0 { write!(self.out, "+")?; } match *access { SubAccess::BufferOffset { group, offset } => { write!(self.out, "__dynamic_buffer_offsets{group}._{offset}")?; } SubAccess::Offset(offset) => { write!(self.out, "{offset}")?; } SubAccess::Index { value, stride } => { self.write_expr(module, value, func_ctx)?; write!(self.out, "*{stride}")?; } } } Ok(()) } fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>( &mut self, module: &crate::Module, var_handle: Handle<crate::GlobalVariable>, sequence: I, func_ctx: &FunctionCtx, ) -> BackendResult { for (i, (ty_resolution, offset)) in sequence.enumerate() { // add the offset temporarily self.temp_access_chain.push(SubAccess::Offset(offset)); if i != 0 { write!(self.out, ", ")?; }; self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?; self.temp_access_chain.pop(); } Ok(()) } /// Emit code to access a [`Storage`] global's component. /// /// Emit HLSL to access the component of `var_handle`, a global /// variable in the [`Storage`] address space, whose type is /// `result_ty` and whose location within the global is given by /// [`self.temp_access_chain`]. See the [`storage`] module's /// documentation for background.
/// /// [`Storage`]: crate::AddressSpace::Storage /// [`self.temp_access_chain`]: super::Writer::temp_access_chain pub(super) fn write_storage_load( &mut self, module: &crate::Module, var_handle: Handle<crate::GlobalVariable>, result_ty: TypeResolution, func_ctx: &FunctionCtx, ) -> BackendResult { match *result_ty.inner_with(&module.types) { crate::TypeInner::Scalar(scalar) => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; // See note about DXC and Load/Store in the module's documentation. if scalar.width == 4 { let cast = scalar.kind.to_hlsl_cast(); write!(self.out, "{cast}({var_name}.Load(")?; } else { let ty = scalar.to_hlsl_str()?; write!(self.out, "{var_name}.Load<{ty}>(")?; }; self.write_storage_address(module, &chain, func_ctx)?; write!(self.out, ")")?; if scalar.width == 4 { write!(self.out, ")")?; } self.temp_access_chain = chain; } crate::TypeInner::Vector { size, scalar } => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; let size = size as u8; // See note about DXC and Load/Store in the module's documentation. if scalar.width == 4 { let cast = scalar.kind.to_hlsl_cast(); write!(self.out, "{cast}({var_name}.Load{size}(")?; } else { let ty = scalar.to_hlsl_str()?; write!(self.out, "{var_name}.Load<{ty}{size}>(")?; }; self.write_storage_address(module, &chain, func_ctx)?; write!(self.out, ")")?; if scalar.width == 4 { write!(self.out, ")")?; } self.temp_access_chain = chain; } crate::TypeInner::Matrix { columns, rows, scalar, } => { write!( self.out, "{}{}x{}(", scalar.to_hlsl_str()?, columns as u8, rows as u8, )?; // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
let row_stride = Alignment::from(rows) * scalar.width as u32; let iter = (0..columns as u32).map(|i| { let ty_inner = crate::TypeInner::Vector { size: rows, scalar }; (TypeResolution::Value(ty_inner), i * row_stride) }); self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; write!(self.out, ")")?; } crate::TypeInner::Array { base, size: crate::ArraySize::Constant(size), stride, } => { let constructor = super::help::WrappedConstructor { ty: result_ty.handle().unwrap(), }; self.write_wrapped_constructor_function_name(module, constructor)?; write!(self.out, "(")?; let iter = (0..size.get()).map(|i| (TypeResolution::Handle(base), stride * i)); self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; write!(self.out, ")")?; } crate::TypeInner::Struct { ref members, .. } => { let constructor = super::help::WrappedConstructor { ty: result_ty.handle().unwrap(), }; self.write_wrapped_constructor_function_name(module, constructor)?; write!(self.out, "(")?; let iter = members .iter() .map(|m| (TypeResolution::Handle(m.ty), m.offset)); self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; write!(self.out, ")")?; } _ => unreachable!(), } Ok(()) } fn write_store_value( &mut self, module: &crate::Module, value: &StoreValue, func_ctx: &FunctionCtx, ) -> BackendResult { match *value { StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?, StoreValue::TempIndex { depth, index, ty: _, } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?, StoreValue::TempAccess { depth, base, member_index, } => { let name = &self.names[&NameKey::StructMember(base, member_index)]; write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")? } StoreValue::TempColumnAccess { depth, base, member_index, column, } => { let name = &self.names[&NameKey::StructMember(base, member_index)]; write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}_{column}")? } } Ok(()) } /// Helper function to write down the Store operation on a `ByteAddressBuffer`. 
pub(super) fn write_storage_store( &mut self, module: &crate::Module, var_handle: Handle<crate::GlobalVariable>, value: StoreValue, func_ctx: &FunctionCtx, level: crate::back::Level, within_struct: Option<Handle<crate::Type>>, ) -> BackendResult { let temp_resolution; let ty_resolution = match value { StoreValue::Expression(expr) => &func_ctx.info[expr].ty, StoreValue::TempIndex { depth: _, index: _, ref ty, } => ty, StoreValue::TempAccess { depth: _, base, member_index, } => { let ty_handle = match module.types[base].inner { crate::TypeInner::Struct { ref members, .. } => { members[member_index as usize].ty } _ => unreachable!(), }; temp_resolution = TypeResolution::Handle(ty_handle); &temp_resolution } StoreValue::TempColumnAccess { .. } => { unreachable!("attempting write_storage_store for TempColumnAccess"); } }; match *ty_resolution.inner_with(&module.types) { crate::TypeInner::Scalar(scalar) => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; // See note about DXC and Load/Store in the module's documentation. if scalar.width == 4 { write!(self.out, "{level}{var_name}.Store(")?; self.write_storage_address(module, &chain, func_ctx)?; write!(self.out, ", asuint(")?; self.write_store_value(module, &value, func_ctx)?; writeln!(self.out, "));")?; } else { write!(self.out, "{level}{var_name}.Store(")?; self.write_storage_address(module, &chain, func_ctx)?; write!(self.out, ", ")?; self.write_store_value(module, &value, func_ctx)?; writeln!(self.out, ");")?; } self.temp_access_chain = chain; } crate::TypeInner::Vector { size, scalar } => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; // See note about DXC and Load/Store in the module's documentation.
if scalar.width == 4 { write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?; self.write_storage_address(module, &chain, func_ctx)?; write!(self.out, ", asuint(")?; self.write_store_value(module, &value, func_ctx)?; writeln!(self.out, "));")?; } else { write!(self.out, "{level}{var_name}.Store(")?; self.write_storage_address(module, &chain, func_ctx)?; write!(self.out, ", ")?; self.write_store_value(module, &value, func_ctx)?; writeln!(self.out, ");")?; } self.temp_access_chain = chain; } crate::TypeInner::Matrix { columns, rows, scalar, } => { // Note: Matrices containing vec3s, due to padding, act like they contain vec4s. let row_stride = Alignment::from(rows) * scalar.width as u32; writeln!(self.out, "{level}{{")?; match within_struct { Some(containing_struct) if rows == crate::VectorSize::Bi => { // If we are within a struct, then the struct was already assigned to // a temporary, we don't need to make another. let mut chain = mem::take(&mut self.temp_access_chain); for i in 0..columns as u32 { chain.push(SubAccess::Offset(i * row_stride)); // working around the borrow checker in `self.write_expr` let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; let StoreValue::TempAccess { member_index, .. } = value else { unreachable!( "write_storage_store within_struct but not TempAccess" ); }; let column_value = StoreValue::TempColumnAccess { depth: level.0, // note not incrementing, b/c no temp base: containing_struct, member_index, column: i, }; // See note about DXC and Load/Store in the module's documentation. 
if scalar.width == 4 { write!( self.out, "{}{}.Store{}(", level.next(), var_name, rows as u8 )?; self.write_storage_address(module, &chain, func_ctx)?; write!(self.out, ", asuint(")?; self.write_store_value(module, &column_value, func_ctx)?; writeln!(self.out, "));")?; } else { write!(self.out, "{}{var_name}.Store(", level.next())?; self.write_storage_address(module, &chain, func_ctx)?; write!(self.out, ", ")?; self.write_store_value(module, &column_value, func_ctx)?; writeln!(self.out, ");")?; } chain.pop(); } self.temp_access_chain = chain; } _ => { // first, assign the value to a temporary let depth = level.0 + 1; write!( self.out, "{}{}{}x{} {}{} = ", level.next(), scalar.to_hlsl_str()?, columns as u8, rows as u8, STORE_TEMP_NAME, depth, )?; self.write_store_value(module, &value, func_ctx)?; writeln!(self.out, ";")?; // then iterate the stores for i in 0..columns as u32 { self.temp_access_chain .push(SubAccess::Offset(i * row_stride)); let ty_inner = crate::TypeInner::Vector { size: rows, scalar }; let sv = StoreValue::TempIndex { depth, index: i, ty: TypeResolution::Value(ty_inner), }; self.write_storage_store( module, var_handle, sv, func_ctx, level.next(), None, )?; self.temp_access_chain.pop(); } } } // done writeln!(self.out, "{level}}}")?; } crate::TypeInner::Array { base, size: crate::ArraySize::Constant(size), stride, } => { // first, assign the value to a temporary writeln!(self.out, "{level}{{")?; write!(self.out, "{}", level.next())?; self.write_type(module, base)?; let depth = level.next().0; write!(self.out, " {STORE_TEMP_NAME}{depth}")?; self.write_array_size(module, base, crate::ArraySize::Constant(size))?; write!(self.out, " = ")?; self.write_store_value(module, &value, func_ctx)?; writeln!(self.out, ";")?; // then iterate the stores for i in 0..size.get() { self.temp_access_chain.push(SubAccess::Offset(i * stride)); let sv = StoreValue::TempIndex { depth, index: i, ty: TypeResolution::Handle(base), }; self.write_storage_store(module, 
var_handle, sv, func_ctx, level.next(), None)?; self.temp_access_chain.pop(); } // done writeln!(self.out, "{level}}}")?; } crate::TypeInner::Struct { ref members, .. } => { // first, assign the value to a temporary writeln!(self.out, "{level}{{")?; let depth = level.next().0; let struct_ty = ty_resolution.handle().unwrap(); let struct_name = &self.names[&NameKey::Type(struct_ty)]; write!( self.out, "{}{} {}{} = ", level.next(), struct_name, STORE_TEMP_NAME, depth )?; self.write_store_value(module, &value, func_ctx)?; writeln!(self.out, ";")?; // then iterate the stores for (i, member) in members.iter().enumerate() { self.temp_access_chain .push(SubAccess::Offset(member.offset)); let sv = StoreValue::TempAccess { depth, base: struct_ty, member_index: i as u32, }; self.write_storage_store( module, var_handle, sv, func_ctx, level.next(), Some(struct_ty), )?; self.temp_access_chain.pop(); } // done writeln!(self.out, "{level}}}")?; } _ => unreachable!(), } Ok(()) } /// Set [`temp_access_chain`] to compute the byte offset of `cur_expr`. /// /// The `cur_expr` expression must be a reference to a global /// variable in the [`Storage`] address space, or a chain of /// [`Access`] and [`AccessIndex`] expressions referring to some /// component of such a global. 
/// /// [`temp_access_chain`]: super::Writer::temp_access_chain /// [`Storage`]: crate::AddressSpace::Storage /// [`Access`]: crate::Expression::Access /// [`AccessIndex`]: crate::Expression::AccessIndex pub(super) fn fill_access_chain( &mut self, module: &crate::Module, mut cur_expr: Handle<crate::Expression>, func_ctx: &FunctionCtx, ) -> Result<Handle<crate::GlobalVariable>, Error> { enum AccessIndex { Expression(Handle<crate::Expression>), Constant(u32), } enum Parent<'a> { Array { stride: u32 }, Struct(&'a [crate::StructMember]), } self.temp_access_chain.clear(); loop { let (next_expr, access_index) = match func_ctx.expressions[cur_expr] { crate::Expression::GlobalVariable(handle) => { if let Some(ref binding) = module.global_variables[handle].binding { // this was already resolved earlier when we started evaluating an entry point. let bt = self.options.resolve_resource_binding(binding).unwrap(); if let Some(dynamic_storage_buffer_offsets_index) = bt.dynamic_storage_buffer_offsets_index { self.temp_access_chain.push(SubAccess::BufferOffset { group: binding.group, offset: dynamic_storage_buffer_offsets_index, }); } } return Ok(handle); } crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)), crate::Expression::AccessIndex { base, index } => { (base, AccessIndex::Constant(index)) } ref other => { return Err(Error::Unimplemented(format!("Pointer access of {other:?}"))) } }; let parent = match *func_ctx.resolve_type(next_expr, &module.types) { crate::TypeInner::Pointer { base, .. } => match module.types[base].inner { crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members), crate::TypeInner::Array { stride, .. } => Parent::Array { stride }, crate::TypeInner::Vector { scalar, .. } => Parent::Array { stride: scalar.width as u32, }, crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array { // The stride between columns is the length of one column: `rows` elements, // padded to vector alignment (so 3-row columns are laid out like 4-row ones).
stride: Alignment::from(rows) * scalar.width as u32, }, _ => unreachable!(), }, crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array { stride: scalar.width as u32, }, _ => unreachable!(), }; let sub = match (parent, access_index) { (Parent::Array { stride }, AccessIndex::Expression(value)) => { SubAccess::Index { value, stride } } (Parent::Array { stride }, AccessIndex::Constant(index)) => { SubAccess::Offset(stride * index) } (Parent::Struct(members), AccessIndex::Constant(index)) => { SubAccess::Offset(members[index as usize].offset) } (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(), }; self.temp_access_chain.push(sub); cur_expr = next_expr; } } } naga-29.0.3/src/back/hlsl/writer.rs000064400000000000000000006254451046102023000151600ustar 00000000000000use alloc::{ format, string::{String, ToString}, vec::Vec, }; use core::{fmt, mem}; use super::{ help, help::{ WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess, WrappedZeroValue, }, storage::StoreValue, BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel, }; use crate::{ back::{self, get_entry_points, Baked}, common, proc::{self, index, ExternalTextureNameKey, NameKey}, valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner, }; const LOCATION_SEMANTIC: &str = "LOC"; const SPECIAL_CBUF_TYPE: &str = "NagaConstants"; const SPECIAL_CBUF_VAR: &str = "_NagaConstants"; const SPECIAL_FIRST_VERTEX: &str = "first_vertex"; const SPECIAL_FIRST_INSTANCE: &str = "first_instance"; const SPECIAL_OTHER: &str = "other"; pub(crate) const MODF_FUNCTION: &str = "naga_modf"; pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits"; pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits"; pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap"; pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap"; pub(crate) const 
SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture"; pub(crate) const ABS_FUNCTION: &str = "naga_abs"; pub(crate) const DIV_FUNCTION: &str = "naga_div"; pub(crate) const MOD_FUNCTION: &str = "naga_mod"; pub(crate) const NEG_FUNCTION: &str = "naga_neg"; pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32"; pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32"; pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64"; pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64"; pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str = "nagaTextureSampleBaseClampToEdge"; pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal"; pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_"; /// Prefix for variables in a naga statement pub(crate) const INTERNAL_PREFIX: &str = "naga_"; enum Index { Expression(Handle<crate::Expression>), Static(u32), } struct EpStructMember { name: String, ty: Handle<crate::Type>, // technically, this should always be `Some` // (we `debug_assert!` this in `write_interface_struct`) binding: Option<crate::Binding>, index: u32, } /// Information required to generate the wrapper struct /// that gathers all of an entry point's arguments struct EntryPointBinding { /// Name of the fake EP argument that contains the struct /// with all the flattened input data. arg_name: String, /// Generated structure name ty_name: String, /// Members of generated structure members: Vec<EpStructMember>, local_invocation_index_name: Option<String>, } pub(super) struct EntryPointInterface { /// If `Some`, the input of an entry point is gathered in a special /// struct with members sorted by binding. /// The `EntryPointBinding::members` array is sorted by index, /// so that we can walk it in `write_ep_arguments_initialization`. input: Option<EntryPointBinding>, /// If `Some`, the output of an entry point is flattened. /// The `EntryPointBinding::members` array is sorted by binding, /// so that we can walk it in the `Statement::Return` handler.
    output: Option,
}

#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
enum InterfaceKey {
    Location(u32),
    BuiltIn(crate::BuiltIn),
    Other,
}

impl InterfaceKey {
    const fn new(binding: Option<&crate::Binding>) -> Self {
        match binding {
            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
            None => Self::Other,
        }
    }
}

#[derive(Copy, Clone, PartialEq)]
enum Io {
    Input,
    Output,
}

const fn is_subgroup_builtin_binding(binding: &Option) -> bool {
    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
        return false;
    };

    matches!(
        builtin,
        crate::BuiltIn::SubgroupSize
            | crate::BuiltIn::SubgroupInvocationId
            | crate::BuiltIn::NumSubgroups
            | crate::BuiltIn::SubgroupId
    )
}

/// Information for how to generate a `binding_array` access.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}

impl<'a, W: fmt::Write> super::Writer<'a, W> {
    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
        Self {
            out,
            names: crate::FastHashMap::default(),
            namer: proc::Namer::default(),
            options,
            pipeline_options,
            entry_point_io: crate::FastHashMap::default(),
            named_expressions: crate::NamedExpressions::default(),
            wrapped: super::Wrapped::default(),
            written_committed_intersection: false,
            written_candidate_intersection: false,
            continue_ctx: back::continue_forward::ContinueCtx::default(),
            temp_access_chain: Vec::new(),
            need_bake_expressions: Default::default(),
        }
    }

    fn reset(&mut self, module: &Module) {
        self.names.clear();
        self.namer.reset(
            module,
            &super::keywords::RESERVED_SET,
            proc::KeywordSet::empty(),
            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
            super::keywords::RESERVED_PREFIXES,
            &mut self.names,
        );
        self.entry_point_io.clear();
        self.named_expressions.clear();
        self.wrapped.clear();
        self.written_committed_intersection = false;
        self.written_candidate_intersection = false;
        self.continue_ctx.clear();
        self.need_bake_expressions.clear();
    }

    /// Generates statements to be inserted immediately before and at the very
    /// start of the body of each loop, to defeat infinite loop reasoning.
    /// The 0th item of the returned tuple should be inserted immediately prior
    /// to the loop and the 1st item should be inserted at the very start of
    /// the loop body.
    ///
    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
    fn gen_force_bounded_loop_statements(
        &mut self,
        level: back::Level,
    ) -> Option<(String, String)> {
        if !self.options.force_loop_bounding {
            return None;
        }

        let loop_bound_name = self.namer.call("loop_bound");
        let max = u32::MAX;
        // Count down from u32::MAX rather than up from 0 to avoid hang on
        // certain Intel drivers. See .
        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
        let level = level.next();
        let break_and_inc = format!(
            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
        );

        Some((decl, break_and_inc))
    }

    /// Helper method used to find which expressions of a given function require baking
    ///
    /// # Notes
    /// Clears `need_bake_expressions` set before adding to it
    fn update_expressions_to_bake(
        &mut self,
        module: &Module,
        func: &crate::Function,
        info: &valid::FunctionInfo,
    ) {
        use crate::Expression;
        self.need_bake_expressions.clear();
        for (exp_handle, expr) in func.expressions.iter() {
            let expr_info = &info[exp_handle];
            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
            if min_ref_count <= expr_info.ref_count {
                self.need_bake_expressions.insert(exp_handle);
            }

            if let Expression::Math {
                fun, arg, arg1, ..
            } = *expr
            {
                match fun {
                    crate::MathFunction::Asinh
                    | crate::MathFunction::Acosh
                    | crate::MathFunction::Atanh
                    | crate::MathFunction::Unpack2x16float
                    | crate::MathFunction::Unpack2x16snorm
                    | crate::MathFunction::Unpack2x16unorm
                    | crate::MathFunction::Unpack4x8snorm
                    | crate::MathFunction::Unpack4x8unorm
                    | crate::MathFunction::Unpack4xI8
                    | crate::MathFunction::Unpack4xU8
                    | crate::MathFunction::Pack2x16float
                    | crate::MathFunction::Pack2x16snorm
                    | crate::MathFunction::Pack2x16unorm
                    | crate::MathFunction::Pack4x8snorm
                    | crate::MathFunction::Pack4x8unorm
                    | crate::MathFunction::Pack4xI8
                    | crate::MathFunction::Pack4xU8
                    | crate::MathFunction::Pack4xI8Clamp
                    | crate::MathFunction::Pack4xU8Clamp => {
                        self.need_bake_expressions.insert(arg);
                    }
                    crate::MathFunction::CountLeadingZeros => {
                        let inner = info[exp_handle].ty.inner_with(&module.types);
                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
                            self.need_bake_expressions.insert(arg);
                        }
                    }
                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
                        self.need_bake_expressions.insert(arg);
                        self.need_bake_expressions.insert(arg1.unwrap());
                    }
                    _ => {}
                }
            }

            if let Expression::Derivative { axis, ctrl, expr } = *expr {
                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
                    self.need_bake_expressions.insert(expr);
                }
            }

            if let Expression::GlobalVariable(_) = *expr {
                let inner = info[exp_handle].ty.inner_with(&module.types);

                if let TypeInner::Sampler { .. } = *inner {
                    self.need_bake_expressions.insert(exp_handle);
                }
            }
        }
        for statement in func.body.iter() {
            match *statement {
                crate::Statement::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: crate::CollectiveOperation::InclusiveScan,
                    argument,
                    result: _,
                } => {
                    self.need_bake_expressions.insert(argument);
                }
                crate::Statement::Atomic {
                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
                    ..
                } => {
                    self.need_bake_expressions.insert(cmp);
                }
                _ => {}
            }
        }
    }

    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result {
        self.reset(module);

        // Write special constants, if needed
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // Save all entry point output types
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write all structs
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    // unsized arrays can only be in storage buffers,
                    // for which we use `ByteAddressBuffer` anyway.
                    continue;
                }

                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write all named constants
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write all globals
        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Add extra newline for readability
            writeln!(self.out)?;
        }

        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write all entry points wrapped structs
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(
                module,
                &ep.function,
                ep.stage,
                &ep_name,
                fragment_entry_point,
            )?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Check if all of the globals are accessible
            if !self.options.fake_missing_bindings {
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                                    && self
                                        .options
                                        .resolve_external_texture_resource_binding(binding)
                                        .is_err()
                            }
                            _ => false,
                        })
                {
                    log::debug!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info)?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write all entry points
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                if self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                                {
                                    ep_error = Some(err);
                                    break;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            if ep.stage.compute_like() {
                // HLSL is calling workgroup size "num threads"
                let num_threads = ep.workgroup_size;
                writeln!(
                    self.out,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info)?;

            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }

    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
        match *binding {
            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
                write!(self.out, "precise ")?;
            }
            crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
                write!(self.out, "noperspective ")?;
            }
            crate::Binding::Location {
                interpolation,
                sampling,
                ..
            } => {
                if let Some(interpolation) = interpolation {
                    if let Some(string) = interpolation.to_hlsl_str() {
                        write!(self.out, "{string} ")?
                    }
                }

                if let Some(sampling) = sampling {
                    if let Some(string) = sampling.to_hlsl_str() {
                        write!(self.out, "{string} ")?
                    }
                }
            }
            crate::Binding::BuiltIn(_) => {}
        }

        Ok(())
    }

    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
    // if they are struct, so that the `stage` argument here could be omitted.
    fn write_semantic(
        &mut self,
        binding: &Option,
        stage: Option<(ShaderStage, Io)>,
    ) -> BackendResult {
        match *binding {
            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
                if builtin == crate::BuiltIn::ViewIndex
                    && self.options.shader_model < ShaderModel::V6_1
                {
                    return Err(Error::ShaderModelTooLow(
                        "used @builtin(view_index) or SV_ViewID".to_string(),
                        ShaderModel::V6_1,
                    ));
                }
                let builtin_str = builtin.to_hlsl_str()?;
                write!(self.out, " : {builtin_str}")?;
            }
            Some(crate::Binding::Location {
                blend_src: Some(1), ..
            }) => {
                write!(self.out, " : SV_Target1")?;
            }
            Some(crate::Binding::Location { location, .. }) => {
                if stage == Some((ShaderStage::Fragment, Io::Output)) {
                    write!(self.out, " : SV_Target{location}")?;
                } else {
                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
                }
            }
            _ => {}
        }

        Ok(())
    }

    fn write_interface_struct(
        &mut self,
        module: &Module,
        shader_stage: (ShaderStage, Io),
        struct_name: String,
        mut members: Vec,
    ) -> Result {
        // Sort the members so that first come the user-defined varyings
        // in ascending locations, and then built-ins. This allows VS and FS
        // interfaces to match with regards to order.
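        //
        // Illustrative example: members bound to `@location(2)`, `@builtin(position)`
        // and `@location(0)` are emitted in the order `@location(0)`, `@location(2)`,
        // `@builtin(position)`. The derived `Ord` on `InterfaceKey` compares the
        // variant first (`Location` < `BuiltIn` < `Other`, in declaration order)
        // and only then the location number.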
        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));

        write!(self.out, "struct {struct_name}")?;
        writeln!(self.out, " {{")?;
        let mut local_invocation_index_name = None;
        let mut subgroup_id_used = false;
        for m in members.iter() {
            // Sanity check that each IO member is a built-in or is assigned a
            // location. Also see note about nesting in `write_ep_input_struct`.
            debug_assert!(m.binding.is_some());

            match m.binding {
                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
                    subgroup_id_used = true;
                }
                Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
                    local_invocation_index_name = Some(m.name.clone());
                }
                _ => (),
            }

            if is_subgroup_builtin_binding(&m.binding) {
                continue;
            }
            write!(self.out, "{}", back::INDENT)?;
            if let Some(ref binding) = m.binding {
                self.write_modifier(binding)?;
            }
            self.write_type(module, m.ty)?;
            write!(self.out, " {}", &m.name)?;
            self.write_semantic(&m.binding, Some(shader_stage))?;
            writeln!(self.out, ";")?;
        }
        if subgroup_id_used && local_invocation_index_name.is_none() {
            let name = self.namer.call("local_invocation_index");
            writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
            local_invocation_index_name = Some(name);
        }
        writeln!(self.out, "}};")?;
        writeln!(self.out)?;

        // See ordering notes on EntryPointInterface fields
        match shader_stage.1 {
            Io::Input => {
                // bring back the original order
                members.sort_by_key(|m| m.index);
            }
            Io::Output => {
                // keep it sorted by binding
            }
        }

        Ok(EntryPointBinding {
            arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
            ty_name: struct_name,
            members,
            local_invocation_index_name,
        })
    }

    /// Flatten all entry point arguments into a single struct.
    /// This is needed since we need to re-order them: first placing user locations,
    /// then built-ins.
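    ///
    /// For example (the entry point name is illustrative), a fragment entry point
    /// whose generated name is `fs_main` gets a struct named `FragmentInput_fs_main`,
    /// because the name is formatted as `{stage:?}Input_{entry_point_name}`. A
    /// struct-typed argument contributes one member per field and any other
    /// argument contributes a single member. Each member's `index` records this
    /// original order, so the input side can be restored to it after sorting by
    /// binding.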
    fn write_ep_input_struct(
        &mut self,
        module: &Module,
        func: &crate::Function,
        stage: ShaderStage,
        entry_point_name: &str,
    ) -> Result {
        let struct_name = format!("{stage:?}Input_{entry_point_name}");

        let mut fake_members = Vec::new();
        for arg in func.arguments.iter() {
            // NOTE: We don't need to handle nesting structs. All members must
            // be either built-ins or assigned a location. I.E. `binding` is
            // `Some`. This is checked in `VaryingContext::validate`. See:
            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
            match module.types[arg.ty].inner {
                TypeInner::Struct { ref members, .. } => {
                    for member in members.iter() {
                        let name = self.namer.call_or(&member.name, "member");
                        let index = fake_members.len() as u32;
                        fake_members.push(EpStructMember {
                            name,
                            ty: member.ty,
                            binding: member.binding.clone(),
                            index,
                        });
                    }
                }
                _ => {
                    let member_name = self.namer.call_or(&arg.name, "member");
                    let index = fake_members.len() as u32;
                    fake_members.push(EpStructMember {
                        name: member_name,
                        ty: arg.ty,
                        binding: arg.binding.clone(),
                        index,
                    });
                }
            }
        }

        self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
    }

    /// Flatten all entry point results into a single struct.
    /// This is needed since we need to re-order them: first placing user locations,
    /// then built-ins.
    fn write_ep_output_struct(
        &mut self,
        module: &Module,
        result: &crate::FunctionResult,
        stage: ShaderStage,
        entry_point_name: &str,
        frag_ep: Option<&FragmentEntryPoint<'_>>,
    ) -> Result {
        let struct_name = format!("{stage:?}Output_{entry_point_name}");

        let empty = [];
        let members = match module.types[result.ty].inner {
            TypeInner::Struct { ref members, .. } => members,
            ref other => {
                log::error!("Unexpected {other:?} output type without a binding");
                &empty[..]
            }
        };

        // Gather list of fragment input locations. We use this below to remove user-defined
        // varyings from VS outputs that aren't in the FS inputs.
        // This makes the VS interface match
        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
        // writer is supplied with information about the fragment entry point.
        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
            let mut fs_input_locs = Vec::new();
            for arg in frag_ep.func.arguments.iter() {
                let mut push_if_location = |binding: &Option| match *binding {
                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
                    Some(crate::Binding::BuiltIn(_)) | None => {}
                };

                // NOTE: We don't need to handle struct nesting. See note in
                // `write_ep_input_struct`.
                match frag_ep.module.types[arg.ty].inner {
                    TypeInner::Struct { ref members, .. } => {
                        for member in members.iter() {
                            push_if_location(&member.binding);
                        }
                    }
                    _ => push_if_location(&arg.binding),
                }
            }
            fs_input_locs.sort();
            Some(fs_input_locs)
        } else {
            None
        };

        let mut fake_members = Vec::new();
        for (index, member) in members.iter().enumerate() {
            if let Some(ref fs_input_locs) = fs_input_locs {
                match member.binding {
                    Some(crate::Binding::Location { location, .. }) => {
                        if fs_input_locs.binary_search(&location).is_err() {
                            continue;
                        }
                    }
                    Some(crate::Binding::BuiltIn(_)) | None => {}
                }
            }

            let member_name = self.namer.call_or(&member.name, "member");
            fake_members.push(EpStructMember {
                name: member_name,
                ty: member.ty,
                binding: member.binding.clone(),
                index: index as u32,
            });
        }

        self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
    }

    /// Writes special interface structures for an entry point. The special structures have
    /// all the fields flattened into them and sorted by binding. They are needed to emulate
    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
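    ///
    /// Concretely, the body below writes:
    /// - an input struct only when the entry point has arguments and is either a
    ///   fragment shader or has an argument directly bound to a subgroup built-in;
    /// - an output struct only for a vertex shader whose result carries no binding
    ///   of its own, that is, a struct whose members carry the bindings.
    ///
    /// For example, a compute entry point taking a `@builtin(subgroup_size)`
    /// argument gets an input struct. The subgroup member itself is skipped when
    /// the struct is written, and the argument is initialized from
    /// `WaveGetLaneCount()` instead.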
    fn write_ep_interface(
        &mut self,
        module: &Module,
        func: &crate::Function,
        stage: ShaderStage,
        ep_name: &str,
        frag_ep: Option<&FragmentEntryPoint<'_>>,
    ) -> Result {
        Ok(EntryPointInterface {
            input: if !func.arguments.is_empty()
                && (stage == ShaderStage::Fragment
                    || func
                        .arguments
                        .iter()
                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
            {
                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
            } else {
                None
            },
            output: match func.result {
                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
                }
                _ => None,
            },
        })
    }

    fn write_ep_argument_initialization(
        &mut self,
        ep: &crate::EntryPoint,
        ep_input: &EntryPointBinding,
        fake_member: &EpStructMember,
    ) -> BackendResult {
        match fake_member.binding {
            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
                write!(self.out, "WaveGetLaneCount()")?
            }
            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
                write!(self.out, "WaveGetLaneIndex()")?
            }
            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
                self.out,
                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
            )?,
            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
                write!(
                    self.out,
                    "{}.{} / WaveGetLaneCount()",
                    ep_input.arg_name,
                    // When writing SubgroupId, we always guarantee that local_invocation_index_name is written
                    ep_input.local_invocation_index_name.as_ref().unwrap()
                )?;
            }
            _ => {
                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
            }
        }

        Ok(())
    }

    /// Write an entry point preface that initializes the arguments as specified in IR.
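    ///
    /// Each argument becomes a local declaration. Plain and array arguments are
    /// assigned with `= <initializer>;`, while struct arguments use a brace list,
    /// `= { <member0>, <member1> };`. Most initializers read the matching member of
    /// the flattened input struct, but subgroup built-ins are computed instead.
    /// For example, with `@workgroup_size(64, 1, 1)` a `num_subgroups` argument is
    /// initialized to `(64u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()`: the
    /// workgroup's invocation count divided by the lane count, rounded up.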
    fn write_ep_arguments_initialization(
        &mut self,
        module: &Module,
        func: &crate::Function,
        ep_index: u16,
    ) -> BackendResult {
        let ep = &module.entry_points[ep_index as usize];
        let ep_input = match self
            .entry_point_io
            .get_mut(&(ep_index as usize))
            .unwrap()
            .input
            .take()
        {
            Some(ep_input) => ep_input,
            None => return Ok(()),
        };
        let mut fake_iter = ep_input.members.iter();
        for (arg_index, arg) in func.arguments.iter().enumerate() {
            write!(self.out, "{}", back::INDENT)?;
            self.write_type(module, arg.ty)?;
            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
            write!(self.out, " {arg_name}")?;
            match module.types[arg.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    self.write_array_size(module, base, size)?;
                    write!(self.out, " = ")?;
                    self.write_ep_argument_initialization(
                        ep,
                        &ep_input,
                        fake_iter.next().unwrap(),
                    )?;
                    writeln!(self.out, ";")?;
                }
                TypeInner::Struct { ref members, .. } => {
                    write!(self.out, " = {{ ")?;
                    for index in 0..members.len() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        self.write_ep_argument_initialization(
                            ep,
                            &ep_input,
                            fake_iter.next().unwrap(),
                        )?;
                    }
                    writeln!(self.out, " }};")?;
                }
                _ => {
                    write!(self.out, " = ")?;
                    self.write_ep_argument_initialization(
                        ep,
                        &ep_input,
                        fake_iter.next().unwrap(),
                    )?;
                    writeln!(self.out, ";")?;
                }
            }
        }
        assert!(fake_iter.next().is_none());
        Ok(())
    }

    /// Helper method used to write global variables
    /// # Notes
    /// Always adds a newline
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures are handled entirely differently, so defer entirely to that method.
        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
        // their bindings separately.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::debug!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // Samplers are handled entirely differently, so defer entirely to that method.
        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::TaskPayload => unimplemented!(),
            crate::AddressSpace::Uniform => {
                // constant buffer declarations are expected to be inlined, e.g.
                // `cbuffer foo: register(b0) { field1: type1; }`
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                if global
                    .memory_decorations
                    .contains(crate::MemoryDecorations::COHERENT)
                {
                    write!(self.out, "globallycoherent ")?;
                }
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    // all storage textures are UAV, unconditionally
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::Immediate => {
                // The type of the immediates will be wrapped in `ConstantBuffer`
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
                unimplemented!()
            }
        };

        // If the global holds immediate data, write the type now because it will be a
        // generic argument to `ConstantBuffer`
        if global.space == crate::AddressSpace::Immediate {
            self.write_global_type(module, global.ty)?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Close the angled brackets for the generic argument
            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        // Immediates need to be assigned a binding explicitly by the consumer
        // since naga has no way to know the binding from the shader alone
        if global.space == crate::AddressSpace::Immediate {
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .immediates_target
                .as_ref()
                .expect("No bind target was defined for the immediates block");
            write!(self.out, ": register(b{}", target.register)?;
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // this was already resolved earlier when we started evaluating an entry point.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // need to write the binding array size if the type was emitted with `write_type`
            if let TypeInner::BindingArray {
                base, size, ..
            } = module.types[global.ty].inner
            {
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }

    fn write_global_sampler(
        &mut self,
        module: &Module,
        handle: Handle,
        global: &crate::GlobalVariable,
    ) -> BackendResult {
        let binding = *global.binding.as_ref().unwrap();

        let key = super::SamplerIndexBufferKey {
            group: binding.group,
        };
        self.write_wrapped_sampler_buffer(key)?;

        // This was already validated, so we can confidently unwrap it.
        let bt = self.options.resolve_resource_binding(&binding).unwrap();

        match module.types[global.ty].inner {
            TypeInner::Sampler { comparison } => {
                // If we are generating a static access, we create a variable for the sampler.
                //
                // This prevents the DXIL from containing multiple lookups for the sampler, which
                // the backend compiler will then have to eliminate.
                // AMD does seem to be able to
                // eliminate these, but better safe than sorry.
                write!(self.out, "static const ")?;
                self.write_type(module, global.ty)?;

                let heap_var = if comparison {
                    COMPARISON_SAMPLER_HEAP_VAR
                } else {
                    SAMPLER_HEAP_VAR
                };

                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
                let name = &self.names[&NameKey::GlobalVariable(handle)];
                writeln!(
                    self.out,
                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
                    register = bt.register
                )?;
            }
            TypeInner::BindingArray { .. } => {
                // If we are generating a binding array, we cannot directly access the sampler as the index
                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
                // that represents the "base" index into the sampler index buffer. This constant is added
                // to the user provided index to get the final index into the sampler index buffer.

                let name = &self.names[&NameKey::GlobalVariable(handle)];
                writeln!(
                    self.out,
                    "static const uint {name} = {register};",
                    register = bt.register
                )?;
            }
            _ => unreachable!(),
        };

        Ok(())
    }

    /// Write the declarations for an external texture global variable.
    /// These are emitted as multiple global variables: Three `Texture2D`s
    /// (one for each plane) and a parameters cbuffer.
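    ///
    /// Each declaration names its register explicitly and appends `, spaceN` only
    /// when the space is non-zero. For example, if the parameters resolve to
    /// register 3 in space 0, the cbuffer is emitted as
    /// `cbuffer <params>: register(b3) { <ParamsType> <params>; };`, where the
    /// angle-bracketed parts are placeholders for the generated names.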
    fn write_global_external_texture(
        &mut self,
        module: &Module,
        handle: Handle,
        global: &crate::GlobalVariable,
    ) -> BackendResult {
        let res_binding = global
            .binding
            .as_ref()
            .expect("External texture global variables must have a resource binding");
        let ext_tex_bindings = match self
            .options
            .resolve_external_texture_resource_binding(res_binding)
        {
            Ok(bindings) => bindings,
            Err(err) => {
                log::debug!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        };

        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
            write!(
                self.out,
                "Texture2D {}: register(t{}",
                name, bt.register
            )?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;
            Ok(())
        };
        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
            let plane_name = &self.names
                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
            write_plane(bt, plane_name)?;
        }

        let params_name = &self.names
            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
        let params_ty_name =
            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
        write!(
            self.out,
            "cbuffer {}: register(b{}",
            params_name, ext_tex_bindings.params.register
        )?;
        if ext_tex_bindings.params.space != 0 {
            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
        }
        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;

        Ok(())
    }

    /// Helper method used to write global constants
    ///
    /// # Notes
    /// Ends in a newline
    fn write_global_constant(
        &mut self,
        module: &Module,
        handle: Handle,
    ) -> BackendResult {
        write!(self.out, "static const ")?;
        let constant = &module.constants[handle];
        self.write_type(module, constant.ty)?;
        let name = &self.names[&NameKey::Constant(handle)];
        write!(self.out, " {name}")?;
        // Write size for array type
        if let TypeInner::Array {
            base, size, ..
        } = module.types[constant.ty].inner
        {
            self.write_array_size(module, base, size)?;
        }
        write!(self.out, " = ")?;
        self.write_const_expression(module, constant.init, &module.global_expressions)?;
        writeln!(self.out, ";")?;
        Ok(())
    }

    pub(super) fn write_array_size(
        &mut self,
        module: &Module,
        base: Handle,
        size: crate::ArraySize,
    ) -> BackendResult {
        write!(self.out, "[")?;

        match size.resolve(module.to_ctx())? {
            proc::IndexableLength::Known(size) => {
                write!(self.out, "{size}")?;
            }
            proc::IndexableLength::Dynamic => unreachable!(),
        }

        write!(self.out, "]")?;

        if let TypeInner::Array {
            base: next_base,
            size: next_size,
            ..
        } = module.types[base].inner
        {
            self.write_array_size(module, next_base, next_size)?;
        }

        Ok(())
    }

    /// Helper method used to write structs
    ///
    /// # Notes
    /// Ends in a newline
    fn write_struct(
        &mut self,
        module: &Module,
        handle: Handle,
        members: &[crate::StructMember],
        span: u32,
        shader_stage: Option<(ShaderStage, Io)>,
    ) -> BackendResult {
        // Write struct name
        let struct_name = &self.names[&NameKey::Type(handle)];
        writeln!(self.out, "struct {struct_name} {{")?;

        let mut last_offset = 0;
        for (index, member) in members.iter().enumerate() {
            if member.binding.is_none() && member.offset > last_offset {
                // using int as padding should work as long as the backend
                // doesn't support a type that's less than 4 bytes in size
                // (Error::UnsupportedScalar catches this)
                let padding = (member.offset - last_offset) / 4;
                for i in 0..padding {
                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
                }
            }
            let ty_inner = &module.types[member.ty].inner;
            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;

            // The indentation is only for readability
            write!(self.out, "{}", back::INDENT)?;

            match module.types[member.ty].inner {
                TypeInner::Array {
                    base, size, ..
                } => {
                    // HLSL arrays are written as `type name[size]`
                    self.write_global_type(module, member.ty)?;

                    // Write `name`
                    write!(
                        self.out,
                        " {}",
                        &self.names[&NameKey::StructMember(handle, index as u32)]
                    )?;
                    // Write [size]
                    self.write_array_size(module, base, size)?;
                }
                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
                // See the module-level block comment in mod.rs for details.
                TypeInner::Matrix {
                    rows,
                    columns,
                    scalar,
                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
                    let vec_ty = TypeInner::Vector { size: rows, scalar };
                    let field_name_key = NameKey::StructMember(handle, index as u32);

                    for i in 0..columns as u8 {
                        if i != 0 {
                            write!(self.out, "; ")?;
                        }
                        self.write_value_type(module, &vec_ty)?;
                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
                    }
                }
                _ => {
                    // Write modifier before type
                    if let Some(ref binding) = member.binding {
                        self.write_modifier(binding)?;
                    }

                    // Even though Naga IR matrices are column-major, we must describe
                    // matrices passed from the CPU as being in row-major order.
                    // See the module-level block comment in mod.rs for details.
                    if let TypeInner::Matrix { ..
                    } = module.types[member.ty].inner {
                        write!(self.out, "row_major ")?;
                    }

                    // Write the member type and name
                    self.write_type(module, member.ty)?;
                    write!(
                        self.out,
                        " {}",
                        &self.names[&NameKey::StructMember(handle, index as u32)]
                    )?;
                }
            }

            self.write_semantic(&member.binding, shader_stage)?;
            writeln!(self.out, ";")?;
        }

        // Add padding at the end, since HLSL does not round type sizes up to their alignment
        if members.last().unwrap().binding.is_none() && span > last_offset {
            let padding = (span - last_offset) / 4;
            for i in 0..padding {
                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
            }
        }

        writeln!(self.out, "}};")?;
        Ok(())
    }

    /// Helper method used to write the non-image/sampler types of globals and struct members
    ///
    /// # Notes
    /// Adds no trailing or leading whitespace
    pub(super) fn write_global_type(
        &mut self,
        module: &Module,
        ty: Handle<crate::Type>,
    ) -> BackendResult {
        let matrix_data = get_inner_matrix_data(module, ty);

        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
        // See the module-level block comment in mod.rs for details.
        if let Some(MatrixType {
            columns,
            rows: crate::VectorSize::Bi,
            width: 4,
        }) = matrix_data
        {
            write!(self.out, "__mat{}x2", columns as u8)?;
        } else {
            // Even though Naga IR matrices are column-major, we must describe
            // matrices passed from the CPU as being in row-major order.
            // See the module-level block comment in mod.rs for details.
            if matrix_data.is_some() {
                write!(self.out, "row_major ")?;
            }

            self.write_type(module, ty)?;
        }

        Ok(())
    }

    /// Helper method used to write non-image/sampler types
    ///
    /// # Notes
    /// Adds no trailing or leading whitespace
    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
        let inner = &module.types[ty].inner;
        match *inner {
            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
            // HLSL arrays have the size separated from the base type
            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
                self.write_type(module, base)?
            }
            ref other => self.write_value_type(module, other)?,
        }

        Ok(())
    }

    /// Helper method used to write value types
    ///
    /// # Notes
    /// Adds no trailing or leading whitespace
    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
        match *inner {
            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
            }
            TypeInner::Vector { size, scalar } => {
                write!(
                    self.out,
                    "{}{}",
                    scalar.to_hlsl_str()?,
                    common::vector_size_str(size)
                )?;
            }
            TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => {
                // The IR supports only float matrices
                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix

                // Because HLSL implicitly transposes all matrices, we need to transpose the size as well.
                write!(
                    self.out,
                    "{}{}x{}",
                    scalar.to_hlsl_str()?,
                    common::vector_size_str(columns),
                    common::vector_size_str(rows),
                )?;
            }
            TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                self.write_image_type(dim, arrayed, class)?;
            }
            TypeInner::Sampler { comparison } => {
                let sampler = if comparison {
                    "SamplerComparisonState"
                } else {
                    "SamplerState"
                };
                write!(self.out, "{sampler}")?;
            }
            // HLSL arrays are written as `type name[size]`.
            // This arm writes only the `[size]` part;
            // the base `type` and the `name` must be written by the caller.
            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
                self.write_array_size(module, base, size)?;
            }
            TypeInner::AccelerationStructure { .. } => {
                write!(self.out, "RaytracingAccelerationStructure")?;
            }
            TypeInner::RayQuery { ..
            } => {
                // These are constant flags. Dynamic flags also exist,
                // but constant flags are not supported by Naga.
                write!(self.out, "RayQuery")?;
            }
            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
        }

        Ok(())
    }

    /// Helper method used to write functions
    /// # Notes
    /// Ends in a newline
    fn write_function(
        &mut self,
        module: &Module,
        name: &str,
        func: &crate::Function,
        func_ctx: &back::FunctionCtx<'_>,
        info: &valid::FunctionInfo,
    ) -> BackendResult {
        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax

        self.update_expressions_to_bake(module, func, info);

        if let Some(ref result) = func.result {
            // Write typedef if return type is an array
            let array_return_type = match module.types[result.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    let array_return_type = self.namer.call(&format!("ret_{name}"));
                    write!(self.out, "typedef ")?;
                    self.write_type(module, result.ty)?;
                    write!(self.out, " {array_return_type}")?;
                    self.write_array_size(module, base, size)?;
                    writeln!(self.out, ";")?;
                    Some(array_return_type)
                }
                _ => None,
            };

            // Write modifier
            if let Some(
                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
            ) = result.binding
            {
                self.write_modifier(binding)?;
            }

            // Write return type
            match func_ctx.ty {
                back::FunctionType::Function(_) => {
                    if let Some(array_return_type) = array_return_type {
                        write!(self.out, "{array_return_type}")?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
                back::FunctionType::EntryPoint(index) => {
                    if let Some(ref ep_output) =
                        self.entry_point_io.get(&(index as usize)).unwrap().output
                    {
                        write!(self.out, "{}", ep_output.ty_name)?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
            }
        } else {
            write!(self.out, "void")?;
        }

        // Write function name
        write!(self.out, " {name}(")?;

        let need_workgroup_variables_initialization =
            self.need_workgroup_variables_initialization(func_ctx, module);

        let needs_local_invocation_id_name =
            need_workgroup_variables_initialization;
        let mut local_invocation_id_name = None;

        // Write function arguments
        match func_ctx.ty {
            back::FunctionType::Function(handle) => {
                for (index, arg) in func.arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }

                    self.write_function_argument(module, handle, arg, index)?;
                }
            }
            back::FunctionType::EntryPoint(ep_index) => {
                if let Some(ref ep_input) =
                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
                {
                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
                } else {
                    let stage = module.entry_points[ep_index as usize].stage;
                    for (index, arg) in func.arguments.iter().enumerate() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        self.write_type(module, arg.ty)?;

                        let argument_name =
                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                        if arg.binding
                            == Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId))
                        {
                            local_invocation_id_name = Some(argument_name.clone());
                        }

                        write!(self.out, " {argument_name}")?;

                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                            self.write_array_size(module, base, size)?;
                        }

                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                    }
                }
                if needs_local_invocation_id_name && local_invocation_id_name.is_none() {
                    if self
                        .entry_point_io
                        .get(&(ep_index as usize))
                        .unwrap()
                        .input
                        .is_some()
                        || !func.arguments.is_empty()
                    {
                        write!(self.out, ", ")?;
                    }
                    let var_name = self.namer.call("local_invocation_id");
                    write!(self.out, "uint3 {var_name} : SV_GroupThreadID")?;
                    local_invocation_id_name = Some(var_name);
                }
            }
        }
        // End of arguments
        write!(self.out, ")")?;

        // Write semantic if present
        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            let stage = module.entry_points[index as usize].stage;
            if let Some(crate::FunctionResult { ref binding, ..
            }) = func.result {
                self.write_semantic(binding, Some((stage, Io::Output)))?;
            }
        }

        // Function body start
        writeln!(self.out)?;
        writeln!(self.out, "{{")?;

        if need_workgroup_variables_initialization {
            self.write_workgroup_variables_initialization(
                func_ctx,
                module,
                // need_workgroup_variables_initialization forces this to be written
                // if the user doesn't specify it (so this must be Some())
                local_invocation_id_name.unwrap(),
            )?;
        }

        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            self.write_ep_arguments_initialization(module, func, index)?;
        }

        // Write function local variables
        for (handle, local) in func.local_variables.iter() {
            // Write indentation (only for readability)
            write!(self.out, "{}", back::INDENT)?;

            // Write the local's type and name
            // The leading space before the name is important
            self.write_type(module, local.ty)?;
            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
            // Write size for array type
            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            let is_ray_query = match module.types[local.ty].inner {
                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
                TypeInner::RayQuery { ..
                } => true,
                _ => {
                    write!(self.out, " = ")?;
                    // Write the local initializer if needed
                    if let Some(init) = local.init {
                        self.write_expr(module, init, func_ctx)?;
                    } else {
                        // Zero initialize local variables
                        self.write_default_init(module, local.ty)?;
                    }
                    false
                }
            };

            // Finish the local with `;` and add a newline (only for readability)
            writeln!(self.out, ";")?;

            // If it's a ray query, we also want a tracker variable
            if is_ray_query {
                write!(self.out, "{}", back::INDENT)?;
                self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
                writeln!(
                    self.out,
                    " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
                    self.names[&func_ctx.name_key(handle)]
                )?;
            }
        }

        if !func.local_variables.is_empty() {
            writeln!(self.out)?;
        }

        // Write the function body (statement list)
        for sta in func.body.iter() {
            // The indentation should always be 1 when writing the function body
            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
        }

        writeln!(self.out, "}}")?;

        self.named_expressions.clear();

        Ok(())
    }

    fn write_function_argument(
        &mut self,
        module: &Module,
        handle: Handle<crate::Function>,
        arg: &crate::FunctionArgument,
        index: usize,
    ) -> BackendResult {
        // External texture arguments must be expanded into separate
        // arguments for each plane and the params buffer.
        if let TypeInner::Image {
            class: crate::ImageClass::External,
            ..
        } = module.types[arg.ty].inner
        {
            return self.write_function_external_texture_argument(module, handle, index);
        }

        // Write argument type
        let arg_ty = match module.types[arg.ty].inner {
            // pointers in function arguments are expected and resolve to `inout`
            TypeInner::Pointer { base, .. } => {
                //TODO: can we narrow this down to just `in` when possible?
                write!(self.out, "inout ")?;
                base
            }
            _ => arg.ty,
        };
        self.write_type(module, arg_ty)?;

        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];

        // Write argument name. Space is important.
        write!(self.out, " {argument_name}")?;
        if let TypeInner::Array { base, size, ..
        } = module.types[arg_ty].inner {
            self.write_array_size(module, base, size)?;
        }

        Ok(())
    }

    fn write_function_external_texture_argument(
        &mut self,
        module: &Module,
        handle: Handle<crate::Function>,
        index: usize,
    ) -> BackendResult {
        let plane_names = [0, 1, 2].map(|i| {
            &self.names[&NameKey::ExternalTextureFunctionArgument(
                handle,
                index as u32,
                ExternalTextureNameKey::Plane(i),
            )]
        });
        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
            handle,
            index as u32,
            ExternalTextureNameKey::Params,
        )];
        let params_ty_name =
            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
        write!(
            self.out,
            "Texture2D {}, Texture2D {}, Texture2D {}, {params_ty_name} {params_name}",
            plane_names[0], plane_names[1], plane_names[2],
        )?;

        Ok(())
    }

    fn need_workgroup_variables_initialization(
        &mut self,
        func_ctx: &back::FunctionCtx,
        module: &Module,
    ) -> bool {
        self.options.zero_initialize_workgroup_memory
            && func_ctx.ty.is_compute_like_entry_point(module)
            && module.global_variables.iter().any(|(handle, var)| {
                !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
            })
    }

    fn write_workgroup_variables_initialization(
        &mut self,
        func_ctx: &back::FunctionCtx,
        module: &Module,
        local_invocation_id_name: String,
    ) -> BackendResult {
        let level = back::Level(1);

        writeln!(
            self.out,
            "{level}if (all({local_invocation_id_name} == uint3(0u, 0u, 0u))) {{"
        )?;

        let vars = module.global_variables.iter().filter(|&(handle, var)| {
            !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
        });

        for (handle, var) in vars {
            let name = &self.names[&NameKey::GlobalVariable(handle)];
            write!(self.out, "{}{} = ", level.next(), name)?;
            self.write_default_init(module, var.ty)?;
            writeln!(self.out, ";")?;
        }

        writeln!(self.out, "{level}}}")?;
        self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
    }

    /// Helper method used to write switches
    fn write_switch(
        &mut self,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
        selector:
            Handle<crate::Expression>,
        cases: &[crate::SwitchCase],
    ) -> BackendResult {
        // Write all cases
        let indent_level_1 = level.next();
        let indent_level_2 = indent_level_1.next();

        // See docs of `back::continue_forward` module.
        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
            writeln!(self.out, "{level}bool {variable} = false;",)?;
        };

        // Check if there is only one body, by seeing if all except the last case are fall through
        // with empty bodies. FXC doesn't handle these switches correctly, so
        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
        // is no need to check if one of the cases would have matched.
        let one_body = cases
            .iter()
            .rev()
            .skip(1)
            .all(|case| case.fall_through && case.body.is_empty());
        if one_body {
            // Start the do-while
            writeln!(self.out, "{level}do {{")?;
            // Note: Expressions have no side-effects so we don't need to emit selector expression.

            // Body
            if let Some(case) = cases.last() {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
                }
            }
            // End do-while
            writeln!(self.out, "{level}}} while(false);")?;
        } else {
            // Start the switch
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            for (i, case) in cases.iter().enumerate() {
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
                }

                // The new block is not only stylistic, it plays a role here:
                // We might end up having to write the same case body
                // multiple times due to FXC not supporting fallthrough.
                // Therefore, some `Expression`s written by `Statement::Emit`
                // will end up having the same name (`_expr`).
                // So we need to put each case in its own scope.
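                // Illustrative sketch (hand-written, not captured compiler output):
                // given IR cases `case 1: a();` marked fall-through and `case 2: b();`
                // not marked fall-through, where `a()` and `b()` are hypothetical calls
                // and `b()` is not a terminator, the fallthrough emulation below
                // emits roughly:
                //
                //     case 1: {
                //         {
                //             a();
                //         }
                //         {
                //             b();
                //         }
                //         break;
                //     }
                //     case 2: {
                //         b();
                //         break;
                //     }
                //
                // `b()`'s body is written twice, so any baked names it declares
                // would collide without the extra per-copy scope.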
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                // Although FXC does support a series of case clauses before
                // a block[^yes], it does not support fallthrough from a
                // non-empty case block to the next[^no]. If this case has a
                // non-empty body with a fallthrough, emulate that by
                // duplicating the bodies of all the cases it would fall
                // into as extensions of this case's own body. This makes
                // the HLSL output potentially quadratic in the size of the
                // Naga IR.
                //
                // [^yes]: ```hlsl
                // case 1:
                // case 2: do_stuff()
                // ```
                // [^no]: ```hlsl
                // case 1: do_this();
                // case 2: do_that();
                // ```
                if case.fall_through && !case.body.is_empty() {
                    let curr_len = i + 1;
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    for case in &cases[i..=end_case_idx] {
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        // Clear all named expressions that were previously inserted by the statements in the block
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?;
        }

        // Handle any forwarded continue statements.
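        // Illustrative sketch (hand-written; `should_continue` is a hypothetical
        // stand-in for whatever name `enter_switch` obtains from the namer). When a
        // `continue` inside this switch has to be forwarded (see the
        // `back::continue_forward` docs for when that happens), the HLSL takes
        // roughly this shape:
        //
        //     bool should_continue = false;
        //     switch(x) {
        //         case 0: {
        //             should_continue = true;
        //             break;
        //         }
        //         // other cases
        //     }
        //     if (should_continue) {
        //         continue;
        //     }
        //
        // The `ExitControlFlow::Break` variant produces the same shape, ending in `break;`.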
        use back::continue_forward::ExitControlFlow;
        let op = match self.continue_ctx.exit_switch() {
            ExitControlFlow::None => None,
            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
            ExitControlFlow::Break { variable } => Some(("break", variable)),
        };
        if let Some((control_flow, variable)) = op {
            writeln!(self.out, "{level}if ({variable}) {{")?;
            writeln!(self.out, "{indent_level_1}{control_flow};")?;
            writeln!(self.out, "{level}}}")?;
        }

        Ok(())
    }

    fn write_index(
        &mut self,
        module: &Module,
        index: Index,
        func_ctx: &back::FunctionCtx<'_>,
    ) -> BackendResult {
        match index {
            Index::Static(index) => {
                write!(self.out, "{index}")?;
            }
            Index::Expression(index) => {
                self.write_expr(module, index, func_ctx)?;
            }
        }

        Ok(())
    }

    /// Helper method used to write statements
    ///
    /// # Notes
    /// Always adds a newline
    fn write_stmt(
        &mut self,
        module: &Module,
        stmt: &crate::Statement,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
    ) -> BackendResult {
        use crate::Statement;

        match *stmt {
            Statement::Emit(ref range) => {
                for handle in range.clone() {
                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
                    let expr_name = if ptr_class.is_some() {
                        // HLSL can't save a pointer-valued expression in a variable,
                        // but we shouldn't ever need to: they should never be named expressions,
                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
                        None
                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
                        // The front end provides names for all variables at the start of writing,
                        // but we write them step by step, so we need to re-cache them here.
                        // Otherwise, we could accidentally write the variable name instead of the full expression.
                        // We also use sanitized names, which prevents the backend from generating
                        // variables named after reserved keywords.
                        Some(self.namer.call(name))
                    } else if self.need_bake_expressions.contains(&handle) {
                        Some(Baked(handle).to_string())
                    } else {
                        None
                    };

                    if let Some(name) = expr_name {
                        write!(self.out, "{level}")?;
                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
                    }
                }
            }
            // TODO: copy-paste from glsl-out
            Statement::Block(ref block) => {
                write!(self.out, "{level}")?;
                writeln!(self.out, "{{")?;
                for sta in block.iter() {
                    // Increase the indentation to help with readability
                    self.write_stmt(module, sta, func_ctx, level.next())?
                }
                writeln!(self.out, "{level}}}")?
            }
            // TODO: copy-paste from glsl-out
            Statement::If {
                condition,
                ref accept,
                ref reject,
            } => {
                write!(self.out, "{level}")?;
                write!(self.out, "if (")?;
                self.write_expr(module, condition, func_ctx)?;
                writeln!(self.out, ") {{")?;

                let l2 = level.next();
                for sta in accept {
                    // Increase indentation to help with readability
                    self.write_stmt(module, sta, func_ctx, l2)?;
                }

                // If there are no statements in the reject block we skip writing it
                // This is only for readability
                if !reject.is_empty() {
                    writeln!(self.out, "{level}}} else {{")?;

                    for sta in reject {
                        // Increase indentation to help with readability
                        self.write_stmt(module, sta, func_ctx, l2)?;
                    }
                }

                writeln!(self.out, "{level}}}")?
            }
            // TODO: copy-paste from glsl-out
            Statement::Kill => writeln!(self.out, "{level}discard;")?,
            Statement::Return { value: None } => {
                writeln!(self.out, "{level}return;")?;
            }
            Statement::Return { value: Some(expr) } => {
                let base_ty_res = &func_ctx.info[expr].ty;
                let mut resolved = base_ty_res.inner_with(&module.types);
                if let TypeInner::Pointer { base, space: _ } = *resolved {
                    resolved = &module.types[base].inner;
                }

                if let TypeInner::Struct { ..
                } = *resolved {
                    // We can safely unwrap here, since we know we are working with a struct
                    let ty = base_ty_res.handle().unwrap();
                    let struct_name = &self.names[&NameKey::Type(ty)];
                    let variable_name = self.namer.call(&struct_name.to_lowercase());
                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
                    self.write_expr(module, expr, func_ctx)?;
                    writeln!(self.out, ";")?;

                    // for entry point returns, we may need to reshuffle the outputs into a different struct
                    let ep_output = match func_ctx.ty {
                        back::FunctionType::Function(_) => None,
                        back::FunctionType::EntryPoint(index) => self
                            .entry_point_io
                            .get(&(index as usize))
                            .unwrap()
                            .output
                            .as_ref(),
                    };
                    let final_name = match ep_output {
                        Some(ep_output) => {
                            let final_name = self.namer.call(&variable_name);
                            write!(
                                self.out,
                                "{}const {} {} = {{ ",
                                level, ep_output.ty_name, final_name,
                            )?;
                            for (index, m) in ep_output.members.iter().enumerate() {
                                if index != 0 {
                                    write!(self.out, ", ")?;
                                }
                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
                                write!(self.out, "{variable_name}.{member_name}")?;
                            }
                            writeln!(self.out, " }};")?;
                            final_name
                        }
                        None => variable_name,
                    };
                    writeln!(self.out, "{level}return {final_name};")?;
                } else {
                    write!(self.out, "{level}return ")?;
                    self.write_expr(module, expr, func_ctx)?;
                    writeln!(self.out, ";")?
                }
            }
            Statement::Store { pointer, value } => {
                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
                if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
                    self.write_storage_store(
                        module,
                        var_handle,
                        StoreValue::Expression(value),
                        func_ctx,
                        level,
                        None,
                    )?;
                } else {
                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
                    // See the module-level block comment in mod.rs for details.
                    //
                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
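                    // Illustrative sketch (hypothetical names): given a struct member
                    // `m` of type `mat3x2<f32>` with no binding, `write_struct` declares
                    // its columns separately as `float2 m_0; float2 m_1; float2 m_2;`.
                    // A store to a column with a constant index, equivalent to
                    // `s.m[1] = v;`, is therefore written by the `MatrixAccess::Direct`
                    // branch below as
                    //
                    //     s.m_1 = v;
                    //
                    // and a store to one element, equivalent to `s.m[1][0] = f;`, as
                    //
                    //     s.m_1[0] = f;
                    //
                    // Dynamic column indices go through the wrapped set-helper functions instead.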
                    enum MatrixAccess {
                        Direct {
                            base: Handle<crate::Expression>,
                            index: u32,
                        },
                        Struct {
                            columns: crate::VectorSize,
                            base: Handle<crate::Expression>,
                        },
                    }

                    let get_members = |expr: Handle<crate::Expression>| {
                        let resolved = func_ctx.resolve_type(expr, &module.types);
                        match *resolved {
                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
                                TypeInner::Struct { ref members, .. } => Some(members),
                                _ => None,
                            },
                            _ => None,
                        }
                    };

                    write!(self.out, "{level}")?;

                    let matrix_access_on_lhs =
                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
                            |(matrix_expr, vector, scalar)| match (
                                func_ctx.resolve_type(matrix_expr, &module.types),
                                &func_ctx.expressions[matrix_expr],
                            ) {
                                (
                                    &TypeInner::Pointer { base: ty, .. },
                                    &crate::Expression::AccessIndex { base, index },
                                ) if matches!(
                                    module.types[ty].inner,
                                    TypeInner::Matrix {
                                        rows: crate::VectorSize::Bi,
                                        ..
                                    }
                                ) && get_members(base)
                                    .map(|members| members[index as usize].binding.is_none())
                                    == Some(true) =>
                                {
                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
                                }
                                _ => {
                                    if let Some(MatrixType {
                                        columns,
                                        rows: crate::VectorSize::Bi,
                                        width: 4,
                                    }) = get_inner_matrix_of_struct_array_member(
                                        module,
                                        matrix_expr,
                                        func_ctx,
                                        true,
                                    ) {
                                        Some((
                                            MatrixAccess::Struct {
                                                columns,
                                                base: matrix_expr,
                                            },
                                            vector,
                                            scalar,
                                        ))
                                    } else {
                                        None
                                    }
                                }
                            },
                        );

                    match matrix_access_on_lhs {
                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
                            let base_ty_res = &func_ctx.info[base].ty;
                            let resolved = base_ty_res.inner_with(&module.types);
                            let ty = match *resolved {
                                TypeInner::Pointer { base, ..
                                } => base,
                                _ => base_ty_res.handle().unwrap(),
                            };

                            if let Some(Index::Static(vec_index)) = vector {
                                self.write_expr(module, base, func_ctx)?;
                                write!(
                                    self.out,
                                    ".{}_{}",
                                    &self.names[&NameKey::StructMember(ty, index)],
                                    vec_index
                                )?;

                                if let Some(scalar_index) = scalar {
                                    write!(self.out, "[")?;
                                    self.write_index(module, scalar_index, func_ctx)?;
                                    write!(self.out, "]")?;
                                }

                                write!(self.out, " = ")?;
                                self.write_expr(module, value, func_ctx)?;
                                writeln!(self.out, ";")?;
                            } else {
                                let access = WrappedStructMatrixAccess { ty, index };
                                match (&vector, &scalar) {
                                    (&Some(_), &Some(_)) => {
                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
                                            access,
                                        )?;
                                    }
                                    (&Some(_), &None) => {
                                        self.write_wrapped_struct_matrix_set_vec_function_name(
                                            access,
                                        )?;
                                    }
                                    (&None, _) => {
                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
                                    }
                                }

                                write!(self.out, "(")?;
                                self.write_expr(module, base, func_ctx)?;
                                write!(self.out, ", ")?;
                                self.write_expr(module, value, func_ctx)?;

                                if let Some(Index::Expression(vec_index)) = vector {
                                    write!(self.out, ", ")?;
                                    self.write_expr(module, vec_index, func_ctx)?;

                                    if let Some(scalar_index) = scalar {
                                        write!(self.out, ", ")?;
                                        self.write_index(module, scalar_index, func_ctx)?;
                                    }
                                }

                                writeln!(self.out, ");")?;
                            }
                        }
                        Some((
                            MatrixAccess::Struct { columns, base },
                            Some(Index::Expression(vec_index)),
                            scalar,
                        )) => {
                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
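                            // Illustrative sketch (hypothetical names): for a store into an
                            // array of `mat3x2` struct members with a dynamic column index `i`,
                            // equivalent to `s.arr[k][i] = v;`, the code below emits roughly
                            //
                            //     __set_col_of_mat3x2(s.arr[k], i, v);
                            //
                            // and, when an element index `j` is also present,
                            //
                            //     __set_el_of_mat3x2(s.arr[k], i, j, v);
                            //
                            // i.e. matrix, column index, optional element index, then the value.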
                            if scalar.is_some() {
                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
                            } else {
                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
                            }
                            write!(self.out, "(")?;
                            self.write_expr(module, base, func_ctx)?;
                            write!(self.out, ", ")?;
                            self.write_expr(module, vec_index, func_ctx)?;

                            if let Some(scalar_index) = scalar {
                                write!(self.out, ", ")?;
                                self.write_index(module, scalar_index, func_ctx)?;
                            }

                            write!(self.out, ", ")?;
                            self.write_expr(module, value, func_ctx)?;

                            writeln!(self.out, ");")?;
                        }
                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
                        | Some((MatrixAccess::Struct { .. }, None, _))
                        | None => {
                            self.write_expr(module, pointer, func_ctx)?;
                            write!(self.out, " = ")?;

                            // We cast the RHS of this store in cases where the LHS
                            // is a struct member with type:
                            //  - matCx2 or
                            //  - a (possibly nested) array of matCx2's
                            if let Some(MatrixType {
                                columns,
                                rows: crate::VectorSize::Bi,
                                width: 4,
                            }) = get_inner_matrix_of_struct_array_member(
                                module, pointer, func_ctx, false,
                            ) {
                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
                                if let TypeInner::Pointer { base, .. } = *resolved {
                                    resolved = &module.types[base].inner;
                                }

                                write!(self.out, "(__mat{}x2", columns as u8)?;
                                if let TypeInner::Array { base, size, .. } = *resolved {
                                    self.write_array_size(module, base, size)?;
                                }
                                write!(self.out, ")")?;
                            }

                            self.write_expr(module, value, func_ctx)?;
                            writeln!(self.out, ";")?
                        }
                    }
                }
            }
            Statement::Loop {
                ref body,
                ref continuing,
                break_if,
            } => {
                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
                let gate_name = (!continuing.is_empty() || break_if.is_some())
                    .then(|| self.namer.call("loop_init"));

                if let Some((ref decl, _)) = force_loop_bound_statements {
                    writeln!(self.out, "{decl}")?;
                }
                if let Some(ref gate_name) = gate_name {
                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
                }

                self.continue_ctx.enter_loop();
                writeln!(self.out, "{level}while(true) {{")?;
                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
                    writeln!(self.out, "{break_and_inc}")?;
                }
                let l2 = level.next();
                if let Some(gate_name) = gate_name {
                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
                    let l3 = l2.next();
                    for sta in continuing.iter() {
                        self.write_stmt(module, sta, func_ctx, l3)?;
                    }
                    if let Some(condition) = break_if {
                        write!(self.out, "{l3}if (")?;
                        self.write_expr(module, condition, func_ctx)?;
                        writeln!(self.out, ") {{")?;
                        writeln!(self.out, "{}break;", l3.next())?;
                        writeln!(self.out, "{l3}}}")?;
                    }
                    writeln!(self.out, "{l2}}}")?;
                    writeln!(self.out, "{l2}{gate_name} = false;")?;
                }

                for sta in body.iter() {
                    self.write_stmt(module, sta, func_ctx, l2)?;
                }

                writeln!(self.out, "{level}}}")?;
                self.continue_ctx.exit_loop();
            }
            Statement::Break => writeln!(self.out, "{level}break;")?,
            Statement::Continue => {
                if let Some(variable) = self.continue_ctx.continue_encountered() {
                    writeln!(self.out, "{level}{variable} = true;")?;
                    writeln!(self.out, "{level}break;")?
                } else {
                    writeln!(self.out, "{level}continue;")?
                }
            }
            Statement::ControlBarrier(barrier) => {
                self.write_control_barrier(barrier, level)?;
            }
            Statement::MemoryBarrier(barrier) => {
                self.write_memory_barrier(barrier, level)?;
            }
            Statement::ImageStore {
                image,
                coordinate,
                array_index,
                value,
            } => {
                write!(self.out, "{level}")?;
                self.write_expr(module, image, func_ctx)?;

                write!(self.out, "[")?;
                if let Some(index) = array_index {
                    // An array index is only accepted for texture_storage_2d_array,
                    // so we can safely use int3(coordinate, array_index) here
                    write!(self.out, "int3(")?;
                    self.write_expr(module, coordinate, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_expr(module, index, func_ctx)?;
                    write!(self.out, ")")?;
                } else {
                    self.write_expr(module, coordinate, func_ctx)?;
                }
                write!(self.out, "]")?;

                write!(self.out, " = ")?;
                self.write_expr(module, value, func_ctx)?;
                writeln!(self.out, ";")?;
            }
            Statement::Call {
                function,
                ref arguments,
                result,
            } => {
                write!(self.out, "{level}")?;
                if let Some(expr) = result {
                    write!(self.out, "const ")?;
                    let name = Baked(expr).to_string();
                    let expr_ty = &func_ctx.info[expr].ty;
                    let ty_inner = match *expr_ty {
                        proc::TypeResolution::Handle(handle) => {
                            self.write_type(module, handle)?;
                            &module.types[handle].inner
                        }
                        proc::TypeResolution::Value(ref value) => {
                            self.write_value_type(module, value)?;
                            value
                        }
                    };
                    write!(self.out, " {name}")?;
                    if let TypeInner::Array { base, size, .. } = *ty_inner {
                        self.write_array_size(module, base, size)?;
                    }
                    write!(self.out, " = ")?;
                    self.named_expressions.insert(expr, name);
                }
                let func_name = &self.names[&NameKey::Function(function)];
                write!(self.out, "{func_name}(")?;
                for (index, argument) in arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }
                    self.write_expr(module, *argument, func_ctx)?;
                }
                writeln!(self.out, ");")?
            }
            Statement::Atomic {
                pointer,
                ref fun,
                value,
                result,
            } => {
                write!(self.out, "{level}")?;
                let res_var_info = if let Some(res_handle) = result {
                    let name = Baked(res_handle).to_string();
                    match func_ctx.info[res_handle].ty {
                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
                        proc::TypeResolution::Value(ref value) => {
                            self.write_value_type(module, value)?
                        }
                    };
                    write!(self.out, " {name}; ")?;
                    self.named_expressions.insert(res_handle, name.clone());
                    Some((res_handle, name))
                } else {
                    None
                };
                let pointer_space = func_ctx
                    .resolve_type(pointer, &module.types)
                    .pointer_space()
                    .unwrap();
                let fun_str = fun.to_hlsl_suffix();
                let compare_expr = match *fun {
                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
                    _ => None,
                };
                match pointer_space {
                    crate::AddressSpace::WorkGroup => {
                        write!(self.out, "Interlocked{fun_str}(")?;
                        self.write_expr(module, pointer, func_ctx)?;
                        self.emit_hlsl_atomic_tail(
                            module,
                            func_ctx,
                            fun,
                            compare_expr,
                            value,
                            &res_var_info,
                        )?;
                    }
                    crate::AddressSpace::Storage { .. } => {
                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                        let width = match func_ctx.resolve_type(value, &module.types) {
                            &TypeInner::Scalar(Scalar { width: 8, ..
                            }) => "64",
                            _ => "",
                        };
                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
                        let chain = mem::take(&mut self.temp_access_chain);
                        self.write_storage_address(module, &chain, func_ctx)?;
                        self.temp_access_chain = chain;
                        self.emit_hlsl_atomic_tail(
                            module,
                            func_ctx,
                            fun,
                            compare_expr,
                            value,
                            &res_var_info,
                        )?;
                    }
                    ref other => {
                        return Err(Error::Custom(format!(
                            "invalid address space {other:?} for atomic statement"
                        )))
                    }
                }
                if let Some(cmp) = compare_expr {
                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
                        write!(
                            self.out,
                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
                        )?;
                        self.write_expr(module, cmp, func_ctx)?;
                        writeln!(self.out, ");")?;
                    }
                }
            }
            Statement::ImageAtomic {
                image,
                coordinate,
                array_index,
                fun,
                value,
            } => {
                write!(self.out, "{level}")?;

                let fun_str = fun.to_hlsl_suffix();
                write!(self.out, "Interlocked{fun_str}(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, "[")?;
                self.write_texture_coordinates(
                    "int",
                    coordinate,
                    array_index,
                    None,
                    module,
                    func_ctx,
                )?;
                write!(self.out, "],")?;

                self.write_expr(module, value, func_ctx)?;
                writeln!(self.out, ");")?;
            }
            Statement::WorkGroupUniformLoad { pointer, result } => {
                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
                write!(self.out, "{level}")?;
                let name = Baked(result).to_string();
                self.write_named_expr(module, pointer, name, result, func_ctx)?;

                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
            }
            Statement::Switch {
                selector,
                ref cases,
            } => {
                self.write_switch(module, func_ctx, level, selector, cases)?;
            }
            Statement::RayQuery { query, ref fun } => {
                // A pointer to a ray query could refer to one of three things:
                // 1. A variable
                // 2. A function argument
                // 3. A struct member
                //
                // 2 and 3 are not possible: a ray query (in Naga IR)
                // is not allowed to be passed into a function, and
                // all languages disallow it in a struct (you get fun results if
                // you try it :) ).
                //
                // Therefore, the ray query expression must be a variable.
                let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
                else {
                    unreachable!()
                };
                let tracker_expr_name = format!(
                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
                    self.names[&func_ctx.name_key(query_var)]
                );
                match *fun {
                    RayQueryFunction::Initialize {
                        acceleration_structure,
                        descriptor,
                    } => {
                        self.write_initialize_function(
                            module,
                            level,
                            query,
                            acceleration_structure,
                            descriptor,
                            &tracker_expr_name,
                            func_ctx,
                        )?;
                    }
                    RayQueryFunction::Proceed { result } => {
                        self.write_proceed(
                            module,
                            level,
                            query,
                            result,
                            &tracker_expr_name,
                            func_ctx,
                        )?;
                    }
                    RayQueryFunction::GenerateIntersection { hit_t } => {
                        self.write_generate_intersection(
                            module,
                            level,
                            query,
                            hit_t,
                            &tracker_expr_name,
                            func_ctx,
                        )?;
                    }
                    RayQueryFunction::ConfirmIntersection => {
                        self.write_confirm_intersection(
                            module,
                            level,
                            query,
                            &tracker_expr_name,
                            func_ctx,
                        )?;
                    }
                    RayQueryFunction::Terminate => {
                        self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
                    }
                }
            }
            Statement::SubgroupBallot { result, predicate } => {
                write!(self.out, "{level}")?;
                let name = Baked(result).to_string();
                write!(self.out, "const uint4 {name} = ")?;
                self.named_expressions.insert(result, name);
                write!(self.out, "WaveActiveBallot(")?;
                match predicate {
                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
                    None => write!(self.out, "true")?,
                }
                writeln!(self.out, ");")?;
            }
            Statement::SubgroupCollectiveOperation {
                op,
                collective_op,
                argument,
                result,
            } => {
                write!(self.out, "{level}")?;
                write!(self.out, "const ")?;
                let name = Baked(result).to_string();
                match func_ctx.info[result].ty {
                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
                    proc::TypeResolution::Value(ref value) => {
                        self.write_value_type(module, value)?
                    }
                };
                write!(self.out, " {name} = ")?;
                self.named_expressions.insert(result, name);
                match (collective_op, op) {
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
                        write!(self.out, "WaveActiveAllTrue(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
                        write!(self.out, "WaveActiveAnyTrue(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
                        write!(self.out, "WaveActiveSum(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
                        write!(self.out, "WaveActiveProduct(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
                        write!(self.out, "WaveActiveMax(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
                        write!(self.out, "WaveActiveMin(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
                        write!(self.out, "WaveActiveBitAnd(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
                        write!(self.out, "WaveActiveBitOr(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
                        write!(self.out, "WaveActiveBitXor(")?
                    }
                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
                        write!(self.out, "WavePrefixSum(")?
                    }
                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
                        write!(self.out, "WavePrefixProduct(")?
                    }
                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
                        self.write_expr(module, argument, func_ctx)?;
                        write!(self.out, " + WavePrefixSum(")?;
                    }
                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
                        self.write_expr(module, argument, func_ctx)?;
                        write!(self.out, " * WavePrefixProduct(")?;
                    }
                    _ => unimplemented!(),
                }
                self.write_expr(module, argument, func_ctx)?;
                writeln!(self.out, ");")?;
            }
            Statement::SubgroupGather {
                mode,
                argument,
                result,
            } => {
                write!(self.out, "{level}")?;
                write!(self.out, "const ")?;
                let name = Baked(result).to_string();
                match func_ctx.info[result].ty {
                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
                    proc::TypeResolution::Value(ref value) => {
                        self.write_value_type(module, value)?
                    }
                };
                write!(self.out, " {name} = ")?;
                self.named_expressions.insert(result, name);
                match mode {
                    crate::GatherMode::BroadcastFirst => {
                        write!(self.out, "WaveReadLaneFirst(")?;
                        self.write_expr(module, argument, func_ctx)?;
                    }
                    crate::GatherMode::QuadBroadcast(index) => {
                        write!(self.out, "QuadReadLaneAt(")?;
                        self.write_expr(module, argument, func_ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(module, index, func_ctx)?;
                    }
                    crate::GatherMode::QuadSwap(direction) => {
                        match direction {
                            crate::Direction::X => {
                                write!(self.out, "QuadReadAcrossX(")?;
                            }
                            crate::Direction::Y => {
                                write!(self.out, "QuadReadAcrossY(")?;
                            }
                            crate::Direction::Diagonal => {
                                write!(self.out, "QuadReadAcrossDiagonal(")?;
                            }
                        }
                        self.write_expr(module, argument, func_ctx)?;
                    }
                    _ => {
                        write!(self.out, "WaveReadLaneAt(")?;
                        self.write_expr(module, argument, func_ctx)?;
                        write!(self.out, ", ")?;
                        match mode {
                            crate::GatherMode::BroadcastFirst => unreachable!(),
                            crate::GatherMode::Broadcast(index)
                            | crate::GatherMode::Shuffle(index) => {
                                self.write_expr(module, index, func_ctx)?;
                            }
                            crate::GatherMode::ShuffleDown(index) => {
                                write!(self.out, "WaveGetLaneIndex() + ")?;
                                self.write_expr(module, index, func_ctx)?;
                            }
                            crate::GatherMode::ShuffleUp(index) => {
                                write!(self.out, "WaveGetLaneIndex() - ")?;
                                self.write_expr(module, index, func_ctx)?;
                            }
                            crate::GatherMode::ShuffleXor(index) => {
                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
                                self.write_expr(module, index, func_ctx)?;
                            }
                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
                            crate::GatherMode::QuadSwap(_) => unreachable!(),
                        }
                    }
                }
                writeln!(self.out, ");")?;
            }
            Statement::CooperativeStore { ..
            } => unimplemented!(),
            Statement::RayPipelineFunction(_) => unreachable!(),
        }

        Ok(())
    }

    fn write_const_expression(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        arena: &crate::Arena<crate::Expression>,
    ) -> BackendResult {
        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
            writer.write_const_expression(module, expr, arena)
        })
    }

    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
        match literal {
            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
            // `-2147483648` is parsed by some compilers as unary negation of
            // positive 2147483648, which is too large for an int. Neither DXC
            // nor FXC appears to have this problem, but this is not specified
            // and could change. We therefore use `-2147483647 - 1` as a
            // precaution.
            crate::Literal::I32(value) if value == i32::MIN => {
                write!(self.out, "int({} - 1)", value + 1)?
            }
            // HLSL has no suffix for explicit i32 literals, but not using any suffix
            // makes the type ambiguous, which prevents overload resolution from
            // working. So we explicitly use the int() constructor syntax.
            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
            // I64 version of the minimum I32 value issue described above.
            crate::Literal::I64(value) if value == i64::MIN => {
                write!(self.out, "({}L - 1L)", value + 1)?;
            }
            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
                return Err(Error::Custom(
                    "Abstract types should not appear in IR presented to backends".into(),
                ));
            }
        }
        Ok(())
    }

    fn write_possibly_const_expression<E>(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        expressions: &crate::Arena<crate::Expression>,
        write_expression: E,
    ) -> BackendResult
    where
        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
    {
        use crate::Expression;

        match expressions[expr] {
            Expression::Literal(literal) => {
                self.write_literal(literal)?;
            }
            Expression::Constant(handle) => {
                let constant = &module.constants[handle];
                if constant.name.is_some() {
                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
                } else {
                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
                }
            }
            Expression::ZeroValue(ty) => {
                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
                write!(self.out, "()")?;
            }
            Expression::Compose { ty, ref components } => {
                match module.types[ty].inner {
                    TypeInner::Struct { .. } | TypeInner::Array { ..
                    } => {
                        self.write_wrapped_constructor_function_name(
                            module,
                            WrappedConstructor { ty },
                        )?;
                    }
                    _ => {
                        self.write_type(module, ty)?;
                    }
                };
                write!(self.out, "(")?;
                for (index, component) in components.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }
                    write_expression(self, *component)?;
                }
                write!(self.out, ")")?;
            }
            Expression::Splat { size, value } => {
                // HLSL does not support single-value vector constructors:
                // if we write, for example, int4(0), DXC returns an error:
                // error: too few elements in vector initialization (expected 4 elements, have 1)
                // So instead we emit the scalar in parentheses followed by a
                // swizzle, e.g. `(value).xxxx` for a 4-component vector.
                let number_of_components = match size {
                    crate::VectorSize::Bi => "xx",
                    crate::VectorSize::Tri => "xxx",
                    crate::VectorSize::Quad => "xxxx",
                };
                write!(self.out, "(")?;
                write_expression(self, value)?;
                write!(self.out, ").{number_of_components}")?
            }
            _ => {
                return Err(Error::Override);
            }
        }

        Ok(())
    }

    /// Helper method to write expressions
    ///
    /// # Notes
    /// Doesn't add any newlines or leading/trailing spaces
    pub(super) fn write_expr(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        func_ctx: &back::FunctionCtx<'_>,
    ) -> BackendResult {
        use crate::Expression;

        // Handle the special semantics of vertex_index/instance_index
        let ff_input = if self.options.special_constants_binding.is_some() {
            func_ctx.is_fixed_function_input(expr, module)
        } else {
            None
        };
        let closing_bracket = match ff_input {
            Some(crate::BuiltIn::VertexIndex) => {
                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
                ")"
            }
            Some(crate::BuiltIn::InstanceIndex) => {
                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
                ")"
            }
            Some(crate::BuiltIn::NumWorkGroups) => {
                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
                // in compute shaders the special constants contain the number
                // of workgroups, which we are using here.
                write!(
                    self.out,
                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
                )?;
                return Ok(());
            }
            _ => "",
        };

        if let Some(name) = self.named_expressions.get(&expr) {
            write!(self.out, "{name}{closing_bracket}")?;
            return Ok(());
        }

        let expression = &func_ctx.expressions[expr];

        match *expression {
            Expression::Literal(_)
            | Expression::Constant(_)
            | Expression::ZeroValue(_)
            | Expression::Compose { .. }
            | Expression::Splat { .. } => {
                self.write_possibly_const_expression(
                    module,
                    expr,
                    func_ctx.expressions,
                    |writer, expr| writer.write_expr(module, expr, func_ctx),
                )?;
            }
            Expression::Override(_) => return Err(Error::Override),
            // Avoid undefined behaviour for addition, subtraction, and
            // multiplication of signed integers by casting operands to
            // unsigned, performing the operation, then casting the result back
            // to signed.
            // TODO(#7109): This relies on the asint()/asuint() functions which only work
            // for 32-bit types, so we must find another solution for different bit widths.
            Expression::Binary {
                op:
                    op @ crate::BinaryOperator::Add
                    | op @ crate::BinaryOperator::Subtract
                    | op @ crate::BinaryOperator::Multiply,
                left,
                right,
            } if matches!(
                func_ctx.resolve_type(expr, &module.types).scalar(),
                Some(Scalar::I32)
            ) =>
            {
                write!(self.out, "asint(asuint(",)?;
                self.write_expr(module, left, func_ctx)?;
                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
                self.write_expr(module, right, func_ctx)?;
                write!(self.out, "))")?;
            }
            // All of the multiplication can be expressed as `mul`,
            // except vector * vector, which needs to use the "*" operator.
            Expression::Binary {
                op: crate::BinaryOperator::Multiply,
                left,
                right,
            } if func_ctx.resolve_type(left, &module.types).is_matrix()
                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
            {
                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
                write!(self.out, "mul(")?;
                self.write_expr(module, right, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, left, func_ctx)?;
                write!(self.out, ")")?;
            }
            // WGSL says that floating-point division by zero should return
            // infinity. Microsoft's Direct3D 11 functional specification
            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
            // says:
            //
            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
            //
            // which is what we want. The DXIL specification for the FDiv
            // instruction corroborates this:
            //
            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
            //
            // Floating-point division therefore needs no wrapper and is written
            // by the generic `Binary` arm below; only integer division goes
            // through `DIV_FUNCTION`.
            Expression::Binary {
                op: crate::BinaryOperator::Divide,
                left,
                right,
            } if matches!(
                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
                Some(ScalarKind::Sint | ScalarKind::Uint)
            ) =>
            {
                write!(self.out, "{DIV_FUNCTION}(")?;
                self.write_expr(module, left, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, right, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::Binary {
                op: crate::BinaryOperator::Modulo,
                left,
                right,
            } if matches!(
                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
            ) =>
            {
                write!(self.out, "{MOD_FUNCTION}(")?;
                self.write_expr(module, left, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, right, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::Binary { op, left, right } => {
                write!(self.out, "(")?;
                self.write_expr(module, left, func_ctx)?;
                write!(self.out, " {} ", back::binary_operation_str(op))?;
                self.write_expr(module, right, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::Access { base, index } => {
                if let Some(crate::AddressSpace::Storage { ..
                }) = func_ctx.resolve_type(expr, &module.types).pointer_space()
                {
                    // do nothing, the chain is written on `Load`/`Store`
                } else {
                    // We use the function __get_col_of_matCx2 here in cases
                    // where `base`'s type resolves to a matCx2 and is part of a
                    // struct member with type of (possibly nested) array of matCx2's.
                    //
                    // Note that this only works for `Load`s and we handle
                    // `Store`s differently in `Statement::Store`.
                    if let Some(MatrixType {
                        columns,
                        rows: crate::VectorSize::Bi,
                        width: 4,
                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
                    {
                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
                        self.write_expr(module, base, func_ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(module, index, func_ctx)?;
                        write!(self.out, ")")?;
                        return Ok(());
                    }
                    let resolved = func_ctx.resolve_type(base, &module.types);
                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
                        TypeInner::BindingArray { .. } => {
                            let uniformity = &func_ctx.info[index].uniformity;
                            (true, uniformity.non_uniform_result.is_some())
                        }
                        _ => (false, false),
                    };
                    self.write_expr(module, base, func_ctx)?;
                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
                        module, func_ctx, base, resolved,
                    );
                    if let Some(ref info) = array_sampler_info {
                        write!(self.out, "{}[", info.sampler_heap_name)?;
                    } else {
                        write!(self.out, "[")?;
                    }
                    let needs_bound_check = self.options.restrict_indexing
                        && !indexing_binding_array
                        && match resolved.pointer_space() {
                            Some(
                                crate::AddressSpace::Function
                                | crate::AddressSpace::Private
                                | crate::AddressSpace::WorkGroup
                                | crate::AddressSpace::Immediate
                                | crate::AddressSpace::TaskPayload
                                | crate::AddressSpace::RayPayload
                                | crate::AddressSpace::IncomingRayPayload,
                            )
                            | None => true,
                            Some(crate::AddressSpace::Uniform) => {
                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
                                let bind_target = self
                                    .options
                                    .resolve_resource_binding(
                                        module.global_variables[var_handle]
                                            .binding
                                            .as_ref()
                                            .unwrap(),
                                    )
                                    .unwrap();
                                bind_target.restrict_indexing
                            }
                            Some(
                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
                            ) => unreachable!(),
                        };
                    // Decide whether this index needs to be clamped to fall within range.
                    let restriction_needed = if needs_bound_check {
                        index::access_needs_check(
                            base,
                            index::GuardedIndex::Expression(index),
                            module,
                            func_ctx.expressions,
                            func_ctx.info,
                        )
                    } else {
                        None
                    };
                    if let Some(limit) = restriction_needed {
                        write!(self.out, "min(uint(")?;
                        self.write_expr(module, index, func_ctx)?;
                        write!(self.out, "), ")?;
                        match limit {
                            index::IndexableLength::Known(limit) => {
                                write!(self.out, "{}u", limit - 1)?;
                            }
                            index::IndexableLength::Dynamic => unreachable!(),
                        }
                        write!(self.out, ")")?;
                    } else {
                        if non_uniform_qualifier {
                            write!(self.out, "NonUniformResourceIndex(")?;
                        }
                        if let Some(ref info) = array_sampler_info {
                            write!(
                                self.out,
                                "{}[{} + ",
                                info.sampler_index_buffer_name,
                                info.binding_array_base_index_name,
                            )?;
                        }
                        self.write_expr(module, index, func_ctx)?;
                        if array_sampler_info.is_some() {
                            write!(self.out, "]")?;
                        }
                        if non_uniform_qualifier {
                            write!(self.out, ")")?;
                        }
                    }
                    write!(self.out, "]")?;
                }
            }
            Expression::AccessIndex { base, index } => {
                if let Some(crate::AddressSpace::Storage { .. }) =
                    func_ctx.resolve_type(expr, &module.types).pointer_space()
                {
                    // do nothing, the chain is written on `Load`/`Store`
                } else {
                    // See if we need to write the matrix column access in a
                    // special way since the type of `base` is our special
                    // __matCx2 struct.
                    if let Some(MatrixType {
                        rows: crate::VectorSize::Bi,
                        width: 4,
                        ..
                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
                    {
                        self.write_expr(module, base, func_ctx)?;
                        write!(self.out, "._{index}")?;
                        return Ok(());
                    }
                    let base_ty_res = &func_ctx.info[base].ty;
                    let mut resolved = base_ty_res.inner_with(&module.types);
                    let base_ty_handle = match *resolved {
                        TypeInner::Pointer { base, .. } => {
                            resolved = &module.types[base].inner;
                            Some(base)
                        }
                        _ => base_ty_res.handle(),
                    };
                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
                    // See the module-level block comment in mod.rs for details.
                    //
                    // We handle matrix reconstruction here for Loads.
                    // Stores are handled directly by `Statement::Store`.
                    if let TypeInner::Struct { ref members, .. } = *resolved {
                        let member = &members[index as usize];
                        match module.types[member.ty].inner {
                            TypeInner::Matrix {
                                rows: crate::VectorSize::Bi,
                                ..
                            } if member.binding.is_none() => {
                                let ty = base_ty_handle.unwrap();
                                self.write_wrapped_struct_matrix_get_function_name(
                                    WrappedStructMatrixAccess { ty, index },
                                )?;
                                write!(self.out, "(")?;
                                self.write_expr(module, base, func_ctx)?;
                                write!(self.out, ")")?;
                                return Ok(());
                            }
                            _ => {}
                        }
                    }
                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
                        module, func_ctx, base, resolved,
                    );
                    if let Some(ref info) = array_sampler_info {
                        write!(
                            self.out,
                            "{}[{}",
                            info.sampler_heap_name, info.sampler_index_buffer_name
                        )?;
                    }
                    self.write_expr(module, base, func_ctx)?;
                    match *resolved {
                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
                        // it and generates completely nonsensical DXBC.
                        //
                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
                            // Write vector access as a swizzle
                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
                        }
                        TypeInner::Matrix { .. }
                        | TypeInner::Array { ..
                        }
                        | TypeInner::BindingArray { .. } => {
                            if let Some(ref info) = array_sampler_info {
                                write!(
                                    self.out,
                                    "[{} + {index}]",
                                    info.binding_array_base_index_name
                                )?;
                            } else {
                                write!(self.out, "[{index}]")?;
                            }
                        }
                        TypeInner::Struct { .. } => {
                            // This `unwrap` will never panic when the type is a `Struct`.
                            // That is not true for other types, so we can only unwrap
                            // inside this match arm.
                            let ty = base_ty_handle.unwrap();
                            write!(
                                self.out,
                                ".{}",
                                &self.names[&NameKey::StructMember(ty, index)]
                            )?
                        }
                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
                    }
                    if array_sampler_info.is_some() {
                        write!(self.out, "]")?;
                    }
                }
            }
            Expression::FunctionArgument(pos) => {
                let ty = func_ctx.resolve_type(expr, &module.types);
                // We know that any external texture function argument has been expanded into
                // separate consecutive arguments for each plane and the parameters buffer. And we
                // also know that external textures can only ever be used as an argument to another
                // function. Therefore we can simply emit each of the expanded arguments in a
                // consecutive comma-separated list.
                if let TypeInner::Image {
                    class: crate::ImageClass::External,
                    ..
                } = *ty
                {
                    let plane_names = [0, 1, 2].map(|i| {
                        &self.names[&func_ctx
                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
                    });
                    let params_name = &self.names[&func_ctx
                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
                    write!(
                        self.out,
                        "{}, {}, {}, {}",
                        plane_names[0], plane_names[1], plane_names[2], params_name
                    )?;
                } else {
                    let key = func_ctx.argument_key(pos);
                    let name = &self.names[&key];
                    write!(self.out, "{name}")?;
                }
            }
            Expression::ImageSample {
                coordinate,
                image,
                sampler,
                clamp_to_edge: true,
                gather: None,
                array_index: None,
                offset: None,
                level: crate::SampleLevel::Zero,
                depth_ref: None,
            } => {
                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::ImageSample {
                image,
                sampler,
                gather,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge,
            } => {
                if clamp_to_edge {
                    return Err(Error::Custom(
                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
                    ));
                }
                use crate::SampleLevel as Sl;
                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
                let (base_str, component_str) = match gather {
                    Some(component) => ("Gather", COMPONENTS[component as usize]),
                    None => ("Sample", ""),
                };
                let cmp_str = match depth_ref {
                    Some(_) => "Cmp",
                    None => "",
                };
                let level_str = match level {
                    Sl::Zero if gather.is_none() => "LevelZero",
                    Sl::Auto | Sl::Zero => "",
                    Sl::Exact(_) => "Level",
                    Sl::Bias(_) => "Bias",
                    Sl::Gradient { ..
                    } => "Grad",
                };
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_texture_coordinates(
                    "float",
                    coordinate,
                    array_index,
                    None,
                    module,
                    func_ctx,
                )?;
                if let Some(depth_ref) = depth_ref {
                    write!(self.out, ", ")?;
                    self.write_expr(module, depth_ref, func_ctx)?;
                }
                match level {
                    Sl::Auto | Sl::Zero => {}
                    Sl::Exact(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Bias(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Gradient { x, y } => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, x, func_ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(module, y, func_ctx)?;
                    }
                }
                if let Some(offset) = offset {
                    write!(self.out, ", ")?;
                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
                    self.write_const_expression(module, offset, func_ctx.expressions)?;
                    write!(self.out, ")")?;
                }
                write!(self.out, ")")?;
            }
            Expression::ImageQuery { image, query } => {
                // use wrapped image query function
                if let TypeInner::Image {
                    dim,
                    arrayed,
                    class,
                } = *func_ctx.resolve_type(image, &module.types)
                {
                    let wrapped_image_query = WrappedImageQuery {
                        dim,
                        arrayed,
                        class,
                        query: query.into(),
                    };
                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
                    write!(self.out, "(")?;
                    // Image always first param
                    self.write_expr(module, image, func_ctx)?;
                    if let crate::ImageQuery::Size { level: Some(level) } = query {
                        write!(self.out, ", ")?;
                        self.write_expr(module, level, func_ctx)?;
                    }
                    write!(self.out, ")")?;
                }
            }
            Expression::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => self.write_image_load(
                &module,
                expr,
                func_ctx,
                image,
                coordinate,
                array_index,
                sample,
                level,
            )?,
            Expression::GlobalVariable(handle) => {
                let global_variable = &module.global_variables[handle];
                let ty =
                    &module.types[global_variable.ty].inner;
                // In the case of binding arrays of samplers, we must not write anything
                // here, as we are in the wrong position to write the full expression.
                //
                // The entire expression is written by `AccessIndex`.
                let is_binding_array_of_samplers = match *ty {
                    TypeInner::BindingArray { base, .. } => {
                        let base_ty = &module.types[base].inner;
                        matches!(*base_ty, TypeInner::Sampler { .. })
                    }
                    _ => false,
                };
                let is_storage_space =
                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
                // Our external texture global variable has been expanded into multiple
                // global variables, one for each plane and the parameters buffer.
                // External textures can only ever be used as arguments to a function
                // call, and we know that an external texture argument to any function
                // will have been expanded to separate consecutive arguments for each
                // plane and the parameters buffer. Therefore we can simply emit each of
                // the expanded global variables in a consecutive comma-separated list.
                if let TypeInner::Image {
                    class: crate::ImageClass::External,
                    ..
                } = *ty
                {
                    let plane_names = [0, 1, 2].map(|i| {
                        &self.names[&NameKey::ExternalTextureGlobalVariable(
                            handle,
                            ExternalTextureNameKey::Plane(i),
                        )]
                    });
                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
                        handle,
                        ExternalTextureNameKey::Params,
                    )];
                    write!(
                        self.out,
                        "{}, {}, {}, {}",
                        plane_names[0], plane_names[1], plane_names[2], params_name
                    )?;
                } else if !is_binding_array_of_samplers && !is_storage_space {
                    let name = &self.names[&NameKey::GlobalVariable(handle)];
                    write!(self.out, "{name}")?;
                }
            }
            Expression::LocalVariable(handle) => {
                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
            }
            Expression::Load { pointer } => {
                match func_ctx
                    .resolve_type(pointer, &module.types)
                    .pointer_space()
                {
                    Some(crate::AddressSpace::Storage { ..
                    }) => {
                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
                        let result_ty = func_ctx.info[expr].ty.clone();
                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
                    }
                    _ => {
                        let mut close_paren = false;
                        // We cast the value loaded to a native HLSL floatCx2
                        // in cases where it is of type:
                        //  - __matCx2 or
                        //  - a (possibly nested) array of __matCx2's
                        if let Some(MatrixType {
                            rows: crate::VectorSize::Bi,
                            width: 4,
                            ..
                        }) = get_inner_matrix_of_struct_array_member(
                            module, pointer, func_ctx, false,
                        )
                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
                        {
                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
                            let ptr_tr = resolved.pointer_base_type();
                            if let Some(ptr_ty) =
                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
                            {
                                resolved = ptr_ty;
                            }
                            write!(self.out, "((")?;
                            if let TypeInner::Array { base, size, .. } = *resolved {
                                self.write_type(module, base)?;
                                self.write_array_size(module, base, size)?;
                            } else {
                                self.write_value_type(module, resolved)?;
                            }
                            write!(self.out, ")")?;
                            close_paren = true;
                        }
                        self.write_expr(module, pointer, func_ctx)?;
                        if close_paren {
                            write!(self.out, ")")?;
                        }
                    }
                }
            }
            Expression::Unary { op, expr } => {
                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
                let op_str = match op {
                    crate::UnaryOperator::Negate => {
                        match func_ctx.resolve_type(expr, &module.types).scalar() {
                            Some(Scalar::I32) => NEG_FUNCTION,
                            _ => "-",
                        }
                    }
                    crate::UnaryOperator::LogicalNot => "!",
                    crate::UnaryOperator::BitwiseNot => "~",
                };
                write!(self.out, "{op_str}(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::As {
                expr,
                kind,
                convert,
            } => {
                let inner = func_ctx.resolve_type(expr, &module.types);
                if inner.scalar_kind() == Some(ScalarKind::Float)
                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
                    && convert.is_some()
                {
                    // Use helper functions for float to int casts in order to
                    // avoid undefined behaviour when value is out of
                    // range for the target type.
                    let fun_name = match (kind, convert) {
                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
                        _ => unreachable!(),
                    };
                    write!(self.out, "{fun_name}(")?;
                    self.write_expr(module, expr, func_ctx)?;
                    write!(self.out, ")")?;
                } else {
                    let close_paren = match convert {
                        Some(dst_width) => {
                            let scalar = Scalar {
                                kind,
                                width: dst_width,
                            };
                            match *inner {
                                TypeInner::Vector { size, .. } => {
                                    write!(
                                        self.out,
                                        "{}{}(",
                                        scalar.to_hlsl_str()?,
                                        common::vector_size_str(size)
                                    )?;
                                }
                                TypeInner::Scalar(_) => {
                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
                                }
                                TypeInner::Matrix { columns, rows, .. } => {
                                    write!(
                                        self.out,
                                        "{}{}x{}(",
                                        scalar.to_hlsl_str()?,
                                        common::vector_size_str(columns),
                                        common::vector_size_str(rows)
                                    )?;
                                }
                                _ => {
                                    return Err(Error::Unimplemented(format!(
                                        "write_expr expression::as {inner:?}"
                                    )));
                                }
                            };
                            true
                        }
                        None => {
                            if inner.scalar_width() == Some(8) {
                                false
                            } else {
                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
                                true
                            }
                        }
                    };
                    self.write_expr(module, expr, func_ctx)?;
                    if close_paren {
                        write!(self.out, ")")?;
                    }
                }
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                use crate::MathFunction as Mf;

                enum Function {
                    Asincosh { is_sin: bool },
                    Atanh,
                    Pack2x16float,
                    Pack2x16snorm,
                    Pack2x16unorm,
                    Pack4x8snorm,
                    Pack4x8unorm,
                    Pack4xI8,
                    Pack4xU8,
                    Pack4xI8Clamp,
                    Pack4xU8Clamp,
                    Unpack2x16float,
                    Unpack2x16snorm,
                    Unpack2x16unorm,
                    Unpack4x8snorm,
                    Unpack4x8unorm,
                    Unpack4xI8,
                    Unpack4xU8,
                    Dot4I8Packed,
                    Dot4U8Packed,
                    QuantizeToF16,
                    Regular(&'static str),
                    MissingIntOverload(&'static str),
                    MissingIntReturnType(&'static str),
                    CountTrailingZeros,
                    CountLeadingZeros,
                }

                let fun = match fun {
                    // comparison
                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
                        _ => Function::Regular("abs"),
                    },
                    Mf::Min => Function::Regular("min"),
                    Mf::Max =>
                        Function::Regular("max"),
                    Mf::Clamp => Function::Regular("clamp"),
                    Mf::Saturate => Function::Regular("saturate"),
                    // trigonometry
                    Mf::Cos => Function::Regular("cos"),
                    Mf::Cosh => Function::Regular("cosh"),
                    Mf::Sin => Function::Regular("sin"),
                    Mf::Sinh => Function::Regular("sinh"),
                    Mf::Tan => Function::Regular("tan"),
                    Mf::Tanh => Function::Regular("tanh"),
                    Mf::Acos => Function::Regular("acos"),
                    Mf::Asin => Function::Regular("asin"),
                    Mf::Atan => Function::Regular("atan"),
                    Mf::Atan2 => Function::Regular("atan2"),
                    Mf::Asinh => Function::Asincosh { is_sin: true },
                    Mf::Acosh => Function::Asincosh { is_sin: false },
                    Mf::Atanh => Function::Atanh,
                    Mf::Radians => Function::Regular("radians"),
                    Mf::Degrees => Function::Regular("degrees"),
                    // decomposition
                    Mf::Ceil => Function::Regular("ceil"),
                    Mf::Floor => Function::Regular("floor"),
                    Mf::Round => Function::Regular("round"),
                    Mf::Fract => Function::Regular("frac"),
                    Mf::Trunc => Function::Regular("trunc"),
                    Mf::Modf => Function::Regular(MODF_FUNCTION),
                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
                    Mf::Ldexp => Function::Regular("ldexp"),
                    // exponent
                    Mf::Exp => Function::Regular("exp"),
                    Mf::Exp2 => Function::Regular("exp2"),
                    Mf::Log => Function::Regular("log"),
                    Mf::Log2 => Function::Regular("log2"),
                    Mf::Pow => Function::Regular("pow"),
                    // geometry
                    Mf::Dot => Function::Regular("dot"),
                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
                    //Mf::Outer => ,
                    Mf::Cross => Function::Regular("cross"),
                    Mf::Distance => Function::Regular("distance"),
                    Mf::Length => Function::Regular("length"),
                    Mf::Normalize => Function::Regular("normalize"),
                    Mf::FaceForward => Function::Regular("faceforward"),
                    Mf::Reflect => Function::Regular("reflect"),
                    Mf::Refract => Function::Regular("refract"),
                    // computational
                    Mf::Sign => Function::Regular("sign"),
                    Mf::Fma => Function::Regular("mad"),
                    Mf::Mix => Function::Regular("lerp"),
                    Mf::Step => Function::Regular("step"),
                    Mf::SmoothStep => Function::Regular("smoothstep"),
                    Mf::Sqrt
                        => Function::Regular("sqrt"),
                    Mf::InverseSqrt => Function::Regular("rsqrt"),
                    //Mf::Inverse =>,
                    Mf::Transpose => Function::Regular("transpose"),
                    Mf::Determinant => Function::Regular("determinant"),
                    Mf::QuantizeToF16 => Function::QuantizeToF16,
                    // bits
                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
                    // Data Packing
                    Mf::Pack2x16float => Function::Pack2x16float,
                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
                    Mf::Pack4xI8 => Function::Pack4xI8,
                    Mf::Pack4xU8 => Function::Pack4xU8,
                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
                    // Data Unpacking
                    Mf::Unpack2x16float => Function::Unpack2x16float,
                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
                    Mf::Unpack4xI8 => Function::Unpack4xI8,
                    Mf::Unpack4xU8 => Function::Unpack4xU8,
                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
                };

                match fun {
                    Function::Asincosh { is_sin } => {
                        write!(self.out, "log(")?;
                        self.write_expr(module, arg, func_ctx)?;
                        write!(self.out, " + sqrt(")?;
                        self.write_expr(module, arg, func_ctx)?;
                        write!(self.out, " * ")?;
                        self.write_expr(module, arg, func_ctx)?;
                        match is_sin {
                            true => write!(self.out, " + 1.0))")?,
                            false => write!(self.out, " - 1.0))")?,
                        }
                    }
                    Function::Atanh => {
                        write!(self.out, "0.5 * log((1.0 + ")?;
self.write_expr(module, arg, func_ctx)?; write!(self.out, ") / (1.0 - ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } Function::Pack2x16float => { write!(self.out, "(f32tof16(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "[0]) | f32tof16(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "[1]) << 16)")?; } Function::Pack2x16snorm => { let scale = 32767; write!(self.out, "uint((int(round(clamp(")?; self.write_expr(module, arg, func_ctx)?; write!( self.out, "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?; } Function::Pack2x16unorm => { let scale = 65535; write!(self.out, "(uint(round(clamp(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?; } Function::Pack4x8snorm => { let scale = 127; write!(self.out, "uint((int(round(clamp(")?; self.write_expr(module, arg, func_ctx)?; write!( self.out, "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp(" )?; self.write_expr(module, arg, func_ctx)?; write!( self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp(" )?; self.write_expr(module, arg, func_ctx)?; write!( self.out, "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?; } Function::Pack4x8unorm => { let scale = 255; write!(self.out, "(uint(round(clamp(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?; self.write_expr(module, arg, func_ctx)?; write!( self.out, "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp(" )?; self.write_expr(module, arg, func_ctx)?; write!( self.out, "[2], 0.0, 1.0) * {scale}.0)) << 16 | 
uint(round(clamp(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?; } fun @ (Function::Pack4xI8 | Function::Pack4xU8 | Function::Pack4xI8Clamp | Function::Pack4xU8Clamp) => { let was_signed = matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp); let clamp_bounds = match fun { Function::Pack4xI8Clamp => Some(("-128", "127")), Function::Pack4xU8Clamp => Some(("0", "255")), _ => None, }; if was_signed { write!(self.out, "uint(")?; } let write_arg = |this: &mut Self| -> BackendResult { if let Some((min, max)) = clamp_bounds { write!(this.out, "clamp(")?; this.write_expr(module, arg, func_ctx)?; write!(this.out, ", {min}, {max})")?; } else { this.write_expr(module, arg, func_ctx)?; } Ok(()) }; write!(self.out, "(")?; write_arg(self)?; write!(self.out, "[0] & 0xFF) | ((")?; write_arg(self)?; write!(self.out, "[1] & 0xFF) << 8) | ((")?; write_arg(self)?; write!(self.out, "[2] & 0xFF) << 16) | ((")?; write_arg(self)?; write!(self.out, "[3] & 0xFF) << 24)")?; if was_signed { write!(self.out, ")")?; } } Function::Unpack2x16float => { write!(self.out, "float2(f16tof32(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "), f16tof32((")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ") >> 16))")?; } Function::Unpack2x16snorm => { let scale = 32767; write!(self.out, "(float2(int2(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " << 16, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ") >> 16) / {scale}.0)")?; } Function::Unpack2x16unorm => { let scale = 65535; write!(self.out, "(float2(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " & 0xFFFF, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 16) / {scale}.0)")?; } Function::Unpack4x8snorm => { let scale = 127; write!(self.out, "(float4(int4(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " << 24, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " << 
16, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " << 8, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ") >> 24) / {scale}.0)")?; } Function::Unpack4x8unorm => { let scale = 255; write!(self.out, "(float4(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " & 0xFF, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 8 & 0xFF, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 16 & 0xFF, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 24) / {scale}.0)")?; } fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => { write!(self.out, "(")?; if matches!(fun, Function::Unpack4xU8) { write!(self.out, "u")?; } write!(self.out, "int4(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 8, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 16, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 24) << 24 >> 24)")?; } fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => { let arg1 = arg1.unwrap(); if self.options.shader_model >= ShaderModel::V6_4 { // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later. let function_name = match fun { Function::Dot4I8Packed => "dot4add_i8packed", Function::Dot4U8Packed => "dot4add_u8packed", _ => unreachable!(), }; write!(self.out, "{function_name}(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, arg1, func_ctx)?; write!(self.out, ", 0)")?; } else { // Fall back to a polyfill, as the `dot4add_{i, u}8packed` intrinsics require SM 6.4 or later.
write!(self.out, "dot(")?; if matches!(fun, Function::Dot4U8Packed) { write!(self.out, "u")?; } write!(self.out, "int4(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 8, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 16, ")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, " >> 24) << 24 >> 24, ")?; if matches!(fun, Function::Dot4U8Packed) { write!(self.out, "u")?; } write!(self.out, "int4(")?; self.write_expr(module, arg1, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, arg1, func_ctx)?; write!(self.out, " >> 8, ")?; self.write_expr(module, arg1, func_ctx)?; write!(self.out, " >> 16, ")?; self.write_expr(module, arg1, func_ctx)?; write!(self.out, " >> 24) << 24 >> 24)")?; } } Function::QuantizeToF16 => { write!(self.out, "f16tof32(f32tof16(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } Function::Regular(fun_name) => { write!(self.out, "{fun_name}(")?; self.write_expr(module, arg, func_ctx)?; if let Some(arg) = arg1 { write!(self.out, ", ")?; self.write_expr(module, arg, func_ctx)?; } if let Some(arg) = arg2 { write!(self.out, ", ")?; self.write_expr(module, arg, func_ctx)?; } if let Some(arg) = arg3 { write!(self.out, ", ")?; self.write_expr(module, arg, func_ctx)?; } write!(self.out, ")")? } // These overloads are only missing on FXC, so this is only needed for 32bit types, // as non-32bit types are DXC only. Function::MissingIntOverload(fun_name) => { let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar(); if let Some(Scalar::I32) = scalar_kind { write!(self.out, "asint({fun_name}(asuint(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } else { write!(self.out, "{fun_name}(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")")?; } } // These overloads are only missing on FXC, so this is only needed for 32bit types, // as non-32bit types are DXC only. 
Function::MissingIntReturnType(fun_name) => { let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar(); if let Some(Scalar::I32) = scalar_kind { write!(self.out, "asint({fun_name}(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { write!(self.out, "{fun_name}(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")")?; } } Function::CountTrailingZeros => { match *func_ctx.resolve_type(arg, &module.types) { TypeInner::Vector { size, scalar } => { let s = match size { crate::VectorSize::Bi => ".xx", crate::VectorSize::Tri => ".xxx", crate::VectorSize::Quad => ".xxxx", }; let scalar_width_bits = scalar.width * 8; if scalar.kind == ScalarKind::Uint || scalar.width != 4 { write!( self.out, "min(({scalar_width_bits}u){s}, firstbitlow(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { // This is only needed for the FXC path, on 32bit signed integers. write!( self.out, "asint(min(({scalar_width_bits}u){s}, firstbitlow(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } } TypeInner::Scalar(scalar) => { let scalar_width_bits = scalar.width * 8; if scalar.kind == ScalarKind::Uint || scalar.width != 4 { write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { // This is only needed for the FXC path, on 32bit signed integers. 
write!( self.out, "asint(min({scalar_width_bits}u, firstbitlow(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } } _ => unreachable!(), } return Ok(()); } Function::CountLeadingZeros => { match *func_ctx.resolve_type(arg, &module.types) { TypeInner::Vector { size, scalar } => { let s = match size { crate::VectorSize::Bi => ".xx", crate::VectorSize::Tri => ".xxx", crate::VectorSize::Quad => ".xxxx", }; // scalar width - 1 let constant = scalar.width * 8 - 1; if scalar.kind == ScalarKind::Uint { write!(self.out, "(({constant}u){s} - firstbithigh(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { let conversion_func = match scalar.width { 4 => "asint", _ => "", }; write!(self.out, "(")?; self.write_expr(module, arg, func_ctx)?; write!( self.out, " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } } TypeInner::Scalar(scalar) => { // scalar width - 1 let constant = scalar.width * 8 - 1; if let ScalarKind::Uint = scalar.kind { write!(self.out, "({constant}u - firstbithigh(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { let conversion_func = match scalar.width { 4 => "asint", _ => "", }; write!(self.out, "(")?; self.write_expr(module, arg, func_ctx)?; write!( self.out, " < 0 ? 
0 : {constant} - {conversion_func}(firstbithigh(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } } _ => unreachable!(), } return Ok(()); } } } Expression::Swizzle { size, vector, pattern, } => { self.write_expr(module, vector, func_ctx)?; write!(self.out, ".")?; for &sc in pattern[..size as usize].iter() { self.out.write_char(back::COMPONENTS[sc as usize])?; } } Expression::ArrayLength(expr) => { let var_handle = match func_ctx.expressions[expr] { Expression::AccessIndex { base, index: _ } => { match func_ctx.expressions[base] { Expression::GlobalVariable(handle) => handle, _ => unreachable!(), } } Expression::GlobalVariable(handle) => handle, _ => unreachable!(), }; let var = &module.global_variables[var_handle]; let (offset, stride) = match module.types[var.ty].inner { TypeInner::Array { stride, .. } => (0, stride), TypeInner::Struct { ref members, .. } => { let last = members.last().unwrap(); let stride = match module.types[last.ty].inner { TypeInner::Array { stride, .. } => stride, _ => unreachable!(), }; (last.offset, stride) } _ => unreachable!(), }; let storage_access = match var.space { crate::AddressSpace::Storage { access } => access, _ => crate::StorageAccess::default(), }; let wrapped_array_length = WrappedArrayLength { writable: storage_access.contains(crate::StorageAccess::STORE), }; write!(self.out, "((")?; self.write_wrapped_array_length_function_name(wrapped_array_length)?; let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; write!(self.out, "({var_name}) - {offset}) / {stride})")? 
} Expression::Derivative { axis, ctrl, expr } => { use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) { let tail = match ctrl { Ctrl::Coarse => "coarse", Ctrl::Fine => "fine", Ctrl::None => unreachable!(), }; write!(self.out, "abs(ddx_{tail}(")?; self.write_expr(module, expr, func_ctx)?; write!(self.out, ")) + abs(ddy_{tail}(")?; self.write_expr(module, expr, func_ctx)?; write!(self.out, "))")? } else { let fun_str = match (axis, ctrl) { (Axis::X, Ctrl::Coarse) => "ddx_coarse", (Axis::X, Ctrl::Fine) => "ddx_fine", (Axis::X, Ctrl::None) => "ddx", (Axis::Y, Ctrl::Coarse) => "ddy_coarse", (Axis::Y, Ctrl::Fine) => "ddy_fine", (Axis::Y, Ctrl::None) => "ddy", (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(), (Axis::Width, Ctrl::None) => "fwidth", }; write!(self.out, "{fun_str}(")?; self.write_expr(module, expr, func_ctx)?; write!(self.out, ")")? } } Expression::Relational { fun, argument } => { use crate::RelationalFunction as Rf; let fun_str = match fun { Rf::All => "all", Rf::Any => "any", Rf::IsNan => "isnan", Rf::IsInf => "isinf", }; write!(self.out, "{fun_str}(")?; self.write_expr(module, argument, func_ctx)?; write!(self.out, ")")? } Expression::Select { condition, accept, reject, } => { write!(self.out, "(")?; self.write_expr(module, condition, func_ctx)?; write!(self.out, " ? ")?; self.write_expr(module, accept, func_ctx)?; write!(self.out, " : ")?; self.write_expr(module, reject, func_ctx)?; write!(self.out, ")")? 
} Expression::RayQueryGetIntersection { query, committed } => { // For the reasoning, see `write_stmt` let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else { unreachable!() }; let tracker_expr_name = format!( "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}", self.names[&func_ctx.name_key(query_var)] ); if committed { write!(self.out, "GetCommittedIntersection(")?; self.write_expr(module, query, func_ctx)?; write!(self.out, ", {tracker_expr_name})")?; } else { write!(self.out, "GetCandidateIntersection(")?; self.write_expr(module, query, func_ctx)?; write!(self.out, ", {tracker_expr_name})")?; } } // Not supported yet Expression::RayQueryVertexPositions { .. } | Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => { unreachable!() } // Nothing to do here, since the call expression is already cached Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } | Expression::RayQueryProceedResult | Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. } => {} } if !closing_bracket.is_empty() { write!(self.out, "{closing_bracket}")?; } Ok(()) } #[allow(clippy::too_many_arguments)] fn write_image_load( &mut self, module: &&Module, expr: Handle, func_ctx: &back::FunctionCtx, image: Handle, coordinate: Handle, array_index: Option>, sample: Option>, level: Option>, ) -> Result<(), Error> { let mut wrapping_type = None; match *func_ctx.resolve_type(image, &module.types) { TypeInner::Image { class: crate::ImageClass::External, .. } => { write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?; self.write_expr(module, image, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, coordinate, func_ctx)?; write!(self.out, ")")?; return Ok(()); } TypeInner::Image { class: crate::ImageClass::Storage { format, .. }, ..
} => { if format.single_component() { wrapping_type = Some(Scalar::from(format)); } } _ => {} } if let Some(scalar) = wrapping_type { write!( self.out, "{}{}(", help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER, scalar.to_hlsl_str()? )?; } // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load self.write_expr(module, image, func_ctx)?; write!(self.out, ".Load(")?; self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?; if let Some(sample) = sample { write!(self.out, ", ")?; self.write_expr(module, sample, func_ctx)?; } // close bracket for Load function write!(self.out, ")")?; if wrapping_type.is_some() { write!(self.out, ")")?; } // return x component if return type is scalar if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) { write!(self.out, ".x")?; } Ok(()) } /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access /// can be generated later. fn sampler_binding_array_info_from_expression( &mut self, module: &Module, func_ctx: &back::FunctionCtx<'_>, base: Handle, resolved: &TypeInner, ) -> Option { if let TypeInner::BindingArray { base: base_ty_handle, .. } = *resolved { let base_ty = &module.types[base_ty_handle].inner; if let TypeInner::Sampler { comparison, .. 
} = *base_ty { let base = &func_ctx.expressions[base]; if let crate::Expression::GlobalVariable(handle) = *base { let variable = &module.global_variables[handle]; let sampler_heap_name = match comparison { true => COMPARISON_SAMPLER_HEAP_VAR, false => SAMPLER_HEAP_VAR, }; return Some(BindingArraySamplerInfo { sampler_heap_name, sampler_index_buffer_name: self .wrapped .sampler_index_buffers .get(&super::SamplerIndexBufferKey { group: variable.binding.unwrap().group, }) .unwrap() .clone(), binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)] .clone(), }); } } } None } fn write_named_expr( &mut self, module: &Module, handle: Handle, name: String, // The expression which is being named. // Generally, this is the same as `handle`, except in WorkGroupUniformLoad named: Handle, ctx: &back::FunctionCtx, ) -> BackendResult { match ctx.info[named].ty { proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner { TypeInner::Struct { .. } => { let ty_name = &self.names[&NameKey::Type(ty_handle)]; write!(self.out, "{ty_name}")?; } _ => { self.write_type(module, ty_handle)?; } }, proc::TypeResolution::Value(ref inner) => { self.write_value_type(module, inner)?; } } let resolved = ctx.resolve_type(named, &module.types); write!(self.out, " {name}")?; // If the RHS is an array type, write the array size if let TypeInner::Array { base, size, .. } = *resolved { self.write_array_size(module, base, size)?; } write!(self.out, " = ")?; self.write_expr(module, handle, ctx)?; writeln!(self.out, ";")?; self.named_expressions.insert(named, name); Ok(()) } /// Helper function that writes a default zero initialization pub(super) fn write_default_init( &mut self, module: &Module, ty: Handle, ) -> BackendResult { write!(self.out, "(")?; self.write_type(module, ty)?; if let TypeInner::Array { base, size, ..
} = module.types[ty].inner { self.write_array_size(module, base, size)?; } write!(self.out, ")0")?; Ok(()) } fn write_control_barrier( &mut self, barrier: crate::Barrier, level: back::Level, ) -> BackendResult { if barrier.contains(crate::Barrier::STORAGE) { writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?; } if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } if barrier.contains(crate::Barrier::SUB_GROUP) { // Does not exist in DirectX } if barrier.contains(crate::Barrier::TEXTURE) { writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?; } Ok(()) } fn write_memory_barrier( &mut self, barrier: crate::Barrier, level: back::Level, ) -> BackendResult { if barrier.contains(crate::Barrier::STORAGE) { writeln!(self.out, "{level}DeviceMemoryBarrier();")?; } if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}GroupMemoryBarrier();")?; } if barrier.contains(crate::Barrier::SUB_GROUP) { // Does not exist in DirectX } if barrier.contains(crate::Barrier::TEXTURE) { writeln!(self.out, "{level}DeviceMemoryBarrier();")?; } Ok(()) } /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result) fn emit_hlsl_atomic_tail( &mut self, module: &Module, func_ctx: &back::FunctionCtx<'_>, fun: &crate::AtomicFunction, compare_expr: Option>, value: Handle, res_var_info: &Option<(Handle, String)>, ) -> BackendResult { if let Some(cmp) = compare_expr { write!(self.out, ", ")?; self.write_expr(module, cmp, func_ctx)?; } write!(self.out, ", ")?; if let crate::AtomicFunction::Subtract = *fun { // we just wrote `InterlockedAdd`, so negate the argument write!(self.out, "-")?; } self.write_expr(module, value, func_ctx)?; if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() { write!(self.out, ", ")?; if compare_expr.is_some() { write!(self.out, "{res_name}.old_value")?; } else { write!(self.out, "{res_name}")?; } } writeln!(self.out, ");")?; 
Ok(()) } } pub(super) struct MatrixType { pub(super) columns: crate::VectorSize, pub(super) rows: crate::VectorSize, pub(super) width: crate::Bytes, } pub(super) fn get_inner_matrix_data( module: &Module, handle: Handle, ) -> Option { match module.types[handle].inner { TypeInner::Matrix { columns, rows, scalar, } => Some(MatrixType { columns, rows, width: scalar.width, }), TypeInner::Array { base, .. } => get_inner_matrix_data(module, base), _ => None, } } /// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`, /// returns a tuple of the matrix, the column (vector) index (if present), and /// the row (scalar) index (if present). fn find_matrix_in_access_chain( module: &Module, base: Handle, func_ctx: &back::FunctionCtx<'_>, ) -> Option<(Handle, Option, Option)> { let mut current_base = base; let mut vector = None; let mut scalar = None; loop { let resolved_tr = func_ctx .resolve_type(current_base, &module.types) .pointer_base_type(); let resolved = resolved_tr.as_ref()?.inner_with(&module.types); match *resolved { TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)), TypeInner::Scalar(_) | TypeInner::Vector { .. } => {} _ => return None, } let index; (current_base, index) = match func_ctx.expressions[current_base] { crate::Expression::Access { base, index } => (base, Index::Expression(index)), crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)), _ => return None, }; match *resolved { TypeInner::Scalar(_) => scalar = Some(index), TypeInner::Vector { .. 
} => vector = Some(index), _ => unreachable!(), } } } /// Returns the matrix data if the access chain starting at `base`: /// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true` /// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`] /// - ends at an expression with resolved type of [`TypeInner::Struct`] pub(super) fn get_inner_matrix_of_struct_array_member( module: &Module, base: Handle, func_ctx: &back::FunctionCtx<'_>, direct: bool, ) -> Option { let mut mat_data = None; let mut array_base = None; let mut current_base = base; loop { let mut resolved = func_ctx.resolve_type(current_base, &module.types); if let TypeInner::Pointer { base, .. } = *resolved { resolved = &module.types[base].inner; }; match *resolved { TypeInner::Matrix { columns, rows, scalar, } => { mat_data = Some(MatrixType { columns, rows, width: scalar.width, }) } TypeInner::Array { base, .. } => { array_base = Some(base); } TypeInner::Struct { .. } => { if let Some(array_base) = array_base { if direct { return mat_data; } else { return get_inner_matrix_data(module, array_base); } } break; } _ => break, } current_base = match func_ctx.expressions[current_base] { crate::Expression::Access { base, .. } => base, crate::Expression::AccessIndex { base, .. } => base, _ => break, }; } None } /// Simpler version of get_inner_matrix_of_global_uniform that only looks at the /// immediate expression, rather than traversing an access chain. 
fn get_global_uniform_matrix( module: &Module, base: Handle, func_ctx: &back::FunctionCtx<'_>, ) -> Option { let base_tr = func_ctx .resolve_type(base, &module.types) .pointer_base_type(); let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types)); match (&func_ctx.expressions[base], base_ty) { ( &crate::Expression::GlobalVariable(handle), Some(&TypeInner::Matrix { columns, rows, scalar, }), ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => { Some(MatrixType { columns, rows, width: scalar.width, }) } _ => None, } } /// Returns the matrix data if the access chain starting at `base`: /// - starts with an expression with resolved type of [`TypeInner::Matrix`] /// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`] /// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform) fn get_inner_matrix_of_global_uniform( module: &Module, base: Handle, func_ctx: &back::FunctionCtx<'_>, ) -> Option { let mut mat_data = None; let mut array_base = None; let mut current_base = base; loop { let mut resolved = func_ctx.resolve_type(current_base, &module.types); if let TypeInner::Pointer { base, .. } = *resolved { resolved = &module.types[base].inner; }; match *resolved { TypeInner::Matrix { columns, rows, scalar, } => { mat_data = Some(MatrixType { columns, rows, width: scalar.width, }) } TypeInner::Array { base, .. } => { array_base = Some(base); } _ => break, } current_base = match func_ctx.expressions[current_base] { crate::Expression::Access { base, .. } => base, crate::Expression::AccessIndex { base, .. 
} => base, crate::Expression::GlobalVariable(handle) if module.global_variables[handle].space == crate::AddressSpace::Uniform => { return mat_data.or_else(|| { array_base.and_then(|array_base| get_inner_matrix_data(module, array_base)) }) } _ => break, }; } None } naga-29.0.3/src/back/mod.rs000064400000000000000000000335431046102023000134510ustar 00000000000000/*! Backend functions that export shader [`Module`](super::Module)s into binary and text formats. */ #![cfg_attr( not(any(dot_out, glsl_out, hlsl_out, msl_out, spv_out, wgsl_out)), allow( dead_code, reason = "shared helpers can be dead if none of the enabled backends need it" ) )] use alloc::string::String; #[cfg(dot_out)] pub mod dot; #[cfg(glsl_out)] pub mod glsl; #[cfg(hlsl_out)] pub mod hlsl; #[cfg(msl_out)] pub mod msl; #[cfg(spv_out)] pub mod spv; #[cfg(wgsl_out)] pub mod wgsl; #[cfg(any(hlsl_out, msl_out, spv_out, glsl_out))] pub mod pipeline_constants; #[cfg(any(hlsl_out, glsl_out))] mod continue_forward; /// Names of vector components. pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; /// Indent for backends. pub const INDENT: &str = " "; /// Expressions that need baking. pub type NeedBakeExpressions = crate::FastHashSet>; /// A type for displaying expression handles as baking identifiers. /// /// Given an [`Expression`] [`Handle`] `h`, `Baked(h)` implements /// [`core::fmt::Display`], showing the handle's index prefixed by /// `_e`. /// /// [`Expression`]: crate::Expression /// [`Handle`]: crate::Handle #[cfg_attr( not(any(glsl_out, hlsl_out, msl_out, wgsl_out)), allow( dead_code, reason = "shared helpers can be dead if none of the enabled backends need it" ) )] struct Baked(crate::Handle); impl core::fmt::Display for Baked { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.0.write_prefixed(f, "_e") } } bitflags::bitflags! 
{ /// How far a ray query has progressed. #[derive(Clone, Copy, Debug, Eq, PartialEq)] #[cfg_attr( not(any(hlsl_out, spv_out)), allow( dead_code, reason = "shared helpers can be dead if none of the enabled backends need it" ) )] pub(super) struct RayQueryPoint: u32 { /// The ray query has been successfully initialized. const INITIALIZED = 1 << 0; /// Proceed has been called on the ray query. const PROCEED = 1 << 1; /// Proceed has returned false (traversal has finished). const FINISHED_TRAVERSAL = 1 << 2; } } /// Specifies the values of pipeline-overridable constants in the shader module. /// /// If an `@id` attribute was specified on the declaration, /// the key must be the pipeline constant ID as a decimal ASCII number; if not, /// the key must be the constant's identifier name. /// /// The value may represent any of WGSL's concrete scalar types. pub type PipelineConstants = hashbrown::HashMap; /// Indentation level. #[derive(Clone, Copy)] pub struct Level(pub usize); impl Level { pub const fn next(&self) -> Self { Level(self.0 + 1) } } impl core::fmt::Display for Level { fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { (0..self.0).try_for_each(|_| formatter.write_str(INDENT)) } } /// Locate the entry point(s) to write. /// /// If `entry_point` is given, and the specified entry point exists, returns a /// length-1 `Range` containing the index of that entry point. If no /// `entry_point` is given, returns the complete range of entry point indices. /// If `entry_point` is given but does not exist, returns an error.
#[cfg(any(hlsl_out, msl_out))] fn get_entry_points( module: &crate::ir::Module, entry_point: Option<&(crate::ir::ShaderStage, String)>, ) -> Result, (crate::ir::ShaderStage, String)> { use alloc::borrow::ToOwned; if let Some(&(stage, ref name)) = entry_point { let Some(ep_index) = module .entry_points .iter() .position(|ep| ep.stage == stage && ep.name == *name) else { return Err((stage, name.to_owned())); }; Ok(ep_index..ep_index + 1) } else { Ok(0..module.entry_points.len()) } } /// Whether we're generating an entry point or a regular function. /// /// Backend languages often require different code for a [`Function`] /// depending on whether it represents an [`EntryPoint`] or not. /// Backends can pass common code one of these values to select the /// right behavior. /// /// These values also carry enough information to find the `Function` /// in the [`Module`]: the `Handle` for a regular function, or the /// index into [`Module::entry_points`] for an entry point. /// /// [`Function`]: crate::Function /// [`EntryPoint`]: crate::EntryPoint /// [`Module`]: crate::Module /// [`Module::entry_points`]: crate::Module::entry_points #[derive(Clone, Copy, Debug)] pub enum FunctionType { /// A regular function. Function(crate::Handle), /// An [`EntryPoint`], and its index in [`Module::entry_points`]. /// /// [`EntryPoint`]: crate::EntryPoint /// [`Module::entry_points`]: crate::Module::entry_points EntryPoint(crate::proc::EntryPointIndex), } impl FunctionType { /// Returns true if the function is an entry point for a compute-like shader. 
pub fn is_compute_like_entry_point(&self, module: &crate::Module) -> bool { match *self { FunctionType::EntryPoint(index) => { module.entry_points[index as usize].stage.compute_like() } FunctionType::Function(_) => false, } } } /// Helper structure that stores the data needed when writing a function pub struct FunctionCtx<'a> { /// The current function being written pub ty: FunctionType, /// Validation analysis of the function pub info: &'a crate::valid::FunctionInfo, /// The expression arena of the current function being written pub expressions: &'a crate::Arena, /// Map of expressions that have associated variable names pub named_expressions: &'a crate::NamedExpressions, } impl FunctionCtx<'_> { /// Helper method that resolves the type of the given expression. pub fn resolve_type<'a>( &'a self, handle: crate::Handle, types: &'a crate::UniqueArena, ) -> &'a crate::TypeInner { self.info[handle].ty.inner_with(types) } /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a local in the current function pub const fn name_key( &self, local: crate::Handle, ) -> crate::proc::NameKey { match self.ty { FunctionType::Function(handle) => crate::proc::NameKey::FunctionLocal(handle, local), FunctionType::EntryPoint(idx) => crate::proc::NameKey::EntryPointLocal(idx, local), } } /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a function argument. /// /// # Panics /// - If `arg` is greater than or equal to the number of function arguments pub const fn argument_key(&self, arg: u32) -> crate::proc::NameKey { match self.ty { FunctionType::Function(handle) => crate::proc::NameKey::FunctionArgument(handle, arg), FunctionType::EntryPoint(ep_index) => { crate::proc::NameKey::EntryPointArgument(ep_index, arg) } } } /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for an external texture /// function argument. /// /// # Panics /// - If `arg` is greater than or equal to the number of function arguments /// - If `self.ty` is not `FunctionType::Function`.
pub const fn external_texture_argument_key( &self, arg: u32, external_texture_key: crate::proc::ExternalTextureNameKey, ) -> crate::proc::NameKey { match self.ty { FunctionType::Function(handle) => { crate::proc::NameKey::ExternalTextureFunctionArgument( handle, arg, external_texture_key, ) } // This is a const function, which _sometimes_ gets called, // so this lint is _sometimes_ triggered, depending on the feature set. #[expect(clippy::allow_attributes)] #[allow(clippy::panic)] FunctionType::EntryPoint(_) => { panic!("External textures cannot be used as arguments to entry points") } } } /// If the given expression refers to a fixed-function pipeline input, /// returns its built-in; otherwise returns `None`. pub fn is_fixed_function_input( &self, mut expression: crate::Handle, module: &crate::Module, ) -> Option { let ep_function = match self.ty { FunctionType::Function(_) => return None, FunctionType::EntryPoint(ep_index) => &module.entry_points[ep_index as usize].function, }; let mut built_in = None; loop { match self.expressions[expression] { crate::Expression::FunctionArgument(arg_index) => { return match ep_function.arguments[arg_index as usize].binding { Some(crate::Binding::BuiltIn(bi)) => Some(bi), _ => built_in, }; } crate::Expression::AccessIndex { base, index } => { match *self.resolve_type(base, &module.types) { crate::TypeInner::Struct { ref members, .. } => { if let Some(crate::Binding::BuiltIn(bi)) = members[index as usize].binding { built_in = Some(bi); } } _ => return None, } expression = base; } _ => return None, } } } } impl crate::Expression { /// Returns the reference count at which this expression /// should be considered for baking. /// /// Note: we have to cache any expressions that depend on the control flow; /// otherwise they may accidentally be moved into non-uniform control flow. /// See the [module-level documentation][emit] for details.
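The walk in `is_fixed_function_input` can be sketched with a simplified, hypothetical expression arena in which plain indices stand in for handles. To keep the example short, only struct arguments accessed directly carry member bindings. The argument's own binding takes precedence, and otherwise the built-in found on the accessed member is returned, matching the `_ => built_in` fallback above.

```rust
// Hypothetical, simplified model: indices stand in for `Handle`s.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BuiltIn {
    Position,
    FrontFacing,
}

pub enum Expr {
    /// Argument `n` of the entry point.
    FunctionArgument(usize),
    /// Member `index` of the struct produced by expression `base`.
    AccessIndex { base: usize, index: usize },
    Other,
}

/// An argument's own binding, plus per-member bindings if it is a struct.
pub struct Arg {
    pub binding: Option<BuiltIn>,
    pub members: Vec<Option<BuiltIn>>,
}

pub fn fixed_function_input(exprs: &[Expr], args: &[Arg], mut expr: usize) -> Option<BuiltIn> {
    let mut built_in = None;
    loop {
        match exprs[expr] {
            // Reached the argument: its own binding wins, else the member's.
            Expr::FunctionArgument(n) => return args[n].binding.or(built_in),
            Expr::AccessIndex { base, index } => {
                // In this sketch, only direct struct arguments have member bindings.
                let Expr::FunctionArgument(n) = exprs[base] else { return None };
                if let Some(bi) = args[n].members.get(index).copied().flatten() {
                    built_in = Some(bi);
                }
                expr = base;
            }
            Expr::Other => return None,
        }
    }
}
```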
/// /// [emit]: index.html#expression-evaluation-time pub const fn bake_ref_count(&self) -> usize { match *self { // accesses are never cached, only loads are crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => usize::MAX, // sampling may use the control flow, and image ops look better by themselves crate::Expression::ImageSample { .. } | crate::Expression::ImageLoad { .. } => 1, // derivatives use the control flow crate::Expression::Derivative { .. } => 1, // TODO: We need a better fix for named `Load` expressions // More info - https://github.com/gfx-rs/naga/pull/914 // And https://github.com/gfx-rs/naga/issues/910 crate::Expression::Load { .. } => 1, // cache expressions that are referenced multiple times _ => 2, } } } /// Helper function that returns the string corresponding to the [`BinaryOperator`](crate::BinaryOperator) pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { use crate::BinaryOperator as Bo; match op { Bo::Add => "+", Bo::Subtract => "-", Bo::Multiply => "*", Bo::Divide => "/", Bo::Modulo => "%", Bo::Equal => "==", Bo::NotEqual => "!=", Bo::Less => "<", Bo::LessEqual => "<=", Bo::Greater => ">", Bo::GreaterEqual => ">=", Bo::And => "&", Bo::ExclusiveOr => "^", Bo::InclusiveOr => "|", Bo::LogicalAnd => "&&", Bo::LogicalOr => "||", Bo::ShiftLeft => "<<", Bo::ShiftRight => ">>", } } impl crate::TypeInner { /// Returns true if a variable of this type is a handle. pub const fn is_handle(&self) -> bool { match *self { Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. } => true, _ => false, } } } impl crate::Statement { /// Returns true if the statement directly terminates the current block. /// /// Used to decide whether case blocks require an explicit `break`. pub const fn is_terminator(&self) -> bool { match *self { crate::Statement::Break | crate::Statement::Continue | crate::Statement::Return { .. } | crate::Statement::Kill => true, _ => false, } } } bitflags::bitflags!
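Backends compare an expression's reference count against the threshold returned by `bake_ref_count`. The sketch below reproduces those thresholds for a hypothetical, reduced set of expression kinds (`ExprKind`, `should_bake`). It assumes that an expression is baked into its own variable once its reference count reaches the threshold.

```rust
// Hypothetical, reduced set of expression kinds.
#[derive(Clone, Copy, Debug)]
pub enum ExprKind {
    Access,
    ImageSample,
    Derivative,
    Load,
    Binary,
}

/// Same thresholds as `bake_ref_count` above, for the reduced kinds.
pub const fn bake_ref_count(kind: ExprKind) -> usize {
    match kind {
        // Accesses are never baked on their own; loads of them are.
        ExprKind::Access => usize::MAX,
        // Control-flow-sensitive expressions are baked on first use.
        ExprKind::ImageSample | ExprKind::Derivative | ExprKind::Load => 1,
        // Everything else is baked only when referenced more than once.
        ExprKind::Binary => 2,
    }
}

/// Should an expression referenced `ref_count` times get its own variable?
pub fn should_bake(kind: ExprKind, ref_count: usize) -> bool {
    ref_count >= bake_ref_count(kind)
}
```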
{ /// Ray flags, for a [`RayDesc`]'s `flags` field. /// /// Note that these exactly correspond to the SPIR-V "Ray Flags" mask, and /// the SPIR-V backend passes them directly through to the /// [`OpRayQueryInitializeKHR`][op] instruction. (We have to choose something, so /// we might as well make one back end's life easier.) /// /// [`RayDesc`]: crate::Module::generate_ray_desc_type /// [op]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpRayQueryInitializeKHR #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub struct RayFlag: u32 { const OPAQUE = 0x01; const NO_OPAQUE = 0x02; const TERMINATE_ON_FIRST_HIT = 0x04; const SKIP_CLOSEST_HIT_SHADER = 0x08; const CULL_BACK_FACING = 0x10; const CULL_FRONT_FACING = 0x20; const CULL_OPAQUE = 0x40; const CULL_NO_OPAQUE = 0x80; const SKIP_TRIANGLES = 0x100; const SKIP_AABBS = 0x200; } } /// The intersection test to use for ray queries. #[repr(u32)] pub enum RayIntersectionType { Triangle = 1, BoundingBox = 4, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct TaskDispatchLimits { pub max_mesh_workgroups_per_dim: u32, pub max_mesh_workgroups_total: u32, } naga-29.0.3/src/back/msl/keywords.rs000064400000000000000000000204321046102023000153250ustar 00000000000000use crate::proc::{concrete_int_scalars, vector_size_str, vector_sizes, KeywordSet}; use crate::racy_lock::RacyLock; use alloc::{format, string::String, vec::Vec}; // MSLS - Metal Shading Language Specification: // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // // C++ - Standard for Programming Language C++ (N4431) // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf const RESERVED: &[&str] = &[ // Undocumented "assert", // found in https://github.com/gfx-rs/wgpu/issues/5347 // Standard for Programming Language C++ (N4431): 2.5 Alternative tokens "and", "bitor", 
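Because the `RayFlag` bits mirror the SPIR-V Ray Flags mask, a combined value can be forwarded as-is. The sketch below uses plain `u32` constants copied from the values above, so it does not depend on the `bitflags` crate. The `combine` helper is hypothetical.

```rust
// Plain-constant copies of a few `RayFlag` bits from the definition above.
pub const OPAQUE: u32 = 0x01;
pub const TERMINATE_ON_FIRST_HIT: u32 = 0x04;
pub const CULL_BACK_FACING: u32 = 0x10;
pub const SKIP_AABBS: u32 = 0x200;

/// Hypothetical helper: OR individual flags into one mask. The result can be
/// passed straight through as the SPIR-V `RayFlags` operand.
pub fn combine(flags: &[u32]) -> u32 {
    flags.iter().fold(0, |acc, flag| acc | flag)
}
```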
"or", "xor", "compl", "bitand", "and_eq", "or_eq", "xor_eq", "not", "not_eq", // Standard for Programming Language C++ (N4431): 2.11 Keywords "alignas", "alignof", "asm", "auto", "bool", "break", "case", "catch", "char", "char16_t", "char32_t", "class", "const", "constexpr", "const_cast", "continue", "decltype", "default", "delete", "do", "double", "dynamic_cast", "else", "enum", "explicit", "export", "extern", "false", "float", "for", "friend", "goto", "if", "inline", "int", "long", "mutable", "namespace", "new", "noexcept", "nullptr", "operator", "private", "protected", "public", "register", "reinterpret_cast", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct", "switch", "template", "this", "thread_local", "throw", "true", "try", "typedef", "typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while", // Metal Shading Language Specification: 1.4.4 Restrictions "main", // Metal Shading Language Specification: 2.1 Scalar Data Types "int8_t", "uchar", "uint8_t", "int16_t", "ushort", "uint16_t", "int32_t", "uint", "uint32_t", "int64_t", "uint64_t", "half", "bfloat", "size_t", "ptrdiff_t", // Metal Shading Language Specification: 2.2 Vector Data Types "bool2", "bool3", "bool4", "char2", "char3", "char4", "short2", "short3", "short4", "int2", "int3", "int4", "long2", "long3", "long4", "uchar2", "uchar3", "uchar4", "ushort2", "ushort3", "ushort4", "uint2", "uint3", "uint4", "ulong2", "ulong3", "ulong4", "half2", "half3", "half4", "bfloat2", "bfloat3", "bfloat4", "float2", "float3", "float4", "vec", // Metal Shading Language Specification: 2.2.3 Packed Vector Types "packed_bool2", "packed_bool3", "packed_bool4", "packed_char2", "packed_char3", "packed_char4", "packed_short2", "packed_short3", "packed_short4", "packed_int2", "packed_int3", "packed_int4", "packed_uchar2", "packed_uchar3", "packed_uchar4", "packed_ushort2", "packed_ushort3", "packed_ushort4", "packed_uint2", "packed_uint3", 
"packed_uint4", "packed_half2", "packed_half3", "packed_half4", "packed_bfloat2", "packed_bfloat3", "packed_bfloat4", "packed_float2", "packed_float3", "packed_float4", "packed_long2", "packed_long3", "packed_long4", "packed_vec", // Metal Shading Language Specification: 2.3 Matrix Data Types "half2x2", "half2x3", "half2x4", "half3x2", "half3x3", "half3x4", "half4x2", "half4x3", "half4x4", "float2x2", "float2x3", "float2x4", "float3x2", "float3x3", "float3x4", "float4x2", "float4x3", "float4x4", "matrix", // Metal Shading Language Specification: 2.6 Atomic Data Types "atomic", "atomic_int", "atomic_uint", "atomic_bool", "atomic_ulong", "atomic_float", // Metal Shading Language Specification: 2.20 Type Conversions and Re-interpreting Data "as_type", // Metal Shading Language Specification: 4 Address Spaces "device", "constant", "thread", "threadgroup", "threadgroup_imageblock", "ray_data", "object_data", // Metal Shading Language Specification: 5.1 Functions "vertex", "fragment", "kernel", // Metal Shading Language Specification: 6.1 Namespace and Header Files "metal", // C99 / C++ extension: "restrict", // Metal reserved types in : "llong", "ullong", "quad", "complex", "imaginary", // Constants in : "CHAR_BIT", "SCHAR_MAX", "SCHAR_MIN", "UCHAR_MAX", "CHAR_MAX", "CHAR_MIN", "USHRT_MAX", "SHRT_MAX", "SHRT_MIN", "UINT_MAX", "INT_MAX", "INT_MIN", "ULONG_MAX", "LONG_MAX", "LONG_MIN", "ULLONG_MAX", "LLONG_MAX", "LLONG_MIN", "FLT_DIG", "FLT_MANT_DIG", "FLT_MAX_10_EXP", "FLT_MAX_EXP", "FLT_MIN_10_EXP", "FLT_MIN_EXP", "FLT_RADIX", "FLT_MAX", "FLT_MIN", "FLT_EPSILON", "FLT_DECIMAL_DIG", "FP_ILOGB0", "FP_ILOGB0", "FP_ILOGBNAN", "FP_ILOGBNAN", "MAXFLOAT", "HUGE_VALF", "INFINITY", "NAN", "M_E_F", "M_LOG2E_F", "M_LOG10E_F", "M_LN2_F", "M_LN10_F", "M_PI_F", "M_PI_2_F", "M_PI_4_F", "M_1_PI_F", "M_2_PI_F", "M_2_SQRTPI_F", "M_SQRT2_F", "M_SQRT1_2_F", "HALF_DIG", "HALF_MANT_DIG", "HALF_MAX_10_EXP", "HALF_MAX_EXP", "HALF_MIN_10_EXP", "HALF_MIN_EXP", "HALF_RADIX", "HALF_MAX", 
"HALF_MIN", "HALF_EPSILON", "HALF_DECIMAL_DIG", "MAXHALF", "HUGE_VALH", "M_E_H", "M_LOG2E_H", "M_LOG10E_H", "M_LN2_H", "M_LN10_H", "M_PI_H", "M_PI_2_H", "M_PI_4_H", "M_1_PI_H", "M_2_PI_H", "M_2_SQRTPI_H", "M_SQRT2_H", "M_SQRT1_2_H", "DBL_DIG", "DBL_MANT_DIG", "DBL_MAX_10_EXP", "DBL_MAX_EXP", "DBL_MIN_10_EXP", "DBL_MIN_EXP", "DBL_RADIX", "DBL_MAX", "DBL_MIN", "DBL_EPSILON", "DBL_DECIMAL_DIG", "MAXDOUBLE", "HUGE_VAL", "M_E", "M_LOG2E", "M_LOG10E", "M_LN2", "M_LN10", "M_PI", "M_PI_2", "M_PI_4", "M_1_PI", "M_2_PI", "M_2_SQRTPI", "M_SQRT2", "M_SQRT1_2", // Naga utilities "DefaultConstructible", // Naga builtin names "__local_invocation_id", super::writer::FREXP_FUNCTION, super::writer::MODF_FUNCTION, super::writer::ABS_FUNCTION, super::writer::DIV_FUNCTION, // DOT_FUNCTION_PREFIX variants are added dynamically below super::writer::MOD_FUNCTION, super::writer::NEG_FUNCTION, super::writer::F2I32_FUNCTION, super::writer::F2U32_FUNCTION, super::writer::F2I64_FUNCTION, super::writer::F2U64_FUNCTION, super::writer::IMAGE_LOAD_EXTERNAL_FUNCTION, super::writer::IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION, super::writer::IMAGE_SIZE_EXTERNAL_FUNCTION, super::writer::ARGUMENT_BUFFER_WRAPPER_STRUCT, super::writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT, super::writer::COOPERATIVE_LOAD_FUNCTION, super::writer::COOPERATIVE_MULTIPLY_ADD_FUNCTION, ]; // The set of concrete integer dot product function variants. // This must match the set of names that could be produced by // `Writer::get_dot_wrapper_function_helper_name`. static DOT_FUNCTION_NAMES: RacyLock> = RacyLock::new(|| { let mut names = Vec::new(); for scalar in concrete_int_scalars().map(crate::Scalar::to_msl_name) { for size_suffix in vector_sizes().map(vector_size_str) { let fun_name = format!( "{}_{}{}", super::writer::DOT_FUNCTION_PREFIX, scalar, size_suffix ); names.push(fun_name); } } names }); /// The above set of reserved keywords, turned into a cached HashSet. 
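The keyword table and the generated dot-product helper names are combined into one set, built once and then reused. The sketch below uses `std::sync::OnceLock` in place of naga's `RacyLock`. The `RESERVED` contents, the `naga_dot` prefix and the MSL scalar names are placeholders, since the real value of `DOT_FUNCTION_PREFIX` is not shown here.

```rust
use std::collections::HashSet;
use std::sync::OnceLock;

// Placeholder stand-ins for `RESERVED` and `DOT_FUNCTION_PREFIX`.
const RESERVED: &[&str] = &["and", "bitor", "float4", "kernel"];
const DOT_PREFIX: &str = "naga_dot";

/// Every `{prefix}_{scalar}{size}` combination, as in the nested loops above.
fn dot_function_names() -> Vec<String> {
    let mut names = Vec::new();
    for scalar in ["int", "uint", "long", "ulong"] {
        for size in ["2", "3", "4"] {
            names.push(format!("{}_{}{}", DOT_PREFIX, scalar, size));
        }
    }
    names
}

/// Built on first use and shared afterwards, so repeated namer resets only
/// pay for lookups, not for rebuilding the set.
pub fn reserved_set() -> &'static HashSet<String> {
    static SET: OnceLock<HashSet<String>> = OnceLock::new();
    SET.get_or_init(|| {
        let mut set: HashSet<String> = RESERVED.iter().map(|s| s.to_string()).collect();
        set.extend(dot_function_names());
        set
    })
}
```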
This saves /// significant time during [`Namer::reset`](crate::proc::Namer::reset). /// /// See for benchmarks. pub static RESERVED_SET: RacyLock = RacyLock::new(|| { let mut set = KeywordSet::from_iter(RESERVED); set.extend(DOT_FUNCTION_NAMES.iter().map(String::as_str)); set }); naga-29.0.3/src/back/msl/mod.rs000064400000000000000000001057141046102023000142440ustar 00000000000000/*! Backend for [MSL][msl] (Metal Shading Language). This backend does not support the [`SHADER_INT64_ATOMIC_ALL_OPS`][all-atom] capability. ## Binding model Metal's bindings are flat per resource. Since there isn't an obvious mapping from SPIR-V's descriptor sets, we require a separate mapping provided in the options. This mapping may have one or more resource endpoints for each descriptor set + index pair. ## Entry points Even though MSL and our IR appear to be similar in that the entry points in both can accept arguments and return values, the restrictions are different. MSL allows the varyings to be either in separate arguments, or inside a single `[[stage_in]]` struct. We gather input varyings and form this artificial structure. We also add all the (non-Private) globals into the arguments. At the beginning of the entry point, we assign the local constants and re-compose the arguments as they are declared on the IR side, so that the rest of the logic can pretend that MSL doesn't have all the restrictions it has. For the result type, if it's a structure, we re-compose it with a temporary value holding the result. [msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf [all-atom]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS ## Pointer-typed bounds-checked expressions and OOB locals MSL (unlike HLSL and GLSL) has native support for pointer-typed function arguments. When the [`BoundsCheckPolicy`] is `ReadZeroSkipWrite` and an out-of-bounds index expression is used for such an argument, our strategy is to pass a pointer to a dummy variable.
These dummy variables are called "OOB locals". We emit at most one OOB local per function for each type, since all expressions producing a result of that type can share the same OOB local. (Note that the OOB local mechanism is not actually implementing "skip write", nor even "read zero" in some cases of read-after-write, but doing so would require additional effort and the difference is unlikely to matter.) [`BoundsCheckPolicy`]: crate::proc::BoundsCheckPolicy ## External textures Support for [`crate::ImageClass::External`] textures is implemented by lowering each external texture global variable to 3 `texture2d`s, and a constant buffer of type `NagaExternalTextureParams`. This provides up to 3 planes of texture data (for example single planar RGBA, or separate Y, Cb, and Cr planes), and the parameters buffer containing information describing how to handle these correctly. The bind target to use for each of these globals is specified via the [`BindTarget::external_texture`] field of the relevant entries in [`EntryPointResources::resources`]. External textures are supported by WGSL's `textureDimensions()`, `textureLoad()`, and `textureSampleBaseClampToEdge()` built-in functions. These are implemented using helper functions. See the following functions for how these are generated: * `Writer::write_wrapped_image_query` * `Writer::write_wrapped_image_load` * `Writer::write_wrapped_image_sample` The lowered global variables for each external texture global are passed to the entry point as separate arguments (see "Entry points" above). However, they are then wrapped in a struct to allow them to be conveniently passed to user defined and helper functions. See `writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT`. 
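The "at most one OOB local per type" rule amounts to a per-function map from type to dummy variable. The sketch below uses hypothetical names (`OobLocals`, `oob_local_N`) and represents types as strings for brevity.

```rust
use std::collections::HashMap;

/// Hypothetical per-function registry: at most one dummy variable per type,
/// shared by every out-of-bounds pointer expression of that type.
#[derive(Default)]
pub struct OobLocals {
    by_type: HashMap<String, String>,
}

impl OobLocals {
    /// Returns the dummy variable for `ty`, declaring it on first use.
    pub fn get_or_declare(&mut self, ty: &str) -> &str {
        let next = self.by_type.len();
        self.by_type
            .entry(ty.to_string())
            .or_insert_with(|| format!("oob_local_{next}"))
    }

    /// Number of distinct OOB locals this function needs to declare.
    pub fn count(&self) -> usize {
        self.by_type.len()
    }
}
```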
*/ use alloc::{ format, string::{String, ToString}, vec::Vec, }; use core::fmt::{Error as FmtError, Write}; use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo}; mod keywords; pub mod sampler; mod writer; pub use writer::Writer; pub type Slot = u8; pub type InlineSamplerIndex = u8; #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub enum BindSamplerTarget { Resource(Slot), Inline(InlineSamplerIndex), } /// Binding information for a Naga [`External`] image global variable. /// /// See the module documentation's section on external textures for details. /// /// [`External`]: crate::ir::ImageClass::External #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct BindExternalTextureTarget { pub planes: [Slot; 3], pub params: Slot, } #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))] pub struct BindTarget { pub buffer: Option, pub texture: Option, pub sampler: Option, pub external_texture: Option, pub mutable: bool, } #[cfg(feature = "deserialize")] #[derive(serde::Deserialize)] struct BindingMapSerialization { resource_binding: crate::ResourceBinding, bind_target: BindTarget, } #[cfg(feature = "deserialize")] fn deserialize_binding_map<'de, D>(deserializer: D) -> Result where D: serde::Deserializer<'de>, { use serde::Deserialize; let vec = Vec::::deserialize(deserializer)?; let mut map = BindingMap::default(); for item in vec { map.insert(item.resource_binding, item.bind_target); } Ok(map) } // Using `BTreeMap` instead of `HashMap` so that the map itself can be hashed.
pub type BindingMap = alloc::collections::BTreeMap; #[derive(Clone, Debug, Default, Hash, Eq, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))] pub struct EntryPointResources { #[cfg_attr( feature = "deserialize", serde(deserialize_with = "deserialize_binding_map") )] pub resources: BindingMap, pub immediates_buffer: Option, /// The slot of a buffer that contains an array of `u32`, /// one for the size of each bound buffer that contains a runtime array, /// in order of [`crate::GlobalVariable`] declarations. pub sizes_buffer: Option, } pub type EntryPointResourceMap = alloc::collections::BTreeMap; enum ResolvedBinding { BuiltIn(crate::BuiltIn), Attribute(u32), Color { location: u32, blend_src: Option, }, User { prefix: &'static str, index: u32, interpolation: Option, }, Resource(BindTarget), } #[derive(Copy, Clone)] enum ResolvedInterpolation { CenterPerspective, CenterNoPerspective, CentroidPerspective, CentroidNoPerspective, SamplePerspective, SampleNoPerspective, Flat, } // Note: some of these should be removed in favor of proper IR validation. 
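Choosing `BTreeMap` for `BindingMap` lets `EntryPointResources` derive `Hash`. `BTreeMap` implements `Hash` when its keys and values do, while `HashMap` does not implement it at all. The sketch below uses a hypothetical reduced struct whose keys are `(group, binding)` pairs mapped to slots. It also shows that insertion order does not change the hash, because a `BTreeMap` always iterates in key order.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};

// Hypothetical reduced binding map: (group, binding) -> buffer slot.
type BindingMap = BTreeMap<(u32, u32), u8>;

/// `#[derive(Hash)]` compiles only because `BTreeMap` implements `Hash`;
/// the same field typed as `HashMap` would be rejected.
#[derive(Debug, Default, Hash, PartialEq, Eq)]
pub struct Resources {
    pub resources: BindingMap,
    pub sizes_buffer: Option<u8>,
}

pub fn fingerprint(resources: &Resources) -> u64 {
    let mut hasher = DefaultHasher::new();
    resources.hash(&mut hasher);
    hasher.finish()
}
```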
#[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] Format(#[from] FmtError), #[error("bind target {0:?} is empty")] UnimplementedBindTarget(BindTarget), #[error("composing of {0:?} is not implemented yet")] UnsupportedCompose(Handle), #[error("operation {0:?} is not implemented yet")] UnsupportedBinaryOp(crate::BinaryOperator), #[error("standard function '{0}' is not implemented yet")] UnsupportedCall(String), #[error("feature '{0}' is not implemented yet")] FeatureNotImplemented(String), #[error("internal naga error: module should not have validated: {0}")] GenericValidation(String), #[error("BuiltIn {0:?} is not supported")] UnsupportedBuiltIn(crate::BuiltIn), #[error("capability {0:?} is not supported")] CapabilityNotSupported(crate::valid::Capabilities), #[error("attribute '{0}' is not supported for target MSL version")] UnsupportedAttribute(String), #[error("function '{0}' is not supported for target MSL version")] UnsupportedFunction(String), #[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")] UnsupportedWriteableStorageBuffer, #[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")] UnsupportedWriteableStorageTexture(ir::ShaderStage), #[error("can not use read-write storage textures prior to MSL 1.2")] UnsupportedRWStorageTexture, #[error("array of '{0}' is not supported for target MSL version")] UnsupportedArrayOf(String), #[error("array of type '{0:?}' is not supported")] UnsupportedArrayOfType(Handle), #[error("ray tracing is not supported prior to MSL 2.4")] UnsupportedRayTracing, #[error("cooperative matrix is not supported prior to MSL 2.3")] UnsupportedCooperativeMatrix, #[error("overrides should not be present at this stage")] Override, #[error("bitcasting to {0:?} is not supported")] UnsupportedBitCast(crate::TypeInner), #[error(transparent)] ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError), #[error("entry point with stage {0:?} and name '{1}' not 
found")] EntryPointNotFound(ir::ShaderStage, String), } #[derive(Clone, Debug, PartialEq, thiserror::Error)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub enum EntryPointError { #[error("global '{0}' doesn't have a binding")] MissingBinding(String), #[error("mapping of {0:?} is missing")] MissingBindTarget(crate::ResourceBinding), #[error("mapping for immediates is missing")] MissingImmediateData, #[error("mapping for sizes buffer is missing")] MissingSizesBuffer, } /// Points in the MSL code where we might emit a pipeline input or output. /// /// Note that, even though vertex shaders' outputs are always fragment /// shaders' inputs, we still need to distinguish `VertexOutput` and /// `FragmentInput`, since there are certain differences in the way /// [`ResolvedBinding`s] are represented on either side. /// /// [`ResolvedBinding`s]: ResolvedBinding #[derive(Clone, Copy, Debug)] enum LocationMode { /// Input to the vertex shader. VertexInput, /// Output from the vertex shader. VertexOutput, /// Input to the fragment shader. FragmentInput, /// Output from the fragment shader. FragmentOutput, /// Compute shader input or output. Uniform, } #[derive(Clone, Debug, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(feature = "deserialize", serde(default))] pub struct Options { /// (Major, Minor) target version of the Metal Shading Language. pub lang_version: (u8, u8), /// Map of entry-point resources, indexed by entry point function name, to slots. pub per_entry_point_map: EntryPointResourceMap, /// Samplers to be inlined into the code. pub inline_samplers: Vec, /// Make it possible to link different stages via SPIRV-Cross. pub spirv_cross_compatibility: bool, /// Don't panic on missing bindings, instead generate invalid MSL. 
pub fake_missing_bindings: bool, /// Bounds checking policies. pub bounds_check_policies: index::BoundsCheckPolicies, /// Should workgroup variables be zero initialized (by polyfilling)? pub zero_initialize_workgroup_memory: bool, /// If set, loops will have code injected into them, forcing the compiler /// to think the number of iterations is bounded. pub force_loop_bounding: bool, } impl Default for Options { fn default() -> Self { Options { lang_version: (1, 0), per_entry_point_map: EntryPointResourceMap::default(), inline_samplers: Vec::new(), spirv_cross_compatibility: false, fake_missing_bindings: true, bounds_check_policies: index::BoundsCheckPolicies::default(), zero_initialize_workgroup_memory: true, force_loop_bounding: true, } } } /// Corresponds to [WebGPU `GPUVertexFormat`]( /// https://gpuweb.github.io/gpuweb/#enumdef-gpuvertexformat). #[repr(u32)] #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub enum VertexFormat { /// One unsigned byte (u8). `u32` in shaders. Uint8 = 0, /// Two unsigned bytes (u8). `vec2` in shaders. Uint8x2 = 1, /// Four unsigned bytes (u8). `vec4` in shaders. Uint8x4 = 2, /// One signed byte (i8). `i32` in shaders. Sint8 = 3, /// Two signed bytes (i8). `vec2` in shaders. Sint8x2 = 4, /// Four signed bytes (i8). `vec4` in shaders. Sint8x4 = 5, /// One unsigned byte (u8). [0, 255] converted to float [0, 1] `f32` in shaders. Unorm8 = 6, /// Two unsigned bytes (u8). [0, 255] converted to float [0, 1] `vec2` in shaders. Unorm8x2 = 7, /// Four unsigned bytes (u8). [0, 255] converted to float [0, 1] `vec4` in shaders. Unorm8x4 = 8, /// One signed byte (i8). [-127, 127] converted to float [-1, 1] `f32` in shaders. Snorm8 = 9, /// Two signed bytes (i8). [-127, 127] converted to float [-1, 1] `vec2` in shaders. Snorm8x2 = 10, /// Four signed bytes (i8). [-127, 127] converted to float [-1, 1] `vec4` in shaders. 
Snorm8x4 = 11, /// One unsigned short (u16). `u32` in shaders. Uint16 = 12, /// Two unsigned shorts (u16). `vec2` in shaders. Uint16x2 = 13, /// Four unsigned shorts (u16). `vec4` in shaders. Uint16x4 = 14, /// One signed short (i16). `i32` in shaders. Sint16 = 15, /// Two signed shorts (i16). `vec2` in shaders. Sint16x2 = 16, /// Four signed shorts (i16). `vec4` in shaders. Sint16x4 = 17, /// One unsigned short (u16). [0, 65535] converted to float [0, 1] `f32` in shaders. Unorm16 = 18, /// Two unsigned shorts (u16). [0, 65535] converted to float [0, 1] `vec2` in shaders. Unorm16x2 = 19, /// Four unsigned shorts (u16). [0, 65535] converted to float [0, 1] `vec4` in shaders. Unorm16x4 = 20, /// One signed short (i16). [-32767, 32767] converted to float [-1, 1] `f32` in shaders. Snorm16 = 21, /// Two signed shorts (i16). [-32767, 32767] converted to float [-1, 1] `vec2` in shaders. Snorm16x2 = 22, /// Four signed shorts (i16). [-32767, 32767] converted to float [-1, 1] `vec4` in shaders. Snorm16x4 = 23, /// One half-precision float (no Rust equiv). `f32` in shaders. Float16 = 24, /// Two half-precision floats (no Rust equiv). `vec2` in shaders. Float16x2 = 25, /// Four half-precision floats (no Rust equiv). `vec4` in shaders. Float16x4 = 26, /// One single-precision float (f32). `f32` in shaders. Float32 = 27, /// Two single-precision floats (f32). `vec2` in shaders. Float32x2 = 28, /// Three single-precision floats (f32). `vec3` in shaders. Float32x3 = 29, /// Four single-precision floats (f32). `vec4` in shaders. Float32x4 = 30, /// One unsigned int (u32). `u32` in shaders. Uint32 = 31, /// Two unsigned ints (u32). `vec2` in shaders. Uint32x2 = 32, /// Three unsigned ints (u32). `vec3` in shaders. Uint32x3 = 33, /// Four unsigned ints (u32). `vec4` in shaders. Uint32x4 = 34, /// One signed int (i32). `i32` in shaders. Sint32 = 35, /// Two signed ints (i32). `vec2` in shaders. Sint32x2 = 36, /// Three signed ints (i32). `vec3` in shaders.
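The normalized formats above map integers to floats by dividing by the largest positive value. For the signed formats, the extra negative value (-128 or -32768) is clamped to -1, which is how the documented [-127, 127] and [-32767, 32767] ranges come out. The conversion functions below are hypothetical scalar helpers, not the shader code the backend actually emits.

```rust
/// `Unorm8`: [0, 255] -> [0, 1].
pub fn unorm8_to_f32(x: u8) -> f32 {
    x as f32 / 255.0
}

/// `Snorm8`: [-127, 127] -> [-1, 1]; -128 clamps to -1.
pub fn snorm8_to_f32(x: i8) -> f32 {
    (x as f32 / 127.0).max(-1.0)
}

/// `Unorm16`: [0, 65535] -> [0, 1].
pub fn unorm16_to_f32(x: u16) -> f32 {
    x as f32 / 65535.0
}

/// `Snorm16`: [-32767, 32767] -> [-1, 1]; -32768 clamps to -1.
pub fn snorm16_to_f32(x: i16) -> f32 {
    (x as f32 / 32767.0).max(-1.0)
}
```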
Sint32x3 = 37, /// Four signed ints (i32). `vec4` in shaders. Sint32x4 = 38, /// Three unsigned 10-bit integers and one unsigned 2-bit integer, packed into a 32-bit integer (u32). [0, 1023] ([0, 3] for the 2-bit component) converted to float [0, 1] `vec4` in shaders. #[cfg_attr( any(feature = "serialize", feature = "deserialize"), serde(rename = "unorm10-10-10-2") )] Unorm10_10_10_2 = 43, /// Four unsigned 8-bit integers, packed into a 32-bit integer (u32). [0, 255] converted to float [0, 1] `vec4` in shaders. #[cfg_attr( any(feature = "serialize", feature = "deserialize"), serde(rename = "unorm8x4-bgra") )] Unorm8x4Bgra = 44, } /// Defines how to advance the data in vertex buffers. #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub enum VertexBufferStepMode { Constant, #[default] ByVertex, ByInstance, } /// A mapping of vertex buffers and their attributes to shader /// locations. #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct AttributeMapping { /// Shader location associated with this attribute pub shader_location: u32, /// Offset in bytes from the start of the vertex buffer structure pub offset: u32, /// Format code to help us unpack the attribute into the type /// used by the shader. Codes correspond to a 0-based index of /// . /// The conversion process is described by /// . pub format: VertexFormat, } /// A description of a vertex buffer with all the information we /// need to address the attributes within it.
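With vertex pulling, the shader reads attributes itself using the step mode, stride and attribute offset described above. The sketch below is hypothetical and is not the backend's generated code. It assumes that `Constant` always reads element 0, `ByVertex` uses the vertex id, `ByInstance` uses the instance id, and that a read is rejected when it would extend past the end of the buffer.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StepMode {
    Constant,
    ByVertex,
    ByInstance,
}

/// Hypothetical vertex-pulling address math: choose the element index from the
/// step mode, then return the attribute's byte offset, or `None` if the read
/// would fall outside the buffer (or the arithmetic would overflow).
pub fn attribute_byte_offset(
    step_mode: StepMode,
    vertex_id: u32,
    instance_id: u32,
    stride: u32,
    attr_offset: u32,
    attr_size: u32,
    buffer_len: u32,
) -> Option<u32> {
    let index = match step_mode {
        StepMode::Constant => 0,
        StepMode::ByVertex => vertex_id,
        StepMode::ByInstance => instance_id,
    };
    let start = index.checked_mul(stride)?.checked_add(attr_offset)?;
    let end = start.checked_add(attr_size)?;
    (end <= buffer_len).then_some(start)
}
```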
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct VertexBufferMapping { /// Shader location associated with this buffer pub id: u32, /// Size of the structure in bytes pub stride: u32, /// Vertex buffer step mode pub step_mode: VertexBufferStepMode, /// Vec of the attributes within the structure pub attributes: Vec, } /// A subset of options that are meant to be changed per pipeline. #[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(feature = "deserialize", serde(default))] pub struct PipelineOptions { /// The entry point to write. /// /// Entry points are identified by a shader stage specification, /// and a name. /// /// If `None`, all entry points will be written. If `Some` and the entry /// point is not found, an error will be returned while writing. pub entry_point: Option<(ir::ShaderStage, String)>, /// Allow `BuiltIn::PointSize` and inject it if it doesn't exist. /// /// Metal doesn't like this for non-point primitive topologies and requires it for /// point primitive topologies. /// /// Enable this for vertex shaders with point primitive topologies. pub allow_and_force_point_size: bool, /// If set, when generating the Metal vertex shader, transform it /// to receive the vertex buffers, lengths, and vertex id as args, /// and bounds-check the vertex id and use the index into the /// vertex buffers to access attributes, rather than using Metal's /// `[[stage_in]]` assembled attribute data. This is true by default, /// but remains configurable for use by tests via deserialization /// of this struct. There is no user-facing way to set this value. pub vertex_pulling_transform: bool, /// vertex_buffer_mappings are used during shader translation to /// support vertex pulling.
pub vertex_buffer_mappings: Vec, } impl Options { fn resolve_local_binding( &self, binding: &crate::Binding, mode: LocationMode, ) -> Result { match *binding { crate::Binding::BuiltIn(mut built_in) => { match built_in { crate::BuiltIn::Position { ref mut invariant } => { if *invariant && self.lang_version < (2, 1) { return Err(Error::UnsupportedAttribute("invariant".to_string())); } // The 'invariant' attribute may only appear on vertex // shader outputs, not fragment shader inputs. if !matches!(mode, LocationMode::VertexOutput) { *invariant = false; } } crate::BuiltIn::BaseInstance if self.lang_version < (1, 2) => { return Err(Error::UnsupportedAttribute("base_instance".to_string())); } crate::BuiltIn::InstanceIndex if self.lang_version < (1, 2) => { return Err(Error::UnsupportedAttribute("instance_id".to_string())); } // macOS: Since Metal 2.2 // iOS: Since Metal 2.3 (check depends on https://github.com/gfx-rs/wgpu/issues/4414) crate::BuiltIn::PrimitiveIndex if self.lang_version < (2, 3) => { return Err(Error::UnsupportedAttribute("primitive_id".to_string())); } // macOS: since Metal 2.3 // iOS: Since Metal 2.2 // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf#page=114 crate::BuiltIn::ViewIndex if self.lang_version < (2, 2) => { return Err(Error::UnsupportedAttribute("amplification_id".to_string())); } // macOS: Since Metal 2.2 // iOS: Since Metal 2.3 (check depends on https://github.com/gfx-rs/wgpu/issues/4414) crate::BuiltIn::Barycentric { .. 
} if self.lang_version < (2, 3) => { return Err(Error::UnsupportedAttribute("barycentric_coord".to_string())); } _ => {} } Ok(ResolvedBinding::BuiltIn(built_in)) } crate::Binding::Location { location, interpolation, sampling, blend_src, per_primitive: _, } => match mode { LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)), LocationMode::FragmentOutput => { if blend_src.is_some() && self.lang_version < (1, 2) { return Err(Error::UnsupportedAttribute("blend_src".to_string())); } Ok(ResolvedBinding::Color { location, blend_src, }) } LocationMode::VertexOutput | LocationMode::FragmentInput => { Ok(ResolvedBinding::User { prefix: if self.spirv_cross_compatibility { "locn" } else { "loc" }, index: location, interpolation: { // unwrap: The verifier ensures that vertex shader outputs and fragment // shader inputs always have fully specified interpolation, and that // sampling is `None` only for Flat interpolation. let interpolation = interpolation.unwrap(); let sampling = sampling.unwrap_or(crate::Sampling::Center); Some(ResolvedInterpolation::from_binding(interpolation, sampling)) }, }) } LocationMode::Uniform => Err(Error::GenericValidation(format!( "Unexpected Binding::Location({location}) for the Uniform mode" ))), }, } } fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> { self.per_entry_point_map.get(&ep.name) } fn get_resource_binding_target( &self, ep: &crate::EntryPoint, res_binding: &crate::ResourceBinding, ) -> Option<&BindTarget> { self.get_entry_point_resources(ep) .and_then(|res| res.resources.get(res_binding)) } fn resolve_resource_binding( &self, ep: &crate::EntryPoint, res_binding: &crate::ResourceBinding, ) -> Result { let target = self.get_resource_binding_target(ep, res_binding); match target { Some(target) => Ok(ResolvedBinding::Resource(target.clone())), None if self.fake_missing_bindings => Ok(ResolvedBinding::User { prefix: "fake", index: 0, interpolation: None, }), None => 
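The version gates in `resolve_local_binding` rely on Rust's lexicographic tuple ordering, so `(1, 2) < (2, 0)` and `(2, 2) < (2, 3)`. The sketch below collects the minimum versions checked above into a hypothetical lookup table, keyed by the attribute names the errors report.

```rust
/// Minimum MSL versions taken from the checks in `resolve_local_binding`.
pub fn min_msl_version(attribute: &str) -> Option<(u8, u8)> {
    match attribute {
        "base_instance" | "instance_id" | "blend_src" => Some((1, 2)),
        "invariant" => Some((2, 1)),
        "amplification_id" => Some((2, 2)),
        "primitive_id" | "barycentric_coord" => Some((2, 3)),
        _ => None,
    }
}

/// Tuple comparison is lexicographic: major version first, then minor.
pub fn check_attribute(lang_version: (u8, u8), attribute: &str) -> Result<(), String> {
    match min_msl_version(attribute) {
        Some(min) if lang_version < min => Err(format!(
            "attribute '{attribute}' is not supported for target MSL version"
        )),
        _ => Ok(()),
    }
}
```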
                Err(EntryPointError::MissingBindTarget(*res_binding)),
        }
    }

    fn resolve_immediates(
        &self,
        ep: &crate::EntryPoint,
    ) -> Result<ResolvedBinding, EntryPointError> {
        let slot = self
            .get_entry_point_resources(ep)
            .and_then(|res| res.immediates_buffer);
        match slot {
            Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
                buffer: Some(slot),
                ..Default::default()
            })),
            None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
                prefix: "fake",
                index: 0,
                interpolation: None,
            }),
            None => Err(EntryPointError::MissingImmediateData),
        }
    }

    fn resolve_sizes_buffer(
        &self,
        ep: &crate::EntryPoint,
    ) -> Result<ResolvedBinding, EntryPointError> {
        let slot = self
            .get_entry_point_resources(ep)
            .and_then(|res| res.sizes_buffer);
        match slot {
            Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
                buffer: Some(slot),
                ..Default::default()
            })),
            None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
                prefix: "fake",
                index: 0,
                interpolation: None,
            }),
            None => Err(EntryPointError::MissingSizesBuffer),
        }
    }
}

impl ResolvedBinding {
    fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> {
        match *self {
            Self::Resource(BindTarget {
                sampler: Some(BindSamplerTarget::Inline(index)),
                ..
            }) => Some(&options.inline_samplers[index as usize]),
            _ => None,
        }
    }

    fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
        write!(out, " [[")?;
        match *self {
            Self::BuiltIn(built_in) => {
                use crate::BuiltIn as Bi;
                let name = match built_in {
                    Bi::Position { invariant: false } => "position",
                    Bi::Position { invariant: true } => "position, invariant",
                    Bi::ViewIndex => "amplification_id",
                    // vertex
                    Bi::BaseInstance => "base_instance",
                    Bi::BaseVertex => "base_vertex",
                    Bi::ClipDistance => "clip_distance",
                    Bi::InstanceIndex => "instance_id",
                    Bi::PointSize => "point_size",
                    Bi::VertexIndex => "vertex_id",
                    // fragment
                    Bi::FragDepth => "depth(any)",
                    Bi::PointCoord => "point_coord",
                    Bi::FrontFacing => "front_facing",
                    Bi::PrimitiveIndex => "primitive_id",
                    Bi::Barycentric { perspective: true } => "barycentric_coord",
                    Bi::Barycentric { perspective: false } => {
                        "barycentric_coord, center_no_perspective"
                    }
                    Bi::SampleIndex => "sample_id",
                    Bi::SampleMask => "sample_mask",
                    // compute
                    Bi::GlobalInvocationId => "thread_position_in_grid",
                    Bi::LocalInvocationId => "thread_position_in_threadgroup",
                    Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
                    Bi::WorkGroupId => "threadgroup_position_in_grid",
                    Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
                    Bi::NumWorkGroups => "threadgroups_per_grid",
                    // subgroup
                    Bi::NumSubgroups => "simdgroups_per_threadgroup",
                    Bi::SubgroupId => "simdgroup_index_in_threadgroup",
                    Bi::SubgroupSize => "threads_per_simdgroup",
                    Bi::SubgroupInvocationId => "thread_index_in_simdgroup",
                    Bi::CullDistance | Bi::DrawIndex => {
                        return Err(Error::UnsupportedBuiltIn(built_in))
                    }
                    Bi::CullPrimitive => "primitive_culled",
                    // TODO: figure out how to make this written as a function call
                    Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices => unimplemented!(),
                    Bi::MeshTaskSize
                    | Bi::VertexCount
                    | Bi::PrimitiveCount
                    | Bi::Vertices
                    | Bi::Primitives
                    | Bi::RayInvocationId
                    | Bi::NumRayInvocations
                    | Bi::InstanceCustomData
                    | Bi::GeometryIndex
                    | Bi::WorldRayOrigin
                    | Bi::WorldRayDirection
                    | Bi::ObjectRayOrigin
                    | Bi::ObjectRayDirection
                    | Bi::RayTmin
                    | Bi::RayTCurrentMax
                    | Bi::ObjectToWorld
                    | Bi::WorldToObject
                    | Bi::HitKind => unreachable!(),
                };
                write!(out, "{name}")?;
            }
            Self::Attribute(index) => write!(out, "attribute({index})")?,
            Self::Color {
                location,
                blend_src,
            } => {
                if let Some(blend_src) = blend_src {
                    write!(out, "color({location}) index({blend_src})")?
                } else {
                    write!(out, "color({location})")?
                }
            }
            Self::User {
                prefix,
                index,
                interpolation,
            } => {
                write!(out, "user({prefix}{index})")?;
                if let Some(interpolation) = interpolation {
                    write!(out, ", ")?;
                    interpolation.try_fmt(out)?;
                }
            }
            Self::Resource(ref target) => {
                if let Some(id) = target.buffer {
                    write!(out, "buffer({id})")?;
                } else if let Some(id) = target.texture {
                    write!(out, "texture({id})")?;
                } else if let Some(BindSamplerTarget::Resource(id)) = target.sampler {
                    write!(out, "sampler({id})")?;
                } else {
                    return Err(Error::UnimplementedBindTarget(target.clone()));
                }
            }
        }
        write!(out, "]]")?;
        Ok(())
    }
}

impl ResolvedInterpolation {
    const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self {
        use crate::Interpolation as I;
        use crate::Sampling as S;

        match (interpolation, sampling) {
            (I::Perspective, S::Center) => Self::CenterPerspective,
            (I::Perspective, S::Centroid) => Self::CentroidPerspective,
            (I::Perspective, S::Sample) => Self::SamplePerspective,
            (I::Linear, S::Center) => Self::CenterNoPerspective,
            (I::Linear, S::Centroid) => Self::CentroidNoPerspective,
            (I::Linear, S::Sample) => Self::SampleNoPerspective,
            (I::Flat, _) => Self::Flat,
            _ => unreachable!(),
        }
    }

    fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> {
        let identifier = match self {
            Self::CenterPerspective => "center_perspective",
            Self::CenterNoPerspective => "center_no_perspective",
            Self::CentroidPerspective => "centroid_perspective",
            Self::CentroidNoPerspective => "centroid_no_perspective",
            Self::SamplePerspective => "sample_perspective",
            Self::SampleNoPerspective => "sample_no_perspective",
            Self::Flat => "flat",
        };
        out.write_str(identifier)?;
        Ok(())
    }
}

/// Information about a translated module that is required
/// for the use of the result.
pub struct TranslationInfo {
    /// Mapping of the entry point names. Each item in the array
    /// corresponds to an entry point index.
    ///
    /// Note: Some entry points may fail translation because of missing bindings.
    pub entry_point_names: Vec<Result<String, EntryPointError>>,
}

pub fn write_string(
    module: &crate::Module,
    info: &ModuleInfo,
    options: &Options,
    pipeline_options: &PipelineOptions,
) -> Result<(String, TranslationInfo), Error> {
    let mut w = Writer::new(String::new());
    let info = w.write(module, info, options, pipeline_options)?;
    Ok((w.finish(), info))
}

pub fn supported_capabilities() -> crate::valid::Capabilities {
    use crate::valid::Capabilities as Caps;
    Caps::IMMEDIATES
        // No FLOAT64
        | Caps::PRIMITIVE_INDEX
        | Caps::TEXTURE_AND_SAMPLER_BINDING_ARRAY
        // No BUFFER_BINDING_ARRAY
        | Caps::STORAGE_TEXTURE_BINDING_ARRAY
        | Caps::STORAGE_BUFFER_BINDING_ARRAY
        | Caps::CLIP_DISTANCE // CLIP_DISTANCE isn't supported by the Metal backend? But it is supported by the MSL writer.
        // No CULL_DISTANCE
        | Caps::STORAGE_TEXTURE_16BIT_NORM_FORMATS
        | Caps::MULTIVIEW
        // No EARLY_DEPTH_TEST
        | Caps::MULTISAMPLED_SHADING
        | Caps::RAY_QUERY
        | Caps::DUAL_SOURCE_BLENDING
        | Caps::CUBE_ARRAY_TEXTURES
        | Caps::SHADER_INT64
        | Caps::SUBGROUP
        | Caps::SUBGROUP_BARRIER
        // No SUBGROUP_VERTEX_STAGE
        | Caps::SHADER_INT64_ATOMIC_MIN_MAX
        // No SHADER_INT64_ATOMIC_ALL_OPS
        | Caps::SHADER_FLOAT32_ATOMIC
        | Caps::TEXTURE_ATOMIC
        | Caps::TEXTURE_INT64_ATOMIC
        // No RAY_HIT_VERTEX_POSITION
        | Caps::SHADER_FLOAT16
        | Caps::TEXTURE_EXTERNAL
        | Caps::SHADER_FLOAT16_IN_FLOAT32
        | Caps::SHADER_BARYCENTRICS
        // No MESH_SHADER
        // No MESH_SHADER_POINT_TOPOLOGY
        | Caps::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
        // No BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
        | Caps::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING
        | Caps::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING
        | Caps::COOPERATIVE_MATRIX
        // No PER_VERTEX
        // No RAY_TRACING_PIPELINE
        // No DRAW_INDEX
        // No MEMORY_DECORATION_VOLATILE
        | Caps::MEMORY_DECORATION_COHERENT
}

#[test]
fn test_error_size() {
    assert_eq!(size_of::<Error>(), 40);
}
naga-29.0.3/src/back/msl/sampler.rs000064400000000000000000000076221046102023000151270ustar 00000000000000
use core::{num::NonZeroU32, ops::Range};

#[cfg(feature = "deserialize")]
use serde::Deserialize;
#[cfg(feature = "serialize")]
use serde::Serialize;

#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum Coord {
    #[default]
    Normalized,
    Pixel,
}

impl Coord {
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Normalized => "normalized",
            Self::Pixel => "pixel",
        }
    }
}

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum Address {
    Repeat,
    MirroredRepeat,
    #[default]
    ClampToEdge,
    ClampToZero,
    ClampToBorder,
}

impl Address {
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Repeat => "repeat",
            Self::MirroredRepeat => "mirrored_repeat",
            Self::ClampToEdge => "clamp_to_edge",
            Self::ClampToZero => "clamp_to_zero",
            Self::ClampToBorder => "clamp_to_border",
        }
    }
}

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum BorderColor {
    #[default]
    TransparentBlack,
    OpaqueBlack,
    OpaqueWhite,
}

impl BorderColor {
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::TransparentBlack => "transparent_black",
            Self::OpaqueBlack => "opaque_black",
            Self::OpaqueWhite => "opaque_white",
        }
    }
}

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum Filter {
    #[default]
    Nearest,
    Linear,
}

impl Filter {
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Nearest => "nearest",
            Self::Linear => "linear",
        }
    }
}

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum CompareFunc {
    #[default]
    Never,
    Less,
    LessEqual,
    Greater,
    GreaterEqual,
    Equal,
    NotEqual,
    Always,
}

impl CompareFunc {
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Never => "never",
            Self::Less => "less",
            Self::LessEqual => "less_equal",
            Self::Greater => "greater",
            Self::GreaterEqual => "greater_equal",
            Self::Equal => "equal",
            Self::NotEqual => "not_equal",
            Self::Always => "always",
        }
    }
}

#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub struct InlineSampler {
    pub coord: Coord,
    pub address: [Address; 3],
    pub border_color: BorderColor,
    pub mag_filter: Filter,
    pub min_filter: Filter,
    pub mip_filter: Option<Filter>,
    pub lod_clamp: Option<Range<f32>>,
    pub max_anisotropy: Option<NonZeroU32>,
    pub compare_func: CompareFunc,
}

impl Eq for InlineSampler {}

impl core::hash::Hash for InlineSampler {
    fn hash<H: core::hash::Hasher>(&self, hasher: &mut H) {
        self.coord.hash(hasher);
        self.address.hash(hasher);
        self.border_color.hash(hasher);
        self.mag_filter.hash(hasher);
        self.min_filter.hash(hasher);
        self.mip_filter.hash(hasher);
        self.lod_clamp
            .as_ref()
            .map(|range| (range.start.to_bits(), range.end.to_bits()))
            .hash(hasher);
        self.max_anisotropy.hash(hasher);
        self.compare_func.hash(hasher);
    }
}
naga-29.0.3/src/back/msl/writer.rs000064400000000000000000012614471046102023000150100ustar 00000000000000
use alloc::{
    format,
    string::{String, ToString},
    vec,
    vec::Vec,
};
use core::{
    cmp::Ordering,
    fmt::{Display, Error as FmtError, Formatter, Write},
    iter,
};

use num_traits::real::Real as _;

use half::f16;

use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
use crate::{
    arena::{Handle, HandleSet},
    back::{self, get_entry_points, Baked},
    common,
    proc::{
        self, concrete_int_scalars,
        index::{self, BoundsCheck},
        ExternalTextureNameKey, NameKey, TypeResolution,
    },
    valid, FastHashMap, FastHashSet,
};

#[cfg(test)]
use core::ptr;

/// Shorthand result used internally by the backend
type BackendResult = Result<(), Error>;

const NAMESPACE: &str = "metal";

// The name of the array member of the Metal struct types we generate to
// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
// for details.
const WRAPPED_ARRAY_FIELD: &str = "inner";

// This is a hack: we need to pass a pointer to an atomic,
// but generally the backend isn't putting "&" in front of every pointer.
// More general handling of pointers still needs to be implemented here.
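
// Illustrative sketch only (hypothetical module and function names, not part
// of Naga's API): a simplified, boolean version of the "tight packing" test
// that `should_pack_struct_member` applies to 3-component vector members.
// In MSL, `float3` is sized like a 4-component vector (16 bytes) and `half3`
// occupies 8 bytes, whereas the module's `vec3<f32>`/`vec3<f16>` member uses
// only 12/6 bytes. If the next member (or the end of the struct) starts right
// after those bytes, a regular MSL vec3 would overlap it, so the member has to
// be emitted as a `packed_` vector instead. The real function also returns the
// member's scalar type; this sketch only answers yes/no.
#[allow(dead_code)]
mod packed_vec3_sketch {
    /// Byte size of a tightly packed 3-component vector of `width`-byte scalars.
    pub const fn packed_vec3_size(width: u32) -> u32 {
        3 * width
    }

    /// Returns `true` if a 3-component vector member at `offset`, whose scalars
    /// are `width` bytes wide, is immediately followed by `next_offset` (the
    /// next member's offset, or the struct span for the last member). Like
    /// `should_pack_struct_member`, only 4-byte and 2-byte scalars qualify.
    pub const fn needs_packed_vec3(offset: u32, width: u32, next_offset: u32) -> bool {
        (width == 4 || width == 2) && next_offset == offset + packed_vec3_size(width)
    }
}

// For example, a `vec3<f32>` at offset 0 followed by an `f32` at offset 12
// needs `packed_float3`, while one followed by padding up to offset 16 does not.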
const ATOMIC_REFERENCE: &str = "&";

const RT_NAMESPACE: &str = "metal::raytracing";
const RAY_QUERY_TYPE: &str = "_RayQuery";
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_MODERN_SUPPORT: bool = false; //TODO
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";

/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
/// However, if you put that texture inside a struct, everything is totally fine. This
/// baffles me to no end.
///
/// As such, we wrap all argument buffers in a struct that has a single generic `<T>` field.
/// This allows `NagaArgumentBufferWrapper<metal::texture<..>>*` to work. The astute among
/// you may have noticed that this should be exactly the same to the compiler, and you're correct.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";

/// Name of the struct that is declared to wrap the 3 textures and parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to,
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
/// The `sizes` slice determines whether this function writes a
/// scalar, vector, or matrix type:
///
/// - An empty slice produces a scalar type.
/// - A one-element slice produces a vector type.
/// - A two-element slice `[ROWS, COLUMNS]` produces a matrix of the given size.
fn put_numeric_type(
    out: &mut impl Write,
    scalar: crate::Scalar,
    sizes: &[crate::VectorSize],
) -> Result<(), FmtError> {
    match (scalar, sizes) {
        (scalar, &[]) => {
            write!(out, "{}", scalar.to_msl_name())
        }
        (scalar, &[rows]) => {
            write!(
                out,
                "{}::{}{}",
                NAMESPACE,
                scalar.to_msl_name(),
                common::vector_size_str(rows)
            )
        }
        (scalar, &[rows, columns]) => {
            write!(
                out,
                "{}::{}{}x{}",
                NAMESPACE,
                scalar.to_msl_name(),
                common::vector_size_str(columns),
                common::vector_size_str(rows)
            )
        }
        (_, _) => Ok(()), // not meaningful
    }
}

const fn scalar_is_int(scalar: crate::Scalar) -> bool {
    use crate::ScalarKind::*;
    match scalar.kind {
        Sint | Uint | AbstractInt | Bool => true,
        Float | AbstractFloat => false,
    }
}

/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
const REINTERPRET_PREFIX: &str = "reinterpreted_";

/// Wrapper for identifier names for clamped level-of-detail values.
///
/// Values of this type implement [`core::fmt::Display`], formatting as
/// the name of the variable used to hold the cached clamped
/// level-of-detail value for an `ImageLoad` expression.
struct ClampedLod(Handle<crate::Expression>);

impl Display for ClampedLod {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
    }
}

/// Wrapper for generating `struct _mslBufferSizes` member names for
/// runtime-sized array lengths.
///
/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
/// as an argument to the entry point. This argument's type in the MSL is
/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
/// each global variable containing a runtime-sized array.
///
/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
/// runtime-sized array, then the value `ArraySizeMember(global)` implements
/// [`core::fmt::Display`], formatting as the name of the struct member carrying
/// the number of elements in that runtime-sized array.
///
/// [`GlobalVariable`]: crate::GlobalVariable
struct ArraySizeMember(Handle<crate::GlobalVariable>);

impl Display for ArraySizeMember {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        self.0.write_prefixed(f, "size")
    }
}

/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
///
/// Implements [`core::fmt::Display`], formatting as a name derived from
/// `target_type` and the variable name of `orig`.
#[derive(Clone, Copy)]
struct Reinterpreted<'a> {
    target_type: &'a str,
    orig: Handle<crate::Expression>,
}

impl<'a> Reinterpreted<'a> {
    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
        Self { target_type, orig }
    }
}

impl Display for Reinterpreted<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.write_str(REINTERPRET_PREFIX)?;
        f.write_str(self.target_type)?;
        self.orig.write_prefixed(f, "_e")
    }
}

struct TypeContext<'a> {
    handle: Handle<crate::Type>,
    gctx: proc::GlobalCtx<'a>,
    names: &'a FastHashMap<NameKey, String>,
    access: crate::StorageAccess,
    first_time: bool,
}

impl TypeContext<'_> {
    fn scalar(&self) -> Option<crate::Scalar> {
        let ty = &self.gctx.types[self.handle];
        ty.inner.scalar()
    }

    fn vector_size(&self) -> Option<crate::VectorSize> {
        let ty = &self.gctx.types[self.handle];
        match ty.inner {
            crate::TypeInner::Vector { size, .. } => Some(size),
            _ => None,
        }
    }
}

impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // Requires Metal 2.3
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Array lengths go at the end of the type definition,
                // so just print the element type here.
                write!(out, "{sub}")
            }
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Storage { format, .. } => {
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RAY_QUERY_TYPE}")
            }
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}

struct TypedGlobalVariable<'a> {
    module: &'a crate::Module,
    names: &'a FastHashMap<NameKey, String>,
    handle: Handle<crate::GlobalVariable>,
    usage: valid::GlobalUse,
    reference: bool,
}

impl TypedGlobalVariable<'_> {
    fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
        let var = &self.module.global_variables[self.handle];
        let name = &self.names[&NameKey::GlobalVariable(self.handle)];

        let storage_access = match var.space {
            crate::AddressSpace::Storage { access } => access,
            _ => match self.module.types[var.ty].inner {
                crate::TypeInner::Image {
                    class: crate::ImageClass::Storage { access, .. },
                    ..
                } => access,
                crate::TypeInner::BindingArray { base, .. } => {
                    match self.module.types[base].inner {
                        crate::TypeInner::Image {
                            class: crate::ImageClass::Storage { access, .. },
                            ..
                        } => access,
                        _ => crate::StorageAccess::default(),
                    }
                }
                _ => crate::StorageAccess::default(),
            },
        };
        let ty_name = TypeContext {
            handle: var.ty,
            gctx: self.module.to_ctx(),
            names: self.names,
            access: storage_access,
            first_time: false,
        };

        let (coherent, space, access, reference) = match var.space.to_msl_name() {
            Some(space) if self.reference => {
                let coherent = if var
                    .memory_decorations
                    .contains(crate::MemoryDecorations::COHERENT)
                {
                    "coherent "
                } else {
                    ""
                };
                let access = if var.space.needs_access_qualifier()
                    && !self.usage.intersects(valid::GlobalUse::WRITE)
                {
                    "const"
                } else {
                    ""
                };
                (coherent, space, access, "&")
            }
            _ => ("", "", "", ""),
        };

        Ok(write!(
            out,
            "{}{}{}{}{}{}{} {}",
            coherent,
            space,
            if space.is_empty() { "" } else { " " },
            ty_name,
            if access.is_empty() { "" } else { " " },
            access,
            reference,
            name,
        )?)
    }
}

#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    ImageLoad {
        class: crate::ImageClass,
    },
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    ImageQuerySize {
        class: crate::ImageClass,
    },
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
}

pub struct Writer<W> {
    out: W,
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    need_bake_expressions: back::NeedBakeExpressions,
    namer: proc::Namer,
    wrapped_functions: FastHashSet<WrappedFunction>,
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}

impl crate::Scalar {
    pub(super) fn to_msl_name(self) -> &'static str {
        use crate::ScalarKind as Sk;
        match self {
            Self {
                kind: Sk::Float,
                width: 4,
            } => "float",
            Self {
                kind: Sk::Float,
                width: 2,
            } => "half",
            Self {
                kind: Sk::Sint,
                width: 4,
            } => "int",
            Self {
                kind: Sk::Uint,
                width: 4,
            } => "uint",
            Self {
                kind: Sk::Sint,
                width: 8,
            } => "long",
            Self {
                kind: Sk::Uint,
                width: 8,
            } => "ulong",
            Self {
                kind: Sk::Bool,
                width: _,
            } => "bool",
            Self {
                kind: Sk::AbstractInt | Sk::AbstractFloat,
                width: _,
            } => unreachable!("Found Abstract scalar kind"),
            _ => unreachable!("Unsupported scalar kind: {:?}", self),
        }
    }
}

const fn separate(need_separator: bool) -> &'static str {
    if need_separator {
        ","
    } else {
        ""
    }
}

fn should_pack_struct_member(
    members: &[crate::StructMember],
    span: u32,
    index: usize,
    module: &crate::Module,
) -> Option<crate::Scalar> {
    let member = &members[index];

    let ty_inner = &module.types[member.ty].inner;
    let last_offset = member.offset + ty_inner.size(module.to_ctx());
    let next_offset = match members.get(index + 1) {
        Some(next) => next.offset,
        None => span,
    };
    let is_tight = next_offset == last_offset;

    match *ty_inner {
        crate::TypeInner::Vector {
            size: crate::VectorSize::Tri,
            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
        } if is_tight => Some(scalar),
        _ => None,
    }
}

fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
    match arena[ty].inner {
        crate::TypeInner::Struct { ref members, .. } => {
            if let Some(member) = members.last() {
                if let crate::TypeInner::Array {
                    size: crate::ArraySize::Dynamic,
                    ..
                } = arena[member.ty].inner
                {
                    return true;
                }
            }
            false
        }
        crate::TypeInner::Array {
            size: crate::ArraySize::Dynamic,
            ..
        } => true,
        _ => false,
    }
}

impl crate::AddressSpace {
    /// Returns true if global variables in this address space are
    /// passed in function arguments. These arguments need to be
    /// passed through any functions called from the entry point.
    const fn needs_pass_through(&self) -> bool {
        match *self {
            Self::Uniform
            | Self::Storage { .. }
            | Self::Private
            | Self::WorkGroup
            | Self::Immediate
            | Self::Handle
            | Self::TaskPayload => true,
            Self::Function => false,
            Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
        }
    }

    /// Returns true if the address space may need a "const" qualifier.
    const fn needs_access_qualifier(&self) -> bool {
        match *self {
            // Note: we are ignoring the storage access here, and instead
            // rely on the actual use of a global by functions. This means we
            // may end up with "const" even if the binding is read-write,
            // and that should be OK.
            Self::Storage { .. } => true,
            Self::TaskPayload | Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
            // These should always be read-write.
            Self::Private | Self::WorkGroup => false,
            // These translate to `constant` address space, no need for qualifiers.
            Self::Uniform | Self::Immediate => false,
            // Not applicable.
            Self::Handle | Self::Function => false,
        }
    }

    const fn to_msl_name(self) -> Option<&'static str> {
        match self {
            Self::Handle => None,
            Self::Uniform | Self::Immediate => Some("constant"),
            Self::Storage { .. } => Some("device"),
            // Note: `RayPayload` probably needs to be emulated as a private
            // variable, because Metal essentially passes it as an inout input
            // at the point where it is passed.
            Self::Private | Self::Function | Self::RayPayload => Some("thread"),
            Self::WorkGroup => Some("threadgroup"),
            Self::TaskPayload => Some("object_data"),
            Self::IncomingRayPayload => Some("ray_data"),
        }
    }
}

impl crate::Type {
    // Returns `true` if we need to emit an alias for this type.
    const fn needs_alias(&self) -> bool {
        use crate::TypeInner as Ti;

        match self.inner {
            // value types are concise enough, we only alias them if they are named
            Ti::Scalar(_)
            | Ti::Vector { .. }
            | Ti::Matrix { .. }
            | Ti::CooperativeMatrix { .. }
            | Ti::Atomic(_)
            | Ti::Pointer { .. }
            | Ti::ValuePointer { .. } => self.name.is_some(),
            // composite types are better to be aliased, regardless of the name
            Ti::Struct { .. } | Ti::Array { .. } => true,
            // handle types may be different, depending on the global var access, so we always inline them
            Ti::Image { .. }
            | Ti::Sampler { .. }
            | Ti::AccelerationStructure { .. }
            | Ti::RayQuery { .. }
            | Ti::BindingArray { .. } => false,
        }
    }
}

#[derive(Clone, Copy)]
enum FunctionOrigin {
    Handle(Handle<crate::Function>),
    EntryPoint(proc::EntryPointIndex),
}

trait NameKeyExt {
    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
        match origin {
            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
        }
    }

    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
    /// policy when it needs to produce a pointer-typed result for an OOB access. These
    /// are unique per accessed type, so the second argument is a type handle. See docs
    /// for [`crate::back::msl`].
    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
        match origin {
            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
        }
    }
}

impl NameKeyExt for NameKey {}

/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    Direct(Handle<crate::Expression>),
    Restricted(Handle<crate::Expression>),
}

/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
/// /// When this is used in code paths unconcerned with the `Restrict` bounds check /// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level` /// will always be either `None` or `Some(Direct(_))`. But this turns out not to /// be too awkward. If that changes, we can revisit. /// /// [`ImageLoad`]: crate::Expression::ImageLoad /// [`ImageStore`]: crate::Statement::ImageStore struct TexelAddress { coordinate: Handle, array_index: Option>, sample: Option>, level: Option, } struct ExpressionContext<'a> { function: &'a crate::Function, origin: FunctionOrigin, info: &'a valid::FunctionInfo, module: &'a crate::Module, mod_info: &'a valid::ModuleInfo, pipeline_options: &'a PipelineOptions, lang_version: (u8, u8), policies: index::BoundsCheckPolicies, /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy /// accesses. These may need to be cached in temporary variables. See /// `index::find_checked_indexes` for details. guarded_indices: HandleSet, /// See [`Writer::gen_force_bounded_loop_statements`] for details. force_loop_bounding: bool, } impl<'a> ExpressionContext<'a> { fn resolve_type(&self, handle: Handle) -> &'a crate::TypeInner { self.info[handle].ty.inner_with(&self.module.types) } /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail. /// /// Only mipmapped images need to specify a level of detail. Since 1D /// textures cannot have mipmaps, MSL requires that the level argument to /// texture1d queries and accesses must be a constexpr 0. It's easiest /// just to omit the level entirely for 1D textures. fn image_needs_lod(&self, image: Handle) -> bool { let image_ty = self.resolve_type(image); if let crate::TypeInner::Image { dim, class, .. 
} = *image_ty {
            class.is_mipmapped() && dim != crate::ImageDimension::D1
        } else {
            false
        }
    }

    fn choose_bounds_check_policy(
        &self,
        pointer: Handle,
    ) -> index::BoundsCheckPolicy {
        self.policies
            .choose_policy(pointer, &self.module.types, self.info)
    }

    /// See docs for [`proc::index::access_needs_check`].
    fn access_needs_check(
        &self,
        base: Handle,
        index: index::GuardedIndex,
    ) -> Option {
        index::access_needs_check(
            base,
            index,
            self.module,
            &self.function.expressions,
            self.info,
        )
    }

    /// See docs for [`proc::index::bounds_check_iter`].
    fn bounds_check_iter(
        &self,
        chain: Handle,
    ) -> impl Iterator + '_ {
        index::bounds_check_iter(chain, self.module, self.function, self.info)
    }

    /// See docs for [`proc::index::oob_local_types`].
    fn oob_local_types(&self) -> FastHashSet> {
        index::oob_local_types(self.module, self.function, self.info, self.policies)
    }

    fn get_packed_vec_kind(&self, expr_handle: Handle) -> Option {
        match self.function.expressions[expr_handle] {
            crate::Expression::AccessIndex { base, index } => {
                let ty = match *self.resolve_type(base) {
                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
                    ref ty => ty,
                };
                match *ty {
                    crate::TypeInner::Struct {
                        ref members, span, ..
                    } => should_pack_struct_member(members, span, index as usize, self.module),
                    _ => None,
                }
            }
            _ => None,
        }
    }
}

struct StatementContext<'a> {
    expression: ExpressionContext<'a>,
    result_struct: Option<&'a str>,
}

impl Writer {
    /// Creates a new `Writer` instance.
    pub fn new(out: W) -> Self {
        Writer {
            out,
            names: FastHashMap::default(),
            named_expressions: Default::default(),
            need_bake_expressions: Default::default(),
            namer: proc::Namer::default(),
            wrapped_functions: FastHashSet::default(),
            #[cfg(test)]
            put_expression_stack_pointers: Default::default(),
            #[cfg(test)]
            put_block_stack_pointers: Default::default(),
            struct_member_pads: FastHashSet::default(),
        }
    }

    /// Finishes writing and returns the output.
    // See https://github.com/rust-lang/rust-clippy/issues/4979.
    pub fn finish(self) -> W {
        self.out
    }

    /// Generates statements to be inserted immediately before and at the very
    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
    /// The 0th item of the returned tuple should be inserted immediately prior
    /// to the loop and the 1st item should be inserted at the very start of
    /// the loop body.
    ///
    /// # What is this trying to solve?
    ///
    /// In Metal Shading Language, an infinite loop has undefined behavior.
    /// (This rule is inherited from C++14.) This means that, if the MSL
    /// compiler determines that a given loop will never exit, it may assume
    /// that it is never reached. It may thus assume that any conditions
    /// sufficient to cause the loop to be reached must be false. Like many
    /// optimizing compilers, MSL uses this kind of analysis to establish limits
    /// on the range of values variables involved in those conditions might
    /// hold.
    ///
    /// For example, suppose the MSL compiler sees the code:
    ///
    /// ```ignore
    /// if (i >= 10) {
    ///     while (true) { }
    /// }
    /// ```
    ///
    /// It will recognize that the `while` loop will never terminate, conclude
    /// that it must be unreachable, and thus infer that, if this code is
    /// reached, then `i < 10` at that point.
    ///
    /// Now suppose that, at some point where `i` has the same value as above,
    /// the compiler sees the code:
    ///
    /// ```ignore
    /// if (i < 10) {
    ///     a[i] = 1;
    /// }
    /// ```
    ///
    /// Because the compiler is confident that `i < 10`, it will make the
    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
    ///
    /// ```ignore
    /// a[i] = 1;
    /// ```
    ///
    /// If that `if` condition was injected by Naga to implement a bounds check,
    /// the MSL compiler's optimizations could allow out-of-bounds array
    /// accesses to occur.
    ///
    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
    /// that a loop is infinite, so an attacker could craft a Naga module
    /// containing an infinite loop protected by conditions that cause the Metal
    /// compiler to remove bounds checks that Naga injected elsewhere in the
    /// function.
    ///
    /// This rewrite could occur even if the conditional assignment appears
    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
    /// reached. This would allow the attacker to save the results of
    /// unauthorized reads somewhere accessible before entering the infinite
    /// loop. But even worse, the MSL compiler has been observed to simply
    /// delete the infinite loop entirely, so that even code dominated by the
    /// loop becomes reachable. This would make the attack even more flexible,
    /// since shaders that would appear to never terminate would actually exit
    /// nicely, after having stolen data from elsewhere in the GPU address
    /// space.
    ///
    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
    /// generates is infinite. One approach would be to add inline assembly to
    /// each loop that is annotated as potentially branching out of the loop,
    /// but which in fact generates no instructions. Unfortunately, inline
    /// assembly is not handled correctly by some Metal device drivers.
    ///
    /// A previously used approach was to add the following code to the bottom
    /// of every loop:
    ///
    /// ```ignore
    /// if (volatile bool unpredictable = false; unpredictable)
    ///     break;
    /// ```
    ///
    /// Although the `if` condition will always be false in any real execution,
    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
    /// it must assume that the `break` might be reached, and hence that the
    /// loop is not unbounded. This prevents the range analysis impact described
    /// above.
    /// Unfortunately this prevented the compiler from making important,
    /// and safe, optimizations such as loop unrolling and was observed to
    /// significantly hurt performance.
    ///
    /// Our current approach declares a counter before every loop and
    /// increments it every iteration, breaking after 2^64 iterations:
    ///
    /// ```ignore
    /// uint2 loop_bound = uint2(0);
    /// while (true) {
    ///     if (metal::all(loop_bound == uint2(4294967295))) { break; }
    ///     loop_bound += uint2(loop_bound.y == 4294967295, 1);
    /// }
    /// ```
    ///
    /// (The code actually emitted counts down from `uint2(u32::MAX)` to
    /// `uint2(0u)` instead, for the reason given in the function body; the
    /// bound on the number of iterations is the same.)
    ///
    /// This convinces the compiler that the loop is finite and therefore may
    /// execute, whilst at the same time allowing optimizations such as loop
    /// unrolling. Furthermore, the 64-bit counter is large enough that it seems
    /// implausible that it would affect the execution of any shader.
    ///
    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
    ///
    fn gen_force_bounded_loop_statements(
        &mut self,
        level: back::Level,
        context: &StatementContext,
    ) -> Option<(String, String)> {
        if !context.expression.force_loop_bounding {
            return None;
        }

        let loop_bound_name = self.namer.call("loop_bound");
        // Count down from u32::MAX rather than up from 0 to avoid hang on
        // certain Intel drivers. See .
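        /*
Worked example (not part of Naga): a minimal Rust model of the counter
emitted below, treating the first `uint2` component (`x`) as the high word
and the second (`y`) as the low word. `step` mirrors
`loop_bound -= uint2(loop_bound.y == 0u, 1u)`: the low word always
decrements, and the high word borrows only when the low word was zero, so
the pair behaves like a single 64-bit countdown from `u64::MAX`. The helper
names `step` and `as_u64` are hypothetical.

```rust
// One iteration of the emitted decrement, using wrapping arithmetic to
// match MSL's unsigned overflow behavior.
fn step(bound: [u32; 2]) -> [u32; 2] {
    let [x, y] = bound;
    [x.wrapping_sub((y == 0) as u32), y.wrapping_sub(1)]
}

// Reinterpret the pair as one 64-bit value, high word first.
fn as_u64(bound: [u32; 2]) -> u64 {
    ((bound[0] as u64) << 32) | bound[1] as u64
}
```

Starting from `[u32::MAX, u32::MAX]`, each `step` lowers `as_u64` by exactly
one, so the `all(loop_bound == uint2(0u))` exit is reached only after
2^64 - 1 decrements.
        */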
        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
        let level = level.next();
        let break_and_inc = format!(
            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
        );

        Some((decl, break_and_inc))
    }

    fn put_call_parameters(
        &mut self,
        parameters: impl Iterator>,
        context: &ExpressionContext,
    ) -> BackendResult {
        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
            writer.put_expression(expr, context, true)
        })
    }

    fn put_call_parameters_impl(
        &mut self,
        parameters: impl Iterator>,
        ctx: &C,
        put_expression: E,
    ) -> BackendResult
    where
        E: Fn(&mut Self, &C, Handle) -> BackendResult,
    {
        write!(self.out, "(")?;
        for (i, handle) in parameters.enumerate() {
            if i != 0 {
                write!(self.out, ", ")?;
            }
            put_expression(self, ctx, handle)?;
        }
        write!(self.out, ")")?;
        Ok(())
    }

    /// Writes the local variables of the given function, as well as any extra
    /// out-of-bounds locals that are needed.
    ///
    /// The names of the OOB locals are also added to `self.names` at the same
    /// time.
    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
        let oob_local_types = context.oob_local_types();
        for &ty in oob_local_types.iter() {
            let name_key = NameKey::oob_local_for_type(context.origin, ty);
            self.names.insert(name_key, self.namer.call("oob"));
        }

        for (name_key, ty, init) in context
            .function
            .local_variables
            .iter()
            .map(|(local_handle, local)| {
                let name_key = NameKey::local(context.origin, local_handle);
                (name_key, local.ty, local.init)
            })
            .chain(oob_local_types.iter().map(|&ty| {
                let name_key = NameKey::oob_local_for_type(context.origin, ty);
                (name_key, ty, None)
            }))
        {
            let ty_name = TypeContext {
                handle: ty,
                gctx: context.module.to_ctx(),
                names: &self.names,
                access: crate::StorageAccess::empty(),
                first_time: false,
            };
            write!(
                self.out,
                "{}{} {}",
                back::INDENT,
                ty_name,
                self.names[&name_key]
            )?;
            match init {
                Some(value) => {
                    write!(self.out, " = ")?;
                    self.put_expression(value, context, true)?;
                }
                None => {
                    write!(self.out, " = {{}}")?;
                }
            };
            writeln!(self.out, ";")?;
        }
        Ok(())
    }

    fn put_level_of_detail(
        &mut self,
        level: LevelOfDetail,
        context: &ExpressionContext,
    ) -> BackendResult {
        match level {
            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
        }
        Ok(())
    }

    fn put_image_query(
        &mut self,
        image: Handle,
        query: &str,
        level: Option,
        context: &ExpressionContext,
    ) -> BackendResult {
        self.put_expression(image, context, false)?;
        write!(self.out, ".get_{query}(")?;
        if let Some(level) = level {
            self.put_level_of_detail(level, context)?;
        }
        write!(self.out, ")")?;
        Ok(())
    }

    fn put_image_size_query(
        &mut self,
        image: Handle,
        level: Option,
        kind: crate::ScalarKind,
        context: &ExpressionContext,
    ) -> BackendResult {
        if let crate::TypeInner::Image {
            class: crate::ImageClass::External,
            ..
        } = *context.resolve_type(image)
        {
            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
            self.put_expression(image, context, true)?;
            write!(self.out, ")")?;
            return Ok(());
        }

        // Note: MSL only has separate width/height/depth queries,
        // so compose the result of them.
        let dim = match *context.resolve_type(image) {
            crate::TypeInner::Image { dim, .. } => dim,
            ref other => unreachable!("Unexpected type {:?}", other),
        };
        let scalar = crate::Scalar { kind, width: 4 };
        let coordinate_type = scalar.to_msl_name();
        match dim {
            crate::ImageDimension::D1 => {
                // Since 1D textures never have mipmaps, MSL requires that the
                // `level` argument be a constexpr 0. It's simplest for us just
                // to pass `None` and omit the level entirely.
                if kind == crate::ScalarKind::Uint {
                    // No need to construct a vector. No cast needed.
                    self.put_image_query(image, "width", None, context)?;
                } else {
                    // There's no definition for `int` in the `metal` namespace.
                    write!(self.out, "int(")?;
                    self.put_image_query(image, "width", None, context)?;
                    write!(self.out, ")")?;
                }
            }
            crate::ImageDimension::D2 => {
                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
                self.put_image_query(image, "width", level, context)?;
                write!(self.out, ", ")?;
                self.put_image_query(image, "height", level, context)?;
                write!(self.out, ")")?;
            }
            crate::ImageDimension::D3 => {
                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
                self.put_image_query(image, "width", level, context)?;
                write!(self.out, ", ")?;
                self.put_image_query(image, "height", level, context)?;
                write!(self.out, ", ")?;
                self.put_image_query(image, "depth", level, context)?;
                write!(self.out, ")")?;
            }
            crate::ImageDimension::Cube => {
                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
                self.put_image_query(image, "width", level, context)?;
                write!(self.out, ")")?;
            }
        }
        Ok(())
    }

    fn put_cast_to_uint_scalar_or_vector(
        &mut self,
        expr: Handle,
        context: &ExpressionContext,
    ) -> BackendResult {
        // coordinates in IR are int, but Metal expects uint
        match
            *context.resolve_type(expr) {
            crate::TypeInner::Scalar(_) => {
                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
            }
            crate::TypeInner::Vector { size, .. } => {
                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
            }
            _ => {
                return Err(Error::GenericValidation(
                    "Invalid type for image coordinate".into(),
                ))
            }
        };

        write!(self.out, "(")?;
        self.put_expression(expr, context, true)?;
        write!(self.out, ")")?;
        Ok(())
    }

    fn put_image_sample_level(
        &mut self,
        image: Handle,
        level: crate::SampleLevel,
        context: &ExpressionContext,
    ) -> BackendResult {
        let has_levels = context.image_needs_lod(image);
        match level {
            crate::SampleLevel::Auto => {}
            crate::SampleLevel::Zero => {
                // TODO: do we support Zero on `Sampled` image classes?
            }
            _ if !has_levels => {
                log::warn!("1D image can't be sampled with level {level:?}");
            }
            crate::SampleLevel::Exact(h) => {
                write!(self.out, ", {NAMESPACE}::level(")?;
                self.put_expression(h, context, true)?;
                write!(self.out, ")")?;
            }
            crate::SampleLevel::Bias(h) => {
                write!(self.out, ", {NAMESPACE}::bias(")?;
                self.put_expression(h, context, true)?;
                write!(self.out, ")")?;
            }
            crate::SampleLevel::Gradient { x, y } => {
                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
                self.put_expression(x, context, true)?;
                write!(self.out, ", ")?;
                self.put_expression(y, context, true)?;
                write!(self.out, ")")?;
            }
        }
        Ok(())
    }

    fn put_image_coordinate_limits(
        &mut self,
        image: Handle,
        level: Option,
        context: &ExpressionContext,
    ) -> BackendResult {
        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
        write!(self.out, " - 1")?;
        Ok(())
    }

    /// General function for writing restricted image indexes.
    ///
    /// This is used to produce restricted mip levels, array indices, and sample
    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
    /// [`Restrict`] bounds check policy.
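    /*
Worked example (not part of Naga): a minimal Rust model of the clamp that the
expression shown below computes, assuming `limit` stands for the value
returned by the `LIMIT_METHOD` query (for example `get_array_size()`) and is
at least 1. Because `uint(INDEX)` reinterprets a negative `int` as a large
unsigned value, negative indices are clamped to `limit - 1` just like indices
that are too large. The function name `restrict_index` is hypothetical.

```rust
// Models `metal::min(uint(index), limit - 1)`; assumes `limit >= 1`.
fn restrict_index(index: i32, limit: u32) -> u32 {
    // `as u32` wraps negative values the same way MSL's `uint(...)` does.
    (index as u32).min(limit - 1)
}
```

So an index of `-1` against a limit of 4 becomes `3`, the last valid slot,
rather than an out-of-bounds access.
    */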
    ///
    /// This function writes an expression of the form:
    ///
    /// ```ignore
    ///
    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
    ///
    /// ```
    ///
    /// [`ImageLoad`]: crate::Expression::ImageLoad
    /// [`ImageStore`]: crate::Statement::ImageStore
    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
    fn put_restricted_scalar_image_index(
        &mut self,
        image: Handle,
        index: Handle,
        limit_method: &str,
        context: &ExpressionContext,
    ) -> BackendResult {
        write!(self.out, "{NAMESPACE}::min(uint(")?;
        self.put_expression(index, context, true)?;
        write!(self.out, "), ")?;
        self.put_expression(image, context, false)?;
        write!(self.out, ".{limit_method}() - 1)")?;
        Ok(())
    }

    fn put_restricted_texel_address(
        &mut self,
        image: Handle,
        address: &TexelAddress,
        context: &ExpressionContext,
    ) -> BackendResult {
        // Write the coordinate.
        write!(self.out, "{NAMESPACE}::min(")?;
        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
        write!(self.out, ", ")?;
        self.put_image_coordinate_limits(image, address.level, context)?;
        write!(self.out, ")")?;

        // Write the array index, if present.
        if let Some(array_index) = address.array_index {
            write!(self.out, ", ")?;
            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
        }

        // Write the sample index, if present.
        if let Some(sample) = address.sample {
            write!(self.out, ", ")?;
            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
        }

        // The level of detail should be clamped and cached by
        // `put_cache_restricted_level`, so we don't need to clamp it here.
        if let Some(level) = address.level {
            write!(self.out, ", ")?;
            self.put_level_of_detail(level, context)?;
        }

        Ok(())
    }

    /// Write an expression that is true if the given image access is in bounds.
    fn put_image_access_bounds_check(
        &mut self,
        image: Handle,
        address: &TexelAddress,
        context: &ExpressionContext,
    ) -> BackendResult {
        let mut conjunction = "";

        // First, check the level of detail.
        // Only if that is in bounds can we
        // use it to find the appropriate bounds for the coordinates.
        let level = if let Some(level) = address.level {
            write!(self.out, "uint(")?;
            self.put_level_of_detail(level, context)?;
            write!(self.out, ") < ")?;
            self.put_expression(image, context, true)?;
            write!(self.out, ".get_num_mip_levels()")?;
            conjunction = " && ";
            Some(level)
        } else {
            None
        };

        // Check sample index, if present. Prefix it with `conjunction` like the
        // other checks, so the output stays well-formed even if a level was
        // written first.
        if let Some(sample) = address.sample {
            write!(self.out, "{conjunction}uint(")?;
            self.put_expression(sample, context, true)?;
            write!(self.out, ") < ")?;
            self.put_expression(image, context, true)?;
            write!(self.out, ".get_num_samples()")?;
            conjunction = " && ";
        }

        // Check array index, if present.
        if let Some(array_index) = address.array_index {
            write!(self.out, "{conjunction}uint(")?;
            self.put_expression(array_index, context, true)?;
            write!(self.out, ") < ")?;
            self.put_expression(image, context, true)?;
            write!(self.out, ".get_array_size()")?;
            conjunction = " && ";
        }

        // Finally, check if the coordinates are within bounds.
        let coord_is_vector = match *context.resolve_type(address.coordinate) {
            crate::TypeInner::Vector { .. } => true,
            _ => false,
        };
        write!(self.out, "{conjunction}")?;
        if coord_is_vector {
            write!(self.out, "{NAMESPACE}::all(")?;
        }
        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
        write!(self.out, " < ")?;
        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
        if coord_is_vector {
            write!(self.out, ")")?;
        }

        Ok(())
    }

    fn put_image_load(
        &mut self,
        load: Handle,
        image: Handle,
        mut address: TexelAddress,
        context: &ExpressionContext,
    ) -> BackendResult {
        if let crate::TypeInner::Image {
            class: crate::ImageClass::External,
            ..
        } = *context.resolve_type(image)
        {
            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
            self.put_expression(image, context, true)?;
            write!(self.out, ", ")?;
            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
            write!(self.out, ")")?;
            return Ok(());
        }

        match context.policies.image_load {
            proc::BoundsCheckPolicy::Restrict => {
                // Use the cached restricted level of detail, if any. Omit the
                // level altogether for 1D textures.
                if address.level.is_some() {
                    address.level = if context.image_needs_lod(image) {
                        Some(LevelOfDetail::Restricted(load))
                    } else {
                        None
                    }
                }
                self.put_expression(image, context, false)?;
                write!(self.out, ".read(")?;
                self.put_restricted_texel_address(image, &address, context)?;
                write!(self.out, ")")?;
            }
            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
                write!(self.out, "(")?;
                self.put_image_access_bounds_check(image, &address, context)?;
                write!(self.out, " ? ")?;
                self.put_unchecked_image_load(image, &address, context)?;
                write!(self.out, ": DefaultConstructible())")?;
            }
            proc::BoundsCheckPolicy::Unchecked => {
                self.put_unchecked_image_load(image, &address, context)?;
            }
        }

        Ok(())
    }

    fn put_unchecked_image_load(
        &mut self,
        image: Handle,
        address: &TexelAddress,
        context: &ExpressionContext,
    ) -> BackendResult {
        self.put_expression(image, context, false)?;
        write!(self.out, ".read(")?;
        // coordinates in IR are int, but Metal expects uint
        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
        if let Some(expr) = address.array_index {
            write!(self.out, ", ")?;
            self.put_expression(expr, context, true)?;
        }
        if let Some(sample) = address.sample {
            write!(self.out, ", ")?;
            self.put_expression(sample, context, true)?;
        }
        if let Some(level) = address.level {
            if context.image_needs_lod(image) {
                write!(self.out, ", ")?;
                self.put_level_of_detail(level, context)?;
            }
        }
        write!(self.out, ")")?;

        Ok(())
    }

    fn put_image_atomic(
        &mut self,
        level: back::Level,
        image: Handle,
        address: &TexelAddress,
        fun: crate::AtomicFunction,
        value: Handle,
        context:
            &StatementContext,
    ) -> BackendResult {
        write!(self.out, "{level}")?;
        self.put_expression(image, &context.expression, false)?;
        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
            fun.to_msl_64_bit()?
        } else {
            fun.to_msl()
        };
        write!(self.out, ".atomic_{op}(")?;
        // coordinates in IR are int, but Metal expects uint
        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
        write!(self.out, ", ")?;
        self.put_expression(value, &context.expression, true)?;
        writeln!(self.out, ");")?;

        // Workaround for Apple Metal TBDR driver bug: fragment shader atomic
        // texture writes randomly drop unless followed by a standard texture
        // write. Insert a dead-code write behind an unprovable condition so
        // the compiler emits proper memory safety barriers.
        // See: https://projects.blender.org/blender/blender/commit/aa95220576706122d79c91c7f5c522e6c7416425
        let value_ty = context.expression.resolve_type(value);
        let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
            (Some(crate::ScalarKind::Sint), _) => "int4(0)",
            (_, Some(8)) => "ulong4(0uL)",
            _ => "uint4(0u)",
        };
        let coord_ty = context.expression.resolve_type(address.coordinate);
        let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
            ""
        } else {
            ".x"
        };
        write!(self.out, "{level}if (")?;
        self.put_expression(address.coordinate, &context.expression, true)?;
        write!(self.out, "{x} == -99999) {{ ")?;
        self.put_expression(image, &context.expression, false)?;
        write!(self.out, ".write({zero_value}, ")?;
        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
        if let Some(array_index) = address.array_index {
            write!(self.out, ", ")?;
            self.put_expression(array_index, &context.expression, true)?;
        }
        writeln!(self.out, "); }}")?;

        Ok(())
    }

    fn put_image_store(
        &mut self,
        level: back::Level,
        image: Handle,
        address: &TexelAddress,
        value: Handle,
        context: &StatementContext,
    ) -> BackendResult {
        write!(self.out, "{level}")?;
        self.put_expression(image,
            &context.expression, false)?;
        write!(self.out, ".write(")?;
        self.put_expression(value, &context.expression, true)?;
        write!(self.out, ", ")?;
        // coordinates in IR are int, but Metal expects uint
        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
        if let Some(expr) = address.array_index {
            write!(self.out, ", ")?;
            self.put_expression(expr, &context.expression, true)?;
        }
        writeln!(self.out, ");")?;

        Ok(())
    }

    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
    ///
    /// The 'maximum valid index' is simply one less than the array's length.
    ///
    /// This emits an expression of the form `a / b`, so the caller must
    /// parenthesize its output if it will be applying operators of higher
    /// precedence.
    ///
    /// `handle` must be the handle of a global variable that is either a
    /// dynamically sized array or a struct whose final member is one.
    fn put_dynamic_array_max_index(
        &mut self,
        handle: Handle,
        context: &ExpressionContext,
    ) -> BackendResult {
        let global = &context.module.global_variables[handle];
        let (offset, array_ty) = match context.module.types[global.ty].inner {
            crate::TypeInner::Struct { ref members, .. } => match members.last() {
                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
                None => return Err(Error::GenericValidation("Struct has no members".into())),
            },
            crate::TypeInner::Array {
                size: crate::ArraySize::Dynamic,
                ..
            } => (0, global.ty),
            ref ty => {
                return Err(Error::GenericValidation(format!(
                    "Expected type with dynamic array, got {ty:?}"
                )))
            }
        };

        let (size, stride) = match context.module.types[array_ty].inner {
            crate::TypeInner::Array { base, stride, .. } => (
                context.module.types[base]
                    .inner
                    .size(context.module.to_ctx()),
                stride,
            ),
            ref ty => {
                return Err(Error::GenericValidation(format!(
                    "Expected array type, got {ty:?}"
                )))
            }
        };

        // When the stride length is larger than the size, the final element's stride of
        // bytes would have padding following the value.
        // But the buffer size in `buffer_sizes.sizeN` may not include this
        // padding; it only needs to be large enough to hold the actual values'
        // bytes.
        //
        // So subtract off the size to get a byte size that falls at the start or within
        // the final element. Then divide by the stride size to get one less than the
        // length, which is the maximum valid index written here. This works even if the
        // buffer size does include the stride padding, since division rounds towards
        // zero (MSL 2.4 §6.1). It will fail if there are zero elements in the array, but
        // the WebGPU `validating shader binding` rules, together with draw-time
        // validation when `minBindingSize` is zero, prevent that.
        write!(
            self.out,
            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
            member = ArraySizeMember(handle),
            offset = offset,
            size = size,
            stride = stride,
        )?;
        Ok(())
    }

    /// Emit code for the arithmetic expression of the dot product.
    ///
    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
    /// an index. It writes out the expression for the vector component at that index.
    fn put_dot_product(
        &mut self,
        arg: T,
        arg1: T,
        size: usize,
        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
    ) -> BackendResult {
        // Write parentheses around the dot product expression to prevent operators
        // with different precedences from applying earlier.
        write!(self.out, "(")?;

        // Cycle through all the components of the vector
        for index in 0..size {
            // Write the addition to the previous product.
            // This prints an extra '+' at the beginning, but that is fine in MSL.
            write!(self.out, " + ")?;
            extractor(self, arg, index)?;
            write!(self.out, " * ")?;
            extractor(self, arg1, index)?;
        }

        write!(self.out, ")")?;
        Ok(())
    }

    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
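    /*
Worked example (not part of Naga): a minimal Rust model of the MSL < 2.1
fallback emitted by `put_pack4x8`. Each lane is masked to its low byte with
`& 0xFF` and shifted into place, lane 0 in the least significant byte, which
gives the same result as reading the four bytes in little-endian order. The
`Clamp` variants clamp each lane first; this sketch models only the masking
path. The function name `pack4x8` is hypothetical.

```rust
// Pack four lanes into a u32, lane `i` occupying bits 8*i up to 8*i + 7.
fn pack4x8(lanes: [i32; 4]) -> u32 {
    let mut packed = 0u32;
    for (i, &lane) in lanes.iter().enumerate() {
        // `as u32` keeps the two's-complement bits, so a signed -1 packs as 0xFF.
        packed |= ((lane as u32) & 0xFF) << (8 * i);
    }
    packed
}
```

Lanes outside the byte range are truncated rather than saturated, which is
why the `Clamp` variants wrap the argument in `metal::clamp` first.
    */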
    fn put_pack4x8(
        &mut self,
        arg: Handle,
        context: &ExpressionContext<'_>,
        was_signed: bool,
        clamp_bounds: Option<(&str, &str)>,
    ) -> Result<(), Error> {
        let write_arg = |this: &mut Self| -> BackendResult {
            if let Some((min, max)) = clamp_bounds {
                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
                write!(this.out, "{NAMESPACE}::clamp(")?;
                this.put_expression(arg, context, true)?;
                write!(this.out, ", {min}, {max})")?;
            } else {
                this.put_expression(arg, context, true)?;
            }
            Ok(())
        };

        if context.lang_version >= (2, 1) {
            let packed_type = if was_signed {
                "packed_char4"
            } else {
                "packed_uchar4"
            };
            // Metal uses little endian byte order, which matches what WGSL expects here.
            write!(self.out, "as_type({packed_type}(")?;
            write_arg(self)?;
            write!(self.out, "))")?;
        } else {
            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
            if was_signed {
                write!(self.out, "uint(")?;
            }
            write!(self.out, "(")?;
            write_arg(self)?;
            write!(self.out, "[0] & 0xFF) | ((")?;
            write_arg(self)?;
            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
            write_arg(self)?;
            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
            write_arg(self)?;
            write!(self.out, "[3] & 0xFF) << 24)")?;
            if was_signed {
                write!(self.out, ")")?;
            }
        }

        Ok(())
    }

    /// Emit code for the isign expression.
    ///
    fn put_isign(
        &mut self,
        arg: Handle,
        context: &ExpressionContext,
    ) -> BackendResult {
        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
        let scalar = context
            .resolve_type(arg)
            .scalar()
            .expect("put_isign should only be called for args which have an integer scalar type")
            .to_msl_name();
        match context.resolve_type(arg) {
            &crate::TypeInner::Vector { size, ..
} => {
                let size = common::vector_size_str(size);
                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
            }
            _ => {
                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
            }
        }
        write!(self.out, ", (")?;
        self.put_expression(arg, context, true)?;
        write!(self.out, " > 0)), {scalar}(0), (")?;
        self.put_expression(arg, context, true)?;
        write!(self.out, " == 0))")?;
        Ok(())
    }

    fn put_const_expression(
        &mut self,
        expr_handle: Handle,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
        arena: &crate::Arena,
    ) -> BackendResult {
        self.put_possibly_const_expression(
            expr_handle,
            arena,
            module,
            mod_info,
            &(module, mod_info),
            |&(_, mod_info), expr| &mod_info[expr],
            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
        )
    }

    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
        match literal {
            crate::Literal::F64(_) => {
                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
            }
            crate::Literal::F16(value) => {
                if value.is_infinite() {
                    let sign = if value.is_sign_negative() { "-" } else { "" };
                    write!(self.out, "{sign}INFINITY")?;
                } else if value.is_nan() {
                    write!(self.out, "NAN")?;
                } else {
                    let suffix = if value.fract() == f16::from_f32(0.0) {
                        ".0h"
                    } else {
                        "h"
                    };
                    write!(self.out, "{value}{suffix}")?;
                }
            }
            crate::Literal::F32(value) => {
                if value.is_infinite() {
                    let sign = if value.is_sign_negative() { "-" } else { "" };
                    write!(self.out, "{sign}INFINITY")?;
                } else if value.is_nan() {
                    write!(self.out, "NAN")?;
                } else {
                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
                    write!(self.out, "{value}{suffix}")?;
                }
            }
            crate::Literal::U32(value) => {
                write!(self.out, "{value}u")?;
            }
            crate::Literal::I32(value) => {
                // `-2147483648` is parsed as unary negation of positive 2147483648.
                // 2147483648 is too large for int32_t, meaning the expression gets
                // promoted to an int64_t, which is not our intention. Avoid this by
                // instead using `-2147483647 - 1`.
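/*
Worked example (not part of Naga): a minimal Rust model of how this arm and
the `Literal::I64` arm below spell their minimum values. Formatting
`value + 1` and subtracting 1 in the emitted code keeps every literal the
Metal compiler sees within range. The function names `msl_i32_literal` and
`msl_i64_literal` are hypothetical.

```rust
fn msl_i32_literal(value: i32) -> String {
    if value == i32::MIN {
        // `i32::MIN + 1` cannot overflow.
        format!("({} - 1)", value + 1)
    } else {
        format!("{value}")
    }
}

fn msl_i64_literal(value: i64) -> String {
    if value == i64::MIN {
        // `i64::MIN + 1` cannot overflow either.
        format!("({}L - 1L)", value + 1)
    } else {
        format!("{value}L")
    }
}
```

For example, `msl_i32_literal(i32::MIN)` produces `(-2147483647 - 1)` and
`msl_i64_literal(i64::MIN)` produces `(-9223372036854775807L - 1L)`.
*/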
                if value == i32::MIN {
                    write!(self.out, "({} - 1)", value + 1)?;
                } else {
                    write!(self.out, "{value}")?;
                }
            }
            crate::Literal::U64(value) => {
                write!(self.out, "{value}uL")?;
            }
            crate::Literal::I64(value) => {
                // `-9223372036854775808` is parsed as unary negation of positive
                // 9223372036854775808. 9223372036854775808 is too large for int64_t,
                // causing Metal to emit a `-Wconstant-conversion` warning and change the
                // value to `-9223372036854775808`, which would then be negated, possibly
                // causing undefined behaviour. Avoid this by instead using
                // `-9223372036854775807L - 1L`.
                if value == i64::MIN {
                    write!(self.out, "({}L - 1L)", value + 1)?;
                } else {
                    write!(self.out, "{value}L")?;
                }
            }
            crate::Literal::Bool(value) => {
                write!(self.out, "{value}")?;
            }
            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
                return Err(Error::GenericValidation(
                    "Unsupported abstract literal".into(),
                ));
            }
        }
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    fn put_possibly_const_expression(
        &mut self,
        expr_handle: Handle,
        expressions: &crate::Arena,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
        ctx: &C,
        get_expr_ty: I,
        put_expression: E,
    ) -> BackendResult
    where
        I: Fn(&C, Handle) -> &TypeResolution,
        E: Fn(&mut Self, &C, Handle) -> BackendResult,
    {
        match expressions[expr_handle] {
            crate::Expression::Literal(literal) => {
                self.put_literal(literal)?;
            }
            crate::Expression::Constant(handle) => {
                let constant = &module.constants[handle];
                if constant.name.is_some() {
                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
                } else {
                    self.put_const_expression(
                        constant.init,
                        module,
                        mod_info,
                        &module.global_expressions,
                    )?;
                }
            }
            crate::Expression::ZeroValue(ty) => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                write!(self.out, "{ty_name} {{}}")?;
            }
            crate::Expression::Compose { ty, ref components } => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names:
                        &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                write!(self.out, "{ty_name}")?;
                match module.types[ty].inner {
                    crate::TypeInner::Scalar(_)
                    | crate::TypeInner::Vector { .. }
                    | crate::TypeInner::Matrix { .. } => {
                        self.put_call_parameters_impl(
                            components.iter().copied(),
                            ctx,
                            put_expression,
                        )?;
                    }
                    crate::TypeInner::Array { .. } => {
                        // Naga Arrays are Metal arrays wrapped in structs, so
                        // we need two levels of braces.
                        write!(self.out, " {{{{")?;
                        for (index, &component) in components.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            put_expression(self, ctx, component)?;
                        }
                        write!(self.out, "}}}}")?;
                    }
                    crate::TypeInner::Struct { .. } => {
                        write!(self.out, " {{")?;
                        for (index, &component) in components.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            // insert padding initialization, if needed
                            if self.struct_member_pads.contains(&(ty, index as u32)) {
                                write!(self.out, "{{}}, ")?;
                            }
                            put_expression(self, ctx, component)?;
                        }
                        write!(self.out, "}}")?;
                    }
                    _ => return Err(Error::UnsupportedCompose(ty)),
                }
            }
            crate::Expression::Splat { size, value } => {
                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
                    crate::TypeInner::Scalar(scalar) => scalar,
                    ref ty => {
                        return Err(Error::GenericValidation(format!(
                            "Expected splat value type must be a scalar, got {ty:?}",
                        )))
                    }
                };
                put_numeric_type(&mut self.out, scalar, &[size])?;
                write!(self.out, "(")?;
                put_expression(self, ctx, value)?;
                write!(self.out, ")")?;
            }
            _ => {
                return Err(Error::Override);
            }
        }

        Ok(())
    }

    /// Emit code for the expression `expr_handle`.
    ///
    /// The `is_scoped` argument is true if the surrounding operators have the
    /// precedence of the comma operator, or lower. So, for example:
    ///
    /// - Pass `true` for `is_scoped` when writing function arguments, an
    ///   expression statement, an initializer expression, or anything already
    ///   wrapped in parentheses.
/// /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really /// almost anything else. fn put_expression( &mut self, expr_handle: Handle<crate::Expression>, context: &ExpressionContext, is_scoped: bool, ) -> BackendResult { // Add to the set in order to track the stack size. #[cfg(test)] self.put_expression_stack_pointers .insert(ptr::from_ref(&expr_handle).cast()); if let Some(name) = self.named_expressions.get(&expr_handle) { write!(self.out, "{name}")?; return Ok(()); } let expression = &context.function.expressions[expr_handle]; match *expression { crate::Expression::Literal(_) | crate::Expression::Constant(_) | crate::Expression::ZeroValue(_) | crate::Expression::Compose { .. } | crate::Expression::Splat { .. } => { self.put_possibly_const_expression( expr_handle, &context.function.expressions, context.module, context.mod_info, context, |context, expr: Handle<crate::Expression>| &context.info[expr].ty, |writer, context, expr| writer.put_expression(expr, context, true), )?; } crate::Expression::Override(_) => return Err(Error::Override), crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { // This is an acceptable place to generate a `ReadZeroSkipWrite` check. // Since `put_bounds_checks` and `put_access_chain` handle an entire // access chain at a time, recursing back through `put_expression` only // for index expressions and the base object, we will never see intermediate // `Access` or `AccessIndex` expressions here. let policy = context.choose_bounds_check_policy(base); if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite && self.put_bounds_checks( expr_handle, context, back::Level(0), if is_scoped { "" } else { "(" }, )? { write!(self.out, " ? ")?; self.put_access_chain(expr_handle, policy, context)?; write!(self.out, " : ")?; if context.resolve_type(base).pointer_space().is_some() { // We can't just use `DefaultConstructible` if this is a pointer.
// Instead, we create a dummy local variable to serve as pointer // target if the access is out of bounds. let result_ty = context.info[expr_handle] .ty .inner_with(&context.module.types) .pointer_base_type(); let result_ty_handle = match result_ty { Some(TypeResolution::Handle(handle)) => handle, Some(TypeResolution::Value(_)) => { // As long as the result of a pointer access expression is // passed to a function or stored in a let binding, the // type will be in the arena. If additional uses of // pointers become valid, this assumption might no longer // hold. Note that the LHS of a load or store doesn't // take this path -- there is dedicated code in `put_load` // and `put_store`. unreachable!( "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena", ); } None => { unreachable!( "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}", ) } }; let name_key = NameKey::oob_local_for_type(context.origin, result_ty_handle); self.out.write_str(&self.names[&name_key])?; } else { write!(self.out, "DefaultConstructible()")?; } if !is_scoped { write!(self.out, ")")?; } } else { self.put_access_chain(expr_handle, policy, context)?; } } crate::Expression::Swizzle { size, vector, pattern, } => { self.put_wrapped_expression_for_packed_vec3_access( vector, context, false, &Self::put_expression, )?; write!(self.out, ".")?; for &sc in pattern[..size as usize].iter() { write!(self.out, "{}", back::COMPONENTS[sc as usize])?; } } crate::Expression::FunctionArgument(index) => { let name_key = match context.origin { FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index), FunctionOrigin::EntryPoint(ep_index) => { NameKey::EntryPointArgument(ep_index, index) } }; let name = &self.names[&name_key]; write!(self.out, "{name}")?; } crate::Expression::GlobalVariable(handle) => { let name = &self.names[&NameKey::GlobalVariable(handle)]; write!(self.out, "{name}")?; } 
crate::Expression::LocalVariable(handle) => { let name_key = NameKey::local(context.origin, handle); let name = &self.names[&name_key]; write!(self.out, "{name}")?; } crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?, crate::Expression::ImageSample { coordinate, image, sampler, clamp_to_edge: true, gather: None, array_index: None, offset: None, level: crate::SampleLevel::Zero, depth_ref: None, } => { write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?; self.put_expression(image, context, true)?; write!(self.out, ", ")?; self.put_expression(sampler, context, true)?; write!(self.out, ", ")?; self.put_expression(coordinate, context, true)?; write!(self.out, ")")?; } crate::Expression::ImageSample { image, sampler, gather, coordinate, array_index, offset, level, depth_ref, clamp_to_edge, } => { if clamp_to_edge { return Err(Error::GenericValidation( "ImageSample::clamp_to_edge should have been validated out".to_string(), )); } let main_op = match gather { Some(_) => "gather", None => "sample", }; let comparison_op = match depth_ref { Some(_) => "_compare", None => "", }; self.put_expression(image, context, false)?; write!(self.out, ".{main_op}{comparison_op}(")?; self.put_expression(sampler, context, true)?; write!(self.out, ", ")?; self.put_expression(coordinate, context, true)?; if let Some(expr) = array_index { write!(self.out, ", ")?; self.put_expression(expr, context, true)?; } if let Some(dref) = depth_ref { write!(self.out, ", ")?; self.put_expression(dref, context, true)?; } self.put_image_sample_level(image, level, context)?; if let Some(offset) = offset { write!(self.out, ", ")?; self.put_expression(offset, context, true)?; } match gather { None | Some(crate::SwizzleComponent::X) => {} Some(component) => { let is_cube_map = match *context.resolve_type(image) { crate::TypeInner::Image { dim: crate::ImageDimension::Cube, .. 
} => true, _ => false, }; // Offset always comes before the gather, except // in cube maps where it's not applicable if offset.is_none() && !is_cube_map { write!(self.out, ", {NAMESPACE}::int2(0)")?; } let letter = back::COMPONENTS[component as usize]; write!(self.out, ", {NAMESPACE}::component::{letter}")?; } } write!(self.out, ")")?; } crate::Expression::ImageLoad { image, coordinate, array_index, sample, level, } => { let address = TexelAddress { coordinate, array_index, sample, level: level.map(LevelOfDetail::Direct), }; self.put_image_load(expr_handle, image, address, context)?; } //Note: for all the queries, the signed integers are expected, // so a conversion is needed. crate::Expression::ImageQuery { image, query } => match query { crate::ImageQuery::Size { level } => { self.put_image_size_query( image, level.map(LevelOfDetail::Direct), crate::ScalarKind::Uint, context, )?; } crate::ImageQuery::NumLevels => { self.put_expression(image, context, false)?; write!(self.out, ".get_num_mip_levels()")?; } crate::ImageQuery::NumLayers => { self.put_expression(image, context, false)?; write!(self.out, ".get_array_size()")?; } crate::ImageQuery::NumSamples => { self.put_expression(image, context, false)?; write!(self.out, ".get_num_samples()")?; } }, crate::Expression::Unary { op, expr } => { let op_str = match op { crate::UnaryOperator::Negate => { match context.resolve_type(expr).scalar_kind() { Some(crate::ScalarKind::Sint) => NEG_FUNCTION, _ => "-", } } crate::UnaryOperator::LogicalNot => "!", crate::UnaryOperator::BitwiseNot => "~", }; write!(self.out, "{op_str}(")?; self.put_expression(expr, context, false)?; write!(self.out, ")")?; } crate::Expression::Binary { op, left, right } => { let kind = context .resolve_type(left) .scalar_kind() .ok_or(Error::UnsupportedBinaryOp(op))?; if op == crate::BinaryOperator::Divide && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint) { write!(self.out, "{DIV_FUNCTION}(")?; self.put_expression(left, context, 
true)?; write!(self.out, ", ")?; self.put_expression(right, context, true)?; write!(self.out, ")")?; } else if op == crate::BinaryOperator::Modulo && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint) { write!(self.out, "{MOD_FUNCTION}(")?; self.put_expression(left, context, true)?; write!(self.out, ", ")?; self.put_expression(right, context, true)?; write!(self.out, ")")?; } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float { // TODO: handle undefined behavior of BinaryOperator::Modulo // // float: // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 write!(self.out, "{NAMESPACE}::fmod(")?; self.put_expression(left, context, true)?; write!(self.out, ", ")?; self.put_expression(right, context, true)?; write!(self.out, ")")?; } else if (op == crate::BinaryOperator::Add || op == crate::BinaryOperator::Subtract || op == crate::BinaryOperator::Multiply) && kind == crate::ScalarKind::Sint { let to_unsigned = |ty: &crate::TypeInner| match *ty { crate::TypeInner::Scalar(scalar) => { Ok(crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Uint, ..scalar })) } crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector { size, scalar: crate::Scalar { kind: crate::ScalarKind::Uint, ..scalar }, }), _ => Err(Error::UnsupportedBitCast(ty.clone())), }; // Avoid undefined behaviour due to overflowing signed // integer arithmetic. Cast the operands to unsigned prior // to performing the operation, then cast the result back // to signed. 
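// Illustrative example (operand names `a` and `b` are placeholders): for two
// `i32` operands, `a + b` should come out roughly as
// `as_type<int>(as_type<uint>(a) + as_type<uint>(b))`. Unsigned wrap-around is
// well defined, so bitcasting the result back yields the two's-complement
// wrapped value instead of undefined behaviour.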
self.put_bitcasted_expression( context.resolve_type(expr_handle), context, &|writer, context, is_scoped| { writer.put_binop( op, left, right, context, is_scoped, &|writer, expr, context, _is_scoped| { writer.put_bitcasted_expression( &to_unsigned(context.resolve_type(expr))?, context, &|writer, context, is_scoped| { writer.put_expression(expr, context, is_scoped) }, ) }, ) }, )?; } else { self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?; } } crate::Expression::Select { condition, accept, reject, } => match *context.resolve_type(condition) { crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Bool, .. }) => { if !is_scoped { write!(self.out, "(")?; } self.put_expression(condition, context, false)?; write!(self.out, " ? ")?; self.put_expression(accept, context, false)?; write!(self.out, " : ")?; self.put_expression(reject, context, false)?; if !is_scoped { write!(self.out, ")")?; } } crate::TypeInner::Vector { scalar: crate::Scalar { kind: crate::ScalarKind::Bool, .. }, .. } => { write!(self.out, "{NAMESPACE}::select(")?; self.put_expression(reject, context, true)?; write!(self.out, ", ")?; self.put_expression(accept, context, true)?; write!(self.out, ", ")?; self.put_expression(condition, context, true)?; write!(self.out, ")")?; } ref ty => { return Err(Error::GenericValidation(format!( "Expected select condition to be a bool or a vector of bools, got {ty:?}", ))) } }, crate::Expression::Derivative { axis, expr, ..
} => { use crate::DerivativeAxis as Axis; let op = match axis { Axis::X => "dfdx", Axis::Y => "dfdy", Axis::Width => "fwidth", }; write!(self.out, "{NAMESPACE}::{op}")?; self.put_call_parameters(iter::once(expr), context)?; } crate::Expression::Relational { fun, argument } => { let op = match fun { crate::RelationalFunction::Any => "any", crate::RelationalFunction::All => "all", crate::RelationalFunction::IsNan => "isnan", crate::RelationalFunction::IsInf => "isinf", }; write!(self.out, "{NAMESPACE}::{op}")?; self.put_call_parameters(iter::once(argument), context)?; } crate::Expression::Math { fun, arg, arg1, arg2, arg3, } => { use crate::MathFunction as Mf; let arg_type = context.resolve_type(arg); let scalar_argument = match arg_type { &crate::TypeInner::Scalar(_) => true, _ => false, }; let fun_name = match fun { // comparison Mf::Abs => "abs", Mf::Min => "min", Mf::Max => "max", Mf::Clamp => "clamp", Mf::Saturate => "saturate", // trigonometry Mf::Cos => "cos", Mf::Cosh => "cosh", Mf::Sin => "sin", Mf::Sinh => "sinh", Mf::Tan => "tan", Mf::Tanh => "tanh", Mf::Acos => "acos", Mf::Asin => "asin", Mf::Atan => "atan", Mf::Atan2 => "atan2", Mf::Asinh => "asinh", Mf::Acosh => "acosh", Mf::Atanh => "atanh", Mf::Radians => "", Mf::Degrees => "", // decomposition Mf::Ceil => "ceil", Mf::Floor => "floor", Mf::Round => "rint", Mf::Fract => "fract", Mf::Trunc => "trunc", Mf::Modf => MODF_FUNCTION, Mf::Frexp => FREXP_FUNCTION, Mf::Ldexp => "ldexp", // exponent Mf::Exp => "exp", Mf::Exp2 => "exp2", Mf::Log => "log", Mf::Log2 => "log2", Mf::Pow => "pow", // geometry Mf::Dot => match *context.resolve_type(arg) { crate::TypeInner::Vector { scalar: crate::Scalar { // Resolve float values to MSL's builtin dot function. kind: crate::ScalarKind::Float, .. }, .. } => "dot", crate::TypeInner::Vector { size, scalar: scalar @ crate::Scalar { kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, .. }, } => { // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`. 
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size); write!(self.out, "{fun_name}(")?; self.put_expression(arg, context, true)?; write!(self.out, ", ")?; self.put_expression(arg1.unwrap(), context, true)?; write!(self.out, ")")?; return Ok(()); } _ => unreachable!( "Correct TypeInner for dot product should be already validated" ), }, fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => { if context.lang_version >= (2, 1) { // Write potentially optimizable code using `packed_(u?)char4`. // The two function arguments were already reinterpreted as packed (signed // or unsigned) chars in `Self::put_block`. let packed_type = match fun { Mf::Dot4I8Packed => "packed_char4", Mf::Dot4U8Packed => "packed_uchar4", _ => unreachable!(), }; return self.put_dot_product( Reinterpreted::new(packed_type, arg), Reinterpreted::new(packed_type, arg1.unwrap()), 4, |writer, arg, index| { // MSL implicitly promotes these (signed or unsigned) chars to // `int` or `uint` in the multiplication, so no overflow can occur. write!(writer.out, "{arg}[{index}]")?; Ok(()) }, ); } else { // Fall back to a polyfill since MSL < 2.1 doesn't seem to support // bitcasting from uint to `packed_char4` or `packed_uchar4`. // See . 
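// Worked example (placeholder argument `a`): for `Dot4I8Packed`, byte 1 of the
// operand is written as `(int(a) << 16 >> 24)`. The left shift by
// (3 - 1) * 8 = 16 moves byte 1 into the top byte, and the right shift by 24
// brings it back down, sign-extending it (this relies on `>>` of a signed `int`
// being an arithmetic shift). Byte 3 only needs `(int(a) >> 24)`. For
// `Dot4U8Packed` the conversion is empty, so the same shifts on the unsigned
// value zero-extend instead.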
let conversion = match fun { Mf::Dot4I8Packed => "int", Mf::Dot4U8Packed => "", _ => unreachable!(), }; return self.put_dot_product( arg, arg1.unwrap(), 4, |writer, arg, index| { write!(writer.out, "({conversion}(")?; writer.put_expression(arg, context, true)?; if index == 3 { write!(writer.out, ") >> 24)")?; } else { write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?; } Ok(()) }, ); } } Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))), Mf::Cross => "cross", Mf::Distance => "distance", Mf::Length if scalar_argument => "abs", Mf::Length => "length", Mf::Normalize => "normalize", Mf::FaceForward => "faceforward", Mf::Reflect => "reflect", Mf::Refract => "refract", // computational Mf::Sign => match arg_type.scalar_kind() { Some(crate::ScalarKind::Sint) => { return self.put_isign(arg, context); } _ => "sign", }, Mf::Fma => "fma", Mf::Mix => "mix", Mf::Step => "step", Mf::SmoothStep => "smoothstep", Mf::Sqrt => "sqrt", Mf::InverseSqrt => "rsqrt", Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))), Mf::Transpose => "transpose", Mf::Determinant => "determinant", Mf::QuantizeToF16 => "", // bits Mf::CountTrailingZeros => "ctz", Mf::CountLeadingZeros => "clz", Mf::CountOneBits => "popcount", Mf::ReverseBits => "reverse_bits", Mf::ExtractBits => "", Mf::InsertBits => "", Mf::FirstTrailingBit => "", Mf::FirstLeadingBit => "", // data packing Mf::Pack4x8snorm => "pack_float_to_snorm4x8", Mf::Pack4x8unorm => "pack_float_to_unorm4x8", Mf::Pack2x16snorm => "pack_float_to_snorm2x16", Mf::Pack2x16unorm => "pack_float_to_unorm2x16", Mf::Pack2x16float => "", Mf::Pack4xI8 => "", Mf::Pack4xU8 => "", Mf::Pack4xI8Clamp => "", Mf::Pack4xU8Clamp => "", // data unpacking Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float", Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float", Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float", Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float", Mf::Unpack2x16float => "", Mf::Unpack4xI8 => "", Mf::Unpack4xU8 => "", }; 
match fun { Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => { // reverse_bits is listed as requiring MSL 2.1 but that // is a copy/paste error. Looking at previous snapshots // on web.archive.org it's present in MSL 1.2. // // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html // also talks about MSL 1.2 adding "New integer // functions to extract, insert, and reverse bits, as // described in Integer Functions." if context.lang_version < (1, 2) { return Err(Error::UnsupportedFunction(fun_name.to_string())); } } _ => {} } match fun { Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => { write!(self.out, "{ABS_FUNCTION}(")?; self.put_expression(arg, context, true)?; write!(self.out, ")")?; } Mf::Distance if scalar_argument => { write!(self.out, "{NAMESPACE}::abs(")?; self.put_expression(arg, context, false)?; write!(self.out, " - ")?; self.put_expression(arg1.unwrap(), context, false)?; write!(self.out, ")")?; } Mf::FirstTrailingBit => { let scalar = context.resolve_type(arg).scalar().unwrap(); let constant = scalar.width * 8 + 1; write!(self.out, "((({NAMESPACE}::ctz(")?; self.put_expression(arg, context, true)?; write!(self.out, ") + 1) % {constant}) - 1)")?; } Mf::FirstLeadingBit => { let inner = context.resolve_type(arg); let scalar = inner.scalar().unwrap(); let constant = scalar.width * 8 - 1; write!( self.out, "{NAMESPACE}::select({constant} - {NAMESPACE}::clz(" )?; if scalar.kind == crate::ScalarKind::Sint { write!(self.out, "{NAMESPACE}::select(")?; self.put_expression(arg, context, true)?; write!(self.out, ", ~")?; self.put_expression(arg, context, true)?; write!(self.out, ", ")?; self.put_expression(arg, context, true)?; write!(self.out, " < 0)")?; } else { self.put_expression(arg, context, true)?; } write!(self.out, "), ")?; // or metal will complain that select is ambiguous match *inner { 
crate::TypeInner::Vector { size, scalar } => { let size = common::vector_size_str(size); let name = scalar.to_msl_name(); write!(self.out, "{name}{size}")?; } crate::TypeInner::Scalar(scalar) => { let name = scalar.to_msl_name(); write!(self.out, "{name}")?; } _ => (), } write!(self.out, "(-1), ")?; self.put_expression(arg, context, true)?; write!(self.out, " == 0 || ")?; self.put_expression(arg, context, true)?; write!(self.out, " == -1)")?; } Mf::Unpack2x16float => { write!(self.out, "float2(as_type<half2>(")?; self.put_expression(arg, context, false)?; write!(self.out, "))")?; } Mf::Pack2x16float => { write!(self.out, "as_type<uint>(half2(")?; self.put_expression(arg, context, false)?; write!(self.out, "))")?; } Mf::ExtractBits => { // The behavior of ExtractBits is undefined when offset + count > bit_width. We need // to sanitize the offset and count first. If we don't do this, Apple chips // will return out-of-spec values if the extracted range is not within the bit width. // // This encodes the exact formula specified by the WGSL spec, without temporary values: // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin // // w = sizeof(x) * 8 // o = min(offset, w) // tmp = w - o // c = min(count, tmp) // // bitfieldExtract(x, o, c) // // extract_bits(e, min(offset, w), min(count, w - min(offset, w))) let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; write!(self.out, "{NAMESPACE}::extract_bits(")?; self.put_expression(arg, context, true)?; write!(self.out, ", {NAMESPACE}::min(")?; self.put_expression(arg1.unwrap(), context, true)?; write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; self.put_expression(arg2.unwrap(), context, true)?; write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; self.put_expression(arg1.unwrap(), context, true)?; write!(self.out, ", {scalar_bits}u)))")?; } Mf::InsertBits => { // The behavior of InsertBits has the same issue as ExtractBits.
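// Worked example of the clamping used by both ExtractBits and InsertBits
// (values chosen for illustration): for a 32-bit `e` with offset = 30 and
// count = 8, the emitted expression uses o = min(30, 32) = 30 and
// c = min(8, 32 - 30) = 2, so only the two bits that fit inside `e` are touched.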
// // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))) let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; write!(self.out, "{NAMESPACE}::insert_bits(")?; self.put_expression(arg, context, true)?; write!(self.out, ", ")?; self.put_expression(arg1.unwrap(), context, true)?; write!(self.out, ", {NAMESPACE}::min(")?; self.put_expression(arg2.unwrap(), context, true)?; write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; self.put_expression(arg3.unwrap(), context, true)?; write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; self.put_expression(arg2.unwrap(), context, true)?; write!(self.out, ", {scalar_bits}u)))")?; } Mf::Radians => { write!(self.out, "((")?; self.put_expression(arg, context, false)?; write!(self.out, ") * 0.017453292519943295474)")?; } Mf::Degrees => { write!(self.out, "((")?; self.put_expression(arg, context, false)?; write!(self.out, ") * 57.295779513082322865)")?; } Mf::Modf | Mf::Frexp => { write!(self.out, "{fun_name}")?; self.put_call_parameters(iter::once(arg), context)?; } Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?, Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?, Mf::Pack4xI8Clamp => { self.put_pack4x8(arg, context, true, Some(("-128", "127")))? } Mf::Pack4xU8Clamp => { self.put_pack4x8(arg, context, false, Some(("0", "255")))? } fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => { let sign_prefix = if matches!(fun, Mf::Unpack4xU8) { "u" } else { "" }; if context.lang_version >= (2, 1) { // Metal uses little-endian byte order, which matches what WGSL expects here. write!( self.out, "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>(" )?; self.put_expression(arg, context, true)?; write!(self.out, "))")?; } else { // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
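// The fallback below builds a vector whose lane i holds `arg >> (8 * i)`, then
// shifts each lane left by 24 and back right by 24 to keep only the low byte:
// sign-extended for `int4` (Unpack4xI8), zero-extended for `uint4` (Unpack4xU8).
// For example, with `arg == 0xFF00u`, lane 1 unpacks to -1 as `int4` and to 255
// as `uint4`.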
write!(self.out, "({sign_prefix}int4(")?; self.put_expression(arg, context, true)?; write!(self.out, ", ")?; self.put_expression(arg, context, true)?; write!(self.out, " >> 8, ")?; self.put_expression(arg, context, true)?; write!(self.out, " >> 16, ")?; self.put_expression(arg, context, true)?; write!(self.out, " >> 24) << 24 >> 24)")?; } } Mf::QuantizeToF16 => { match *context.resolve_type(arg) { crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?, crate::TypeInner::Vector { size, .. } => write!( self.out, "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(", size = common::vector_size_str(size), )?, _ => unreachable!( "Correct TypeInner for QuantizeToF16 should be already validated" ), }; self.put_expression(arg, context, true)?; write!(self.out, "))")?; } _ => { write!(self.out, "{NAMESPACE}::{fun_name}")?; self.put_call_parameters( iter::once(arg).chain(arg1).chain(arg2).chain(arg3), context, )?; } } } crate::Expression::As { expr, kind, convert, } => match *context.resolve_type(expr) { crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => { if src.kind == crate::ScalarKind::Float && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint) && convert.is_some() { // Use helper functions for float to int casts in order to avoid // undefined behaviour when value is out of range for the target // type. 
let fun_name = match (kind, convert) { (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION, (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION, (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION, (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION, _ => unreachable!(), }; write!(self.out, "{fun_name}(")?; self.put_expression(expr, context, true)?; write!(self.out, ")")?; } else { let target_scalar = crate::Scalar { kind, width: convert.unwrap_or(src.width), }; let op = match convert { Some(_) => "static_cast", None => "as_type", }; write!(self.out, "{op}<")?; match *context.resolve_type(expr) { crate::TypeInner::Vector { size, .. } => { put_numeric_type(&mut self.out, target_scalar, &[size])? } _ => put_numeric_type(&mut self.out, target_scalar, &[])?, }; write!(self.out, ">(")?; self.put_expression(expr, context, true)?; write!(self.out, ")")?; } } crate::TypeInner::Matrix { columns, rows, scalar, } => { let target_scalar = crate::Scalar { kind, width: convert.unwrap_or(scalar.width), }; put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?; write!(self.out, "(")?; self.put_expression(expr, context, true)?; write!(self.out, ")")?; } ref ty => { return Err(Error::GenericValidation(format!( "Unsupported type for As: {ty:?}" ))) } }, // has to be a named expression crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } | crate::Expression::SubgroupBallotResult | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::RayQueryProceedResult => { unreachable!() } crate::Expression::ArrayLength(expr) => { // Find the global to which the array belongs. let global = match context.function.expressions[expr] { crate::Expression::AccessIndex { base, .. 
} => { match context.function.expressions[base] { crate::Expression::GlobalVariable(handle) => handle, ref ex => { return Err(Error::GenericValidation(format!( "Expected global variable in AccessIndex, got {ex:?}" ))) } } } crate::Expression::GlobalVariable(handle) => handle, ref ex => { return Err(Error::GenericValidation(format!( "Unexpected expression in ArrayLength, got {ex:?}" ))) } }; if !is_scoped { write!(self.out, "(")?; } write!(self.out, "1 + ")?; self.put_dynamic_array_max_index(global, context)?; if !is_scoped { write!(self.out, ")")?; } } crate::Expression::RayQueryVertexPositions { .. } => { unimplemented!() } crate::Expression::RayQueryGetIntersection { query, committed: _, } => { if context.lang_version < (2, 4) { return Err(Error::UnsupportedRayTracing); } let ty = context.module.special_types.ray_intersection.unwrap(); let type_name = &self.names[&NameKey::Type(ty)]; write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?; self.put_expression(query, context, true)?; write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?; let fields = [ "distance", "user_instance_id", // req Metal 2.4 "instance_id", "", // SBT offset "geometry_id", "primitive_id", "triangle_barycentric_coord", "triangle_front_facing", "", // padding "object_to_world_transform", // req Metal 2.4 "world_to_object_transform", // req Metal 2.4 ]; for field in fields { write!(self.out, ", ")?; if field.is_empty() { write!(self.out, "{{}}")?; } else { self.put_expression(query, context, true)?; write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?; } } write!(self.out, "}}")?; } crate::Expression::CooperativeLoad { ref data, .. 
} => { if context.lang_version < (2, 3) { return Err(Error::UnsupportedCooperativeMatrix); } write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?; write!(self.out, "&")?; self.put_access_chain(data.pointer, context.policies.index, context)?; write!(self.out, ", ")?; self.put_expression(data.stride, context, true)?; write!(self.out, ", {})", data.row_major)?; } crate::Expression::CooperativeMultiplyAdd { a, b, c } => { if context.lang_version < (2, 3) { return Err(Error::UnsupportedCooperativeMatrix); } write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?; self.put_expression(a, context, true)?; write!(self.out, ", ")?; self.put_expression(b, context, true)?; write!(self.out, ", ")?; self.put_expression(c, context, true)?; write!(self.out, ")")?; } } Ok(()) } /// Emits code for a binary operation, using the provided callback to emit /// the left and right operands. fn put_binop<F>( &mut self, op: crate::BinaryOperator, left: Handle<crate::Expression>, right: Handle<crate::Expression>, context: &ExpressionContext, is_scoped: bool, put_expression: &F, ) -> BackendResult where F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult, { let op_str = back::binary_operation_str(op); if !is_scoped { write!(self.out, "(")?; } // Cast packed vector if necessary // Packed vector - matrix multiplications are not supported in MSL if op == crate::BinaryOperator::Multiply && matches!( context.resolve_type(right), &crate::TypeInner::Matrix { .. } ) { self.put_wrapped_expression_for_packed_vec3_access( left, context, false, put_expression, )?; } else { put_expression(self, left, context, false)?; } write!(self.out, " {op_str} ")?; // See comment above if op == crate::BinaryOperator::Multiply && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { ..
}) { self.put_wrapped_expression_for_packed_vec3_access( right, context, false, put_expression, )?; } else { put_expression(self, right, context, false)?; } if !is_scoped { write!(self.out, ")")?; } Ok(()) } /// Used by expressions like Swizzle and Binary, since they need `packed_vec3` values to be cast to a `vec3`. fn put_wrapped_expression_for_packed_vec3_access<F>( &mut self, expr_handle: Handle<crate::Expression>, context: &ExpressionContext, is_scoped: bool, put_expression: &F, ) -> BackendResult where F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult, { if let Some(scalar) = context.get_packed_vec_kind(expr_handle) { write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?; put_expression(self, expr_handle, context, is_scoped)?; write!(self.out, ")")?; } else { put_expression(self, expr_handle, context, is_scoped)?; } Ok(()) } /// Emits code for an expression using the provided callback, wrapping the /// result in a bitcast to the type `cast_to`. fn put_bitcasted_expression<F>( &mut self, cast_to: &crate::TypeInner, context: &ExpressionContext, put_expression: &F, ) -> BackendResult where F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult, { write!(self.out, "as_type<")?; match *cast_to { crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?, crate::TypeInner::Vector { size, scalar } => { put_numeric_type(&mut self.out, scalar, &[size])? } _ => return Err(Error::UnsupportedBitCast(cast_to.clone())), }; write!(self.out, ">(")?; put_expression(self, context, true)?; write!(self.out, ")")?; Ok(()) } /// Write a `GuardedIndex` as a Metal expression. fn put_index( &mut self, index: index::GuardedIndex, context: &ExpressionContext, is_scoped: bool, ) -> BackendResult { match index { index::GuardedIndex::Expression(expr) => { self.put_expression(expr, context, is_scoped)? } index::GuardedIndex::Known(value) => write!(self.out, "{value}")?, } Ok(()) } /// Emit an index bounds check condition for `chain`, if required.
/// /// `chain` is a subtree of `Access` and `AccessIndex` expressions, /// operating either on a pointer to a value, or on a value directly. If we cannot /// statically determine that all indexing operations in `chain` are within /// bounds, then write a conditional expression to check them dynamically, /// and return `true`. All accesses in the chain are checked by the generated /// expression. /// /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`]. /// /// The text written is of the form: /// /// ```ignore /// {level}{prefix}uint(i) < 4 && uint(j) < 10 /// ``` /// /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`] /// statements, presumably these arguments start an indented `if` statement; for /// [`Load`] expressions, the caller is probably building up a ternary `?:` /// expression. In either case, what is written is not a complete syntactic structure /// in its own right, and the caller will have to finish it off if we return `true`. /// /// If no expression is written, return `false`. /// /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite /// [`Store`]: crate::Statement::Store /// [`Load`]: crate::Expression::Load fn put_bounds_checks( &mut self, chain: Handle<crate::Expression>, context: &ExpressionContext, level: back::Level, prefix: &'static str, ) -> Result<bool, Error> { let mut check_written = false; // Iterate over the access chain, handling each required bounds check. for item in context.bounds_check_iter(chain) { let BoundsCheck { base, index, length, } = item; if check_written { write!(self.out, " && ")?; } else { write!(self.out, "{level}{prefix}")?; check_written = true; } // Check that the index falls within bounds. Do this with a single // comparison, by casting the index to `uint` first, so that negative // indices become large positive values.
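// For example (placeholder index `i`, known length 4): `uint(i) < 4` rejects
// i = -1 because `uint(-1)` is 4294967295, so one unsigned comparison covers
// both `i < 0` and `i >= 4`.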
            write!(self.out, "uint(")?;
            self.put_index(index, context, true)?;
            self.out.write_str(") < ")?;
            match length {
                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
                index::IndexableLength::Dynamic => {
                    let global = context.function.originating_global(base).ok_or_else(|| {
                        Error::GenericValidation("Could not find originating global".into())
                    })?;
                    write!(self.out, "1 + ")?;
                    self.put_dynamic_array_max_index(global, context)?
                }
            }
        }

        Ok(check_written)
    }

    /// Write the access chain `chain`.
    ///
    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
    /// operating either on a pointer to a value, or on a value directly.
    ///
    /// Generate bounds-check code only if `policy` is [`Restrict`]. The
    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
    /// that must be handled in the caller.
    ///
    /// Handle the entire chain, recursing back into `put_expression` only for index
    /// expressions and the base expression that originates the pointer or composite value
    /// being accessed. This allows `put_expression` to assume that any `Access` or
    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
    /// `ReadZeroSkipWrite` checks.
    ///
    /// [`Access`]: crate::Expression::Access
    /// [`AccessIndex`]: crate::Expression::AccessIndex
    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
    fn put_access_chain(
        &mut self,
        chain: Handle<crate::Expression>,
        policy: index::BoundsCheckPolicy,
        context: &ExpressionContext,
    ) -> BackendResult {
        match context.function.expressions[chain] {
            crate::Expression::Access { base, index } => {
                let mut base_ty = context.resolve_type(base);

                // Look through any pointers to see what we're really indexing.
                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
                    base_ty = &context.module.types[base].inner;
                }

                self.put_subscripted_access_chain(
                    base,
                    base_ty,
                    index::GuardedIndex::Expression(index),
                    policy,
                    context,
                )?;
            }
            crate::Expression::AccessIndex { base, index } => {
                let base_resolution = &context.info[base].ty;
                let mut base_ty = base_resolution.inner_with(&context.module.types);
                let mut base_ty_handle = base_resolution.handle();

                // Look through any pointers to see what we're really indexing.
                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
                    base_ty = &context.module.types[base].inner;
                    base_ty_handle = Some(base);
                }

                // Handle structs and anything else that can use `.x` syntax here, so
                // `put_subscripted_access_chain` won't have to handle the absurd case of
                // indexing a struct with an expression.
                match *base_ty {
                    crate::TypeInner::Struct { .. } => {
                        let base_ty = base_ty_handle.unwrap();
                        self.put_access_chain(base, policy, context)?;
                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
                        write!(self.out, ".{name}")?;
                    }
                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
                        self.put_access_chain(base, policy, context)?;
                        // Prior to Metal 2.1, component access on packed vectors wasn't
                        // available, but array indexing was.
                        if context.get_packed_vec_kind(base).is_some() {
                            write!(self.out, "[{index}]")?;
                        } else {
                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
                        }
                    }
                    _ => {
                        self.put_subscripted_access_chain(
                            base,
                            base_ty,
                            index::GuardedIndex::Known(index),
                            policy,
                            context,
                        )?;
                    }
                }
            }
            _ => self.put_expression(chain, context, false)?,
        }

        Ok(())
    }

    /// Write a `[]`-style access of `base` by `index`.
    ///
    /// If `policy` is [`Restrict`], then generate code as needed to force all index
    /// values within bounds.
    ///
    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
    /// removed. Our callers often already have this handy.
    ///
    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
    /// references to vector components by name.
    ///
    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
    /// [`Array`]: crate::TypeInner::Array
    /// [`Vector`]: crate::TypeInner::Vector
    /// [`Pointer`]: crate::TypeInner::Pointer
    fn put_subscripted_access_chain(
        &mut self,
        base: Handle<crate::Expression>,
        base_ty: &crate::TypeInner,
        index: index::GuardedIndex,
        policy: index::BoundsCheckPolicy,
        context: &ExpressionContext,
    ) -> BackendResult {
        let accessing_wrapped_array = match *base_ty {
            crate::TypeInner::Array {
                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
                ..
            } => true,
            _ => false,
        };
        let accessing_wrapped_binding_array =
            matches!(*base_ty, crate::TypeInner::BindingArray { .. });

        self.put_access_chain(base, policy, context)?;
        if accessing_wrapped_array {
            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
        }
        write!(self.out, "[")?;

        // Decide whether this index needs to be clamped to fall within range.
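For the `Restrict` policy, the code that follows clamps instead of testing: it emits `min(unsigned(i), N - 1u)` for a known length `N`, or clamps against the dynamic maximum index for runtime-sized arrays. A standalone sketch of that arithmetic, with hypothetical helper names that are not part of naga, assuming the array is non-empty:

```rust
// Hypothetical, standalone model of the Restrict policy; not naga code.

/// Clamp like the emitted `metal::min(unsigned(i), N - 1u)` for an array of
/// known, non-zero length `len`. Reinterpreting the index as unsigned means
/// negative values clamp to the last element as well.
fn restrict_index(index: i32, len: u32) -> u32 {
    (index as u32).min(len - 1)
}

/// Runtime-sized arrays clamp against a maximum index computed at run time;
/// here the slice length stands in for the buffer-size-derived value.
/// Assumes `data` is non-empty.
fn restricted_load(data: &[u32], index: i32) -> u32 {
    let max_index = data.len() as u32 - 1;
    data[(index as u32).min(max_index) as usize]
}
```

Unlike `ReadZeroSkipWrite`, this never changes which code runs: every access still happens, just at an index forced into range.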
        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
            context.access_needs_check(base, index)
        } else {
            None
        };
        if let Some(limit) = restriction_needed {
            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
            self.put_index(index, context, true)?;
            write!(self.out, "), ")?;
            match limit {
                index::IndexableLength::Known(limit) => {
                    write!(self.out, "{}u", limit - 1)?;
                }
                index::IndexableLength::Dynamic => {
                    let global = context.function.originating_global(base).ok_or_else(|| {
                        Error::GenericValidation("Could not find originating global".into())
                    })?;
                    self.put_dynamic_array_max_index(global, context)?;
                }
            }
            write!(self.out, ")")?;
        } else {
            self.put_index(index, context, true)?;
        }

        write!(self.out, "]")?;

        if accessing_wrapped_binding_array {
            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
        }

        Ok(())
    }

    fn put_load(
        &mut self,
        pointer: Handle<crate::Expression>,
        context: &ExpressionContext,
        is_scoped: bool,
    ) -> BackendResult {
        // Since access chains never cross between address spaces, we can just
        // check the index bounds check policy once at the top.
        let policy = context.choose_bounds_check_policy(pointer);
        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
            && self.put_bounds_checks(
                pointer,
                context,
                back::Level(0),
                if is_scoped { "" } else { "(" },
            )?
        {
            write!(self.out, " ? ")?;
            self.put_unchecked_load(pointer, policy, context)?;
            write!(self.out, " : DefaultConstructible()")?;

            if !is_scoped {
                write!(self.out, ")")?;
            }
        } else {
            self.put_unchecked_load(pointer, policy, context)?;
        }

        Ok(())
    }

    fn put_unchecked_load(
        &mut self,
        pointer: Handle<crate::Expression>,
        policy: index::BoundsCheckPolicy,
        context: &ExpressionContext,
    ) -> BackendResult {
        let is_atomic_pointer = context
            .resolve_type(pointer)
            .is_atomic_pointer(&context.module.types);

        if is_atomic_pointer {
            write!(
                self.out,
                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
            )?;
            self.put_access_chain(pointer, policy, context)?;
            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
        } else {
            // We don't dereference with `*` here, because pointer arguments to
            // functions are passed as `&` references rather than `*` pointers,
            // and references do not need to be dereferenced.
            self.put_access_chain(pointer, policy, context)?;
        }

        Ok(())
    }

    fn put_return_value(
        &mut self,
        level: back::Level,
        expr_handle: Handle<crate::Expression>,
        result_struct: Option<&str>,
        context: &ExpressionContext,
    ) -> BackendResult {
        match result_struct {
            Some(struct_name) => {
                let mut has_point_size = false;
                let result_ty = context.function.result.as_ref().unwrap().ty;
                match context.module.types[result_ty].inner {
                    crate::TypeInner::Struct { ref members, .. } => {
                        let tmp = "_tmp";
                        write!(self.out, "{level}const auto {tmp} = ")?;
                        self.put_expression(expr_handle, context, true)?;
                        writeln!(self.out, ";")?;
                        write!(self.out, "{level}return {struct_name} {{")?;

                        let mut is_first = true;

                        for (index, member) in members.iter().enumerate() {
                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
                                member.binding
                            {
                                has_point_size = true;
                                if !context.pipeline_options.allow_and_force_point_size {
                                    continue;
                                }
                            }

                            let comma = if is_first { "" } else { "," };
                            is_first = false;
                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
                            // HACK: we are forcefully deduplicating the expression here
                            // to convert from a wrapped struct to a raw array, e.g.
                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
                            if let crate::TypeInner::Array {
                                size: crate::ArraySize::Constant(size),
                                ..
                            } = context.module.types[member.ty].inner
                            {
                                write!(self.out, "{comma} {{")?;
                                for j in 0..size.get() {
                                    if j != 0 {
                                        write!(self.out, ",")?;
                                    }
                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
                                }
                                write!(self.out, "}}")?;
                            } else {
                                write!(self.out, "{comma} {tmp}.{name}")?;
                            }
                        }
                    }
                    _ => {
                        write!(self.out, "{level}return {struct_name} {{ ")?;
                        self.put_expression(expr_handle, context, true)?;
                    }
                }

                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
                    let stage = context.module.entry_points[ep_index as usize].stage;
                    if context.pipeline_options.allow_and_force_point_size
                        && stage == crate::ShaderStage::Vertex
                        && !has_point_size
                    {
                        // point size was injected and comes last
                        write!(self.out, ", 1.0")?;
                    }
                }
                write!(self.out, " }}")?;
            }
            None => {
                write!(self.out, "{level}return ")?;
                self.put_expression(expr_handle, context, true)?;
            }
        }
        writeln!(self.out, ";")?;
        Ok(())
    }

    /// Helper method used to find which expressions of a given function require baking.
    ///
    /// # Notes
    /// This function overwrites the contents of `self.need_bake_expressions`.
    fn update_expressions_to_bake(
        &mut self,
        func: &crate::Function,
        info: &valid::FunctionInfo,
        context: &ExpressionContext,
    ) {
        use crate::Expression;
        self.need_bake_expressions.clear();

        for (expr_handle, expr) in func.expressions.iter() {
            // Expressions whose reference count reaches the
            // threshold should always be stored in temporaries.
            let expr_info = &info[expr_handle];
            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
            if min_ref_count <= expr_info.ref_count {
                self.need_bake_expressions.insert(expr_handle);
            } else {
                match expr_info.ty {
                    // force ray desc to be baked: it's used multiple times internally
                    TypeResolution::Handle(h)
                        if Some(h) == context.module.special_types.ray_desc =>
                    {
                        self.need_bake_expressions.insert(expr_handle);
                    }
                    _ => {}
                }
            }

            if let Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                ..
            } = *expr
            {
                match fun {
                    // WGSL's `dot` function works on any `vecN` type, but Metal's only
                    // works on floating-point vectors, so we emit inline code for
                    // integer vector `dot` calls. But that code uses each argument `N`
                    // times, once for each component (see `put_dot_product`), so to
                    // avoid duplicated evaluation, we must bake integer operands.
                    // This applies both when using the polyfill (because of the duplicate
                    // evaluation issue) and when we don't use the polyfill (because we
                    // need them to be emitted before casting to packed chars -- see the
                    // comment at the call to `put_casting_to_packed_chars`).
                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
                        self.need_bake_expressions.insert(arg);
                        self.need_bake_expressions.insert(arg1.unwrap());
                    }
                    crate::MathFunction::FirstLeadingBit => {
                        self.need_bake_expressions.insert(arg);
                    }
                    crate::MathFunction::Pack4xI8
                    | crate::MathFunction::Pack4xU8
                    | crate::MathFunction::Pack4xI8Clamp
                    | crate::MathFunction::Pack4xU8Clamp
                    | crate::MathFunction::Unpack4xI8
                    | crate::MathFunction::Unpack4xU8 => {
                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
                        if context.lang_version < (2, 1) {
                            self.need_bake_expressions.insert(arg);
                        }
                    }
                    crate::MathFunction::ExtractBits => {
                        // Only argument 1 is re-used.
                        self.need_bake_expressions.insert(arg1.unwrap());
                    }
                    crate::MathFunction::InsertBits => {
                        // Only argument 2 is re-used.
                        self.need_bake_expressions.insert(arg2.unwrap());
                    }
                    crate::MathFunction::Sign => {
                        // WGSL's `sign` function also works on signed integers, but Metal's
                        // only works on floating-point values, so we emit inline code for
                        // integer `sign` calls. But that code uses the argument twice (see
                        // `put_isign`), so to avoid duplicated evaluation, we must bake the
                        // argument.
                        let inner = context.resolve_type(expr_handle);
                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
                            self.need_bake_expressions.insert(arg);
                        }
                    }
                    _ => {}
                }
            }
        }
    }

    fn start_baking_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        context: &ExpressionContext,
        name: &str,
    ) -> BackendResult {
        match context.info[handle].ty {
            TypeResolution::Handle(ty_handle) => {
                let ty_name = TypeContext {
                    handle: ty_handle,
                    gctx: context.module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                write!(self.out, "{ty_name}")?;
            }
            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
                put_numeric_type(&mut self.out, scalar, &[])?;
            }
            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
                put_numeric_type(&mut self.out, scalar, &[size])?;
            }
            TypeResolution::Value(crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            }) => {
                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
            }
            TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            }) => {
                write!(
                    self.out,
                    "{}::simdgroup_{}{}x{}",
                    NAMESPACE,
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )?;
            }
            TypeResolution::Value(ref other) => {
                log::warn!("Type {other:?} isn't a known local");
                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
            }
        }

        //TODO: figure out the naming scheme that wouldn't collide with user names.
        write!(self.out, " {name} = ")?;

        Ok(())
    }

    /// Cache a clamped level-of-detail value, if necessary.
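Why argument reuse forces baking can be illustrated outside the writer. This standalone sketch assumes an integer `sign` polyfill that compares its argument twice (the precise MSL that `put_isign` emits is not shown in this section). The closure returned by the hypothetical `counted` helper stands in for an argument expression and counts how often it is evaluated; none of these names exist in naga.

```rust
use std::cell::Cell;

/// Hypothetical stand-in for an argument expression: every evaluation
/// bumps the shared `evaluations` counter.
fn counted(value: i32, evaluations: &Cell<u32>) -> impl Fn() -> i32 + '_ {
    move || {
        evaluations.set(evaluations.get() + 1);
        value
    }
}

/// Integer `sign` built from two comparisons. Substituting the argument
/// expression into both comparisons (no baking) evaluates it twice.
fn isign_unbaked(arg: impl Fn() -> i32) -> i32 {
    (arg() > 0) as i32 - (arg() < 0) as i32
}

/// Baked version: evaluate the argument once into a temporary, then reuse it,
/// which is what inserting the handle into `need_bake_expressions` achieves.
fn isign_baked(arg: impl Fn() -> i32) -> i32 {
    let baked = arg();
    (baked > 0) as i32 - (baked < 0) as i32
}
```

The same reasoning covers the other arms of `update_expressions_to_bake`: whichever argument a polyfill mentions more than once is the one that gets baked.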
    ///
    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
    /// properly clamped level-of-detail value both in the access itself, and
    /// for fetching the size of the requested MIP level, needed to clamp the
    /// coordinates. To avoid recomputing this clamped level of detail, we cache
    /// it in a temporary variable, as part of the [`Emit`] statement covering
    /// the [`ImageLoad`] expression.
    ///
    /// [`ImageLoad`]: crate::Expression::ImageLoad
    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
    /// [`Emit`]: crate::Statement::Emit
    fn put_cache_restricted_level(
        &mut self,
        load: Handle<crate::Expression>,
        image: Handle<crate::Expression>,
        mip_level: Option<Handle<crate::Expression>>,
        indent: back::Level,
        context: &StatementContext,
    ) -> BackendResult {
        // Does this image access actually require (or even permit) a
        // level-of-detail, and does the policy require us to restrict it?
        let level_of_detail = match mip_level {
            Some(level) => level,
            None => return Ok(()),
        };

        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
            || !context.expression.image_needs_lod(image)
        {
            return Ok(());
        }

        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
        self.put_restricted_scalar_image_index(
            image,
            level_of_detail,
            "get_num_mip_levels",
            &context.expression,
        )?;
        writeln!(self.out, ";")?;

        Ok(())
    }

    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
    ///
    /// Caches the results in temporary variables (whose names are derived from
    /// the original variable names). This caching avoids the need to redo the
    /// casting for each vector component when emitting the dot product.
    fn put_casting_to_packed_chars(
        &mut self,
        fun: crate::MathFunction,
        arg0: Handle<crate::Expression>,
        arg1: Handle<crate::Expression>,
        indent: back::Level,
        context: &StatementContext<'_>,
    ) -> Result<(), Error> {
        let packed_type = match fun {
            crate::MathFunction::Dot4I8Packed => "packed_char4",
            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
            _ => unreachable!(),
        };

        for arg in [arg0, arg1] {
            write!(
                self.out,
                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
                Reinterpreted::new(packed_type, arg)
            )?;
            self.put_expression(arg, &context.expression, true)?;
            writeln!(self.out, ");")?;
        }

        Ok(())
    }

    fn put_block(
        &mut self,
        level: back::Level,
        statements: &[crate::Statement],
        context: &StatementContext,
    ) -> BackendResult {
        // Add to the set in order to track the stack size.
        #[cfg(test)]
        self.put_block_stack_pointers
            .insert(ptr::from_ref(&level).cast());

        for statement in statements {
            log::trace!("statement[{}] {:?}", level.0, statement);
            match *statement {
                crate::Statement::Emit(ref range) => {
                    for handle in range.clone() {
                        use crate::MathFunction as Mf;

                        match context.expression.function.expressions[handle] {
                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
                            // may need to cache a clamped version of their level-of-detail argument.
                            crate::Expression::ImageLoad {
                                image,
                                level: mip_level,
                                ..
                            } => {
                                self.put_cache_restricted_level(
                                    handle, image, mip_level, level, context,
                                )?;
                            }

                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
                            // 2.1+, then we introduce two intermediate variables that recast the two
                            // arguments as packed (signed or unsigned) chars. The actual dot product
                            // is implemented in `Self::put_expression`, and it uses both of these
                            // intermediate variables multiple times. There's no danger that the
                            // original arguments get modified between the definition of these
                            // intermediate variables and the implementation of the actual dot
                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
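A hypothetical reference model of the packed dot products, useful for sanity-checking expected results. It assumes WGSL's `dot4I8Packed` / `dot4U8Packed` semantics (four 8-bit lanes per `u32`) and uses `u32::to_le_bytes` as a stand-in for the one-time `as_type<packed_char4>` reinterpretation that `put_casting_to_packed_chars` emits. The function names are illustrative, not naga APIs, and since lane order does not affect a dot product, the byte order here is only a convention.

```rust
/// Hypothetical model of `dot4I8Packed`: reinterpret each `u32` once as four
/// signed bytes, then reuse those lanes for the multiply-accumulate.
fn dot4_i8_packed(a: u32, b: u32) -> i32 {
    let (a, b) = (a.to_le_bytes(), b.to_le_bytes());
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| i32::from(x as i8) * i32::from(y as i8))
        .sum()
}

/// Unsigned counterpart, modelling `dot4U8Packed` / `packed_uchar4`.
fn dot4_u8_packed(a: u32, b: u32) -> u32 {
    let (a, b) = (a.to_le_bytes(), b.to_le_bytes());
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| u32::from(x) * u32::from(y))
        .sum()
}
```

For example, `dot4_i8_packed(0x0102_0304, 0x0101_0101)` sums the lanes 4, 3, 2 and 1, giving 10, while a lane of `0xFF` counts as -1 in the signed version.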
                            crate::Expression::Math {
                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
                                arg,
                                arg1,
                                ..
                            } if context.expression.lang_version >= (2, 1) => {
                                self.put_casting_to_packed_chars(
                                    fun,
                                    arg,
                                    arg1.unwrap(),
                                    level,
                                    context,
                                )?;
                            }

                            _ => (),
                        }

                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
                        let expr_name = if ptr_class.is_some() {
                            None // don't bake pointer expressions (just yet)
                        } else if let Some(name) =
                            context.expression.function.named_expressions.get(&handle)
                        {
                            // The `crate::Function::named_expressions` table holds
                            // expressions that should be saved in temporaries once they
                            // are `Emit`ted. We only add them to `self.named_expressions`
                            // when we reach the `Emit` that covers them, so that we don't
                            // try to use their names before we've actually initialized
                            // the temporary that holds them.
                            //
                            // Don't assume the names in `named_expressions` are unique,
                            // or even valid. Use the `Namer`.
                            Some(self.namer.call(name))
                        } else {
                            // If this expression is an index that we're going to first compare
                            // against a limit, and then actually use as an index, then we may
                            // want to cache it in a temporary, to avoid evaluating it twice.
                            let bake = if context.expression.guarded_indices.contains(handle) {
                                true
                            } else {
                                self.need_bake_expressions.contains(&handle)
                            };

                            if bake {
                                Some(Baked(handle).to_string())
                            } else {
                                None
                            }
                        };

                        if let Some(name) = expr_name {
                            write!(self.out, "{level}")?;
                            self.start_baking_expression(handle, &context.expression, &name)?;
                            self.put_expression(handle, &context.expression, true)?;
                            self.named_expressions.insert(handle, name);
                            writeln!(self.out, ";")?;
                        }
                    }
                }
                crate::Statement::Block(ref block) => {
                    if !block.is_empty() {
                        writeln!(self.out, "{level}{{")?;
                        self.put_block(level.next(), block, context)?;
                        writeln!(self.out, "{level}}}")?;
                    }
                }
                crate::Statement::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    write!(self.out, "{level}if (")?;
                    self.put_expression(condition, &context.expression, true)?;
                    writeln!(self.out, ") {{")?;
                    self.put_block(level.next(), accept, context)?;
                    if !reject.is_empty() {
                        writeln!(self.out, "{level}}} else {{")?;
                        self.put_block(level.next(), reject, context)?;
                    }
                    writeln!(self.out, "{level}}}")?;
                }
                crate::Statement::Switch {
                    selector,
                    ref cases,
                } => {
                    write!(self.out, "{level}switch(")?;
                    self.put_expression(selector, &context.expression, true)?;
                    writeln!(self.out, ") {{")?;
                    let lcase = level.next();
                    for case in cases.iter() {
                        match case.value {
                            crate::SwitchValue::I32(value) => {
                                write!(self.out, "{lcase}case {value}:")?;
                            }
                            crate::SwitchValue::U32(value) => {
                                write!(self.out, "{lcase}case {value}u:")?;
                            }
                            crate::SwitchValue::Default => {
                                write!(self.out, "{lcase}default:")?;
                            }
                        }

                        let write_block_braces = !(case.fall_through && case.body.is_empty());
                        if write_block_braces {
                            writeln!(self.out, " {{")?;
                        } else {
                            writeln!(self.out)?;
                        }

                        self.put_block(lcase.next(), &case.body, context)?;
                        if !case.fall_through
                            && case.body.last().is_none_or(|s| !s.is_terminator())
                        {
                            writeln!(self.out, "{}break;", lcase.next())?;
                        }

                        if write_block_braces {
                            writeln!(self.out, "{lcase}}}")?;
                        }
                    }
                    writeln!(self.out, "{level}}}")?;
                }
                crate::Statement::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let force_loop_bound_statements =
                        self.gen_force_bounded_loop_statements(level, context);
                    let gate_name = (!continuing.is_empty() || break_if.is_some())
                        .then(|| self.namer.call("loop_init"));

                    if let Some((ref decl, _)) = force_loop_bound_statements {
                        writeln!(self.out, "{decl}")?;
                    }
                    if let Some(ref gate_name) = gate_name {
                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
                    }

                    writeln!(self.out, "{level}while(true) {{",)?;
                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
                        writeln!(self.out, "{break_and_inc}")?;
                    }
                    if let Some(ref gate_name) = gate_name {
                        let lif = level.next();
                        let lcontinuing = lif.next();
                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
                        self.put_block(lcontinuing, continuing, context)?;
                        if let Some(condition) = break_if {
                            write!(self.out, "{lcontinuing}if (")?;
                            self.put_expression(condition, &context.expression, true)?;
                            writeln!(self.out, ") {{")?;
                            writeln!(self.out, "{}break;", lcontinuing.next())?;
                            writeln!(self.out, "{lcontinuing}}}")?;
                        }
                        writeln!(self.out, "{lif}}}")?;
                        writeln!(self.out, "{lif}{gate_name} = false;")?;
                    }
                    self.put_block(level.next(), body, context)?;

                    writeln!(self.out, "{level}}}")?;
                }
                crate::Statement::Break => {
                    writeln!(self.out, "{level}break;")?;
                }
                crate::Statement::Continue => {
                    writeln!(self.out, "{level}continue;")?;
                }
                crate::Statement::Return {
                    value: Some(expr_handle),
                } => {
                    self.put_return_value(
                        level,
                        expr_handle,
                        context.result_struct,
                        &context.expression,
                    )?;
                }
                crate::Statement::Return { value: None } => {
                    writeln!(self.out, "{level}return;")?;
                }
                crate::Statement::Kill => {
                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
                }
                crate::Statement::ControlBarrier(flags)
                | crate::Statement::MemoryBarrier(flags) => {
                    self.write_barrier(flags, level)?;
                }
                crate::Statement::Store { pointer, value } => {
                    self.put_store(pointer, value, level, context)?
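The gated loop emitted above runs `continuing` (and the `break_if` test) at the top of every iteration except the first. The standalone sketch below models that control flow in plain Rust for a hypothetical WGSL loop `loop { body; continuing { i += 1; break if i >= limit; } }`. `run_gated_loop` is an illustrative name, not part of naga.

```rust
/// Hypothetical model of the gated `while(true)` lowering. Returns a trace
/// of which part ran, together with the counter value at that moment.
fn run_gated_loop(limit: u32) -> Vec<String> {
    let mut trace = Vec::new();
    let mut i = 0;
    // Mirrors the emitted `bool loop_init = true;`.
    let mut loop_init = true;
    loop {
        // Mirrors `if (!loop_init) { continuing; if (break_if) { break; } }`.
        if !loop_init {
            i += 1;
            trace.push(format!("continuing i={i}"));
            if i >= limit {
                break;
            }
        }
        loop_init = false;
        // The loop body is emitted after the gated continuing block.
        trace.push(format!("body i={i}"));
    }
    trace
}
```

With `limit = 2`, the trace is `body i=0`, `continuing i=1`, `body i=1`, `continuing i=2`: the continuing block always runs between two bodies, and `break if` is tested right after it, matching WGSL's ordering.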
                }
                crate::Statement::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let address = TexelAddress {
                        coordinate,
                        array_index,
                        sample: None,
                        level: None,
                    };
                    self.put_image_store(level, image, &address, value, context)?
                }
                crate::Statement::Call {
                    function,
                    ref arguments,
                    result,
                } => {
                    write!(self.out, "{level}")?;
                    if let Some(expr) = result {
                        let name = Baked(expr).to_string();
                        self.start_baking_expression(expr, &context.expression, &name)?;
                        self.named_expressions.insert(expr, name);
                    }
                    let fun_name = &self.names[&NameKey::Function(function)];
                    write!(self.out, "{fun_name}(")?;
                    // first, write down the actual arguments
                    for (i, &handle) in arguments.iter().enumerate() {
                        if i != 0 {
                            write!(self.out, ", ")?;
                        }
                        self.put_expression(handle, &context.expression, true)?;
                    }
                    // follow up with any global resources used
                    let mut separate = !arguments.is_empty();
                    let fun_info = &context.expression.mod_info[function];
                    let mut needs_buffer_sizes = false;
                    for (handle, var) in context.expression.module.global_variables.iter() {
                        if fun_info[handle].is_empty() {
                            continue;
                        }
                        if var.space.needs_pass_through() {
                            let name = &self.names[&NameKey::GlobalVariable(handle)];
                            if separate {
                                write!(self.out, ", ")?;
                            } else {
                                separate = true;
                            }
                            write!(self.out, "{name}")?;
                        }
                        needs_buffer_sizes |=
                            needs_array_length(var.ty, &context.expression.module.types);
                    }
                    if needs_buffer_sizes {
                        if separate {
                            write!(self.out, ", ")?;
                        }
                        write!(self.out, "_buffer_sizes")?;
                    }

                    // done
                    writeln!(self.out, ");")?;
                }
                crate::Statement::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result,
                } => {
                    let context = &context.expression;

                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
                    // `Some`, we are not operating on a 64-bit value, and that if we are
                    // operating on a 64-bit value, `result` is `None`.
                    write!(self.out, "{level}")?;
                    let fun_key = if let Some(result) = result {
                        let res_name = Baked(result).to_string();
                        self.start_baking_expression(result, context, &res_name)?;
                        self.named_expressions.insert(result, res_name);
                        fun.to_msl()
                    } else if context.resolve_type(value).scalar_width() == Some(8) {
                        fun.to_msl_64_bit()?
                    } else {
                        fun.to_msl()
                    };

                    // If the pointer we're passing to the atomic operation needs to be conditional
                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
                    // the pointer operand should be unchecked.
                    let policy = context.choose_bounds_check_policy(pointer);
                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;

                    // If bounds checks were requested and written, continue the ternary expression.
                    if checked {
                        write!(self.out, " ? ")?;
                    }

                    // Put the atomic function invocation.
                    match *fun {
                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
                            self.put_access_chain(pointer, policy, context)?;
                            write!(self.out, ", ")?;
                            self.put_expression(cmp, context, true)?;
                            write!(self.out, ", ")?;
                            self.put_expression(value, context, true)?;
                            write!(self.out, ")")?;
                        }
                        _ => {
                            write!(
                                self.out,
                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
                            )?;
                            self.put_access_chain(pointer, policy, context)?;
                            write!(self.out, ", ")?;
                            self.put_expression(value, context, true)?;
                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
                        }
                    }

                    // Finish the ternary expression.
                    if checked {
                        write!(self.out, " : DefaultConstructible()")?;
                    }

                    // Done
                    writeln!(self.out, ";")?;
                }
                crate::Statement::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun,
                    value,
                } => {
                    let address = TexelAddress {
                        coordinate,
                        array_index,
                        sample: None,
                        level: None,
                    };
                    self.put_image_atomic(level, image, &address, fun, value, context)?
                }
                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;

                    write!(self.out, "{level}")?;
                    let name = self.namer.call("");
                    self.start_baking_expression(result, &context.expression, &name)?;
                    self.put_load(pointer, &context.expression, true)?;
                    self.named_expressions.insert(result, name);

                    writeln!(self.out, ";")?;
                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
                }
                crate::Statement::RayQuery { query, ref fun } => {
                    if context.expression.lang_version < (2, 4) {
                        return Err(Error::UnsupportedRayTracing);
                    }

                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            //TODO: how to deal with winding?
                            write!(self.out, "{level}")?;
                            self.put_expression(query, &context.expression, true)?;
                            writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
                            {
                                let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
                                let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
                                write!(self.out, "{level}")?;
                                self.put_expression(query, &context.expression, true)?;
                                write!(
                                    self.out,
                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
                                )?;
                                self.put_expression(descriptor, &context.expression, true)?;
                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
                                self.put_expression(descriptor, &context.expression, true)?;
                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
                                writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
                            }
                            {
                                let f_opaque = back::RayFlag::OPAQUE.bits();
                                let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
                                write!(self.out, "{level}")?;
                                self.put_expression(query, &context.expression, true)?;
                                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
                                self.put_expression(descriptor, &context.expression, true)?;
                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
                                self.put_expression(descriptor, &context.expression, true)?;
                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
                                writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
                            }
                            {
                                let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
                                write!(self.out, "{level}")?;
                                self.put_expression(query, &context.expression, true)?;
                                write!(
                                    self.out,
                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
                                )?;
                                self.put_expression(descriptor, &context.expression, true)?;
                                writeln!(self.out, ".flags & {flag}) != 0);")?;
                            }

                            write!(self.out, "{level}")?;
                            self.put_expression(query, &context.expression, true)?;
                            write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
                            self.put_expression(query, &context.expression, true)?;
                            write!(
                                self.out,
                                ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
                            )?;
                            self.put_expression(descriptor, &context.expression, true)?;
                            write!(self.out, ".origin, ")?;
                            self.put_expression(descriptor, &context.expression, true)?;
                            write!(self.out, ".dir, ")?;
                            self.put_expression(descriptor, &context.expression, true)?;
                            write!(self.out, ".tmin, ")?;
                            self.put_expression(descriptor, &context.expression, true)?;
                            write!(self.out, ".tmax), ")?;
                            self.put_expression(acceleration_structure, &context.expression, true)?;
                            write!(self.out, ", ")?;
                            self.put_expression(descriptor, &context.expression, true)?;
                            write!(self.out, ".cull_mask);")?;

                            write!(self.out, "{level}")?;
                            self.put_expression(query, &context.expression, true)?;
                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
                        }
                        crate::RayQueryFunction::Proceed { result } => {
                            write!(self.out, "{level}")?;
                            let name = Baked(result).to_string();
                            self.start_baking_expression(result, &context.expression, &name)?;
                            self.named_expressions.insert(result, name);
                            self.put_expression(query, &context.expression, true)?;
                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
                            if RAY_QUERY_MODERN_SUPPORT {
                                write!(self.out, "{level}")?;
                                self.put_expression(query, &context.expression, true)?;
                                writeln!(self.out, ".?.next();")?;
                            }
                        }
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            if RAY_QUERY_MODERN_SUPPORT {
                                write!(self.out, "{level}")?;
                                self.put_expression(query, &context.expression, true)?;
                                write!(self.out, ".?.commit_bounding_box_intersection(")?;
                                self.put_expression(hit_t, &context.expression, true)?;
                                writeln!(self.out, ");")?;
                            } else {
                                log::warn!("Ray Query GenerateIntersection is not yet supported");
                            }
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {
                            if RAY_QUERY_MODERN_SUPPORT {
                                write!(self.out, "{level}")?;
                                self.put_expression(query, &context.expression, true)?;
                                writeln!(self.out, ".?.commit_triangle_intersection();")?;
                            } else {
                                log::warn!("Ray Query ConfirmIntersection is not yet supported");
                            }
                        }
                        crate::RayQueryFunction::Terminate => {
                            if RAY_QUERY_MODERN_SUPPORT {
                                write!(self.out, "{level}")?;
                                self.put_expression(query, &context.expression, true)?;
                                writeln!(self.out, ".?.abort();")?;
                            }
                            write!(self.out, "{level}")?;
                            self.put_expression(query, &context.expression, true)?;
                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
                        }
                    }
                }
                crate::Statement::SubgroupBallot { result, predicate } => {
                    write!(self.out, "{level}")?;
                    let name = self.namer.call("");
                    self.start_baking_expression(result, &context.expression, &name)?;
                    self.named_expressions.insert(result, name);
                    write!(
                        self.out,
                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
                    )?;
                    if let Some(predicate) = predicate {
                        self.put_expression(predicate, &context.expression, true)?;
                    } else {
                        write!(self.out, "true")?;
                    }
                    writeln!(self.out, "), 0, 0, 0);")?;
                }
                crate::Statement::SubgroupCollectiveOperation {
                    op,
                    collective_op,
                    argument,
                    result,
                } => {
                    write!(self.out, "{level}")?;
                    let name = self.namer.call("");
                    self.start_baking_expression(result, &context.expression, &name)?;
                    self.named_expressions.insert(result, name);
                    match (collective_op, op) {
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
                            write!(self.out, "{NAMESPACE}::simd_all(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
                            write!(self.out, "{NAMESPACE}::simd_any(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
                            write!(self.out, "{NAMESPACE}::simd_sum(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
                            write!(self.out, "{NAMESPACE}::simd_product(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
                            write!(self.out, "{NAMESPACE}::simd_max(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
                            write!(self.out, "{NAMESPACE}::simd_min(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
                            write!(self.out, "{NAMESPACE}::simd_and(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
                            write!(self.out, "{NAMESPACE}::simd_or(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
                            write!(self.out, "{NAMESPACE}::simd_xor(")?
                        }
                        (
                            crate::CollectiveOperation::ExclusiveScan,
                            crate::SubgroupOperation::Add,
                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
                        (
                            crate::CollectiveOperation::ExclusiveScan,
                            crate::SubgroupOperation::Mul,
                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
                        (
                            crate::CollectiveOperation::InclusiveScan,
                            crate::SubgroupOperation::Add,
                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
                        (
                            crate::CollectiveOperation::InclusiveScan,
                            crate::SubgroupOperation::Mul,
                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
                        _ => unimplemented!(),
                    }
                    self.put_expression(argument, &context.expression, true)?;
                    writeln!(self.out, ");")?;
                }
                crate::Statement::SubgroupGather {
                    mode,
                    argument,
                    result,
                } => {
                    write!(self.out, "{level}")?;
                    let name = self.namer.call("");
                    self.start_baking_expression(result, &context.expression, &name)?;
                    self.named_expressions.insert(result, name);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {
                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
                        }
                        crate::GatherMode::Broadcast(_) => {
                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
                        }
                        crate::GatherMode::Shuffle(_) => {
                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
                        }
                        crate::GatherMode::ShuffleDown(_) => {
                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
                        }
                        crate::GatherMode::ShuffleUp(_) => {
                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
                        }
                        crate::GatherMode::ShuffleXor(_) => {
                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
                        }
                        crate::GatherMode::QuadBroadcast(_) => {
                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
                        }
                        crate::GatherMode::QuadSwap(_) => {
                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
                        }
                    }
                    self.put_expression(argument, &context.expression, true)?;
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index)
                        | crate::GatherMode::QuadBroadcast(index) => {
                            write!(self.out, ", ")?;
                            self.put_expression(index, &context.expression, true)?;
                        }
                        crate::GatherMode::QuadSwap(direction) => {
                            write!(self.out, ", ")?;
                            match direction {
                                crate::Direction::X => {
                                    write!(self.out, "1u")?;
                                }
                                crate::Direction::Y => {
                                    write!(self.out, "2u")?;
                                }
                                crate::Direction::Diagonal => {
                                    write!(self.out, "3u")?;
                                }
                            }
                        }
                    }
                    writeln!(self.out, ");")?;
                }
                crate::Statement::CooperativeStore { target, ref data } => {
                    write!(self.out, "{level}simdgroup_store(")?;
                    self.put_expression(target, &context.expression, true)?;
                    write!(self.out, ", &")?;
                    self.put_access_chain(
                        data.pointer,
                        context.expression.policies.index,
                        &context.expression,
                    )?;
                    write!(self.out, ", ")?;
                    self.put_expression(data.stride, &context.expression, true)?;
                    if data.row_major {
                        let matrix_origin = "0";
                        let transpose = true;
                        write!(self.out, ", {matrix_origin}, {transpose}")?;
                    }
                    writeln!(self.out, ");")?;
                }
                crate::Statement::RayPipelineFunction(_) =>
                    unreachable!(),
            }
        }

        // un-emit expressions
        //TODO: take care of loop/continuing?
        for statement in statements {
            if let crate::Statement::Emit(ref range) = *statement {
                for handle in range.clone() {
                    self.named_expressions.shift_remove(&handle);
                }
            }
        }
        Ok(())
    }

    fn put_store(
        &mut self,
        pointer: Handle<crate::Expression>,
        value: Handle<crate::Expression>,
        level: back::Level,
        context: &StatementContext,
    ) -> BackendResult {
        let policy = context.expression.choose_bounds_check_policy(pointer);
        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
        {
            writeln!(self.out, ") {{")?;
            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
            writeln!(self.out, "{level}}}")?;
        } else {
            self.put_unchecked_store(pointer, value, policy, level, context)?;
        }
        Ok(())
    }

    fn put_unchecked_store(
        &mut self,
        pointer: Handle<crate::Expression>,
        value: Handle<crate::Expression>,
        policy: index::BoundsCheckPolicy,
        level: back::Level,
        context: &StatementContext,
    ) -> BackendResult {
        let is_atomic_pointer = context
            .expression
            .resolve_type(pointer)
            .is_atomic_pointer(&context.expression.module.types);

        if is_atomic_pointer {
            write!(
                self.out,
                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
            )?;
            self.put_access_chain(pointer, policy, &context.expression)?;
            write!(self.out, ", ")?;
            self.put_expression(value, &context.expression, true)?;
            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
        } else {
            write!(self.out, "{level}")?;
            self.put_access_chain(pointer, policy, &context.expression)?;
            write!(self.out, " = ")?;
            self.put_expression(value, &context.expression, true)?;
            writeln!(self.out, ";")?;
        }

        Ok(())
    }

    pub fn write(
        &mut self,
        module: &crate::Module,
        info: &valid::ModuleInfo,
        options: &Options,
        pipeline_options: &PipelineOptions,
    ) -> Result<TranslationInfo, Error> {
        self.names.clear();
        self.namer.reset(
            module,
            &super::keywords::RESERVED_SET,
            proc::KeywordSet::empty(),
            proc::CaseInsensitiveKeywordSet::empty(),
            &[CLAMPED_LOD_LOAD_PREFIX],
            &mut self.names,
        );
        self.wrapped_functions.clear();
        self.struct_member_pads.clear();

        writeln!(
            self.out,
            "// language: metal{}.{}",
            options.lang_version.0, options.lang_version.1
        )?;
        writeln!(self.out, "#include <metal_stdlib>")?;
        writeln!(self.out, "#include <simd/simd.h>")?;
        writeln!(self.out)?;
        // Work around Metal bug where `uint` is not available by default
        writeln!(self.out, "using {NAMESPACE}::uint;")?;

        let mut uses_ray_query = false;
        for (_, ty) in module.types.iter() {
            match ty.inner {
                crate::TypeInner::AccelerationStructure { .. } => {
                    if options.lang_version < (2, 4) {
                        return Err(Error::UnsupportedRayTracing);
                    }
                }
                crate::TypeInner::RayQuery { .. } => {
                    if options.lang_version < (2, 4) {
                        return Err(Error::UnsupportedRayTracing);
                    }
                    uses_ray_query = true;
                }
                _ => (),
            }
        }
        if module.special_types.ray_desc.is_some()
            || module.special_types.ray_intersection.is_some()
        {
            if options.lang_version < (2, 4) {
                return Err(Error::UnsupportedRayTracing);
            }
        }

        if uses_ray_query {
            self.put_ray_query_type()?;
        }
        if options
            .bounds_check_policies
            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
        {
            self.put_default_constructible()?;
        }
        writeln!(self.out)?;

        {
            // Make a `Vec` of all the `GlobalVariable`s that contain
            // runtime-sized arrays.
            let globals: Vec<Handle<crate::GlobalVariable>> = module
                .global_variables
                .iter()
                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
                .map(|(handle, _)| handle)
                .collect();

            let mut buffer_indices = vec![];
            for vbm in &pipeline_options.vertex_buffer_mappings {
                buffer_indices.push(vbm.id);
            }

            if !globals.is_empty() || !buffer_indices.is_empty() {
                writeln!(self.out, "struct _mslBufferSizes {{")?;

                for global in globals {
                    writeln!(
                        self.out,
                        "{}uint {};",
                        back::INDENT,
                        ArraySizeMember(global)
                    )?;
                }

                for idx in buffer_indices {
                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
                }

                writeln!(self.out, "}};")?;
                writeln!(self.out)?;
            }
        };

        self.write_type_defs(module)?;
        self.write_global_constants(module, info)?;
        self.write_functions(module, info, options, pipeline_options)
    }

    /// Write the definition for the `DefaultConstructible` class.
    ///
    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
    /// produce 'zero' values for any type, including structs, arrays, and so
    /// on. We could do this by emitting default constructor applications, but
    /// that would entail printing the name of the type, which is more trouble
    /// than you'd think. Instead, we just construct this magic C++14 class that
    /// can be converted to any type that can be default constructed, using
    /// template parameter inference to detect which type is needed, so we don't
    /// have to figure out the name.
    ///
    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
    fn put_default_constructible(&mut self) -> BackendResult {
        let tab = back::INDENT;
        writeln!(self.out, "struct DefaultConstructible {{")?;
        writeln!(self.out, "{tab}template<typename T>")?;
        writeln!(self.out, "{tab}operator T() && {{")?;
        writeln!(self.out, "{tab}{tab}return T {{}};")?;
        writeln!(self.out, "{tab}}}")?;
        writeln!(self.out, "}};")?;
        Ok(())
    }

    fn put_ray_query_type(&mut self) -> BackendResult {
        let tab = back::INDENT;
        writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
        let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
        writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
        writeln!(
            self.out,
            "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
        )?;
        writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
        writeln!(self.out, "}};")?;
        writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
        let v_triangle = back::RayIntersectionType::Triangle as u32;
        let v_bbox = back::RayIntersectionType::BoundingBox as u32;
        writeln!(
            self.out,
            "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
        )?;
        writeln!(
            self.out,
            "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
        )?;
        writeln!(self.out, "}}")?;
        Ok(())
    }

    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
        let mut generated_argument_buffer_wrapper = false;
        let mut generated_external_texture_wrapper = false;
        for (handle, ty) in module.types.iter() {
            match ty.inner {
                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
                    writeln!(self.out, "template <typename T>")?;
                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_argument_buffer_wrapper = true;
                }
                crate::TypeInner::Image {
                    class: crate::ImageClass::External,
                    ..
                } if !generated_external_texture_wrapper => {
                    let params_ty_name = &self.names
                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
                        back::INDENT
                    )?;
                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_external_texture_wrapper = true;
                }
                _ => {}
            }

            if !ty.needs_alias() {
                continue;
            }
            let name = &self.names[&NameKey::Type(handle)];
            match ty.inner {
                // Naga IR can pass around arrays by value, but Metal, following
                // C++, performs an array-to-pointer conversion (C++ [conv.array])
                // on expressions of array type, so assigning the array by value
                // isn't possible. However, Metal *does* assign structs by
                // value. So in our Metal output, we wrap all array types in
                // synthetic struct types:
                //
                //     struct type1 {
                //         float inner[10];
                //     };
                //
                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
                // any expression that actually wants access to the array.
                crate::TypeInner::Array {
                    base,
                    size,
                    stride: _,
                } => {
                    let base_name = TypeContext {
                        handle: base,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: false,
                    };

                    match size.resolve(module.to_ctx())? {
                        proc::IndexableLength::Known(size) => {
                            writeln!(self.out, "struct {name} {{")?;
                            writeln!(
                                self.out,
                                "{}{} {}[{}];",
                                back::INDENT,
                                base_name,
                                WRAPPED_ARRAY_FIELD,
                                size
                            )?;
                            writeln!(self.out, "}};")?;
                        }
                        proc::IndexableLength::Dynamic => {
                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
                        }
                    }
                }
                crate::TypeInner::Struct {
                    ref members, span, ..
                } => {
                    writeln!(self.out, "struct {name} {{")?;
                    let mut last_offset = 0;
                    for (index, member) in members.iter().enumerate() {
                        if member.offset > last_offset {
                            self.struct_member_pads.insert((handle, index as u32));
                            let pad = member.offset - last_offset;
                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
                        }
                        let ty_inner = &module.types[member.ty].inner;
                        last_offset = member.offset + ty_inner.size(module.to_ctx());

                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];

                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
                        match should_pack_struct_member(members, span, index, module) {
                            Some(scalar) => {
                                writeln!(
                                    self.out,
                                    "{}{}::packed_{}3 {};",
                                    back::INDENT,
                                    NAMESPACE,
                                    scalar.to_msl_name(),
                                    member_name
                                )?;
                            }
                            None => {
                                let base_name = TypeContext {
                                    handle: member.ty,
                                    gctx: module.to_ctx(),
                                    names: &self.names,
                                    access: crate::StorageAccess::empty(),
                                    first_time: false,
                                };
                                writeln!(
                                    self.out,
                                    "{}{} {};",
                                    back::INDENT,
                                    base_name,
                                    member_name
                                )?;

                                // for 3-component vectors, add one component
                                if let crate::TypeInner::Vector {
                                    size: crate::VectorSize::Tri,
                                    scalar,
                                } = *ty_inner
                                {
                                    last_offset += scalar.width as u32;
                                }
                            }
                        }
                    }
                    if last_offset < span {
                        let pad = span - last_offset;
                        writeln!(
                            self.out,
                            "{}char _pad{}[{}];",
                            back::INDENT,
                            members.len(),
                            pad
                        )?;
                    }
                    writeln!(self.out, "}};")?;
                }
                _ => {
                    let ty_name = TypeContext {
                        handle,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: true,
                    };
                    writeln!(self.out, "typedef {ty_name} {name};")?;
                }
            }
        }

        // Write functions to create special types.
        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
            match type_key {
                &crate::PredeclaredType::ModfResult { size, scalar }
                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
                    let arg_type_name_owner;
                    let arg_type_name = if let Some(size) = size {
                        arg_type_name_owner = format!(
                            "{NAMESPACE}::{}{}",
                            if scalar.width == 8 { "double" } else { "float" },
                            size as u8
                        );
                        &arg_type_name_owner
                    } else if scalar.width == 8 {
                        "double"
                    } else {
                        "float"
                    };

                    let other_type_name_owner;
                    let (defined_func_name, called_func_name, other_type_name) =
                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
                            (MODF_FUNCTION, "modf", arg_type_name)
                        } else {
                            let other_type_name = if let Some(size) = size {
                                other_type_name_owner = format!("int{}", size as u8);
                                &other_type_name_owner
                            } else {
                                "int"
                            };
                            (FREXP_FUNCTION, "frexp", other_type_name)
                        };

                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    writeln!(self.out)?;
                    writeln!(
                        self.out,
                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
    {other_type_name} other;
    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
    return {struct_name}{{ fract, other }};
}}"
                    )?;
                }
                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
                    let arg_type_name = scalar.to_msl_name();
                    let called_func_name = "atomic_compare_exchange_weak_explicit";
                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    writeln!(self.out)?;

                    for address_space_name in ["device", "threadgroup"] {
                        writeln!(
                            self.out,
                            "\
template <typename A>
{struct_name} {defined_func_name}(
    {address_space_name} A *atomic_ptr,
    {arg_type_name} cmp,
    {arg_type_name} v
) {{
    bool swapped = {NAMESPACE}::{called_func_name}(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return {struct_name}{{cmp, swapped}};
}}"
                        )?;
                    }
                }
            }
        }

        Ok(())
    }

    /// Writes all named constants
    fn write_global_constants(
        &mut self,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
    ) -> BackendResult {
        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());

        for (handle, constant) in constants {
            let ty_name = TypeContext {
                handle: constant.ty,
                gctx: module.to_ctx(),
                names: &self.names,
                access: crate::StorageAccess::empty(),
                first_time: false,
            };
            let name = &self.names[&NameKey::Constant(handle)];
            write!(self.out, "constant {ty_name} {name} = ")?;
            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
            writeln!(self.out, ";")?;
        }

        Ok(())
    }

    fn put_inline_sampler_properties(
        &mut self,
        level: back::Level,
        sampler: &sm::InlineSampler,
    ) -> BackendResult {
        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
            writeln!(
                self.out,
                "{}{}::{}_address::{},",
                level,
                NAMESPACE,
                letter,
                address.as_str(),
            )?;
        }
        writeln!(
            self.out,
            "{}{}::mag_filter::{},",
            level,
            NAMESPACE,
            sampler.mag_filter.as_str(),
        )?;
        writeln!(
            self.out,
            "{}{}::min_filter::{},",
            level,
            NAMESPACE,
            sampler.min_filter.as_str(),
        )?;
        if let Some(filter) = sampler.mip_filter {
            writeln!(
                self.out,
                "{}{}::mip_filter::{},",
                level,
                NAMESPACE,
                filter.as_str(),
            )?;
        }
        // avoid setting it on platforms that don't support it
        if sampler.border_color != sm::BorderColor::TransparentBlack {
            writeln!(
                self.out,
                "{}{}::border_color::{},",
                level,
                NAMESPACE,
                sampler.border_color.as_str(),
            )?;
        }
        //TODO: I'm not able to feed this in a way that MSL likes:
        //>error: use of undeclared identifier 'lod_clamp'
        //>error: no member named 'max_anisotropy' in namespace 'metal'
        if false {
            if let Some(ref lod) = sampler.lod_clamp {
                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
            }
            if let Some(aniso) = sampler.max_anisotropy {
                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
            }
        }
        if sampler.compare_func != sm::CompareFunc::Never {
            writeln!(
                self.out,
                "{}{}::compare_func::{},",
                level,
                NAMESPACE,
                sampler.compare_func.as_str(),
            )?;
        }
        writeln!(
            self.out,
            "{}{}::coord::{}",
            level,
            NAMESPACE,
            sampler.coord.as_str()
        )?;
        Ok(())
    }

    fn write_unpacking_function(
        &mut self,
        format: back::msl::VertexFormat,
    ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
        use crate::{Scalar, VectorSize};
        use back::msl::VertexFormat::*;
        match format {
            Uint8 => {
                let name = self.namer.call("unpackUint8");
                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
                writeln!(self.out, "}}")?;
                Ok((name, 1, None, Scalar::U32))
            }
            Uint8x2 => {
                let name = self.namer.call("unpackUint8x2");
                writeln!(
                    self.out,
                    "metal::uint2 {name}(metal::uchar b0, \
                    metal::uchar b1) {{"
                )?;
                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
            }
            Uint8x4 => {
                let name = self.namer.call("unpackUint8x4");
                writeln!(
                    self.out,
                    "metal::uint4 {name}(metal::uchar b0, \
                    metal::uchar b1, \
                    metal::uchar b2, \
                    metal::uchar b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::uint4(b0, b1, b2, b3);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
            }
            Sint8 => {
                let name = self.namer.call("unpackSint8");
                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
                writeln!(self.out, "}}")?;
                Ok((name, 1, None, Scalar::I32))
            }
            Sint8x2 => {
                let name = self.namer.call("unpackSint8x2");
                writeln!(
                    self.out,
                    "metal::int2 {name}(metal::uchar b0, \
                    metal::uchar b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int2(as_type<char>(b0), \
                    as_type<char>(b1));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
            }
            Sint8x4 => {
                let name = self.namer.call("unpackSint8x4");
                writeln!(
                    self.out,
                    "metal::int4 {name}(metal::uchar b0, \
                    metal::uchar b1, \
                    metal::uchar b2, \
                    metal::uchar b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int4(as_type<char>(b0), \
                    as_type<char>(b1), \
                    as_type<char>(b2), \
                    as_type<char>(b3));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
            }
            Unorm8 => {
                let name = self.namer.call("unpackUnorm8");
                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
                writeln!(
                    self.out,
                    "{}return float(float(b0) / 255.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 1, None, Scalar::F32))
            }
            Unorm8x2 => {
                let name = self.namer.call("unpackUnorm8x2");
                writeln!(
                    self.out,
                    "metal::float2 {name}(metal::uchar b0, \
                    metal::uchar b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float2(float(b0) / 255.0f, \
                    float(b1) / 255.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
            }
            Unorm8x4 => {
                let name = self.namer.call("unpackUnorm8x4");
                writeln!(
                    self.out,
                    "metal::float4 {name}(metal::uchar b0, \
                    metal::uchar b1, \
                    metal::uchar b2, \
                    metal::uchar b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(float(b0) / 255.0f, \
                    float(b1) / 255.0f, \
                    float(b2) / 255.0f, \
                    float(b3) / 255.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
            }
            Snorm8 => {
                let name = self.namer.call("unpackSnorm8");
                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
                writeln!(
                    self.out,
                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 1, None, Scalar::F32))
            }
            Snorm8x2 => {
                let name = self.namer.call("unpackSnorm8x2");
                writeln!(
                    self.out,
                    "metal::float2 {name}(metal::uchar b0, \
                    metal::uchar b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
                    metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
            }
            Snorm8x4 => {
                let name = self.namer.call("unpackSnorm8x4");
                writeln!(
                    self.out,
                    "metal::float4 {name}(metal::uchar b0, \
                    metal::uchar b1, \
                    metal::uchar b2, \
                    metal::uchar b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
                    metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
                    metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
                    metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
            }
            Uint16 => {
                let name = self.namer.call("unpackUint16");
                writeln!(
                    self.out,
                    "metal::uint {name}(metal::uint b0, \
                    metal::uint b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::uint(b1 << 8 | b0);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, None, Scalar::U32))
            }
            Uint16x2 => {
                let name = self.namer.call("unpackUint16x2");
                writeln!(
                    self.out,
                    "metal::uint2 {name}(metal::uint b0, \
                    metal::uint b1, \
                    metal::uint b2, \
                    metal::uint b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::uint2(b1 << 8 | b0, \
                    b3 << 8 | b2);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
            }
            Uint16x4 => {
                let name = self.namer.call("unpackUint16x4");
                writeln!(
                    self.out,
                    "metal::uint4 {name}(metal::uint b0, \
                    metal::uint b1, \
                    metal::uint b2, \
                    metal::uint b3, \
                    metal::uint b4, \
                    metal::uint b5, \
                    metal::uint b6, \
                    metal::uint b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::uint4(b1 << 8 | b0, \
                    b3 << 8 | b2, \
                    b5 << 8 | b4, \
                    b7 << 8 | b6);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
            }
            Sint16 => {
                let name = self.namer.call("unpackSint16");
                writeln!(
                    self.out,
                    "int {name}(metal::ushort b0, \
                    metal::ushort b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, None, Scalar::I32))
            }
            Sint16x2 => {
                let name = self.namer.call("unpackSint16x2");
                writeln!(
                    self.out,
                    "metal::int2 {name}(metal::ushort b0, \
                    metal::ushort b1, \
                    metal::ushort b2, \
                    metal::ushort b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
                    as_type<short>(metal::ushort(b3 << 8 | b2)));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
            }
            Sint16x4 => {
                let name = self.namer.call("unpackSint16x4");
                writeln!(
                    self.out,
                    "metal::int4 {name}(metal::ushort b0, \
                    metal::ushort b1, \
                    metal::ushort b2, \
                    metal::ushort b3, \
                    metal::ushort b4, \
                    metal::ushort b5, \
                    metal::ushort b6, \
                    metal::ushort b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
                    as_type<short>(metal::ushort(b3 << 8 | b2)), \
                    as_type<short>(metal::ushort(b5 << 8 | b4)), \
                    as_type<short>(metal::ushort(b7 << 8 | b6)));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
            }
            Unorm16 => {
                let name = self.namer.call("unpackUnorm16");
                writeln!(
                    self.out,
                    "float {name}(metal::ushort b0, \
                    metal::ushort b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, None, Scalar::F32))
            }
            Unorm16x2 => {
                let name = self.namer.call("unpackUnorm16x2");
                writeln!(
                    self.out,
                    "metal::float2 {name}(metal::ushort b0, \
                    metal::ushort b1, \
                    metal::ushort b2, \
                    metal::ushort b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
                    float(b3 << 8 | b2) / 65535.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
            }
            Unorm16x4 => {
                let name = self.namer.call("unpackUnorm16x4");
                writeln!(
                    self.out,
                    "metal::float4 {name}(metal::ushort b0, \
                    metal::ushort b1, \
                    metal::ushort b2, \
                    metal::ushort b3, \
                    metal::ushort b4, \
                    metal::ushort b5, \
                    metal::ushort b6, \
                    metal::ushort b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
                    float(b3 << 8 | b2) / 65535.0f, \
                    float(b5 << 8 | b4) / 65535.0f, \
                    float(b7 << 8 | b6) / 65535.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
            }
            Snorm16 => {
                let name = self.namer.call("unpackSnorm16");
                writeln!(
                    self.out,
                    "float {name}(metal::ushort b0, \
                    metal::ushort b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
Ok((name, 2, None, Scalar::F32)) } Snorm16x2 => { let name = self.namer.call("unpackSnorm16x2"); writeln!( self.out, "metal::float2 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3) {{" )?; writeln!( self.out, "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 4, Some(VectorSize::Bi), Scalar::F32)) } Snorm16x4 => { let name = self.namer.call("unpackSnorm16x4"); writeln!( self.out, "metal::float4 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7) {{" )?; writeln!( self.out, "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 8, Some(VectorSize::Quad), Scalar::F32)) } Float16 => { let name = self.namer.call("unpackFloat16"); writeln!( self.out, "float {name}(metal::ushort b0, \ metal::ushort b1) {{" )?; writeln!( self.out, "{}return float(as_type(metal::ushort(b1 << 8 | b0)));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 2, None, Scalar::F32)) } Float16x2 => { let name = self.namer.call("unpackFloat16x2"); writeln!( self.out, "metal::float2 {name}(metal::ushort b0, \ metal::ushort b1, \ metal::ushort b2, \ metal::ushort b3) {{" )?; writeln!( self.out, "{}return metal::float2(as_type(metal::ushort(b1 << 8 | b0)), \ as_type(metal::ushort(b3 << 8 | b2)));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 4, Some(VectorSize::Bi), Scalar::F32)) } Float16x4 => { let name = self.namer.call("unpackFloat16x4"); writeln!( self.out, "metal::float4 {name}(metal::ushort b0, \ metal::ushort b1, \ metal::ushort b2, \ metal::ushort b3, \ metal::ushort b4, \ metal::ushort b5, \ metal::ushort b6, \ metal::ushort b7) {{" )?; writeln!( self.out, "{}return metal::float4(as_type(metal::ushort(b1 << 8 | b0)), \ as_type(metal::ushort(b3 << 8 | b2)), \ as_type(metal::ushort(b5 << 8 | 
b4)), \ as_type(metal::ushort(b7 << 8 | b6)));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 8, Some(VectorSize::Quad), Scalar::F32)) } Float32 => { let name = self.namer.call("unpackFloat32"); writeln!( self.out, "float {name}(uint b0, \ uint b1, \ uint b2, \ uint b3) {{" )?; writeln!( self.out, "{}return as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0);", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 4, None, Scalar::F32)) } Float32x2 => { let name = self.namer.call("unpackFloat32x2"); writeln!( self.out, "metal::float2 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7) {{" )?; writeln!( self.out, "{}return metal::float2(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 8, Some(VectorSize::Bi), Scalar::F32)) } Float32x3 => { let name = self.namer.call("unpackFloat32x3"); writeln!( self.out, "metal::float3 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7, \ uint b8, \ uint b9, \ uint b10, \ uint b11) {{" )?; writeln!( self.out, "{}return metal::float3(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4), \ as_type(b11 << 24 | b10 << 16 | b9 << 8 | b8));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 12, Some(VectorSize::Tri), Scalar::F32)) } Float32x4 => { let name = self.namer.call("unpackFloat32x4"); writeln!( self.out, "metal::float4 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7, \ uint b8, \ uint b9, \ uint b10, \ uint b11, \ uint b12, \ uint b13, \ uint b14, \ uint b15) {{" )?; writeln!( self.out, "{}return metal::float4(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4), \ as_type(b11 << 24 | b10 << 16 | b9 << 8 | b8), \ as_type(b15 << 24 | b14 << 16 | b13 << 8 | b12));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 16, 
Some(VectorSize::Quad), Scalar::F32)) } Uint32 => { let name = self.namer.call("unpackUint32"); writeln!( self.out, "uint {name}(uint b0, \ uint b1, \ uint b2, \ uint b3) {{" )?; writeln!( self.out, "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 4, None, Scalar::U32)) } Uint32x2 => { let name = self.namer.call("unpackUint32x2"); writeln!( self.out, "uint2 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7) {{" )?; writeln!( self.out, "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \ (b7 << 24 | b6 << 16 | b5 << 8 | b4));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 8, Some(VectorSize::Bi), Scalar::U32)) } Uint32x3 => { let name = self.namer.call("unpackUint32x3"); writeln!( self.out, "uint3 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7, \ uint b8, \ uint b9, \ uint b10, \ uint b11) {{" )?; writeln!( self.out, "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \ (b7 << 24 | b6 << 16 | b5 << 8 | b4), \ (b11 << 24 | b10 << 16 | b9 << 8 | b8));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 12, Some(VectorSize::Tri), Scalar::U32)) } Uint32x4 => { let name = self.namer.call("unpackUint32x4"); writeln!( self.out, "{NAMESPACE}::uint4 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7, \ uint b8, \ uint b9, \ uint b10, \ uint b11, \ uint b12, \ uint b13, \ uint b14, \ uint b15) {{" )?; writeln!( self.out, "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \ (b7 << 24 | b6 << 16 | b5 << 8 | b4), \ (b11 << 24 | b10 << 16 | b9 << 8 | b8), \ (b15 << 24 | b14 << 16 | b13 << 8 | b12));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 16, Some(VectorSize::Quad), Scalar::U32)) } Sint32 => { let name = self.namer.call("unpackSint32"); writeln!( self.out, "int {name}(uint b0, \ uint b1, \ uint b2, \ uint b3) {{" )?; writeln!( self.out, "{}return 
as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0);", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 4, None, Scalar::I32)) } Sint32x2 => { let name = self.namer.call("unpackSint32x2"); writeln!( self.out, "metal::int2 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7) {{" )?; writeln!( self.out, "{}return metal::int2(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 8, Some(VectorSize::Bi), Scalar::I32)) } Sint32x3 => { let name = self.namer.call("unpackSint32x3"); writeln!( self.out, "metal::int3 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7, \ uint b8, \ uint b9, \ uint b10, \ uint b11) {{" )?; writeln!( self.out, "{}return metal::int3(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4), \ as_type(b11 << 24 | b10 << 16 | b9 << 8 | b8));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 12, Some(VectorSize::Tri), Scalar::I32)) } Sint32x4 => { let name = self.namer.call("unpackSint32x4"); writeln!( self.out, "metal::int4 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3, \ uint b4, \ uint b5, \ uint b6, \ uint b7, \ uint b8, \ uint b9, \ uint b10, \ uint b11, \ uint b12, \ uint b13, \ uint b14, \ uint b15) {{" )?; writeln!( self.out, "{}return metal::int4(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4), \ as_type(b11 << 24 | b10 << 16 | b9 << 8 | b8), \ as_type(b15 << 24 | b14 << 16 | b13 << 8 | b12));", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 16, Some(VectorSize::Quad), Scalar::I32)) } Unorm10_10_10_2 => { let name = self.namer.call("unpackUnorm10_10_10_2"); writeln!( self.out, "metal::float4 {name}(uint b0, \ uint b1, \ uint b2, \ uint b3) {{" )?; writeln!( self.out, // The following is correct for RGBA packing, but our format seems to // match ABGR, which can be fed into the 
Metal builtin function // unpack_unorm10a2_to_float. /* "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \ uint r = (v & 0xFFC00000) >> 22; \ uint g = (v & 0x003FF000) >> 12; \ uint b = (v & 0x00000FFC) >> 2; \ uint a = (v & 0x00000003); \ return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);", */ "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 4, Some(VectorSize::Quad), Scalar::F32)) } Unorm8x4Bgra => { let name = self.namer.call("unpackUnorm8x4Bgra"); writeln!( self.out, "metal::float4 {name}(metal::uchar b0, \ metal::uchar b1, \ metal::uchar b2, \ metal::uchar b3) {{" )?; writeln!( self.out, "{}return metal::float4(float(b2) / 255.0f, \ float(b1) / 255.0f, \ float(b0) / 255.0f, \ float(b3) / 255.0f);", back::INDENT )?; writeln!(self.out, "}}")?; Ok((name, 4, Some(VectorSize::Quad), Scalar::F32)) } } } fn write_wrapped_unary_op( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, op: crate::UnaryOperator, operand: Handle<crate::Expression>, ) -> BackendResult { let operand_ty = func_ctx.resolve_type(operand, &module.types); match op { // Negating the TYPE_MIN of a two's complement signed integer // type causes overflow, which is undefined behaviour in MSL. To // avoid this we bitcast the value to unsigned and negate it, // then bitcast back to signed. // This adheres to the WGSL spec, under which negating the // type's minimum value yields the minimum value.
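/*
A minimal sketch of the same bit-cast trick in plain Rust, assuming 32-bit
`int` operands. The helper names below are hypothetical and not part of naga.
The value is reinterpreted as unsigned, negated with wrap-around, and then
reinterpreted back. For `i32::MIN` this yields `i32::MIN`, as the comment
above requires.

```rust
/// Negates a signed 32-bit integer the way the generated MSL helper does:
/// bit-cast to unsigned, negate with wrap-around, bit-cast back.
/// (Hypothetical helper, for illustration only.)
fn neg_via_unsigned(val: i32) -> i32 {
    // `as` between same-width integers keeps the bit pattern unchanged.
    let bits = val as u32;
    // Unsigned negation is modular (2^32 - bits), so it can never overflow.
    bits.wrapping_neg() as i32
}

/// Component-wise version, mirroring the vector overloads of the helper.
fn neg_vec4(v: [i32; 4]) -> [i32; 4] {
    v.map(neg_via_unsigned)
}
```

Rust's own `i32::wrapping_neg` produces the same values directly. The writer
routes through the unsigned type because, per the comment above, signed
overflow is undefined behaviour in MSL, while unsigned arithmetic wraps.
*/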
crate::UnaryOperator::Negate if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => { let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else { return Ok(()); }; let wrapped = WrappedFunction::UnaryOp { op, ty: (vector_size, scalar), }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } let unsigned_scalar = crate::Scalar { kind: crate::ScalarKind::Uint, ..scalar }; let mut type_name = String::new(); let mut unsigned_type_name = String::new(); match vector_size { None => { put_numeric_type(&mut type_name, scalar, &[])?; put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])? } Some(size) => { put_numeric_type(&mut type_name, scalar, &[size])?; put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?; } }; writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?; let level = back::Level(1); writeln!( self.out, "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));" )?; writeln!(self.out, "}}")?; writeln!(self.out)?; } _ => {} } Ok(()) } fn write_wrapped_binary_op( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, expr: Handle, op: crate::BinaryOperator, left: Handle, right: Handle, ) -> BackendResult { let expr_ty = func_ctx.resolve_type(expr, &module.types); let left_ty = func_ctx.resolve_type(left, &module.types); let right_ty = func_ctx.resolve_type(right, &module.types); match (op, expr_ty.scalar_kind()) { // Signed integer division of TYPE_MIN / -1, or signed or // unsigned division by zero, gives an unspecified value in MSL. // We override the divisor to 1 in these cases. 
// This adheres to the WGSL spec in that: // * TYPE_MIN / -1 == TYPE_MIN // * x / 0 == x ( crate::BinaryOperator::Divide, Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint), ) => { let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else { return Ok(()); }; let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else { return Ok(()); }; let wrapped = WrappedFunction::BinaryOp { op, left_ty: left_wrapped_ty, right_ty: right_wrapped_ty, }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else { return Ok(()); }; let mut type_name = String::new(); match vector_size { None => put_numeric_type(&mut type_name, scalar, &[])?, Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?, }; writeln!( self.out, "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{" )?; let level = back::Level(1); match scalar.kind { crate::ScalarKind::Sint => { let min_val = match scalar.width { 4 => crate::Literal::I32(i32::MIN), 8 => crate::Literal::I64(i64::MIN), _ => { return Err(Error::GenericValidation(format!( "Unexpected width for scalar {scalar:?}" ))); } }; write!( self.out, "{level}return lhs / metal::select(rhs, 1, (lhs == " )?; self.put_literal(min_val)?; writeln!(self.out, " & rhs == -1) | (rhs == 0));")? } crate::ScalarKind::Uint => writeln!( self.out, "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);" )?, _ => unreachable!(), } writeln!(self.out, "}}")?; writeln!(self.out)?; } // Integer modulo where one or both operands are negative, or the // divisor is zero, is undefined behaviour in MSL. To avoid this // we use the following equation: // // dividend - (dividend / divisor) * divisor // // overriding the divisor to 1 if either it is 0, or it is -1 // and the dividend is TYPE_MIN. 
// // This adheres to the WGSL spec in that: // * TYPE_MIN % -1 == 0 // * x % 0 == 0 ( crate::BinaryOperator::Modulo, Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint), ) => { let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else { return Ok(()); }; let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar() else { return Ok(()); }; let wrapped = WrappedFunction::BinaryOp { op, left_ty: left_wrapped_ty, right_ty: (right_vector_size, right_scalar), }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else { return Ok(()); }; let mut type_name = String::new(); match vector_size { None => put_numeric_type(&mut type_name, scalar, &[])?, Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?, }; let mut rhs_type_name = String::new(); match right_vector_size { None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?, Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?, }; writeln!( self.out, "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{" )?; let level = back::Level(1); match scalar.kind { crate::ScalarKind::Sint => { let min_val = match scalar.width { 4 => crate::Literal::I32(i32::MIN), 8 => crate::Literal::I64(i64::MIN), _ => { return Err(Error::GenericValidation(format!( "Unexpected width for scalar {scalar:?}" ))); } }; write!( self.out, "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == " )?; self.put_literal(min_val)?; writeln!(self.out, " & rhs == -1) | (rhs == 0));")?; writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")? } crate::ScalarKind::Uint => writeln!( self.out, "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);" )?, _ => unreachable!(), } writeln!(self.out, "}}")?; writeln!(self.out)?; } _ => {} } Ok(()) } /// Build the mangled helper name for integer vector dot products. /// /// `scalar` must be a concrete integer scalar type. 
/// /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`). fn get_dot_wrapper_function_helper_name( &self, scalar: crate::Scalar, size: crate::VectorSize, ) -> String { // Check for consistency with [`super::keywords::RESERVED_SET`] debug_assert!(concrete_int_scalars().any(|s| s == scalar)); let type_name = scalar.to_msl_name(); let size_suffix = common::vector_size_str(size); format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}") } #[allow(clippy::too_many_arguments)] fn write_wrapped_math_function( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, fun: crate::MathFunction, arg: Handle<crate::Expression>, _arg1: Option<Handle<crate::Expression>>, _arg2: Option<Handle<crate::Expression>>, _arg3: Option<Handle<crate::Expression>>, ) -> BackendResult { let arg_ty = func_ctx.resolve_type(arg, &module.types); match fun { // Taking the absolute value of the TYPE_MIN of a two's // complement signed integer type causes overflow, which is // undefined behaviour in MSL. To avoid this, when the value is // negative we bitcast the value to unsigned and negate it, then // bitcast back to signed. // This adheres to the WGSL spec, under which the absolute value // of the type's minimum value is the minimum value itself. crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => { let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else { return Ok(()); }; let wrapped = WrappedFunction::Math { fun, arg_ty: (vector_size, scalar), }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } let unsigned_scalar = crate::Scalar { kind: crate::ScalarKind::Uint, ..scalar }; let mut type_name = String::new(); let mut unsigned_type_name = String::new(); match vector_size { None => { put_numeric_type(&mut type_name, scalar, &[])?; put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
} Some(size) => { put_numeric_type(&mut type_name, scalar, &[size])?; put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?; } }; writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?; let level = back::Level(1); writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?; writeln!(self.out, "}}")?; writeln!(self.out)?; } crate::MathFunction::Dot => match *arg_ty { crate::TypeInner::Vector { size, scalar } if matches!( scalar.kind, crate::ScalarKind::Sint | crate::ScalarKind::Uint ) => { // De-duplicate per (fun, arg type) like other wrapped math functions let wrapped = WrappedFunction::Math { fun, arg_ty: (Some(size), scalar), }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } let mut vec_ty = String::new(); put_numeric_type(&mut vec_ty, scalar, &[size])?; let mut ret_ty = String::new(); put_numeric_type(&mut ret_ty, scalar, &[])?; let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size); // Emit function signature and body using put_dot_product for the expression writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?; let level = back::Level(1); write!(self.out, "{level}return ")?; self.put_dot_product("a", "b", size as usize, |writer, name, index| { write!(writer.out, "{name}.{}", back::COMPONENTS[index])?; Ok(()) })?; writeln!(self.out, ";")?; writeln!(self.out, "}}")?; writeln!(self.out)?; } _ => {} }, _ => {} } Ok(()) } fn write_wrapped_cast( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, expr: Handle, kind: crate::ScalarKind, convert: Option, ) -> BackendResult { // Avoid undefined behaviour when casting from a float to integer // when the value is out of range for the target type. Additionally // ensure we clamp to the correct value as per the WGSL spec. 
// // https://www.w3.org/TR/WGSL/#floating-point-conversion: // * If X is exactly representable in the target type T, then the // result is that value. // * Otherwise, the result is the value in T closest to // truncate(X) and also exactly representable in the original // floating point type. let src_ty = func_ctx.resolve_type(expr, &module.types); let Some(width) = convert else { return Ok(()); }; let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else { return Ok(()); }; let dst_scalar = crate::Scalar { kind, width }; if src_scalar.kind != crate::ScalarKind::Float || (dst_scalar.kind != crate::ScalarKind::Sint && dst_scalar.kind != crate::ScalarKind::Uint) { return Ok(()); } let wrapped = WrappedFunction::Cast { src_scalar, vector_size, dst_scalar, }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar); let mut src_type_name = String::new(); match vector_size { None => put_numeric_type(&mut src_type_name, src_scalar, &[])?, Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?, }; let mut dst_type_name = String::new(); match vector_size { None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?, Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?, }; let fun_name = match dst_scalar { crate::Scalar::I32 => F2I32_FUNCTION, crate::Scalar::U32 => F2U32_FUNCTION, crate::Scalar::I64 => F2I64_FUNCTION, crate::Scalar::U64 => F2U64_FUNCTION, _ => unreachable!(), }; writeln!( self.out, "{dst_type_name} {fun_name}({src_type_name} value) {{" )?; let level = back::Level(1); write!( self.out, "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, " )?; self.put_literal(min)?; write!(self.out, ", ")?; self.put_literal(max)?; writeln!(self.out, "));")?; writeln!(self.out, "}}")?; writeln!(self.out)?; Ok(()) } /// Helper function used by [`Self::write_wrapped_image_load`] and /// 
[`Self::write_wrapped_image_sample`] to write the shared YUV to RGB /// conversion code for external textures. Expects the preceding code to /// declare the Y component as a `float` variable of name `y`, the UV /// components as a `float2` variable of name `uv`, and the external /// texture params as a variable of name `params`. The emitted code will /// return the result. fn write_convert_yuv_to_rgb_and_return( &mut self, level: back::Level, y: &str, uv: &str, params: &str, ) -> BackendResult { let l1 = level; let l2 = l1.next(); // Convert from YUV to non-linear RGB in the source color space. writeln!( self.out, "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;" )?; // Apply the inverse of the source transfer function to convert to // linear RGB in the source color space. writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?; writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?; writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?; writeln!( self.out, "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);" )?; // Multiply by the gamut conversion matrix to convert to linear RGB in // the destination color space. writeln!( self.out, "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;" )?; // Finally, apply the dest transfer function to convert to non-linear // RGB in the destination color space, and return the result. 
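/*
For reference, here is a per-channel sketch of the two piecewise transfer
functions that this helper emits, written in plain Rust. `TransferFn` is a
hypothetical mirror of the `a`, `b`, `g` and `k` fields that the generated MSL
reads from `params.src_tf` and `params.dst_tf`. The sRGB-like constants are
used only as an example input and are not taken from naga.

```rust
/// Hypothetical mirror of the transfer-function fields used by the generated MSL.
#[derive(Clone, Copy)]
struct TransferFn {
    a: f32,
    b: f32,
    g: f32,
    k: f32,
}

/// Gamma-encoded -> linear, matching the `srcLinearRgb` select for one channel.
fn to_linear(tf: TransferFn, v: f32) -> f32 {
    if v < tf.k * tf.b {
        v / tf.k
    } else {
        ((v + tf.a - 1.0) / tf.a).powf(tf.g)
    }
}

/// Linear -> gamma-encoded, matching the `dstGammaRgb` select for one channel.
fn to_gamma(tf: TransferFn, v: f32) -> f32 {
    if v < tf.b {
        tf.k * v
    } else {
        tf.a * v.powf(1.0 / tf.g) - (tf.a - 1.0)
    }
}

/// Example parameters shaped like sRGB (assumed, for illustration only).
const SRGB: TransferFn = TransferFn { a: 1.055, b: 0.0031308, g: 2.4, k: 12.92 };
```

The two thresholds are expressed in different spaces. The inverse function
compares the encoded value against `k * b`. The forward function compares the
linear value against `b`. Both refer to the same breakpoint, so a round trip
through both functions is (approximately) the identity.
*/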
writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?; writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?; writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?; writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?; writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?; Ok(()) } #[allow(clippy::too_many_arguments)] fn write_wrapped_image_load( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, image: Handle, _coordinate: Handle, _array_index: Option>, _sample: Option>, _level: Option>, ) -> BackendResult { // We currently only need to wrap image loads for external textures let class = match *func_ctx.resolve_type(image, &module.types) { crate::TypeInner::Image { class, .. } => class, _ => unreachable!(), }; if class != crate::ImageClass::External { return Ok(()); } let wrapped = WrappedFunction::ImageLoad { class }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?; let l1 = back::Level(1); let l2 = l1.next(); let l3 = l2.next(); writeln!( self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());" )?; // Clamp coords to provided size of external texture to prevent OOB // read. If params.size is zero then clamp to the actual size of the // texture. writeln!( self.out, "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? 
tex.params.size : plane0_size;" )?; writeln!( self.out, "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);" )?; // Apply load transformation writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?; writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?; // For single plane, simply read from plane0 writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?; writeln!(self.out, "{l1}}} else {{")?; // Chroma planes may be subsampled so we must scale the coords accordingly. writeln!( self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());" )?; writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?; // For multi-plane, read the Y value from plane 0 writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?; writeln!(self.out, "{l2}float2 uv;")?; writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?; // For 2 planes, read UV from interleaved plane 1 writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?; writeln!(self.out, "{l2}}} else {{")?; // For 3 planes, read U and V from planes 1 and 2 respectively writeln!( self.out, "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());" )?; writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?; writeln!( self.out, "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);" )?; writeln!(self.out, "{l2}}}")?; self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?; writeln!(self.out, "{l1}}}")?; writeln!(self.out, "}}")?; writeln!(self.out)?; Ok(()) } #[allow(clippy::too_many_arguments)] fn write_wrapped_image_sample( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, image: Handle, _sampler: Handle, _gather: 
Option, _coordinate: Handle, _array_index: Option>, _offset: Option>, _level: crate::SampleLevel, _depth_ref: Option>, clamp_to_edge: bool, ) -> BackendResult { // We currently only need to wrap textureSampleBaseClampToEdge, for // both sampled and external textures. if !clamp_to_edge { return Ok(()); } let class = match *func_ctx.resolve_type(image, &module.types) { crate::TypeInner::Image { class, .. } => class, _ => unreachable!(), }; let wrapped = WrappedFunction::ImageSample { class, clamp_to_edge: true, }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } match class { crate::ImageClass::External => { writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?; let l1 = back::Level(1); let l2 = l1.next(); let l3 = l2.next(); writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?; writeln!( self.out, "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);" )?; // Calculate the sample bounds. The purported size of the texture // (params.size) is irrelevant here as we are dealing with normalized // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must // apply the sample transformation to that, also bearing in mind that it // may contain a flip on either axis. We calculate and adjust for the // half-texel separately for each plane as it depends on the actual // texture size which may vary between planes. 
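/*
The bounds computation described above can be sketched in plain Rust. This
sketch assumes that `sample_transform` is a column-major `float3x2`, so that
`t * float3(p, 1.0)` equals `p.x * col0 + p.y * col1 + col2`. `Affine`,
`apply` and `clamp_sample_coords` are hypothetical names. Because each plane
is clamped with its own half-texel inset, a subsampled chroma plane gets a
narrower range than the luma plane.

```rust
/// Hypothetical stand-in for a column-major `float3x2`: three columns of two rows.
type Affine = [[f32; 2]; 3];

/// Equivalent of `t * float3(p, 1.0)`: p.x * col0 + p.y * col1 + col2.
fn apply(t: &Affine, p: [f32; 2]) -> [f32; 2] {
    [
        t[0][0] * p[0] + t[1][0] * p[1] + t[2][0],
        t[0][1] * p[0] + t[1][1] * p[1] + t[2][1],
    ]
}

/// Transform `coords`, then clamp them to the transformed unit square, inset
/// by half a texel of a plane whose size in texels is `plane_size`.
fn clamp_sample_coords(t: &Affine, coords: [f32; 2], plane_size: [u32; 2]) -> [f32; 2] {
    let c = apply(t, coords);
    let corner0 = apply(t, [0.0, 0.0]);
    let corner1 = apply(t, [1.0, 1.0]);
    let mut out = [0.0f32; 2];
    for i in 0..2 {
        // A flip on this axis makes corner0 > corner1, so order them explicitly.
        let lo = corner0[i].min(corner1[i]);
        let hi = corner0[i].max(corner1[i]);
        let half_texel = 0.5 / plane_size[i] as f32;
        // Assumes a non-empty range; `f32::clamp` panics if min > max.
        out[i] = c[i].clamp(lo + half_texel, hi - half_texel);
    }
    out
}
```

With a vertical flip, the transformed corners come out as (0, 1) and (1, 0).
Taking the per-axis min and max still recovers the range 0..1 before the
inset is applied.
*/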
writeln!( self.out, "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);" )?; writeln!( self.out, "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);" )?; writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?; writeln!( self.out, "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);" )?; writeln!( self.out, "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);" )?; writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?; // For single plane, simply sample from plane0 writeln!( self.out, "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));" )?; writeln!(self.out, "{l1}}} else {{")?; writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?; writeln!( self.out, "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);" )?; writeln!( self.out, "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);" )?; // For multi-plane, sample the Y value from plane 0 writeln!( self.out, "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;" )?; writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?; writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?; // For 2 planes, sample UV from interleaved plane 1 writeln!( self.out, "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;" )?; writeln!(self.out, "{l2}}} else {{")?; // For 3 planes, sample U and V from planes 1 and 2 respectively writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?; writeln!( self.out, "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);" )?; writeln!( self.out, "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, 
bounds.xy + plane2_half_texel, bounds.zw - plane2_half_texel);" )?; writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?; writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?; writeln!(self.out, "{l2}}}")?; self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?; writeln!(self.out, "{l1}}}")?; writeln!(self.out, "}}")?; writeln!(self.out)?; } _ => { writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?; let l1 = back::Level(1); writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?; writeln!( self.out, "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));" )?; writeln!(self.out, "}}")?; writeln!(self.out)?; } } Ok(()) } fn write_wrapped_image_query( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, image: Handle<crate::Expression>, query: crate::ImageQuery, ) -> BackendResult { // We currently only need to wrap size image queries for external textures if !matches!(query, crate::ImageQuery::Size { .. }) { return Ok(()); } let class = match *func_ctx.resolve_type(image, &module.types) { crate::TypeInner::Image { class, .. 
} => class, _ => unreachable!(), }; if class != crate::ImageClass::External { return Ok(()); } let wrapped = WrappedFunction::ImageQuerySize { class }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } writeln!( self.out, "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{" )?; let l1 = back::Level(1); let l2 = l1.next(); writeln!( self.out, "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{" )?; writeln!(self.out, "{l2}return tex.params.size;")?; writeln!(self.out, "{l1}}} else {{")?; // params.size == (0, 0) indicates to query and return plane 0's actual size writeln!( self.out, "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());" )?; writeln!(self.out, "{l1}}}")?; writeln!(self.out, "}}")?; writeln!(self.out)?; Ok(()) } fn write_wrapped_cooperative_load( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, columns: crate::CooperativeSize, rows: crate::CooperativeSize, pointer: Handle, ) -> BackendResult { let ptr_ty = func_ctx.resolve_type(pointer, &module.types); let space = ptr_ty.pointer_space().unwrap(); let space_name = space.to_msl_name().unwrap_or_default(); let scalar = ptr_ty .pointer_base_type() .unwrap() .inner_with(&module.types) .scalar() .unwrap(); let wrapped = WrappedFunction::CooperativeLoad { space_name, columns, rows, scalar, }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } let scalar_name = scalar.to_msl_name(); writeln!( self.out, "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{", columns as u32, rows as u32, )?; let l1 = back::Level(1); writeln!( self.out, "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;", columns as u32, rows as u32 )?; let matrix_origin = "0"; writeln!( self.out, "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);" )?; writeln!(self.out, "{l1}return m;")?; writeln!(self.out, "}}")?; writeln!(self.out)?; Ok(()) } fn 
write_wrapped_cooperative_multiply_add( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, space: crate::AddressSpace, a: Handle, b: Handle, ) -> BackendResult { let space_name = space.to_msl_name().unwrap_or_default(); let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) { crate::TypeInner::CooperativeMatrix { columns, rows, scalar, .. } => (columns, rows, scalar), _ => unreachable!(), }; let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) { crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows), _ => unreachable!(), }; let wrapped = WrappedFunction::CooperativeMultiplyAdd { space_name, columns: b_c, rows: a_r, intermediate: a_c, scalar, }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } let scalar_name = scalar.to_msl_name(); writeln!( self.out, "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{", b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32, )?; let l1 = back::Level(1); writeln!( self.out, "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;", b_c as u32, a_r as u32 )?; writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?; writeln!(self.out, "{l1}return d;")?; writeln!(self.out, "}}")?; writeln!(self.out)?; Ok(()) } pub(super) fn write_wrapped_functions( &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, ) -> BackendResult { for (expr_handle, expr) in func_ctx.expressions.iter() { match *expr { crate::Expression::Unary { op, expr: operand } => { self.write_wrapped_unary_op(module, func_ctx, op, operand)?; } crate::Expression::Binary { op, left, right } => { self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?; } crate::Expression::Math { fun, arg, arg1, arg2, 
arg3, } => { self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?; } crate::Expression::As { expr, kind, convert, } => { self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?; } crate::Expression::ImageLoad { image, coordinate, array_index, sample, level, } => { self.write_wrapped_image_load( module, func_ctx, image, coordinate, array_index, sample, level, )?; } crate::Expression::ImageSample { image, sampler, gather, coordinate, array_index, offset, level, depth_ref, clamp_to_edge, } => { self.write_wrapped_image_sample( module, func_ctx, image, sampler, gather, coordinate, array_index, offset, level, depth_ref, clamp_to_edge, )?; } crate::Expression::ImageQuery { image, query } => { self.write_wrapped_image_query(module, func_ctx, image, query)?; } crate::Expression::CooperativeLoad { columns, rows, role: _, ref data, } => { self.write_wrapped_cooperative_load( module, func_ctx, columns, rows, data.pointer, )?; } crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => { let space = crate::AddressSpace::Private; self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?; } _ => {} } } Ok(()) } // Returns the array of mapped entry point names. fn write_functions( &mut self, module: &crate::Module, mod_info: &valid::ModuleInfo, options: &Options, pipeline_options: &PipelineOptions, ) -> Result { use back::msl::VertexFormat; // Define structs to hold resolved/generated data for vertex buffers and // their attributes. struct AttributeMappingResolved { ty_name: String, dimension: Option, scalar: crate::Scalar, name: String, } let mut am_resolved = FastHashMap::::default(); struct VertexBufferMappingResolved<'a> { id: u32, stride: u32, step_mode: back::msl::VertexBufferStepMode, ty_name: String, param_name: String, elem_name: String, attributes: &'a Vec, } let mut vbm_resolved = Vec::::new(); // Define a struct to hold a named reference to a byte-unpacking function. 
struct UnpackingFunction { name: String, byte_count: u32, dimension: Option, scalar: crate::Scalar, } let mut unpacking_functions = FastHashMap::::default(); // Check if we are attempting vertex pulling. If we are, generate some // names we'll need, and iterate the vertex buffer mappings to output // all the conversion functions we'll need to unpack the attribute data. // We can re-use these names for all entry points that need them, since // those entry points also use self.namer. let mut needs_vertex_id = false; let v_id = self.namer.call("v_id"); let mut needs_instance_id = false; let i_id = self.namer.call("i_id"); if pipeline_options.vertex_pulling_transform { for vbm in &pipeline_options.vertex_buffer_mappings { let buffer_id = vbm.id; let buffer_stride = vbm.stride; assert!( buffer_stride > 0, "Vertex pulling requires a non-zero buffer stride." ); match vbm.step_mode { back::msl::VertexBufferStepMode::Constant => {} back::msl::VertexBufferStepMode::ByVertex => { needs_vertex_id = true; } back::msl::VertexBufferStepMode::ByInstance => { needs_instance_id = true; } } let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str()); let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str()); let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str()); vbm_resolved.push(VertexBufferMappingResolved { id: buffer_id, stride: buffer_stride, step_mode: vbm.step_mode, ty_name: buffer_ty, param_name: buffer_param, elem_name: buffer_elem, attributes: &vbm.attributes, }); // Iterate the attributes and generate needed unpacking functions. 
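/*
Each unpacking function emitted by `write_unpacking_function` rebuilds
attribute data from individual bytes, least-significant byte first. For
example, `Uint32` becomes `b3 << 24 | b2 << 16 | b1 << 8 | b0`. The following
is a minimal Rust sketch of that reassembly. The helper names are hypothetical.
The `u8` parameters are an assumption, because the generated MSL takes each
byte as a `uint`.

```rust
/// Reassemble a 32-bit value from four bytes, least-significant byte first,
/// using the same shift-and-or expression as the `unpackUint32` helper.
fn unpack_uint32(b0: u8, b1: u8, b2: u8, b3: u8) -> u32 {
    ((b3 as u32) << 24) | ((b2 as u32) << 16) | ((b1 as u32) << 8) | (b0 as u32)
}

/// Signed variant: reinterpret the same bits, as `as_type<int>` does in MSL.
fn unpack_sint32(b0: u8, b1: u8, b2: u8, b3: u8) -> i32 {
    unpack_uint32(b0, b1, b2, b3) as i32
}

/// Vector form: consecutive groups of four bytes become consecutive
/// components, like `unpackUint32x2`.
fn unpack_uint32x2(b: [u8; 8]) -> [u32; 2] {
    [
        unpack_uint32(b[0], b[1], b[2], b[3]),
        unpack_uint32(b[4], b[5], b[6], b[7]),
    ]
}
```

This ordering is the same as `u32::from_le_bytes`. `unpacking_functions` is
keyed by vertex format, so the loop below writes each format's helper at most
once, even when several attributes share that format.
*/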
for attribute in &vbm.attributes { if unpacking_functions.contains_key(&attribute.format) { continue; } let (name, byte_count, dimension, scalar) = match self.write_unpacking_function(attribute.format) { Ok((name, byte_count, dimension, scalar)) => { (name, byte_count, dimension, scalar) } _ => { continue; } }; unpacking_functions.insert( attribute.format, UnpackingFunction { name, byte_count, dimension, scalar, }, ); } } } let mut pass_through_globals = Vec::new(); for (fun_handle, fun) in module.functions.iter() { log::trace!( "function {:?}, handle {:?}", fun.name.as_deref().unwrap_or("(anonymous)"), fun_handle ); let ctx = back::FunctionCtx { ty: back::FunctionType::Function(fun_handle), info: &mod_info[fun_handle], expressions: &fun.expressions, named_expressions: &fun.named_expressions, }; writeln!(self.out)?; self.write_wrapped_functions(module, &ctx)?; let fun_info = &mod_info[fun_handle]; pass_through_globals.clear(); let mut needs_buffer_sizes = false; for (handle, var) in module.global_variables.iter() { if !fun_info[handle].is_empty() { if var.space.needs_pass_through() { pass_through_globals.push(handle); } needs_buffer_sizes |= needs_array_length(var.ty, &module.types); } } let fun_name = &self.names[&NameKey::Function(fun_handle)]; match fun.result { Some(ref result) => { let ty_name = TypeContext { handle: result.ty, gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), first_time: false, }; write!(self.out, "{ty_name}")?; } None => { write!(self.out, "void")?; } } writeln!(self.out, " {fun_name}(")?; for (index, arg) in fun.arguments.iter().enumerate() { let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)]; let param_type_name = TypeContext { handle: arg.ty, gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), first_time: false, }; let separator = separate( !pass_through_globals.is_empty() || index + 1 != fun.arguments.len() || needs_buffer_sizes, ); writeln!( 
self.out, "{}{} {}{}", back::INDENT, param_type_name, name, separator )?; } for (index, &handle) in pass_through_globals.iter().enumerate() { let tyvar = TypedGlobalVariable { module, names: &self.names, handle, usage: fun_info[handle], reference: true, }; let separator = separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes); write!(self.out, "{}", back::INDENT)?; tyvar.try_fmt(&mut self.out)?; writeln!(self.out, "{separator}")?; } if needs_buffer_sizes { writeln!( self.out, "{}constant _mslBufferSizes& _buffer_sizes", back::INDENT )?; } writeln!(self.out, ") {{")?; let guarded_indices = index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies); let context = StatementContext { expression: ExpressionContext { function: fun, origin: FunctionOrigin::Handle(fun_handle), info: fun_info, lang_version: options.lang_version, policies: options.bounds_check_policies, guarded_indices, module, mod_info, pipeline_options, force_loop_bounding: options.force_loop_bounding, }, result_struct: None, }; self.put_locals(&context.expression)?; self.update_expressions_to_bake(fun, fun_info, &context.expression); self.put_block(back::Level(1), &fun.body, &context)?; writeln!(self.out, "}}")?; self.named_expressions.clear(); } let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref()) .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?; let mut info = TranslationInfo { entry_point_names: Vec::with_capacity(ep_range.len()), }; for ep_index in ep_range { let ep = &module.entry_points[ep_index]; let fun = &ep.function; let fun_info = mod_info.get_entry_point(ep_index); let mut ep_error = None; // For vertex_id and instance_id arguments, presume that we'll // use our generated names, but switch to the name of an // existing @builtin param, if we find one. 
let mut v_existing_id = None; let mut i_existing_id = None; log::trace!( "entry point {:?}, index {:?}", fun.name.as_deref().unwrap_or("(anonymous)"), ep_index ); let ctx = back::FunctionCtx { ty: back::FunctionType::EntryPoint(ep_index as u16), info: fun_info, expressions: &fun.expressions, named_expressions: &fun.named_expressions, }; self.write_wrapped_functions(module, &ctx)?; let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage { crate::ShaderStage::Vertex => ( "vertex", LocationMode::VertexInput, LocationMode::VertexOutput, true, ), crate::ShaderStage::Fragment => ( "fragment", LocationMode::FragmentInput, LocationMode::FragmentOutput, false, ), crate::ShaderStage::Compute => ( "kernel", LocationMode::Uniform, LocationMode::Uniform, false, ), crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(), crate::ShaderStage::RayGeneration | crate::ShaderStage::AnyHit | crate::ShaderStage::ClosestHit | crate::ShaderStage::Miss => unimplemented!(), }; // Should this entry point be modified to do vertex pulling? let do_vertex_pulling = can_vertex_pull && pipeline_options.vertex_pulling_transform && !pipeline_options.vertex_buffer_mappings.is_empty(); // Is any global variable used by this entry point dynamically sized? let needs_buffer_sizes = do_vertex_pulling || module .global_variables .iter() .filter(|&(handle, _)| !fun_info[handle].is_empty()) .any(|(_, var)| needs_array_length(var.ty, &module.types)); // skip this entry point if any global bindings are missing, // or their types are incompatible. if !options.fake_missing_bindings { for (var_handle, var) in module.global_variables.iter() { if fun_info[var_handle].is_empty() { continue; } match var.space { crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. 
} | crate::AddressSpace::Handle => { let br = match var.binding { Some(ref br) => br, None => { let var_name = var.name.clone().unwrap_or_default(); ep_error = Some(super::EntryPointError::MissingBinding(var_name)); break; } }; let target = options.get_resource_binding_target(ep, br); let good = match target { Some(target) => { // We intentionally don't dereference binding_arrays here, // so that binding arrays fall to the buffer location. match module.types[var.ty].inner { crate::TypeInner::Image { class: crate::ImageClass::External, .. } => target.external_texture.is_some(), crate::TypeInner::Image { .. } => target.texture.is_some(), crate::TypeInner::Sampler { .. } => { target.sampler.is_some() } _ => target.buffer.is_some(), } } None => false, }; if !good { ep_error = Some(super::EntryPointError::MissingBindTarget(*br)); break; } } crate::AddressSpace::Immediate => { if let Err(e) = options.resolve_immediates(ep) { ep_error = Some(e); break; } } crate::AddressSpace::TaskPayload => { unimplemented!() } crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {} crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => unimplemented!(), } } if needs_buffer_sizes { if let Err(err) = options.resolve_sizes_buffer(ep) { ep_error = Some(err); } } } if let Some(err) = ep_error { info.entry_point_names.push(Err(err)); continue; } let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)]; info.entry_point_names.push(Ok(fun_name.clone())); writeln!(self.out)?; // Since `Namer.reset` wasn't expecting struct members to be // suddenly injected into another namespace like this, // `self.names` doesn't keep them distinct from other variables. // Generate fresh names for these arguments, and remember the // mapping. 
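// Illustrative sketch only. This is a hypothetical helper, not naga's
// `proc::Namer`. A flattened struct member can collide with a name already
// used for another argument or global, so it needs a fresh identifier. The
// scheme below appends `_N` until the name is unused; it is assumed here for
// illustration and is not necessarily the suffix format naga uses.
#[allow(dead_code)]
fn sketch_fresh_name(taken: &[String], base: &str) -> String {
    // Reuse the base name when nothing has claimed it yet.
    if !taken.iter().any(|t| t.as_str() == base) {
        return base.to_string();
    }
    // Otherwise try `base_1`, `base_2`, ... until one is free.
    let mut n = 1u32;
    loop {
        let candidate = format!("{base}_{n}");
        if !taken.iter().any(|t| *t == candidate) {
            return candidate;
        }
        n += 1;
    }
}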
let mut flattened_member_names = FastHashMap::default(); // Varyings' members get their own namespace let mut varyings_namer = proc::Namer::default(); // List all the Naga `EntryPoint`'s `Function`'s arguments, // flattening structs into their members. In Metal, we will pass // each of these values to the entry point as a separate argument— // except for the varyings, handled next. let mut flattened_arguments = Vec::new(); for (arg_index, arg) in fun.arguments.iter().enumerate() { match module.types[arg.ty].inner { crate::TypeInner::Struct { ref members, .. } => { for (member_index, member) in members.iter().enumerate() { let member_index = member_index as u32; flattened_arguments.push(( NameKey::StructMember(arg.ty, member_index), member.ty, member.binding.as_ref(), )); let name_key = NameKey::StructMember(arg.ty, member_index); let name = match member.binding { Some(crate::Binding::Location { .. }) => { if do_vertex_pulling { self.namer.call(&self.names[&name_key]) } else { varyings_namer.call(&self.names[&name_key]) } } _ => self.namer.call(&self.names[&name_key]), }; flattened_member_names.insert(name_key, name); } } _ => flattened_arguments.push(( NameKey::EntryPointArgument(ep_index as _, arg_index as u32), arg.ty, arg.binding.as_ref(), )), } } // Identify the varyings among the argument values, and maybe emit // a struct type named `Input` to hold them. If we are doing // vertex pulling, we instead update our attribute mapping to // note the types, names, and zero values of the attributes. let stage_in_name = self.namer.call(&format!("{fun_name}Input")); let varyings_member_name = self.namer.call("varyings"); let mut has_varyings = false; if !flattened_arguments.is_empty() { if !do_vertex_pulling { writeln!(self.out, "struct {stage_in_name} {{")?; } for &(ref name_key, ty, binding) in flattened_arguments.iter() { let Some(binding) = binding else { continue; }; let name = match *name_key { NameKey::StructMember(..) 
=> &flattened_member_names[name_key], _ => &self.names[name_key], }; let ty_name = TypeContext { handle: ty, gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), first_time: false, }; let resolved = options.resolve_local_binding(binding, in_mode)?; let location = match *binding { crate::Binding::Location { location, .. } => Some(location), crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None, crate::Binding::BuiltIn(_) => continue, }; if do_vertex_pulling { let Some(location) = location else { continue; }; // Update our attribute mapping. am_resolved.insert( location, AttributeMappingResolved { ty_name: ty_name.to_string(), dimension: ty_name.vector_size(), scalar: ty_name.scalar().unwrap(), name: name.to_string(), }, ); } else { has_varyings = true; write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; resolved.try_fmt(&mut self.out)?; writeln!(self.out, ";")?; } } if !do_vertex_pulling { writeln!(self.out, "}};")?; } } // If the function has a return value, define a struct type named // `Output` to hold it. let stage_out_name = self.namer.call(&format!("{fun_name}Output")); let result_member_name = self.namer.call("member"); let result_type_name = match fun.result { Some(ref result) => { let mut result_members = Vec::new(); if let crate::TypeInner::Struct { ref members, ..
} = module.types[result.ty].inner { for (member_index, member) in members.iter().enumerate() { result_members.push(( &self.names[&NameKey::StructMember(result.ty, member_index as u32)], member.ty, member.binding.as_ref(), )); } } else { result_members.push(( &result_member_name, result.ty, result.binding.as_ref(), )); } writeln!(self.out, "struct {stage_out_name} {{")?; let mut has_point_size = false; for (name, ty, binding) in result_members { let ty_name = TypeContext { handle: ty, gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), first_time: true, }; let binding = binding.ok_or_else(|| { Error::GenericValidation("Expected binding, got None".into()) })?; if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding { has_point_size = true; if !pipeline_options.allow_and_force_point_size { continue; } } let array_len = match module.types[ty].inner { crate::TypeInner::Array { size: crate::ArraySize::Constant(size), .. } => Some(size), _ => None, }; let resolved = options.resolve_local_binding(binding, out_mode)?; write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; if let Some(array_len) = array_len { write!(self.out, " [{array_len}]")?; } resolved.try_fmt(&mut self.out)?; writeln!(self.out, ";")?; } if pipeline_options.allow_and_force_point_size && ep.stage == crate::ShaderStage::Vertex && !has_point_size { // inject the point size output last writeln!( self.out, "{}float _point_size [[point_size]];", back::INDENT )?; } writeln!(self.out, "}};")?; &stage_out_name } None => "void", }; // If we're doing a vertex pulling transform, define the buffer // structure types. if do_vertex_pulling { for vbm in &vbm_resolved { let buffer_stride = vbm.stride; let buffer_ty = &vbm.ty_name; // Define a structure of bytes of the appropriate size. // When we access the attributes, we'll be unpacking these // bytes at some offset. 
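// Illustrative sketch only (hypothetical helper). Each vertex buffer element
// is declared as a plain array of `stride` bytes. The unpacking functions then
// receive an attribute's bytes as separate arguments (`data[offset]`,
// `data[offset + 1]`, ...). For a 32-bit float, and assuming little-endian
// byte order, the reassembly amounts to the following.
#[allow(dead_code)]
fn sketch_unpack_float32(b0: u8, b1: u8, b2: u8, b3: u8) -> f32 {
    // Shift each byte into place, lowest byte first, then reinterpret the bits.
    let bits = u32::from(b3) << 24 | u32::from(b2) << 16 | u32::from(b1) << 8 | u32::from(b0);
    f32::from_bits(bits)
}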
writeln!( self.out, "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};" )?; } } // Write the entry point function's name, and begin its argument list. writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?; let mut is_first_argument = true; let mut separator = || { if is_first_argument { is_first_argument = false; ' ' } else { ',' } }; // If we have produced a struct holding the `EntryPoint`'s // `Function`'s arguments' varyings, pass that struct first. if has_varyings { writeln!( self.out, "{} {stage_in_name} {varyings_member_name} [[stage_in]]", separator() )?; } let mut local_invocation_id = None; // Then pass the remaining arguments not included in the varyings // struct. for &(ref name_key, ty, binding) in flattened_arguments.iter() { let binding = match binding { Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue, Some(binding @ &crate::Binding::BuiltIn { .. }) => binding, _ => continue, }; let name = match *name_key { NameKey::StructMember(..) 
=> &flattened_member_names[name_key], _ => &self.names[name_key], }; if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) { local_invocation_id = Some(name_key); } let ty_name = TypeContext { handle: ty, gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), first_time: false, }; match *binding { crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => { v_existing_id = Some(name.clone()); } crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => { i_existing_id = Some(name.clone()); } _ => {} }; let resolved = options.resolve_local_binding(binding, in_mode)?; write!(self.out, "{} {ty_name} {name}", separator())?; resolved.try_fmt(&mut self.out)?; writeln!(self.out)?; } let need_workgroup_variables_initialization = self.need_workgroup_variables_initialization(options, ep, module, fun_info); if need_workgroup_variables_initialization && local_invocation_id.is_none() { writeln!( self.out, "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator() )?; } // Those global variables used by this entry point and its callees // get passed as arguments. `Private` globals are an exception, they // don't outlive this invocation, so we declare them below as locals // within the entry point. for (handle, var) in module.global_variables.iter() { let usage = fun_info[handle]; if usage.is_empty() || var.space == crate::AddressSpace::Private { continue; } if options.lang_version < (1, 2) { match var.space { // This restriction is not documented in the MSL spec // but validation will fail if it is not upheld. // // We infer the required version from the "Function // Buffer Read-Writes" section of [what's new], where // the feature sets listed correspond with the ones // supporting MSL 1.2. 
// // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html crate::AddressSpace::Storage { access } if access.contains(crate::StorageAccess::STORE) && ep.stage == crate::ShaderStage::Fragment => { return Err(Error::UnsupportedWriteableStorageBuffer) } crate::AddressSpace::Handle => { match module.types[var.ty].inner { crate::TypeInner::Image { class: crate::ImageClass::Storage { access, .. }, .. } => { // This restriction is not documented in the MSL spec // but validation will fail if it is not upheld. // // We infer the required version from the "Function // Texture Read-Writes" section of [what's new], where // the feature sets listed correspond with the ones // supporting MSL 1.2. // // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html if access.contains(crate::StorageAccess::STORE) && (ep.stage == crate::ShaderStage::Vertex || ep.stage == crate::ShaderStage::Fragment) { return Err(Error::UnsupportedWriteableStorageTexture( ep.stage, )); } if access.contains( crate::StorageAccess::LOAD | crate::StorageAccess::STORE, ) { return Err(Error::UnsupportedRWStorageTexture); } } _ => {} } } _ => {} } } // Check min MSL version for binding arrays match var.space { crate::AddressSpace::Handle => match module.types[var.ty].inner { crate::TypeInner::BindingArray { base, .. } => { match module.types[base].inner { crate::TypeInner::Sampler { .. } => { if options.lang_version < (2, 0) { return Err(Error::UnsupportedArrayOf( "samplers".to_string(), )); } } crate::TypeInner::Image { class, .. } => match class { crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } | crate::ImageClass::Storage { access: crate::StorageAccess::LOAD, .. 
} => { // Array of textures since: // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164) // - macOS: Metal 2 if options.lang_version < (2, 0) { return Err(Error::UnsupportedArrayOf( "textures".to_string(), )); } } crate::ImageClass::Storage { access: crate::StorageAccess::STORE, .. } => { // Array of write-only textures since: // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164) // - macOS: Metal 2 if options.lang_version < (2, 0) { return Err(Error::UnsupportedArrayOf( "write-only textures".to_string(), )); } } crate::ImageClass::Storage { .. } => { if options.lang_version < (3, 0) { return Err(Error::UnsupportedArrayOf( "read-write textures".to_string(), )); } } crate::ImageClass::External => { return Err(Error::UnsupportedArrayOf( "external textures".to_string(), )); } }, _ => { return Err(Error::UnsupportedArrayOfType(base)); } } } _ => {} }, _ => {} } // the bindings have already been resolved and checked in the `!fake_missing_bindings` case let resolved = match var.space { crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(), crate::AddressSpace::WorkGroup => None, _ => options .resolve_resource_binding(ep, var.binding.as_ref().unwrap()) .ok(), }; if let Some(ref resolved) = resolved { // Inline samplers are defined in the EP body if resolved.as_inline_sampler(options).is_some() { continue; } } match module.types[var.ty].inner { crate::TypeInner::Image { class: crate::ImageClass::External, .. } => { // External texture global variables get lowered to 3 textures // and a constant buffer. We must emit a separate argument for // each of these.
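// Illustrative sketch only (hypothetical helper; the `u32` slot type is an
// assumption). Each of the three planes gets its own `[[texture(N)]]` slot,
// taken from `target.planes`. The parameters buffer gets a `[[buffer(M)]]`
// slot, taken from `target.params`.
#[allow(dead_code)]
fn sketch_external_texture_attrs(planes: [u32; 3], params: u32) -> ([String; 3], String) {
    let plane_attrs = [
        format!("[[texture({})]]", planes[0]),
        format!("[[texture({})]]", planes[1]),
        format!("[[texture({})]]", planes[2]),
    ];
    (plane_attrs, format!("[[buffer({params})]]"))
}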
let target = match resolved { Some(back::msl::ResolvedBinding::Resource(target)) => { target.external_texture } _ => None, }; for i in 0..3 { write!(self.out, "{} ", separator())?; let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable( handle, ExternalTextureNameKey::Plane(i), )]; write!( self.out, "{NAMESPACE}::texture2d {plane_name}" )?; if let Some(ref target) = target { write!(self.out, " [[texture({})]]", target.planes[i])?; } writeln!(self.out)?; } let params_ty_name = &self.names [&NameKey::Type(module.special_types.external_texture_params.unwrap())]; let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable( handle, ExternalTextureNameKey::Params, )]; write!(self.out, "{} ", separator())?; write!(self.out, "constant {params_ty_name}& {params_name}")?; if let Some(ref target) = target { write!(self.out, " [[buffer({})]]", target.params)?; } } _ => { let tyvar = TypedGlobalVariable { module, names: &self.names, handle, usage, reference: true, }; write!(self.out, "{} ", separator())?; tyvar.try_fmt(&mut self.out)?; if let Some(resolved) = resolved { resolved.try_fmt(&mut self.out)?; } if let Some(value) = var.init { write!(self.out, " = ")?; self.put_const_expression( value, module, mod_info, &module.global_expressions, )?; } } } writeln!(self.out)?; } if do_vertex_pulling { if needs_vertex_id && v_existing_id.is_none() { // Write the [[vertex_id]] argument. writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?; } if needs_instance_id && i_existing_id.is_none() { writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?; } // Iterate vbm_resolved, output one argument for every vertex buffer, // using the names we generated earlier. 
for vbm in &vbm_resolved { let id = &vbm.id; let ty_name = &vbm.ty_name; let param_name = &vbm.param_name; writeln!( self.out, "{} const device {ty_name}* {param_name} [[buffer({id})]]", separator() )?; } } // If this entry point uses any variable-length arrays, their sizes are // passed as a final struct-typed argument. if needs_buffer_sizes { // this is checked earlier let resolved = options.resolve_sizes_buffer(ep).unwrap(); write!( self.out, "{} constant _mslBufferSizes& _buffer_sizes", separator() )?; resolved.try_fmt(&mut self.out)?; writeln!(self.out)?; } // end of the entry point argument list writeln!(self.out, ") {{")?; // Starting the function body. if do_vertex_pulling { // Provide zero values for all the attributes, which we will overwrite with // real data from the vertex attribute buffers, if the indices are in-bounds. for vbm in &vbm_resolved { for attribute in vbm.attributes { let location = attribute.shader_location; let Some(am) = am_resolved.get(&location) else { // This bound attribute isn't used in this entry point, so // don't bother zero-initializing it. continue; }; let attribute_ty_name = &am.ty_name; let attribute_name = &am.name; writeln!( self.out, "{}{attribute_ty_name} {attribute_name} = {{}};", back::Level(1) )?; } // Output a bounds check block that will set real values for the // attributes, if the bounds are satisfied. write!(self.out, "{}if (", back::Level(1))?; let idx = &vbm.id; let stride = &vbm.stride; let index_name = match vbm.step_mode { back::msl::VertexBufferStepMode::Constant => "0", back::msl::VertexBufferStepMode::ByVertex => { if let Some(ref name) = v_existing_id { name } else { &v_id } } back::msl::VertexBufferStepMode::ByInstance => { if let Some(ref name) = i_existing_id { name } else { &i_id } } }; write!( self.out, "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})" )?; writeln!(self.out, ") {{")?; // Pull the bytes out of the vertex buffer.
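// Illustrative sketch only (hypothetical types and helpers). The bounds check
// emitted for each vertex buffer reads element `index` only when
// `index < buffer_size / stride`. A partially present trailing element is
// therefore never read, and its attributes keep their zero values. The index
// is `0` for constant-step buffers. Otherwise it is the vertex or instance id,
// preferring a builtin parameter name that the entry point already declared.
#[allow(dead_code)]
#[derive(Clone, Copy)]
enum SketchStepMode {
    Constant,
    ByVertex,
    ByInstance,
}

#[allow(dead_code)]
fn sketch_index_name<'a>(
    mode: SketchStepMode,
    existing_vertex_id: Option<&'a str>,
    existing_instance_id: Option<&'a str>,
    generated_vertex_id: &'a str,
    generated_instance_id: &'a str,
) -> &'a str {
    match mode {
        SketchStepMode::Constant => "0",
        SketchStepMode::ByVertex => existing_vertex_id.unwrap_or(generated_vertex_id),
        SketchStepMode::ByInstance => existing_instance_id.unwrap_or(generated_instance_id),
    }
}

#[allow(dead_code)]
fn sketch_element_in_bounds(index: u32, buffer_size: u32, stride: u32) -> bool {
    // `stride` is non-zero: the writer asserts this when resolving the mappings.
    index < buffer_size / stride
}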
let ty_name = &vbm.ty_name; let elem_name = &vbm.elem_name; let param_name = &vbm.param_name; writeln!( self.out, "{}const {ty_name} {elem_name} = {param_name}[{index_name}];", back::Level(2), )?; // Now set real values for each of the attributes, by unpacking the data // from the buffer elements. for attribute in vbm.attributes { let location = attribute.shader_location; let Some(am) = am_resolved.get(&location) else { // This bound attribute isn't used in this entry point, so // don't bother extracting the data. Too bad we emitted the // unpacking function earlier -- it might not get used. continue; }; let attribute_name = &am.name; let attribute_ty_name = &am.ty_name; let offset = attribute.offset; let func = unpacking_functions .get(&attribute.format) .expect("Should have generated this unpacking function earlier."); let func_name = &func.name; // Check dimensionality of the attribute compared to the unpacking // function. If attribute dimension > unpack dimension, we have to // pad out the unpack value from a vec4(0, 0, 0, 1) of matching // scalar type. Otherwise, if attribute dimension is < unpack // dimension, then we need to explicitly truncate the result. let needs_padding_or_truncation = am.dimension.cmp(&func.dimension); // We need an extra type conversion if the shader type does not // match the type returned from the unpacking function. let needs_conversion = am.scalar != func.scalar; if needs_padding_or_truncation != Ordering::Equal { // Emit a comment flagging that a conversion is happening, // since the actual logic can be at the end of a long line. 
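// Illustrative sketch only (hypothetical helper, shown on plain `f32` values).
// If the shader attribute has more components than the unpacked value,
// component 3 (w) is filled with one and any other missing component with
// zero. If it has fewer, only the leading components are kept. Integer
// attributes use `0`/`1` instead of `0.0`/`1.0`.
#[allow(dead_code)]
fn sketch_pad_or_truncate(unpacked: &[f32], out: &mut [f32]) {
    // `out.len()` plays the role of the attribute's dimension.
    for (i, slot) in out.iter_mut().enumerate() {
        *slot = if i < unpacked.len() {
            unpacked[i]
        } else if i == 3 {
            1.0
        } else {
            0.0
        };
    }
}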
writeln!( self.out, "{}// {attribute_ty_name} <- {:?}", back::Level(2), attribute.format )?; } write!(self.out, "{}{attribute_name} = ", back::Level(2),)?; if needs_padding_or_truncation == Ordering::Greater { // Needs padding: emit constructor call for wider type write!(self.out, "{attribute_ty_name}(")?; } // Emit call to unpacking function if needs_conversion { put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?; write!(self.out, "(")?; } write!(self.out, "{func_name}({elem_name}.data[{offset}]")?; for i in (offset + 1)..(offset + func.byte_count) { write!(self.out, ", {elem_name}.data[{i}]")?; } write!(self.out, ")")?; if needs_conversion { write!(self.out, ")")?; } match needs_padding_or_truncation { Ordering::Greater => { // Padding let ty_is_int = scalar_is_int(am.scalar); let zero_value = if ty_is_int { "0" } else { "0.0" }; let one_value = if ty_is_int { "1" } else { "1.0" }; for i in func.dimension.map_or(1, u8::from) ..am.dimension.map_or(1, u8::from) { write!( self.out, ", {}", if i == 3 { one_value } else { zero_value } )?; } } Ordering::Less => { // Truncate to the first `am.dimension` components write!( self.out, ".{}", &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))] )?; } Ordering::Equal => {} } if needs_padding_or_truncation == Ordering::Greater { write!(self.out, ")")?; } writeln!(self.out, ";")?; } // End the bounds check / attribute setting block. writeln!(self.out, "{}}}", back::Level(1))?; } } if need_workgroup_variables_initialization { self.write_workgroup_variables_initialization( module, mod_info, fun_info, local_invocation_id, )?; } // Metal doesn't support private mutable variables outside of functions, // so we put them here, just like the locals. 
for (handle, var) in module.global_variables.iter() { let usage = fun_info[handle]; if usage.is_empty() { continue; } if var.space == crate::AddressSpace::Private { let tyvar = TypedGlobalVariable { module, names: &self.names, handle, usage, reference: false, }; write!(self.out, "{}", back::INDENT)?; tyvar.try_fmt(&mut self.out)?; match var.init { Some(value) => { write!(self.out, " = ")?; self.put_const_expression( value, module, mod_info, &module.global_expressions, )?; writeln!(self.out, ";")?; } None => { writeln!(self.out, " = {{}};")?; } }; } else if let Some(ref binding) = var.binding { let resolved = options.resolve_resource_binding(ep, binding).unwrap(); if let Some(sampler) = resolved.as_inline_sampler(options) { // write an inline sampler let name = &self.names[&NameKey::GlobalVariable(handle)]; writeln!( self.out, "{}constexpr {}::sampler {}(", back::INDENT, NAMESPACE, name )?; self.put_inline_sampler_properties(back::Level(2), sampler)?; writeln!(self.out, "{});", back::INDENT)?; } else if let crate::TypeInner::Image { class: crate::ImageClass::External, .. } = module.types[var.ty].inner { // Wrap the individual arguments for each external texture global // in a struct which can be easily passed around. 
let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)]; let l1 = back::Level(1); let l2 = l1.next(); writeln!( self.out, "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{" )?; for i in 0..3 { let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable( handle, ExternalTextureNameKey::Plane(i), )]; writeln!(self.out, "{l2}.plane{i} = {plane_name},")?; } let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable( handle, ExternalTextureNameKey::Params, )]; writeln!(self.out, "{l2}.params = {params_name},")?; writeln!(self.out, "{l1}}};")?; } } } // Now take the arguments that we gathered into structs, and the // structs that we flattened into arguments, and emit local // variables with initializers that put everything back the way the // body code expects. // // If we had to generate fresh names for struct members passed as // arguments, be sure to use those names when rebuilding the struct. // // "Each day, I change some zeros to ones, and some ones to zeros. // The rest, I leave alone." for (arg_index, arg) in fun.arguments.iter().enumerate() { let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)]; match module.types[arg.ty].inner { crate::TypeInner::Struct { ref members, .. } => { let struct_name = &self.names[&NameKey::Type(arg.ty)]; write!( self.out, "{}const {} {} = {{ ", back::INDENT, struct_name, arg_name )?; for (member_index, member) in members.iter().enumerate() { let key = NameKey::StructMember(arg.ty, member_index as u32); let name = &flattened_member_names[&key]; if member_index != 0 { write!(self.out, ", ")?; } // insert padding initialization, if needed if self .struct_member_pads .contains(&(arg.ty, member_index as u32)) { write!(self.out, "{{}}, ")?; } if let Some(crate::Binding::Location { .. 
}) = member.binding { if has_varyings { write!(self.out, "{varyings_member_name}.")?; } } write!(self.out, "{name}")?; } writeln!(self.out, " }};")?; } _ => match arg.binding { Some(crate::Binding::Location { .. }) | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => { if has_varyings { writeln!( self.out, "{}const auto {} = {}.{};", back::INDENT, arg_name, varyings_member_name, arg_name )?; } } _ => {} }, } } let guarded_indices = index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies); let context = StatementContext { expression: ExpressionContext { function: fun, origin: FunctionOrigin::EntryPoint(ep_index as _), info: fun_info, lang_version: options.lang_version, policies: options.bounds_check_policies, guarded_indices, module, mod_info, pipeline_options, force_loop_bounding: options.force_loop_bounding, }, result_struct: Some(&stage_out_name), }; // Finally, declare all the local variables that we need // TODO: we can postpone this till the relevant expressions are emitted self.put_locals(&context.expression)?; self.update_expressions_to_bake(fun, fun_info, &context.expression); self.put_block(back::Level(1), &fun.body, &context)?; writeln!(self.out, "}}")?; if ep_index + 1 != module.entry_points.len() { writeln!(self.out)?; } self.named_expressions.clear(); } Ok(info) } fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult { // Note: OR-ing bitflags requires `__HAVE_MEMFLAG_OPERATORS__`, // so we try to avoid it here.
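// Illustrative sketch only (hypothetical helper). OR-ing `mem_flags` needs
// `__HAVE_MEMFLAG_OPERATORS__`, so each requested flag gets its own barrier
// call instead. An empty flag set still emits one barrier with `mem_none`.
// Namespace prefixes are omitted from the strings below.
#[allow(dead_code)]
fn sketch_barrier_calls(
    storage: bool,
    work_group: bool,
    sub_group: bool,
    texture: bool,
) -> Vec<&'static str> {
    let mut calls = Vec::new();
    if !(storage || work_group || sub_group || texture) {
        calls.push("threadgroup_barrier(mem_none)");
    }
    if storage {
        calls.push("threadgroup_barrier(mem_device)");
    }
    if work_group {
        calls.push("threadgroup_barrier(mem_threadgroup)");
    }
    if sub_group {
        calls.push("simdgroup_barrier(mem_threadgroup)");
    }
    if texture {
        calls.push("threadgroup_barrier(mem_texture)");
    }
    calls
}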
if flags.is_empty() { writeln!( self.out, "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);", )?; } if flags.contains(crate::Barrier::STORAGE) { writeln!( self.out, "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);", )?; } if flags.contains(crate::Barrier::WORK_GROUP) { writeln!( self.out, "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", )?; } if flags.contains(crate::Barrier::SUB_GROUP) { writeln!( self.out, "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", )?; } if flags.contains(crate::Barrier::TEXTURE) { writeln!( self.out, "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);", )?; } Ok(()) } } /// Initializing workgroup variables is more tricky for Metal because we have to deal /// with atomics at the type-level (which don't have a copy constructor). mod workgroup_mem_init { use crate::EntryPoint; use super::*; enum Access { GlobalVariable(Handle), StructMember(Handle, u32), Array(usize), } impl Access { fn write( &self, writer: &mut W, names: &FastHashMap, ) -> Result<(), core::fmt::Error> { match *self { Access::GlobalVariable(handle) => { write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)]) } Access::StructMember(handle, index) => { write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)]) } Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"), } } } struct AccessStack { stack: Vec, array_depth: usize, } impl AccessStack { const fn new() -> Self { Self { stack: Vec::new(), array_depth: 0, } } fn enter_array(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R { let array_depth = self.array_depth; self.stack.push(Access::Array(array_depth)); self.array_depth += 1; let res = cb(self, array_depth); self.stack.pop(); self.array_depth -= 1; res } fn enter(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R { self.stack.push(new); let res = cb(self); 
self.stack.pop(); res } fn write( &self, writer: &mut W, names: &FastHashMap, ) -> Result<(), core::fmt::Error> { for next in self.stack.iter() { next.write(writer, names)?; } Ok(()) } } impl Writer { pub(super) fn need_workgroup_variables_initialization( &mut self, options: &Options, ep: &EntryPoint, module: &crate::Module, fun_info: &valid::FunctionInfo, ) -> bool { options.zero_initialize_workgroup_memory && ep.stage.compute_like() && module.global_variables.iter().any(|(handle, var)| { !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup }) } pub(super) fn write_workgroup_variables_initialization( &mut self, module: &crate::Module, module_info: &valid::ModuleInfo, fun_info: &valid::FunctionInfo, local_invocation_id: Option<&NameKey>, ) -> BackendResult { let level = back::Level(1); writeln!( self.out, "{}if ({}::all({} == {}::uint3(0u))) {{", level, NAMESPACE, local_invocation_id .map(|name_key| self.names[name_key].as_str()) .unwrap_or("__local_invocation_id"), NAMESPACE, )?; let mut access_stack = AccessStack::new(); let vars = module.global_variables.iter().filter(|&(handle, var)| { !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup }); for (handle, var) in vars { access_stack.enter(Access::GlobalVariable(handle), |access_stack| { self.write_workgroup_variable_initialization( module, module_info, var.ty, access_stack, level.next(), ) })?; } writeln!(self.out, "{level}}}")?; self.write_barrier(crate::Barrier::WORK_GROUP, level) } fn write_workgroup_variable_initialization( &mut self, module: &crate::Module, module_info: &valid::ModuleInfo, ty: Handle, access_stack: &mut AccessStack, level: back::Level, ) -> BackendResult { if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) { write!(self.out, "{level}")?; access_stack.write(&mut self.out, &self.names)?; writeln!(self.out, " = {{}};")?; } else { match module.types[ty].inner { crate::TypeInner::Atomic { .. 
} => { write!( self.out, "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" )?; access_stack.write(&mut self.out, &self.names)?; writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?; } crate::TypeInner::Array { base, size, .. } => { let count = match size.resolve(module.to_ctx())? { proc::IndexableLength::Known(count) => count, proc::IndexableLength::Dynamic => unreachable!(), }; access_stack.enter_array(|access_stack, array_depth| { writeln!( self.out, "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{" )?; self.write_workgroup_variable_initialization( module, module_info, base, access_stack, level.next(), )?; writeln!(self.out, "{level}}}")?; BackendResult::Ok(()) })?; } crate::TypeInner::Struct { ref members, .. } => { for (index, member) in members.iter().enumerate() { access_stack.enter( Access::StructMember(ty, index as u32), |access_stack| { self.write_workgroup_variable_initialization( module, module_info, member.ty, access_stack, level, ) }, )?; } } _ => unreachable!(), } } Ok(()) } } } impl crate::AtomicFunction { const fn to_msl(self) -> &'static str { match self { Self::Add => "fetch_add", Self::Subtract => "fetch_sub", Self::And => "fetch_and", Self::InclusiveOr => "fetch_or", Self::ExclusiveOr => "fetch_xor", Self::Min => "fetch_min", Self::Max => "fetch_max", Self::Exchange { compare: None } => "exchange", Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION, } } fn to_msl_64_bit(self) -> Result<&'static str, Error> { Ok(match self { Self::Min => "min", Self::Max => "max", _ => Err(Error::FeatureNotImplemented( "64-bit atomic operation other than min/max".to_string(), ))?, }) } } naga-29.0.3/src/back/pipeline_constants.rs000064400000000000000000001205521046102023000165700ustar 00000000000000use alloc::{ borrow::Cow, string::{String, ToString}, vec::Vec, }; use core::mem; use hashbrown::HashSet; use thiserror::Error; use super::PipelineConstants; use crate::{ arena::HandleVec, 
    compact::{compact, KeepUnused},
    ir,
    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
    Span, Statement, TypeInner, WithSpan,
};

// Possibly unused when compiled with `std`
#[allow(unused_imports)]
use num_traits::float::FloatCore as _;

#[derive(Error, Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub enum PipelineConstantError {
    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
    MissingValue(String),
    #[error("pipeline-overridable constant '{0}' not found in the shader")]
    NotFound(String),
    #[error(
        "Source f64 value needs to be finite ({}) for number destinations",
        "NaNs and Infinities are not allowed"
    )]
    SrcNeedsToBeFinite,
    #[error("Source f64 value doesn't fit in destination")]
    DstRangeTooSmall,
    #[error(transparent)]
    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
    #[error(transparent)]
    ValidationError(#[from] WithSpan<ValidationError>),
    #[error("workgroup_size override isn't strictly positive")]
    NegativeWorkgroupSize,
    #[error("max vertices or max primitives is negative")]
    NegativeMeshOutputMax,
}

/// Compact `module` and replace all overrides with constants.
///
/// `module` must be valid. Both compaction and constant evaluation may produce
/// invalid results (e.g. replace an invalid expression with a constant) for
/// invalid modules.
///
/// If no changes are needed, this just returns `Cow::Borrowed` references to
/// `module` and `module_info`. Otherwise, it clones `module`, retains only the
/// selected entry point, compacts the module, edits its [`global_expressions`]
/// arena to contain only fully-evaluated expressions, and returns the
/// simplified module and its validation results.
///
/// In the module returned, every `Expression::Override` has been replaced with
/// an `Expression::Constant`, and the `global_expressions` arena contains only
/// fully-evaluated expressions.
/// /// [`global_expressions`]: Module::global_expressions pub fn process_overrides<'a>( module: &'a Module, module_info: &'a ModuleInfo, entry_point: Option<(ir::ShaderStage, &str)>, pipeline_constants: &PipelineConstants, ) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> { let mut handles = module .overrides .iter() .map(|(handle, _)| handle) .collect::>(); for c in pipeline_constants.keys() { let c_id = c.parse().ok(); if let Some((i, _)) = handles.iter().enumerate().find(|&(_, handle)| { let o = &module.overrides[*handle]; if o.id.is_some() { o.id == c_id } else { o.name.as_deref() == Some(c.as_str()) } }) { handles.swap_remove(i); } else { return Err(PipelineConstantError::NotFound(c.clone())); } } if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() { // We skip compacting the module here mostly to reduce the risk of // hitting corner cases like https://github.com/gfx-rs/wgpu/issues/7793. // Compaction doesn't cost very much [1], so it would also be reasonable // to do it unconditionally. Even when there is a single entry point or // when no entry point is specified, it is still possible that there // are unreferenced items in the module that would be removed by this // compaction. // // [1]: https://github.com/gfx-rs/wgpu/pull/7703#issuecomment-2902153760 return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info))); } let mut module = module.clone(); if let Some((ep_stage, ep_name)) = entry_point { module .entry_points .retain(|ep| ep.stage == ep_stage && ep.name == ep_name); } // Compact the module to remove anything not reachable from an entry point. // This is necessary because we may not have values for overrides that are // not reachable from the/an entry point. compact(&mut module, KeepUnused::No); // If there are no overrides in the module, then we can skip the rest. 
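
The loop at the top of `process_overrides` checks that every key in `pipeline_constants` names some override: a key is compared against an override's numeric `id` when the override has one, and against its `name` otherwise, and each override can be claimed at most once. A minimal sketch of that lookup follows. `OverrideStub` and `find_unknown_key` are hypothetical names, and the constants table is assumed to be a `HashMap<String, f64>`.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for an override: an optional numeric `@id` and an optional name.
struct OverrideStub {
    id: Option<u16>,
    name: Option<String>,
}

/// Return a key in `constants` that matches no override, or `None` if every key matches.
/// Overrides with an id are matched only by that id (as a decimal string);
/// overrides without one are matched by name. Each override is claimed at most once.
fn find_unknown_key<'a>(
    overrides: &[OverrideStub],
    constants: &'a HashMap<String, f64>,
) -> Option<&'a str> {
    // Overrides not yet claimed by any key.
    let mut unclaimed: Vec<&OverrideStub> = overrides.iter().collect();
    for key in constants.keys() {
        // Keys that do not parse as a u16 can only match by name.
        let key_id: Option<u16> = key.parse().ok();
        let found = unclaimed.iter().position(|o| match o.id {
            Some(_) => o.id == key_id,
            None => o.name.as_deref() == Some(key.as_str()),
        });
        match found {
            Some(i) => {
                unclaimed.swap_remove(i);
            }
            None => return Some(key.as_str()),
        }
    }
    None
}
```

Note that an override declared with an id cannot be set by its name in this model, which mirrors the `if o.id.is_some()` branch in the original loop.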
if module.overrides.is_empty() { return revalidate(module); } // A map from override handles to the handles of the constants // we've replaced them with. let mut override_map = HandleVec::with_capacity(module.overrides.len()); // A map from `module`'s original global expression handles to // handles in the new, simplified global expression arena. let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len()); // The set of constants whose initializer handles we've already // updated to refer to the newly built global expression arena. // // All constants in `module` must have their `init` handles // updated to point into the new, simplified global expression // arena. Some of these we can most easily handle as a side effect // during the simplification process, but we must handle the rest // in a final fixup pass, guided by `adjusted_global_expressions`. We // add their handles to this set, so that the final fixup step can // leave them alone. let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); let mut layouter = crate::proc::Layouter::default(); // An iterator through the original overrides table, consumed in // approximate tandem with the global expressions. let mut overrides = mem::take(&mut module.overrides); let mut override_iter = overrides.iter_mut_span(); // Do two things in tandem: // // - Rebuild the global expression arena from scratch, fully // evaluating all expressions, and replacing each `Override` // expression in `module.global_expressions` with a `Constant` // expression. // // - Build a new `Constant` in `module.constants` to take the // place of each `Override`. // // Build a map from old global expression handles to their // fully-evaluated counterparts in `adjusted_global_expressions` as we // go. // // Why in tandem? 
Overrides refer to expressions, and expressions // refer to overrides, so we can't disentangle the two into // separate phases. However, we can take advantage of the fact // that the overrides and expressions must form a DAG, and work // our way from the leaves to the roots, replacing and evaluating // as we go. // // Although the two loops are nested, this is really two // alternating phases: we adjust and evaluate constant expressions // until we hit an `Override` expression, at which point we switch // to building `Constant`s for `Overrides` until we've handled the // one used by the expression. Then we switch back to processing // expressions. Because we know they form a DAG, we know the // `Override` expressions we encounter can only have initializers // referring to global expressions we've already simplified. for (old_h, expr, span) in module.global_expressions.drain() { let mut expr = match expr { Expression::Override(h) => { let c_h = if let Some(new_h) = override_map.get(h) { *new_h } else { let mut new_h = None; for entry in override_iter.by_ref() { let stop = entry.0 == h; new_h = Some(process_override( entry, pipeline_constants, &mut module, &mut override_map, &adjusted_global_expressions, &mut adjusted_constant_initializers, &mut global_expression_kind_tracker, )?); if stop { break; } } new_h.unwrap() }; Expression::Constant(c_h) } Expression::Constant(c_h) => { if adjusted_constant_initializers.insert(c_h) { let init = &mut module.constants[c_h].init; *init = adjusted_global_expressions[*init]; } expr } expr => expr, }; let mut evaluator = ConstantEvaluator::for_wgsl_module( &mut module, &mut global_expression_kind_tracker, &mut layouter, false, ); adjust_expr(&adjusted_global_expressions, &mut expr); let h = evaluator.try_eval_and_append(expr, span)?; adjusted_global_expressions.insert(old_h, h); } // Finish processing any overrides we didn't visit in the loop above. for entry in override_iter { match *entry.1 { Override { name: Some(_), .. 
} | Override { id: Some(_), .. } => { process_override( entry, pipeline_constants, &mut module, &mut override_map, &adjusted_global_expressions, &mut adjusted_constant_initializers, &mut global_expression_kind_tracker, )?; } Override { init: Some(ref mut init), .. } => { *init = adjusted_global_expressions[*init]; } _ => {} } } // Update the initialization expression handles of all `Constant`s // and `GlobalVariable`s. Skip `Constant`s we'd already updated en // passant. for (_, c) in module .constants .iter_mut() .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h)) { c.init = adjusted_global_expressions[c.init]; } for (_, v) in module.global_variables.iter_mut() { if let Some(ref mut init) = v.init { *init = adjusted_global_expressions[*init]; } } let mut functions = mem::take(&mut module.functions); for (_, function) in functions.iter_mut() { process_function(&mut module, &override_map, &mut layouter, function)?; } module.functions = functions; let mut entry_points = mem::take(&mut module.entry_points); for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?; process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?; process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?; } module.entry_points = entry_points; module.overrides = overrides; // Now that we've rewritten all the expressions, we need to // recompute their types and other metadata. For the time being, // do a full re-validation. 
revalidate(module) } fn revalidate( module: Module, ) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> { let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); let module_info = validator.validate_resolved_overrides(&module)?; Ok((Cow::Owned(module), Cow::Owned(module_info))) } fn process_workgroup_size_override( module: &mut Module, adjusted_global_expressions: &HandleVec>, ep: &mut crate::EntryPoint, ) -> Result<(), PipelineConstantError> { match ep.workgroup_size_overrides { None => {} Some(overrides) => { overrides.iter().enumerate().try_for_each( |(i, overridden)| -> Result<(), PipelineConstantError> { match *overridden { None => Ok(()), Some(h) => { ep.workgroup_size[i] = module .to_ctx() .get_const_val(adjusted_global_expressions[h]) .map(|n| { if n == 0 { Err(PipelineConstantError::NegativeWorkgroupSize) } else { Ok(n) } }) .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??; Ok(()) } } }, )?; ep.workgroup_size_overrides = None; } } Ok(()) } fn process_mesh_shader_overrides( module: &mut Module, adjusted_global_expressions: &HandleVec>, ep: &mut crate::EntryPoint, ) -> Result<(), PipelineConstantError> { if let Some(ref mut mesh_info) = ep.mesh_info { if let Some(r#override) = mesh_info.max_vertices_override { mesh_info.max_vertices = module .to_ctx() .get_const_val(adjusted_global_expressions[r#override]) .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?; } if let Some(r#override) = mesh_info.max_primitives_override { mesh_info.max_primitives = module .to_ctx() .get_const_val(adjusted_global_expressions[r#override]) .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?; } } Ok(()) } /// Add a [`Constant`] to `module` for the override `old_h`. /// /// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. 
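
`process_workgroup_size_override` replaces each overridden workgroup dimension with the override's evaluated value and rejects zero; if the value cannot be obtained as a constant `u32`, that failure is likewise reported as `NegativeWorkgroupSize`. The sketch below models the rule on plain integers. `resolve_workgroup_size` and `SizeError` are hypothetical, and the evaluated override values are assumed to arrive as `i64`.

```rust
/// Hypothetical error standing in for `PipelineConstantError::NegativeWorkgroupSize`.
#[derive(Debug, PartialEq)]
enum SizeError {
    NotStrictlyPositive,
}

/// Apply workgroup-size overrides: `None` keeps the declared dimension,
/// `Some(v)` replaces it with the evaluated override value `v`.
fn resolve_workgroup_size(
    declared: [u32; 3],
    overridden: [Option<i64>; 3],
) -> Result<[u32; 3], SizeError> {
    let mut size = declared;
    for (dim, value) in overridden.iter().enumerate() {
        if let Some(v) = *value {
            // Zero, negative, and out-of-range values all report the same error,
            // just as naga maps every failed lookup to one error variant.
            if v <= 0 || v > i64::from(u32::MAX) {
                return Err(SizeError::NotStrictlyPositive);
            }
            size[dim] = v as u32;
        }
    }
    Ok(size)
}
```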
fn process_override( (old_h, r#override, span): (Handle, &mut Override, &Span), pipeline_constants: &PipelineConstants, module: &mut Module, override_map: &mut HandleVec>, adjusted_global_expressions: &HandleVec>, adjusted_constant_initializers: &mut HashSet>, global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, ) -> Result, PipelineConstantError> { // Determine which key to use for `r#override` in `pipeline_constants`. let key = if let Some(id) = r#override.id { Cow::Owned(id.to_string()) } else if let Some(ref name) = r#override.name { Cow::Borrowed(name) } else { unreachable!(); }; // Generate a global expression for `r#override`'s value, either // from the provided `pipeline_constants` table or its initializer // in the module. let init = if let Some(value) = pipeline_constants.get::(&key) { let literal = match module.types[r#override.ty].inner { TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, _ => unreachable!(), }; let expr = module .global_expressions .append(Expression::Literal(literal), Span::UNDEFINED); global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const); expr } else if let Some(init) = r#override.init { adjusted_global_expressions[init] } else { return Err(PipelineConstantError::MissingValue(key.to_string())); }; // Generate a new `Constant` to represent the override's value. let constant = Constant { name: r#override.name.clone(), ty: r#override.ty, init, }; let h = module.constants.append(constant, *span); override_map.insert(old_h, h); adjusted_constant_initializers.insert(h); r#override.init = Some(init); Ok(h) } /// Replace all override expressions in `function` with fully-evaluated constants. /// /// Replace all `Expression::Override`s in `function`'s expression arena with /// the corresponding `Expression::Constant`s, as given in `override_map`. /// Replace any expressions whose values are now known with their fully /// evaluated form. 
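
`process_override` converts each supplied `f64` into a literal of the override's scalar type through `map_value_to_literal`, whose comments cite the WebIDL conversion rules. The following standalone sketch reproduces only the `bool` and `i32` cases; `to_bool`, `to_i32`, and `ConvError` are hypothetical names.

```rust
/// Hypothetical error type mirroring `SrcNeedsToBeFinite` and `DstRangeTooSmall`.
#[derive(Debug, PartialEq)]
enum ConvError {
    NotFinite,
    OutOfRange,
}

/// Boolean conversion: 0, -0 and NaN are false; everything else is true.
fn to_bool(value: f64) -> bool {
    value != 0.0 && !value.is_nan()
}

/// Integer conversion: reject NaN and infinities, truncate toward zero,
/// then require the result to fit in an `i32`.
fn to_i32(value: f64) -> Result<i32, ConvError> {
    if !value.is_finite() {
        return Err(ConvError::NotFinite);
    }
    let value = value.trunc();
    if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
        return Err(ConvError::OutOfRange);
    }
    Ok(value as i32)
}
```

Truncating before the range check is what lets a value such as `2147483647.9` still map to `i32::MAX`.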
///
/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
/// `Handle<Constant>` for the override's final value.
fn process_function(
    module: &mut Module,
    override_map: &HandleVec<Override, Handle<Constant>>,
    layouter: &mut crate::proc::Layouter,
    function: &mut Function,
) -> Result<(), ConstantEvaluatorError> {
    // A map from original local expression handles to
    // handles in the new, local expression arena.
    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());

    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();

    let mut expressions = mem::take(&mut function.expressions);

    // Dummy `emitter` and `block` for the constant evaluator.
    // We can ignore the concept of emitting expressions here since
    // expressions have already been covered by a `Statement::Emit`
    // in the frontend.
    // The only thing we might have to do is remove some expressions
    // that have been covered by a `Statement::Emit`. See the docs of
    // `filter_emits_in_block` for the reasoning.
    let mut emitter = Emitter::default();
    let mut block = Block::new();

    let mut evaluator = ConstantEvaluator::for_wgsl_function(
        module,
        &mut function.expressions,
        &mut local_expression_kind_tracker,
        layouter,
        &mut emitter,
        &mut block,
        false,
    );

    for (old_h, mut expr, span) in expressions.drain() {
        if let Expression::Override(h) = expr {
            expr = Expression::Constant(override_map[h]);
        }
        adjust_expr(&adjusted_local_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_local_expressions.insert(old_h, h);
    }

    adjust_block(&adjusted_local_expressions, &mut function.body);

    filter_emits_in_block(&mut function.body, &function.expressions);

    // Update local expression initializers.
    for (_, local) in function.local_variables.iter_mut() {
        if let &mut Some(ref mut init) = &mut local.init {
            *init = adjusted_local_expressions[*init];
        }
    }

    // We've changed the keys of `function.named_expressions`, so we have to
    // rebuild it from scratch.
let named_expressions = mem::take(&mut function.named_expressions); for (expr_h, name) in named_expressions { function .named_expressions .insert(adjusted_local_expressions[expr_h], name); } Ok(()) } /// Replace every expression handle in `expr` with its counterpart /// given by `new_pos`. fn adjust_expr(new_pos: &HandleVec>, expr: &mut Expression) { let adjust = |expr: &mut Handle| { *expr = new_pos[*expr]; }; match *expr { Expression::Compose { ref mut components, ty: _, } => { for c in components.iter_mut() { adjust(c); } } Expression::Access { ref mut base, ref mut index, } => { adjust(base); adjust(index); } Expression::AccessIndex { ref mut base, index: _, } => { adjust(base); } Expression::Splat { ref mut value, size: _, } => { adjust(value); } Expression::Swizzle { ref mut vector, size: _, pattern: _, } => { adjust(vector); } Expression::Load { ref mut pointer } => { adjust(pointer); } Expression::ImageSample { ref mut image, ref mut sampler, ref mut coordinate, ref mut array_index, ref mut offset, ref mut level, ref mut depth_ref, gather: _, clamp_to_edge: _, } => { adjust(image); adjust(sampler); adjust(coordinate); if let Some(e) = array_index.as_mut() { adjust(e); } if let Some(e) = offset.as_mut() { adjust(e); } match *level { crate::SampleLevel::Exact(ref mut expr) | crate::SampleLevel::Bias(ref mut expr) => { adjust(expr); } crate::SampleLevel::Gradient { ref mut x, ref mut y, } => { adjust(x); adjust(y); } _ => {} } if let Some(e) = depth_ref.as_mut() { adjust(e); } } Expression::ImageLoad { ref mut image, ref mut coordinate, ref mut array_index, ref mut sample, ref mut level, } => { adjust(image); adjust(coordinate); if let Some(e) = array_index.as_mut() { adjust(e); } if let Some(e) = sample.as_mut() { adjust(e); } if let Some(e) = level.as_mut() { adjust(e); } } Expression::ImageQuery { ref mut image, ref mut query, } => { adjust(image); match *query { crate::ImageQuery::Size { ref mut level } => { if let Some(e) = level.as_mut() { adjust(e); } } 
crate::ImageQuery::NumLevels | crate::ImageQuery::NumLayers | crate::ImageQuery::NumSamples => {} } } Expression::Unary { ref mut expr, op: _, } => { adjust(expr); } Expression::Binary { ref mut left, ref mut right, op: _, } => { adjust(left); adjust(right); } Expression::Select { ref mut condition, ref mut accept, ref mut reject, } => { adjust(condition); adjust(accept); adjust(reject); } Expression::Derivative { ref mut expr, axis: _, ctrl: _, } => { adjust(expr); } Expression::Relational { ref mut argument, fun: _, } => { adjust(argument); } Expression::Math { ref mut arg, ref mut arg1, ref mut arg2, ref mut arg3, fun: _, } => { adjust(arg); if let Some(e) = arg1.as_mut() { adjust(e); } if let Some(e) = arg2.as_mut() { adjust(e); } if let Some(e) = arg3.as_mut() { adjust(e); } } Expression::As { ref mut expr, kind: _, convert: _, } => { adjust(expr); } Expression::ArrayLength(ref mut expr) => { adjust(expr); } Expression::RayQueryGetIntersection { ref mut query, committed: _, } => { adjust(query); } Expression::Literal(_) | Expression::FunctionArgument(_) | Expression::GlobalVariable(_) | Expression::LocalVariable(_) | Expression::CallResult(_) | Expression::RayQueryProceedResult | Expression::Constant(_) | Expression::Override(_) | Expression::ZeroValue(_) | Expression::AtomicResult { ty: _, comparison: _, } | Expression::WorkGroupUniformLoadResult { ty: _ } | Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. } => {} Expression::RayQueryVertexPositions { ref mut query, committed: _, } => { adjust(query); } Expression::CooperativeLoad { ref mut data, .. } => { adjust(&mut data.pointer); adjust(&mut data.stride); } Expression::CooperativeMultiplyAdd { ref mut a, ref mut b, ref mut c, } => { adjust(a); adjust(b); adjust(c); } } } /// Replace every expression handle in `block` with its counterpart /// given by `new_pos`. 
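
`adjust_expr`, `adjust_block`, and `adjust_stmt` all rely on the same idea: expressions are rebuilt in arena order, and because an expression may only refer to earlier ones, each operand's new position is already recorded in `new_pos` by the time it is needed. The sketch below shows this on a toy arena that uses `usize` indices instead of `Handle<Expression>`. `Expr`, `adjust`, `literal_at`, `rebuild`, and the literal folding that stands in for `ConstantEvaluator::try_eval_and_append` are all assumptions made for illustration.

```rust
/// Toy expression whose operands are indices of earlier expressions
/// (a hypothetical stand-in for naga's `Handle<Expression>`).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Expr {
    Literal(i32),
    Arg(u32),
    Neg(usize),
    Add(usize, usize),
}

/// Replace every operand index in `expr` with its counterpart in `new_pos`.
fn adjust(new_pos: &[usize], expr: &mut Expr) {
    match *expr {
        Expr::Literal(_) | Expr::Arg(_) => {}
        Expr::Neg(ref mut a) => *a = new_pos[*a],
        Expr::Add(ref mut a, ref mut b) => {
            *a = new_pos[*a];
            *b = new_pos[*b];
        }
    }
}

fn literal_at(arena: &[Expr], i: usize) -> Option<i32> {
    match arena[i] {
        Expr::Literal(x) => Some(x),
        _ => None,
    }
}

/// Rebuild `old` into a new arena, folding operations on literals
/// (a stand-in for constant evaluation) and recording old -> new positions.
fn rebuild(old: Vec<Expr>) -> (Vec<Expr>, Vec<usize>) {
    let mut arena: Vec<Expr> = Vec::with_capacity(old.len());
    let mut new_pos: Vec<usize> = Vec::with_capacity(old.len());
    for mut expr in old {
        // Operands only refer to earlier expressions, so `new_pos` already covers them.
        adjust(&new_pos, &mut expr);
        let folded = match expr {
            Expr::Neg(a) => literal_at(&arena, a).map(|x| Expr::Literal(-x)),
            Expr::Add(a, b) => match (literal_at(&arena, a), literal_at(&arena, b)) {
                (Some(x), Some(y)) => Some(Expr::Literal(x + y)),
                _ => None,
            },
            Expr::Literal(_) | Expr::Arg(_) => None,
        };
        arena.push(folded.unwrap_or(expr));
        new_pos.push(arena.len() - 1);
    }
    (arena, new_pos)
}
```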
fn adjust_block(new_pos: &HandleVec>, block: &mut Block) { for stmt in block.iter_mut() { adjust_stmt(new_pos, stmt); } } /// Replace every expression handle in `stmt` with its counterpart /// given by `new_pos`. fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut Statement) { let adjust = |expr: &mut Handle| { *expr = new_pos[*expr]; }; match *stmt { Statement::Emit(ref mut range) => { if let Some((mut first, mut last)) = range.first_and_last() { adjust(&mut first); adjust(&mut last); *range = Range::new_from_bounds(first, last); } } Statement::Block(ref mut block) => { adjust_block(new_pos, block); } Statement::If { ref mut condition, ref mut accept, ref mut reject, } => { adjust(condition); adjust_block(new_pos, accept); adjust_block(new_pos, reject); } Statement::Switch { ref mut selector, ref mut cases, } => { adjust(selector); for case in cases.iter_mut() { adjust_block(new_pos, &mut case.body); } } Statement::Loop { ref mut body, ref mut continuing, ref mut break_if, } => { adjust_block(new_pos, body); adjust_block(new_pos, continuing); if let Some(e) = break_if.as_mut() { adjust(e); } } Statement::Return { ref mut value } => { if let Some(e) = value.as_mut() { adjust(e); } } Statement::Store { ref mut pointer, ref mut value, } => { adjust(pointer); adjust(value); } Statement::ImageStore { ref mut image, ref mut coordinate, ref mut array_index, ref mut value, } => { adjust(image); adjust(coordinate); if let Some(e) = array_index.as_mut() { adjust(e); } adjust(value); } Statement::Atomic { ref mut pointer, ref mut value, ref mut result, ref mut fun, } => { adjust(pointer); adjust(value); if let Some(ref mut result) = *result { adjust(result); } match *fun { crate::AtomicFunction::Exchange { compare: Some(ref mut compare), } => { adjust(compare); } crate::AtomicFunction::Add | crate::AtomicFunction::Subtract | crate::AtomicFunction::And | crate::AtomicFunction::ExclusiveOr | crate::AtomicFunction::InclusiveOr | crate::AtomicFunction::Min | 
crate::AtomicFunction::Max | crate::AtomicFunction::Exchange { compare: None } => {} } } Statement::ImageAtomic { ref mut image, ref mut coordinate, ref mut array_index, fun: _, ref mut value, } => { adjust(image); adjust(coordinate); if let Some(ref mut array_index) = *array_index { adjust(array_index); } adjust(value); } Statement::WorkGroupUniformLoad { ref mut pointer, ref mut result, } => { adjust(pointer); adjust(result); } Statement::SubgroupBallot { ref mut result, ref mut predicate, } => { if let Some(ref mut predicate) = *predicate { adjust(predicate); } adjust(result); } Statement::SubgroupCollectiveOperation { ref mut argument, ref mut result, .. } => { adjust(argument); adjust(result); } Statement::SubgroupGather { ref mut mode, ref mut argument, ref mut result, } => { match *mode { crate::GatherMode::BroadcastFirst => {} crate::GatherMode::Broadcast(ref mut index) | crate::GatherMode::Shuffle(ref mut index) | crate::GatherMode::ShuffleDown(ref mut index) | crate::GatherMode::ShuffleUp(ref mut index) | crate::GatherMode::ShuffleXor(ref mut index) | crate::GatherMode::QuadBroadcast(ref mut index) => { adjust(index); } crate::GatherMode::QuadSwap(_) => {} } adjust(argument); adjust(result) } Statement::Call { ref mut arguments, ref mut result, function: _, } => { for argument in arguments.iter_mut() { adjust(argument); } if let Some(e) = result.as_mut() { adjust(e); } } Statement::RayQuery { ref mut query, ref mut fun, } => { adjust(query); match *fun { crate::RayQueryFunction::Initialize { ref mut acceleration_structure, ref mut descriptor, } => { adjust(acceleration_structure); adjust(descriptor); } crate::RayQueryFunction::Proceed { ref mut result } => { adjust(result); } crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => { adjust(hit_t); } crate::RayQueryFunction::ConfirmIntersection => {} crate::RayQueryFunction::Terminate => {} } } Statement::CooperativeStore { ref mut target, ref mut data, } => { adjust(target); adjust(&mut 
data.pointer); adjust(&mut data.stride); } Statement::RayPipelineFunction(ref mut func) => match *func { crate::RayPipelineFunction::TraceRay { ref mut acceleration_structure, ref mut descriptor, ref mut payload, } => { adjust(acceleration_structure); adjust(descriptor); adjust(payload); } }, Statement::Break | Statement::Continue | Statement::Kill | Statement::ControlBarrier(_) | Statement::MemoryBarrier(_) => {} } } /// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced. /// /// According to validation, [`Emit`] statements must not cover any expressions /// for which [`Expression::needs_pre_emit`] returns true. All expressions built /// by successful constant evaluation fall into that category, meaning that /// `process_function` will usually rewrite [`Override`] expressions and those /// that use their values into pre-emitted expressions, leaving any [`Emit`] /// statements that cover them invalid. /// /// This function rewrites all [`Emit`] statements into zero or more new /// [`Emit`] statements covering only those expressions in the original range /// that are not pre-emitted. 
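
The splitting step described above can be sketched on plain index ranges, with a predicate standing in for `Expression::needs_pre_emit`. `split_emit_range` is a hypothetical helper; it returns half-open `Range<usize>` values, whereas naga builds each new range from inclusive bounds with `Range::new_from_bounds`.

```rust
use std::ops::Range;

/// Split `range` into maximal runs of indices for which `pre_emitted` is false.
fn split_emit_range(range: Range<usize>, pre_emitted: impl Fn(usize) -> bool) -> Vec<Range<usize>> {
    let mut out = Vec::new();
    // The run being built, as inclusive (first, last) bounds.
    let mut current: Option<(usize, usize)> = None;
    for i in range {
        if pre_emitted(i) {
            // A pre-emitted index ends the current run, if there is one.
            if let Some((first, last)) = current.take() {
                out.push(first..last + 1);
            }
        } else if let Some((_, ref mut last)) = current {
            // Extend the run; indices in an emit range are contiguous.
            *last = i;
        } else {
            current = Some((i, i));
        }
    }
    if let Some((first, last)) = current {
        out.push(first..last + 1);
    }
    out
}
```

A range whose indices are all pre-emitted yields no ranges at all, which corresponds to the original `Emit` statement disappearing entirely.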
/// /// [`Emit`]: Statement::Emit /// [`needs_pre_emit`]: Expression::needs_pre_emit /// [`Override`]: Expression::Override fn filter_emits_in_block(block: &mut Block, expressions: &Arena) { let original = mem::replace(block, Block::with_capacity(block.len())); for (stmt, span) in original.span_into_iter() { match stmt { Statement::Emit(range) => { let mut current = None; for expr_h in range { if expressions[expr_h].needs_pre_emit() { if let Some((first, last)) = current { block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); } current = None; } else if let Some((_, ref mut last)) = current { *last = expr_h; } else { current = Some((expr_h, expr_h)); } } if let Some((first, last)) = current { block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); } } Statement::Block(mut child) => { filter_emits_in_block(&mut child, expressions); block.push(Statement::Block(child), span); } Statement::If { condition, mut accept, mut reject, } => { filter_emits_in_block(&mut accept, expressions); filter_emits_in_block(&mut reject, expressions); block.push( Statement::If { condition, accept, reject, }, span, ); } Statement::Switch { selector, mut cases, } => { for case in &mut cases { filter_emits_in_block(&mut case.body, expressions); } block.push(Statement::Switch { selector, cases }, span); } Statement::Loop { mut body, mut continuing, break_if, } => { filter_emits_in_block(&mut body, expressions); filter_emits_in_block(&mut continuing, expressions); block.push( Statement::Loop { body, continuing, break_if, }, span, ); } stmt => block.push(stmt.clone(), span), } } } fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { // note that in rust 0.0 == -0.0 match scalar { Scalar::BOOL => { // https://webidl.spec.whatwg.org/#js-boolean let value = value != 0.0 && !value.is_nan(); Ok(Literal::Bool(value)) } Scalar::I32 => { // https://webidl.spec.whatwg.org/#js-long if !value.is_finite() { return Err(PipelineConstantError::SrcNeedsToBeFinite); 
} let value = value.trunc(); if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) { return Err(PipelineConstantError::DstRangeTooSmall); } let value = value as i32; Ok(Literal::I32(value)) } Scalar::U32 => { // https://webidl.spec.whatwg.org/#js-unsigned-long if !value.is_finite() { return Err(PipelineConstantError::SrcNeedsToBeFinite); } let value = value.trunc(); if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) { return Err(PipelineConstantError::DstRangeTooSmall); } let value = value as u32; Ok(Literal::U32(value)) } Scalar::F16 => { // https://webidl.spec.whatwg.org/#js-float if !value.is_finite() { return Err(PipelineConstantError::SrcNeedsToBeFinite); } let value = half::f16::from_f64(value); if !value.is_finite() { return Err(PipelineConstantError::DstRangeTooSmall); } Ok(Literal::F16(value)) } Scalar::F32 => { // https://webidl.spec.whatwg.org/#js-float if !value.is_finite() { return Err(PipelineConstantError::SrcNeedsToBeFinite); } let value = value as f32; if !value.is_finite() { return Err(PipelineConstantError::DstRangeTooSmall); } Ok(Literal::F32(value)) } Scalar::F64 => { // https://webidl.spec.whatwg.org/#js-double if !value.is_finite() { return Err(PipelineConstantError::SrcNeedsToBeFinite); } Ok(Literal::F64(value)) } Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => { unreachable!("abstract values should not be validated out of override processing") } _ => unreachable!("unrecognized scalar type for override"), } } #[test] fn test_map_value_to_literal() { let bool_test_cases = [ (0.0, false), (-0.0, false), (f64::NAN, false), (1.0, true), (f64::INFINITY, true), (f64::NEG_INFINITY, true), ]; for (value, out) in bool_test_cases { let res = Ok(Literal::Bool(out)); assert_eq!(map_value_to_literal(value, Scalar::BOOL), res); } for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] { for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { let res = Err(PipelineConstantError::SrcNeedsToBeFinite); 
assert_eq!(map_value_to_literal(value, scalar), res); } } // i32 assert_eq!( map_value_to_literal(f64::from(i32::MIN), Scalar::I32), Ok(Literal::I32(i32::MIN)) ); assert_eq!( map_value_to_literal(f64::from(i32::MAX), Scalar::I32), Ok(Literal::I32(i32::MAX)) ); assert_eq!( map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32), Err(PipelineConstantError::DstRangeTooSmall) ); assert_eq!( map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32), Err(PipelineConstantError::DstRangeTooSmall) ); // u32 assert_eq!( map_value_to_literal(f64::from(u32::MIN), Scalar::U32), Ok(Literal::U32(u32::MIN)) ); assert_eq!( map_value_to_literal(f64::from(u32::MAX), Scalar::U32), Ok(Literal::U32(u32::MAX)) ); assert_eq!( map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32), Err(PipelineConstantError::DstRangeTooSmall) ); assert_eq!( map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32), Err(PipelineConstantError::DstRangeTooSmall) ); // f32 assert_eq!( map_value_to_literal(f64::from(f32::MIN), Scalar::F32), Ok(Literal::F32(f32::MIN)) ); assert_eq!( map_value_to_literal(f64::from(f32::MAX), Scalar::F32), Ok(Literal::F32(f32::MAX)) ); assert_eq!( map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32), Ok(Literal::F32(f32::MIN)) ); assert_eq!( map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32), Ok(Literal::F32(f32::MAX)) ); assert_eq!( map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32), Err(PipelineConstantError::DstRangeTooSmall) ); assert_eq!( map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32), Err(PipelineConstantError::DstRangeTooSmall) ); // f64 assert_eq!( map_value_to_literal(f64::MIN, Scalar::F64), Ok(Literal::F64(f64::MIN)) ); assert_eq!( map_value_to_literal(f64::MAX, Scalar::F64), Ok(Literal::F64(f64::MAX)) ); } naga-29.0.3/src/back/spv/block.rs000064400000000000000000005571111046102023000145760ustar 00000000000000/*! Implementations for `BlockContext` methods. 
*/ use alloc::vec::Vec; use arrayvec::ArrayVec; use spirv::Word; use super::{ helpers::map_storage_class, index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType, ResultMember, WrappedFunction, Writer, WriterFlags, }; use crate::{ arena::Handle, back::spv::helpers::is_uniform_matcx2_struct_member_access, proc::index::GuardedIndex, Statement, }; fn get_dimension(type_inner: &crate::TypeInner) -> Dimension { match *type_inner { crate::TypeInner::Scalar(_) => Dimension::Scalar, crate::TypeInner::Vector { .. } => Dimension::Vector, crate::TypeInner::Matrix { .. } => Dimension::Matrix, crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix, _ => unreachable!(), } } /// How to derive the type of `OpAccessChain` instructions from Naga IR. /// /// Most of the time, we compile Naga IR to SPIR-V instructions whose result /// types are simply the direct SPIR-V analog of the Naga IR's. But in some /// cases, the Naga IR and SPIR-V types need to diverge. /// /// This enum specifies how [`BlockContext::write_access_chain`] should /// choose a SPIR-V result type for the `OpAccessChain` it generates, based on /// the type of the given Naga IR [`Expression`] it's generating code for. /// /// [`Expression`]: crate::Expression #[derive(Copy, Clone)] enum AccessTypeAdjustment { /// No adjustment needed: the SPIR-V type should be the direct /// analog of the Naga IR expression type. /// /// For most access chains, this is the right thing: the Naga IR access /// expression produces a [`Pointer`] to the element / component, and the /// SPIR-V `OpAccessChain` instruction does the same. /// /// [`Pointer`]: crate::TypeInner::Pointer None, /// The SPIR-V type should be an `OpPointer` to the direct analog of the /// Naga IR expression's type. 
/// /// This is necessary for indexing binding arrays in the [`Handle`] address /// space: /// /// - In Naga IR, referencing a binding array [`GlobalVariable`] in the /// [`Handle`] address space produces a value of type [`BindingArray`], /// not a pointer to such. And [`Access`] and [`AccessIndex`] expressions /// operate on handle binding arrays by value, and produce handle values, /// not pointers. /// /// - In SPIR-V, a binding array `OpVariable` produces a pointer to an /// array, and `OpAccessChain` instructions operate on pointers, /// regardless of whether the elements are opaque types or not. /// /// See also the documentation for [`BindingArray`]. /// /// [`Handle`]: crate::AddressSpace::Handle /// [`GlobalVariable`]: crate::GlobalVariable /// [`BindingArray`]: crate::TypeInner::BindingArray /// [`Access`]: crate::Expression::Access /// [`AccessIndex`]: crate::Expression::AccessIndex IntroducePointer(spirv::StorageClass), /// The SPIR-V type should be an `OpPointer` to the std140 layout /// compatible variant of the Naga IR expression's base type. /// /// This is used when accessing a type through an [`AddressSpace::Uniform`] /// pointer in cases where the original type is incompatible with std140 /// layout requirements and we have therefore declared the uniform to be of /// an alternative std140 compliant type. /// /// [`AddressSpace::Uniform`]: crate::AddressSpace::Uniform UseStd140CompatType, } /// The results of emitting code for a left-hand-side expression. /// /// On success, `write_access_chain` returns one of these. enum ExpressionPointer { /// The pointer to the expression's value is available, as the value of the /// expression with the given id. Ready { pointer_id: Word }, /// The access expression must be conditional on the value of `condition`, a boolean /// expression that is true if all indices are in bounds. 
If `condition` is true, then /// `access` is an `OpAccessChain` instruction that will compute a pointer to the /// expression's value. If `condition` is false, then executing `access` would be /// undefined behavior. Conditional { condition: Word, access: Instruction, }, } /// The termination statement to be added to the end of the block. enum BlockExit { /// Generates an `OpReturn` (void return). Return, /// Generates an `OpBranch` to the specified block. Branch { /// The branch target block. target: Word, }, /// Translates a loop `break if` into an `OpBranchConditional` that branches to the /// merge block if the condition is true (the merge block is passed through /// [`LoopContext::break_id`]), or else to the loop header (passed through /// [`preamble_id`]). /// /// [`preamble_id`]: Self::BreakIf::preamble_id BreakIf { /// The condition of the `break if`. condition: Handle<crate::Expression>, /// The loop header block id. preamble_id: Word, }, } /// What code generation did with a provided [`BlockExit`] value. /// /// A function that accepts a [`BlockExit`] argument should return a value of /// this type, to indicate whether the code it generated ended up using the /// provided exit, or ignored it and did a non-local exit of some other kind /// (say, [`Break`] or [`Continue`]). Some callers must use this information to /// decide whether to generate the target block at all. /// /// [`Break`]: Statement::Break /// [`Continue`]: Statement::Continue #[must_use] enum BlockExitDisposition { /// The generated code used the provided `BlockExit` value. If it included a /// block label, the caller should be sure to actually emit the block it /// refers to. Used, /// The generated code did not use the provided `BlockExit` value. If it /// included a block label, the caller should not bother to actually emit /// the block it refers to, unless it knows the block is needed for /// something else.
Discarded, } #[derive(Clone, Copy, Default)] struct LoopContext { continuing_id: Option<Word>, break_id: Option<Word>, } #[derive(Debug)] pub(crate) struct DebugInfoInner<'a> { pub source_code: &'a str, pub source_file_id: Word, } impl Writer { // Flip Y coordinate to adjust for coordinate space difference // between SPIR-V and our IR. // The `position_id` argument is a pointer to a `vecN<f32>`, // whose `y` component we will negate. fn write_epilogue_position_y_flip( &mut self, position_id: Word, body: &mut Vec<Instruction>, ) -> Result<(), Error> { let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output); let index_y_id = self.get_index_constant(1); let access_id = self.id_gen.next(); body.push(Instruction::access_chain( float_ptr_type_id, access_id, position_id, &[index_y_id], )); let float_type_id = self.get_f32_type_id(); let load_id = self.id_gen.next(); body.push(Instruction::load(float_type_id, load_id, access_id, None)); let neg_id = self.id_gen.next(); body.push(Instruction::unary( spirv::Op::FNegate, float_type_id, neg_id, load_id, )); body.push(Instruction::store(access_id, neg_id, None)); Ok(()) } // Clamp fragment depth between 0 and 1.
fn write_epilogue_frag_depth_clamp( &mut self, frag_depth_id: Word, body: &mut Vec<Instruction>, ) -> Result<(), Error> { let float_type_id = self.get_f32_type_id(); let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0)); let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0)); let original_id = self.id_gen.next(); body.push(Instruction::load( float_type_id, original_id, frag_depth_id, None, )); let clamp_id = self.id_gen.next(); body.push(Instruction::ext_inst_gl_op( self.gl450_ext_inst_id, spirv::GlslStd450Op::FClamp, float_type_id, clamp_id, &[original_id, zero_scalar_id, one_scalar_id], )); body.push(Instruction::store(frag_depth_id, clamp_id, None)); Ok(()) } fn write_entry_point_return( &mut self, value_id: Word, ir_result: &crate::FunctionResult, result_members: &[ResultMember], body: &mut Vec<Instruction>, ) -> Result<Instruction, Error> { for (index, res_member) in result_members.iter().enumerate() { // This isn't a real builtin, and is handled elsewhere if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) { return Ok(Instruction::return_value(value_id)); } let member_value_id = match ir_result.binding { Some(_) => value_id, None => { let member_value_id = self.id_gen.next(); body.push(Instruction::composite_extract( res_member.type_id, member_value_id, value_id, &[index as u32], )); member_value_id } }; self.store_io_with_f16_polyfill(body, res_member.id, member_value_id); match res_member.built_in { Some(crate::BuiltIn::Position { .. }) if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) => { self.write_epilogue_position_y_flip(res_member.id, body)?; } Some(crate::BuiltIn::FragDepth) if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) => { self.write_epilogue_frag_depth_clamp(res_member.id, body)?; } _ => {} } } Ok(Instruction::return_void()) } } impl BlockContext<'_> { /// Generates code to ensure that a loop is bounded. Should be called immediately /// after adding the OpLoopMerge instruction to `block`.
This function will /// [`consume()`](crate::back::spv::Function::consume) `block` and append its /// instructions to a new [`Block`], which will be returned to the caller to be /// consumed prior to writing the loop body. /// /// Additionally, this function will populate [`force_loop_bounding_vars`](crate::back::spv::Function::force_loop_bounding_vars), /// ensuring that [`Function::to_words()`](crate::back::spv::Function::to_words) will /// declare the required variables. /// /// See [`crate::back::msl::Writer::gen_force_bounded_loop_statements`] for details /// of why this is required. fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block { let uint_type_id = self.writer.get_u32_type_id(); let uint2_type_id = self.writer.get_vec2u_type_id(); let uint2_ptr_type_id = self .writer .get_vec2u_pointer_type_id(spirv::StorageClass::Function); let bool_type_id = self.writer.get_bool_type_id(); let bool2_type_id = self.writer.get_vec2_bool_type_id(); let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0)); let zero_uint2_const_id = self.writer.get_constant_composite( LookupType::Local(LocalType::Numeric(NumericType::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::U32, })), &[zero_uint_const_id, zero_uint_const_id], ); let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1)); let max_uint_const_id = self .writer .get_constant_scalar(crate::Literal::U32(u32::MAX)); let max_uint2_const_id = self.writer.get_constant_composite( LookupType::Local(LocalType::Numeric(NumericType::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::U32, })), &[max_uint_const_id, max_uint_const_id], ); let loop_counter_var_id = self.gen_id(); if self.writer.flags.contains(WriterFlags::DEBUG) { self.writer .debugs .push(Instruction::name(loop_counter_var_id, "loop_bound")); } let var = super::LocalVariable { id: loop_counter_var_id, instruction: Instruction::variable( uint2_ptr_type_id,
loop_counter_var_id, spirv::StorageClass::Function, Some(max_uint2_const_id), ), }; self.function.force_loop_bounding_vars.push(var); let break_if_block = self.gen_id(); self.function .consume(block, Instruction::branch(break_if_block)); block = Block::new(break_if_block); // Load the current loop counter value from its variable. We use a vec2 to // simulate a 64-bit counter. let load_id = self.gen_id(); block.body.push(Instruction::load( uint2_type_id, load_id, loop_counter_var_id, None, )); // If both the high and low u32s have reached 0 then break. ie // if (all(eq(loop_counter, vec2(0)))) { break; } let eq_id = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::IEqual, bool2_type_id, eq_id, zero_uint2_const_id, load_id, )); let all_eq_id = self.gen_id(); block.body.push(Instruction::relational( spirv::Op::All, bool_type_id, all_eq_id, eq_id, )); let inc_counter_block_id = self.gen_id(); block.body.push(Instruction::selection_merge( inc_counter_block_id, spirv::SelectionControl::empty(), )); self.function.consume( block, Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id), ); block = Block::new(inc_counter_block_id); // To simulate a 64-bit counter we always decrement the low u32, and decrement // the high u32 when the low u32 overflows. ie // counter -= vec2(select(0u, 1u, counter.y == 0), 1u); // Count down from u32::MAX rather than up from 0 to avoid hang on // certain Intel drivers. See . 
let low_id = self.gen_id(); block.body.push(Instruction::composite_extract( uint_type_id, low_id, load_id, &[1], )); let low_overflow_id = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::IEqual, bool_type_id, low_overflow_id, low_id, zero_uint_const_id, )); let carry_bit_id = self.gen_id(); block.body.push(Instruction::select( uint_type_id, carry_bit_id, low_overflow_id, one_uint_const_id, zero_uint_const_id, )); let decrement_id = self.gen_id(); block.body.push(Instruction::composite_construct( uint2_type_id, decrement_id, &[carry_bit_id, one_uint_const_id], )); let result_id = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::ISub, uint2_type_id, result_id, load_id, decrement_id, )); block .body .push(Instruction::store(loop_counter_var_id, result_id, None)); block } /// If `pointer` refers to an access chain that contains a dynamic indexing /// of a two-row matrix in the [`Uniform`] address space, write code to /// access the value, returning the ID of the result. Otherwise, return `None`. /// /// Two-row matrices in the uniform address space will have been declared /// using an alternative std140-layout-compatible type, where each column is /// a member of a containing struct. As a result, SPIR-V is unable to access /// its columns with a non-constant index. To work around this limitation, /// this function will call [`Self::write_checked_load()`] to load the /// matrix itself, which handles conversion from the std140-compatible type /// to the real matrix type. It then calls a [`wrapper function`] to obtain /// the correct column from the matrix, and possibly extracts a component /// from the vector too.
/// /// [`Uniform`]: crate::AddressSpace::Uniform /// [`wrapper function`]: super::Writer::write_wrapped_matcx2_get_column fn maybe_write_uniform_matcx2_dynamic_access( &mut self, pointer: Handle<crate::Expression>, block: &mut Block, ) -> Result<Option<Word>, Error> { // If this access chain contains a dynamic matrix access, `pointer` is // either a pointer to a vector (the column) or to a scalar (a component // within the column). In either case, grab the pointer to the column, // and remember the component index if there is one. If `pointer` // points to any other type, we're not interested. let (column_pointer, component_index) = match self.fun_info[pointer] .ty .inner_with(&self.ir_module.types) .pointer_base_type() { Some(resolution) => match *resolution.inner_with(&self.ir_module.types) { crate::TypeInner::Scalar(_) => match self.ir_function.expressions[pointer] { crate::Expression::Access { base, index } => { (base, Some(GuardedIndex::Expression(index))) } crate::Expression::AccessIndex { base, index } => { (base, Some(GuardedIndex::Known(index))) } _ => return Ok(None), }, crate::TypeInner::Vector { .. } => (pointer, None), _ => return Ok(None), }, None => return Ok(None), }; // Ensure the column is accessed with a dynamic index (i.e. // `Expression::Access`), and grab the pointer to the matrix. let crate::Expression::Access { base: matrix_pointer, index: column_index, } = self.ir_function.expressions[column_pointer] else { return Ok(None); }; // Ensure the matrix pointer is in the uniform address space. let crate::TypeInner::Pointer { base: matrix_pointer_base_type, space: crate::AddressSpace::Uniform, } = *self.fun_info[matrix_pointer] .ty .inner_with(&self.ir_module.types) else { return Ok(None); }; // Ensure the matrix pointer actually points to a Cx2 matrix.
let crate::TypeInner::Matrix { columns, rows: rows @ crate::VectorSize::Bi, scalar, } = self.ir_module.types[matrix_pointer_base_type].inner else { return Ok(None); }; let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix { columns, rows, scalar, }); let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar }); let component_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar)); let get_column_function_id = self.writer.wrapped_functions [&WrappedFunction::MatCx2GetColumn { r#type: matrix_pointer_base_type, }]; let matrix_load_id = self.write_checked_load( matrix_pointer, block, AccessTypeAdjustment::None, matrix_type_id, )?; // Naga IR allows the index to be either an I32 or a U32, but our wrapper // function expects a U32 argument, so convert it if required. let column_index_id = match *self.fun_info[column_index] .ty .inner_with(&self.ir_module.types) { crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Uint, .. }) => self.cached[column_index], crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Sint, .. }) => { let cast_id = self.gen_id(); let u32_type_id = self.writer.get_u32_type_id(); block.body.push(Instruction::unary( spirv::Op::Bitcast, u32_type_id, cast_id, self.cached[column_index], )); cast_id } _ => return Err(Error::Validation("Matrix access index must be u32 or i32")), }; let column_id = self.gen_id(); block.body.push(Instruction::function_call( column_type_id, column_id, get_column_function_id, &[matrix_load_id, column_index_id], )); let result_id = match component_index { Some(index) => self.write_vector_access( component_type_id, column_pointer, Some(column_id), index, block, )?, None => column_id, }; Ok(Some(result_id)) } /// If `pointer` refers to a two-row matrix that is a member of a struct in /// the [`Uniform`] address space, write code to load the matrix, returning /// the ID of the result. Otherwise, return `None`.
/// /// Two-row matrices that are struct members in the uniform address space /// will have been decomposed such that the struct contains a separate /// vector member for each column of the matrix. This function will load /// each column separately from the containing struct, then combine them /// into the real matrix type. /// /// [`Uniform`]: crate::AddressSpace::Uniform fn maybe_write_load_uniform_matcx2_struct_member( &mut self, pointer: Handle<crate::Expression>, block: &mut Block, ) -> Result<Option<Word>, Error> { // Check this is a uniform address space pointer to a two-row matrix. let crate::TypeInner::Pointer { base: matrix_type, space: space @ crate::AddressSpace::Uniform, } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) else { return Ok(None); }; let crate::TypeInner::Matrix { columns, rows: rows @ crate::VectorSize::Bi, scalar, } = self.ir_module.types[matrix_type].inner else { return Ok(None); }; // Check this is a struct member. Note that struct members can only be // accessed with `AccessIndex`. let crate::Expression::AccessIndex { base: struct_pointer, index: member_index, } = self.ir_function.expressions[pointer] else { return Ok(None); }; let crate::TypeInner::Pointer { base: struct_type, .. } = *self.fun_info[struct_pointer] .ty .inner_with(&self.ir_module.types) else { return Ok(None); }; let crate::TypeInner::Struct { ..
} = self.ir_module.types[struct_type].inner else { return Ok(None); }; let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix { columns, rows, scalar, }); let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar }); let column_pointer_type_id = self.get_pointer_type_id(column_type_id, map_storage_class(space)); let column0_index = self.writer.std140_compat_uniform_types[&struct_type].member_indices [member_index as usize]; let column_indices = (0..columns as u32) .map(|c| self.get_index_constant(column0_index + c)) .collect::<ArrayVec<_, 4>>(); // Load each column from the struct, then combine them into the real // matrix type. let load_mat_from_struct = |struct_pointer_id: Word, id_gen: &mut IdGenerator, block: &mut Block| -> Word { let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new(); for index in &column_indices { let column_pointer_id = id_gen.next(); block.body.push(Instruction::access_chain( column_pointer_type_id, column_pointer_id, struct_pointer_id, &[*index], )); let column_id = id_gen.next(); block.body.push(Instruction::load( column_type_id, column_id, column_pointer_id, None, )); column_ids.push(column_id); } let result_id = id_gen.next(); block.body.push(Instruction::composite_construct( matrix_type_id, result_id, &column_ids, )); result_id }; let result_id = match self.write_access_chain( struct_pointer, block, AccessTypeAdjustment::UseStd140CompatType, )? { ExpressionPointer::Ready { pointer_id } => { load_mat_from_struct(pointer_id, &mut self.writer.id_gen, block) } ExpressionPointer::Conditional { condition, access } => self .write_conditional_indexed_load( matrix_type_id, condition, block, |id_gen, block| { let pointer_id = access.result_id.unwrap(); block.body.push(access); load_mat_from_struct(pointer_id, id_gen, block) }, ), }; Ok(Some(result_id)) } /// Cache an expression for a value.
pub(super) fn cache_expression_value( &mut self, expr_handle: Handle<crate::Expression>, block: &mut Block, ) -> Result<(), Error> { let is_named_expression = self .ir_function .named_expressions .contains_key(&expr_handle); if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression { return Ok(()); } let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty); let id = match self.ir_function.expressions[expr_handle] { crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal), crate::Expression::Constant(handle) => { let init = self.ir_module.constants[handle].init; self.writer.constant_ids[init] } crate::Expression::Override(_) => return Err(Error::Override), crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); if self.expression_constness.is_const(expr_handle) { self.temp_list.extend( crate::proc::flatten_compose( ty, components, &self.ir_function.expressions, &self.ir_module.types, ) .map(|component| self.cached[component]), ); self.writer .get_constant_composite(LookupType::Handle(ty), &self.temp_list) } else { self.temp_list .extend(components.iter().map(|&component| self.cached[component])); let id = self.gen_id(); block.body.push(Instruction::composite_construct( result_type_id, id, &self.temp_list, )); id } } crate::Expression::Splat { size, value } => { let value_id = self.cached[value]; let components = &[value_id; 4][..size as usize]; if self.expression_constness.is_const(expr_handle) { let ty = self .writer .get_expression_lookup_type(&self.fun_info[expr_handle].ty); self.writer.get_constant_composite(ty, components) } else { let id = self.gen_id(); block.body.push(Instruction::composite_construct( result_type_id, id, components, )); id } } crate::Expression::Access { base, index } => { let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types); match *base_ty_inner { crate::TypeInner::Pointer {
.. } | crate::TypeInner::ValuePointer { .. } => { // When we have a chain of `Access` and `AccessIndex` expressions // operating on pointers, we want to generate a single // `OpAccessChain` instruction for the whole chain. Put off // generating any code for this until we find the `Expression` // that actually dereferences the pointer. 0 } _ if self.function.spilled_accesses.contains(base) => { // As far as Naga IR is concerned, this expression does not yield // a pointer (we just checked, above), but this backend spilled it // to a temporary variable, so SPIR-V thinks we're accessing it // via a pointer. // Since the base expression was spilled, mark this access to it // as spilled, too. self.function.spilled_accesses.insert(expr_handle); self.maybe_access_spilled_composite(expr_handle, block, result_type_id)? } crate::TypeInner::Vector { .. } => self.write_vector_access( result_type_id, base, None, GuardedIndex::Expression(index), block, )?, crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => { // See if `index` is known at compile time. match GuardedIndex::from_expression( index, &self.ir_function.expressions, self.ir_module, ) { GuardedIndex::Known(value) => { // If `index` is known and in bounds, we can just use // `OpCompositeExtract`. // // At the moment, validation rejects programs if this // index is out of bounds, so we don't need bounds checks. // However, that rejection is incorrect, since WGSL says // that `let` bindings are not constant expressions // (#6396). So eventually we will need to emulate bounds // checks here. let id = self.gen_id(); let base_id = self.cached[base]; block.body.push(Instruction::composite_extract( result_type_id, id, base_id, &[value], )); id } GuardedIndex::Expression(_) => { // We are subscripting an array or matrix that is not // behind a pointer, using an index computed at runtime. 
// SPIR-V has no instructions that do this, so the best we // can do is spill the value to a new temporary variable, // at which point we can get a pointer to that and just // use `OpAccessChain` in the usual way. self.spill_to_internal_variable(base, block); // Since the base was spilled, mark this access to it as // spilled, too. self.function.spilled_accesses.insert(expr_handle); self.maybe_access_spilled_composite( expr_handle, block, result_type_id, )? } } } crate::TypeInner::BindingArray { base: binding_type, .. } => { // Only binding arrays in the `Handle` address space will take // this path, since we handled the `Pointer` case above. let result_id = match self.write_access_chain( expr_handle, block, AccessTypeAdjustment::IntroducePointer( spirv::StorageClass::UniformConstant, ), )? { ExpressionPointer::Ready { pointer_id } => pointer_id, ExpressionPointer::Conditional { .. } => { return Err(Error::FeatureNotImplemented( "Texture array out-of-bounds handling", )); } }; let binding_type_id = self.get_handle_type_id(binding_type); let load_id = self.gen_id(); block.body.push(Instruction::load( binding_type_id, load_id, result_id, None, )); // Subsequent image operations require the image/sampler to be decorated as NonUniform // if the image/sampler binding array was accessed with a non-uniform index // see VUID-RuntimeSpirv-NonUniform-06274 if self.fun_info[index].uniformity.non_uniform_result.is_some() { self.writer .decorate_non_uniform_binding_array_access(load_id)?; } load_id } ref other => { log::error!( "Unable to access base {:?} of type {:?}", self.ir_function.expressions[base], other ); return Err(Error::Validation( "only vectors and arrays may be dynamically indexed by value", )); } } } crate::Expression::AccessIndex { base, index } => { match *self.fun_info[base].ty.inner_with(&self.ir_module.types) { crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. 
} => { // When we have a chain of `Access` and `AccessIndex` expressions // operating on pointers, we want to generate a single // `OpAccessChain` instruction for the whole chain. Put off // generating any code for this until we find the `Expression` // that actually dereferences the pointer. 0 } _ if self.function.spilled_accesses.contains(base) => { // As far as Naga IR is concerned, this expression does not yield // a pointer (we just checked, above), but this backend spilled it // to a temporary variable, so SPIR-V thinks we're accessing it // via a pointer. // Since the base expression was spilled, mark this access to it // as spilled, too. self.function.spilled_accesses.insert(expr_handle); self.maybe_access_spilled_composite(expr_handle, block, result_type_id)? } crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } | crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => { // We never need bounds checks here: dynamically sized arrays can // only appear behind pointers, and are thus handled by the // `is_intermediate` case above. Everything else's size is // statically known and checked in validation. let id = self.gen_id(); let base_id = self.cached[base]; block.body.push(Instruction::composite_extract( result_type_id, id, base_id, &[index], )); id } crate::TypeInner::BindingArray { base: binding_type, .. } => { // Only binding arrays in the `Handle` address space will take // this path, since we handled the `Pointer` case above. let result_id = match self.write_access_chain( expr_handle, block, AccessTypeAdjustment::IntroducePointer( spirv::StorageClass::UniformConstant, ), )? { ExpressionPointer::Ready { pointer_id } => pointer_id, ExpressionPointer::Conditional { .. 
} => { return Err(Error::FeatureNotImplemented( "Texture array out-of-bounds handling", )); } }; let binding_type_id = self.get_handle_type_id(binding_type); let load_id = self.gen_id(); block.body.push(Instruction::load( binding_type_id, load_id, result_id, None, )); load_id } ref other => { log::error!("Unable to access index of {other:?}"); return Err(Error::FeatureNotImplemented("access index for type")); } } } crate::Expression::GlobalVariable(handle) => { self.writer.global_variables[handle].access_id } crate::Expression::Swizzle { size, vector, pattern, } => { let vector_id = self.cached[vector]; self.temp_list.clear(); for &sc in pattern[..size as usize].iter() { self.temp_list.push(sc as Word); } let id = self.gen_id(); block.body.push(Instruction::vector_shuffle( result_type_id, id, vector_id, vector_id, &self.temp_list, )); id } crate::Expression::Unary { op, expr } => { let id = self.gen_id(); let expr_id = self.cached[expr]; let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types); let spirv_op = match op { crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() { Some(crate::ScalarKind::Float) => spirv::Op::FNegate, Some(crate::ScalarKind::Sint) => spirv::Op::SNegate, _ => return Err(Error::Validation("Unexpected kind for negation")), }, crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot, crate::UnaryOperator::BitwiseNot => spirv::Op::Not, }; block .body .push(Instruction::unary(spirv_op, result_type_id, id, expr_id)); id } crate::Expression::Binary { op, left, right } => { let id = self.gen_id(); let left_id = self.cached[left]; let right_id = self.cached[right]; let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty); let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty); if let Some(function_id) = self.writer .wrapped_functions .get(&WrappedFunction::BinaryOp { op, left_type_id, right_type_id, }) { block.body.push(Instruction::function_call( result_type_id, id, *function_id, 
&[left_id, right_id], )); } else { let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types); let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types); let left_dimension = get_dimension(left_ty_inner); let right_dimension = get_dimension(right_ty_inner); let mut reverse_operands = false; let spirv_op = match op { crate::BinaryOperator::Add => match *left_ty_inner { crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { crate::ScalarKind::Float => spirv::Op::FAdd, _ => spirv::Op::IAdd, }, crate::TypeInner::Matrix { columns, rows, scalar, } => { //TODO: why not just rely on `Fadd` for matrices? self.write_matrix_matrix_column_op( block, id, result_type_id, left_id, right_id, columns, rows, scalar.width, spirv::Op::FAdd, ); self.cached[expr_handle] = id; return Ok(()); } crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd, _ => unimplemented!(), }, crate::BinaryOperator::Subtract => match *left_ty_inner { crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { crate::ScalarKind::Float => spirv::Op::FSub, _ => spirv::Op::ISub, }, crate::TypeInner::Matrix { columns, rows, scalar, } => { self.write_matrix_matrix_column_op( block, id, result_type_id, left_id, right_id, columns, rows, scalar.width, spirv::Op::FSub, ); self.cached[expr_handle] = id; return Ok(()); } crate::TypeInner::CooperativeMatrix { .. 
} => spirv::Op::FSub, _ => unimplemented!(), }, crate::BinaryOperator::Multiply => { match (left_dimension, right_dimension) { (Dimension::Scalar, Dimension::Vector) => { self.write_vector_scalar_mult( block, id, result_type_id, right_id, left_id, right_ty_inner, ); self.cached[expr_handle] = id; return Ok(()); } (Dimension::Vector, Dimension::Scalar) => { self.write_vector_scalar_mult( block, id, result_type_id, left_id, right_id, left_ty_inner, ); self.cached[expr_handle] = id; return Ok(()); } (Dimension::Vector, Dimension::Matrix) => { spirv::Op::VectorTimesMatrix } (Dimension::Matrix, Dimension::Scalar) | (Dimension::CooperativeMatrix, Dimension::Scalar) => { spirv::Op::MatrixTimesScalar } (Dimension::Scalar, Dimension::Matrix) | (Dimension::Scalar, Dimension::CooperativeMatrix) => { reverse_operands = true; spirv::Op::MatrixTimesScalar } (Dimension::Matrix, Dimension::Vector) => { spirv::Op::MatrixTimesVector } (Dimension::Matrix, Dimension::Matrix) => { spirv::Op::MatrixTimesMatrix } (Dimension::Vector, Dimension::Vector) | (Dimension::Scalar, Dimension::Scalar) if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) => { spirv::Op::FMul } (Dimension::Vector, Dimension::Vector) | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul, (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix) //Note: technically can do `FMul` but IR doesn't have matrix per-component multiplication | (Dimension::CooperativeMatrix, _) | (_, Dimension::CooperativeMatrix) => { unimplemented!() } } } crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Sint) => spirv::Op::SDiv, Some(crate::ScalarKind::Uint) => spirv::Op::UDiv, Some(crate::ScalarKind::Float) => spirv::Op::FDiv, _ => unimplemented!(), }, crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() { // TODO: handle undefined behavior // if right == 0 return ? 
see https://github.com/gpuweb/gpuweb/issues/2798 Some(crate::ScalarKind::Float) => spirv::Op::FRem, Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { unreachable!("Should have been handled by wrapped function") } _ => unimplemented!(), }, crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { spirv::Op::IEqual } Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual, Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual, _ => unimplemented!(), }, crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { spirv::Op::INotEqual } Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual, Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual, _ => unimplemented!(), }, crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan, Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan, Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan, _ => unimplemented!(), }, crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual, Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual, Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual, _ => unimplemented!(), }, crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan, Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan, Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan, _ => unimplemented!(), }, crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual, Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual, Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual, _ => unimplemented!(), }, crate::BinaryOperator::And => match 
left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd, _ => spirv::Op::BitwiseAnd, }, crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor, crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr, _ => spirv::Op::BitwiseOr, }, crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd, crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr, crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical, crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() { Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic, Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical, _ => unimplemented!(), }, }; block.body.push(Instruction::binary( spirv_op, result_type_id, id, if reverse_operands { right_id } else { left_id }, if reverse_operands { left_id } else { right_id }, )); } id } crate::Expression::Math { fun, arg, arg1, arg2, arg3, } => { use crate::MathFunction as Mf; enum MathOp { Ext(spirv::GlslStd450Op), Custom(Instruction), } let arg0_id = self.cached[arg]; let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types); let arg_scalar_kind = arg_ty.scalar_kind(); let arg1_id = match arg1 { Some(handle) => self.cached[handle], None => 0, }; let arg2_id = match arg2 { Some(handle) => self.cached[handle], None => 0, }; let arg3_id = match arg3 { Some(handle) => self.cached[handle], None => 0, }; let id = self.gen_id(); let math_op = match fun { // comparison Mf::Abs => { match arg_scalar_kind { Some(crate::ScalarKind::Float) => { MathOp::Ext(spirv::GlslStd450Op::FAbs) } Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GlslStd450Op::SAbs), Some(crate::ScalarKind::Uint) => { MathOp::Custom(Instruction::unary( spirv::Op::CopyObject, // do nothing result_type_id, id, arg0_id, )) } other => unimplemented!("Unexpected abs({:?})", other), } } Mf::Min => MathOp::Ext(match arg_scalar_kind { Some(crate::ScalarKind::Float) => 
                    spirv::GlslStd450Op::FMin,
                    Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMin,
                    Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMin,
                    other => unimplemented!("Unexpected min({:?})", other),
                }),
                Mf::Max => MathOp::Ext(match arg_scalar_kind {
                    Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FMax,
                    Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMax,
                    Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMax,
                    other => unimplemented!("Unexpected max({:?})", other),
                }),
                Mf::Clamp => match arg_scalar_kind {
                    // Clamp is undefined if min > max. In practice this means it can be
                    // implemented with a median-of-three instruction. That is fine according to
                    // the WGSL spec for float clamp, but integer clamp _must_ be computed as
                    // min(max(e, low), high), so we write out max followed by min.
                    Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GlslStd450Op::FClamp),
                    Some(_) => {
                        let (min_op, max_op) = match arg_scalar_kind {
                            Some(crate::ScalarKind::Sint) => {
                                (spirv::GlslStd450Op::SMin, spirv::GlslStd450Op::SMax)
                            }
                            Some(crate::ScalarKind::Uint) => {
                                (spirv::GlslStd450Op::UMin, spirv::GlslStd450Op::UMax)
                            }
                            _ => unreachable!(),
                        };

                        let max_id = self.gen_id();
                        block.body.push(Instruction::ext_inst_gl_op(
                            self.writer.gl450_ext_inst_id,
                            max_op,
                            result_type_id,
                            max_id,
                            &[arg0_id, arg1_id],
                        ));

                        MathOp::Custom(Instruction::ext_inst_gl_op(
                            self.writer.gl450_ext_inst_id,
                            min_op,
                            result_type_id,
                            id,
                            &[max_id, arg2_id],
                        ))
                    }
                    other => unimplemented!("Unexpected clamp({:?})", other),
                },
                Mf::Saturate => {
                    let (maybe_size, scalar) = match *arg_ty {
                        crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
                        crate::TypeInner::Scalar(scalar) => (None, scalar),
                        ref other => unimplemented!("Unexpected saturate({:?})", other),
                    };
                    let scalar = crate::Scalar::float(scalar.width);
                    let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
                    let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;

                    if let Some(size) = maybe_size {
                        let ty = LocalType::Numeric(NumericType::Vector {
size, scalar }).into(); self.temp_list.clear(); self.temp_list.resize(size as _, arg1_id); arg1_id = self.writer.get_constant_composite(ty, &self.temp_list); self.temp_list.fill(arg2_id); arg2_id = self.writer.get_constant_composite(ty, &self.temp_list); } MathOp::Custom(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GlslStd450Op::FClamp, result_type_id, id, &[arg0_id, arg1_id, arg2_id], )) } // trigonometry Mf::Sin => MathOp::Ext(spirv::GlslStd450Op::Sin), Mf::Sinh => MathOp::Ext(spirv::GlslStd450Op::Sinh), Mf::Asin => MathOp::Ext(spirv::GlslStd450Op::Asin), Mf::Cos => MathOp::Ext(spirv::GlslStd450Op::Cos), Mf::Cosh => MathOp::Ext(spirv::GlslStd450Op::Cosh), Mf::Acos => MathOp::Ext(spirv::GlslStd450Op::Acos), Mf::Tan => MathOp::Ext(spirv::GlslStd450Op::Tan), Mf::Tanh => MathOp::Ext(spirv::GlslStd450Op::Tanh), Mf::Atan => MathOp::Ext(spirv::GlslStd450Op::Atan), Mf::Atan2 => MathOp::Ext(spirv::GlslStd450Op::Atan2), Mf::Asinh => MathOp::Ext(spirv::GlslStd450Op::Asinh), Mf::Acosh => MathOp::Ext(spirv::GlslStd450Op::Acosh), Mf::Atanh => MathOp::Ext(spirv::GlslStd450Op::Atanh), Mf::Radians => MathOp::Ext(spirv::GlslStd450Op::Radians), Mf::Degrees => MathOp::Ext(spirv::GlslStd450Op::Degrees), // decomposition Mf::Ceil => MathOp::Ext(spirv::GlslStd450Op::Ceil), Mf::Round => MathOp::Ext(spirv::GlslStd450Op::RoundEven), Mf::Floor => MathOp::Ext(spirv::GlslStd450Op::Floor), Mf::Fract => MathOp::Ext(spirv::GlslStd450Op::Fract), Mf::Trunc => MathOp::Ext(spirv::GlslStd450Op::Trunc), Mf::Modf => MathOp::Ext(spirv::GlslStd450Op::ModfStruct), Mf::Frexp => MathOp::Ext(spirv::GlslStd450Op::FrexpStruct), Mf::Ldexp => MathOp::Ext(spirv::GlslStd450Op::Ldexp), // geometry Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) { crate::TypeInner::Vector { scalar: crate::Scalar { kind: crate::ScalarKind::Float, .. }, .. 
} => MathOp::Custom(Instruction::binary( spirv::Op::Dot, result_type_id, id, arg0_id, arg1_id, )), // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available crate::TypeInner::Vector { size, .. } => { self.write_dot_product( id, result_type_id, arg0_id, arg1_id, size as u32, block, |result_id, composite_id, index| { Instruction::composite_extract( result_type_id, result_id, composite_id, &[index], ) }, ); self.cached[expr_handle] = id; return Ok(()); } _ => unreachable!( "Correct TypeInner for dot product should be already validated" ), }, fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => { if self .writer .require_all(&[ spirv::Capability::DotProduct, spirv::Capability::DotProductInput4x8BitPacked, ]) .is_ok() { // Write optimized code using `PackedVectorFormat4x8Bit`. if self.writer.lang_version() < (1, 6) { // SPIR-V 1.6 supports the required capabilities natively, so the extension // is only required for earlier versions. See right column of // . self.writer.use_extension("SPV_KHR_integer_dot_product"); } let op = match fun { Mf::Dot4I8Packed => spirv::Op::SDot, Mf::Dot4U8Packed => spirv::Op::UDot, _ => unreachable!(), }; block.body.push(Instruction::ternary( op, result_type_id, id, arg0_id, arg1_id, spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word, )); } else { // Fall back to a polyfill since `PackedVectorFormat4x8Bit` is not available. let (extract_op, arg0_id, arg1_id) = match fun { Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id), Mf::Dot4I8Packed => { // Convert both packed arguments to signed integers so that we can apply the // `BitFieldSExtract` operation on them in `write_dot_product` below. 
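                        // Illustrative sketch only (hypothetical helpers, not naga API): a CPU
                        // model of the polyfill emitted below when `DotProductInput4x8BitPacked`
                        // is unavailable. Lane `i` is read from bit offset `i * 8` with a count
                        // of 8, like the `BitFieldUExtract` / `BitFieldSExtract` calls passed to
                        // `write_dot_product`. The signed variant sign-extends each lane, which
                        // is why the packed `u32` arguments are bitcast to signed integers first.
                        #[allow(dead_code)]
                        fn dot4_u8_packed_model(a: u32, b: u32) -> u32 {
                            // Unsigned lanes: shift the lane down and mask off 8 bits.
                            (0..4)
                                .map(|i| ((a >> (8 * i)) & 0xFF) * ((b >> (8 * i)) & 0xFF))
                                .sum()
                        }

                        #[allow(dead_code)]
                        fn dot4_i8_packed_model(a: u32, b: u32) -> i32 {
                            // Signed lanes: move lane `i` to the top byte, then shift it back
                            // down arithmetically so its sign bit is replicated.
                            let lane = |x: u32, i: u32| ((x as i32) << (24 - 8 * i)) >> 24;
                            (0..4u32).map(|i| lane(a, i) * lane(b, i)).sum()
                        }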
let new_arg0_id = self.gen_id(); block.body.push(Instruction::unary( spirv::Op::Bitcast, result_type_id, new_arg0_id, arg0_id, )); let new_arg1_id = self.gen_id(); block.body.push(Instruction::unary( spirv::Op::Bitcast, result_type_id, new_arg1_id, arg1_id, )); (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id) } _ => unreachable!(), }; let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); const VEC_LENGTH: u8 = 4; let bit_shifts: [_; VEC_LENGTH as usize] = core::array::from_fn(|index| { self.writer .get_constant_scalar(crate::Literal::U32(index as u32 * 8)) }); self.write_dot_product( id, result_type_id, arg0_id, arg1_id, VEC_LENGTH as Word, block, |result_id, composite_id, index| { Instruction::ternary( extract_op, result_type_id, result_id, composite_id, bit_shifts[index as usize], eight, ) }, ); } self.cached[expr_handle] = id; return Ok(()); } Mf::Outer => MathOp::Custom(Instruction::binary( spirv::Op::OuterProduct, result_type_id, id, arg0_id, arg1_id, )), Mf::Cross => MathOp::Ext(spirv::GlslStd450Op::Cross), Mf::Distance => MathOp::Ext(spirv::GlslStd450Op::Distance), Mf::Length => MathOp::Ext(spirv::GlslStd450Op::Length), Mf::Normalize => MathOp::Ext(spirv::GlslStd450Op::Normalize), Mf::FaceForward => MathOp::Ext(spirv::GlslStd450Op::FaceForward), Mf::Reflect => MathOp::Ext(spirv::GlslStd450Op::Reflect), Mf::Refract => MathOp::Ext(spirv::GlslStd450Op::Refract), // exponent Mf::Exp => MathOp::Ext(spirv::GlslStd450Op::Exp), Mf::Exp2 => MathOp::Ext(spirv::GlslStd450Op::Exp2), Mf::Log => MathOp::Ext(spirv::GlslStd450Op::Log), Mf::Log2 => MathOp::Ext(spirv::GlslStd450Op::Log2), Mf::Pow => MathOp::Ext(spirv::GlslStd450Op::Pow), // computational Mf::Sign => MathOp::Ext(match arg_scalar_kind { Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FSign, Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SSign, other => unimplemented!("Unexpected sign({:?})", other), }), Mf::Fma => MathOp::Ext(spirv::GlslStd450Op::Fma), Mf::Mix => { let 
selector = arg2.unwrap(); let selector_ty = self.fun_info[selector].ty.inner_with(&self.ir_module.types); match (arg_ty, selector_ty) { // if the selector is a scalar, we need to splat it ( &crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar(scalar), ) => { let selector_type_id = self.get_numeric_type_id(NumericType::Vector { size, scalar }); self.temp_list.clear(); self.temp_list.resize(size as usize, arg2_id); let selector_id = self.gen_id(); block.body.push(Instruction::composite_construct( selector_type_id, selector_id, &self.temp_list, )); MathOp::Custom(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GlslStd450Op::FMix, result_type_id, id, &[arg0_id, arg1_id, selector_id], )) } _ => MathOp::Ext(spirv::GlslStd450Op::FMix), } } Mf::Step => MathOp::Ext(spirv::GlslStd450Op::Step), Mf::SmoothStep => MathOp::Ext(spirv::GlslStd450Op::SmoothStep), Mf::Sqrt => MathOp::Ext(spirv::GlslStd450Op::Sqrt), Mf::InverseSqrt => MathOp::Ext(spirv::GlslStd450Op::InverseSqrt), Mf::Inverse => MathOp::Ext(spirv::GlslStd450Op::MatrixInverse), Mf::Transpose => MathOp::Custom(Instruction::unary( spirv::Op::Transpose, result_type_id, id, arg0_id, )), Mf::Determinant => MathOp::Ext(spirv::GlslStd450Op::Determinant), Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary( spirv::Op::QuantizeToF16, result_type_id, id, arg0_id, )), Mf::ReverseBits => MathOp::Custom(Instruction::unary( spirv::Op::BitReverse, result_type_id, id, arg0_id, )), Mf::CountTrailingZeros => { let uint_id = match *arg_ty { crate::TypeInner::Vector { size, scalar } => { let ty = LocalType::Numeric(NumericType::Vector { size, scalar }).into(); self.temp_list.clear(); self.temp_list.resize( size as _, self.writer .get_constant_scalar_with(scalar.width * 8, scalar)?, ); self.writer.get_constant_composite(ty, &self.temp_list) } crate::TypeInner::Scalar(scalar) => self .writer .get_constant_scalar_with(scalar.width * 8, scalar)?, _ => unreachable!(), }; let lsb_id = self.gen_id(); 
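                    // Illustrative sketch only (hypothetical helpers, not naga API): this shows
                    // why the trailing-zero count below is emitted as `UMin(w, FindILsb(x))`,
                    // and why the `countLeadingZeros` arm that follows emits
                    // `(w - 1) - FindUMsb(x)`. GLSL.std.450 `FindILsb` and `FindUMsb` both
                    // return -1 (all bits set) when no bit is set. For x == 0 the first formula
                    // therefore clamps to `w`, and the second becomes `(w - 1) - (-1) == w`.
                    // The models assume 32-bit operands.
                    #[allow(dead_code)]
                    fn find_lsb_model(x: u32) -> u32 {
                        if x == 0 { u32::MAX } else { x.trailing_zeros() }
                    }

                    #[allow(dead_code)]
                    fn find_umsb_model(x: u32) -> u32 {
                        if x == 0 { u32::MAX } else { 31 - x.leading_zeros() }
                    }

                    #[allow(dead_code)]
                    fn count_trailing_zeros_model(x: u32) -> u32 {
                        // Unsigned min turns the all-ones "not found" result into 32.
                        find_lsb_model(x).min(32)
                    }

                    #[allow(dead_code)]
                    fn count_leading_zeros_model(x: u32) -> u32 {
                        // Wrapping subtraction mirrors `OpISub` on an all-ones MSB index.
                        31u32.wrapping_sub(find_umsb_model(x))
                    }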
                    block.body.push(Instruction::ext_inst_gl_op(
                        self.writer.gl450_ext_inst_id,
                        spirv::GlslStd450Op::FindILsb,
                        result_type_id,
                        lsb_id,
                        &[arg0_id],
                    ));
                    MathOp::Custom(Instruction::ext_inst_gl_op(
                        self.writer.gl450_ext_inst_id,
                        spirv::GlslStd450Op::UMin,
                        result_type_id,
                        id,
                        &[uint_id, lsb_id],
                    ))
                }
                Mf::CountLeadingZeros => {
                    let (int_type_id, int_id, width) = match *arg_ty {
                        crate::TypeInner::Vector { size, scalar } => {
                            let ty =
                                LocalType::Numeric(NumericType::Vector { size, scalar }).into();

                            self.temp_list.clear();
                            self.temp_list.resize(
                                size as _,
                                self.writer
                                    .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
                            );

                            (
                                self.get_type_id(ty),
                                self.writer.get_constant_composite(ty, &self.temp_list),
                                scalar.width,
                            )
                        }
                        crate::TypeInner::Scalar(scalar) => (
                            self.get_numeric_type_id(NumericType::Scalar(scalar)),
                            self.writer
                                .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
                            scalar.width,
                        ),
                        _ => unreachable!(),
                    };

                    if width != 4 {
                        unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
                    };

                    let msb_id = self.gen_id();
                    block.body.push(Instruction::ext_inst_gl_op(
                        self.writer.gl450_ext_inst_id,
                        if width != 4 {
                            spirv::GlslStd450Op::FindILsb
                        } else {
                            spirv::GlslStd450Op::FindUMsb
                        },
                        int_type_id,
                        msb_id,
                        &[arg0_id],
                    ));
                    MathOp::Custom(Instruction::binary(
                        spirv::Op::ISub,
                        result_type_id,
                        id,
                        int_id,
                        msb_id,
                    ))
                }
                Mf::CountOneBits => MathOp::Custom(Instruction::unary(
                    spirv::Op::BitCount,
                    result_type_id,
                    id,
                    arg0_id,
                )),
                Mf::ExtractBits => {
                    let op = match arg_scalar_kind {
                        Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
                        Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
                        other => unimplemented!("Unexpected extractBits({:?})", other),
                    };

                    // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
                    // to sanitize the offset and count first. If we don't do this, AMD and Intel
                    // will return out-of-spec values if the extracted range is not within the bit width.
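                    // Illustrative sketch only (hypothetical helpers, not naga API): a CPU model
                    // of the sanitized `extractBits` that the instructions below emit. It uses
                    // the WGSL formula o = min(offset, w), c = min(count, w - o) with w = 32.
                    // After clamping, `o + c <= w` always holds, so the extracted field never
                    // reaches past the top bit.
                    #[allow(dead_code)]
                    fn extract_bits_u32_model(e: u32, offset: u32, count: u32) -> u32 {
                        let o = offset.min(32);
                        let c = count.min(32 - o);
                        if c == 0 {
                            return 0;
                        }
                        let mask = if c == 32 { u32::MAX } else { (1u32 << c) - 1 };
                        (e >> o) & mask
                    }

                    #[allow(dead_code)]
                    fn extract_bits_i32_model(e: i32, offset: u32, count: u32) -> i32 {
                        let o = offset.min(32);
                        let c = count.min(32 - o);
                        if c == 0 {
                            return 0;
                        }
                        // Move the field to the top of the word, then shift it back down
                        // arithmetically so its highest bit is replicated, as
                        // `BitFieldSExtract` does.
                        (e << (32 - o - c)) >> (32 - c)
                    }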
                    //
                    // This encodes the exact formula specified by the WGSL spec:
                    // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
                    //
                    // w = sizeof(x) * 8
                    // o = min(offset, w)
                    // tmp = w - o
                    // c = min(count, tmp)
                    //
                    // bitfieldExtract(x, o, c)
                    let bit_width = arg_ty.scalar_width().unwrap() * 8;
                    let width_constant = self
                        .writer
                        .get_constant_scalar(crate::Literal::U32(bit_width as u32));

                    let u32_type =
                        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));

                    // o = min(offset, w)
                    let offset_id = self.gen_id();
                    block.body.push(Instruction::ext_inst_gl_op(
                        self.writer.gl450_ext_inst_id,
                        spirv::GlslStd450Op::UMin,
                        u32_type,
                        offset_id,
                        &[arg1_id, width_constant],
                    ));

                    // tmp = w - o
                    let max_count_id = self.gen_id();
                    block.body.push(Instruction::binary(
                        spirv::Op::ISub,
                        u32_type,
                        max_count_id,
                        width_constant,
                        offset_id,
                    ));

                    // c = min(count, tmp)
                    let count_id = self.gen_id();
                    block.body.push(Instruction::ext_inst_gl_op(
                        self.writer.gl450_ext_inst_id,
                        spirv::GlslStd450Op::UMin,
                        u32_type,
                        count_id,
                        &[arg2_id, max_count_id],
                    ));

                    MathOp::Custom(Instruction::ternary(
                        op,
                        result_type_id,
                        id,
                        arg0_id,
                        offset_id,
                        count_id,
                    ))
                }
                Mf::InsertBits => {
                    // InsertBits has the same undefined behavior as ExtractBits when
                    // offset + count > bit_width, so the offset and count are sanitized the same way.
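                    // Illustrative sketch only (hypothetical helper, not naga API): `insertBits`
                    // applies the same o = min(offset, w), c = min(count, w - o) clamp as the
                    // `extractBits` arm above. It then replaces bits [o, o + c) of `e` with the
                    // low `c` bits of `newbits`. When c == 0, `e` is returned unchanged.
                    #[allow(dead_code)]
                    fn insert_bits_u32_model(e: u32, newbits: u32, offset: u32, count: u32) -> u32 {
                        let o = offset.min(32);
                        let c = count.min(32 - o);
                        if c == 0 {
                            return e;
                        }
                        let field = if c == 32 { u32::MAX } else { (1u32 << c) - 1 };
                        let mask = field << o;
                        // Clear the target bits, then OR in the shifted replacement bits.
                        (e & !mask) | ((newbits << o) & mask)
                    }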
let bit_width = arg_ty.scalar_width().unwrap() * 8; let width_constant = self .writer .get_constant_scalar(crate::Literal::U32(bit_width as u32)); let u32_type = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)); // o = min(offset, w) let offset_id = self.gen_id(); block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GlslStd450Op::UMin, u32_type, offset_id, &[arg2_id, width_constant], )); // tmp = w - o let max_count_id = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::ISub, u32_type, max_count_id, width_constant, offset_id, )); // c = min(count, tmp) let count_id = self.gen_id(); block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GlslStd450Op::UMin, u32_type, count_id, &[arg3_id, max_count_id], )); MathOp::Custom(Instruction::quaternary( spirv::Op::BitFieldInsert, result_type_id, id, arg0_id, arg1_id, offset_id, count_id, )) } Mf::FirstTrailingBit => MathOp::Ext(spirv::GlslStd450Op::FindILsb), Mf::FirstLeadingBit => { if arg_ty.scalar_width() == Some(4) { let thing = match arg_scalar_kind { Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::FindUMsb, Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::FindSMsb, other => unimplemented!("Unexpected firstLeadingBit({:?})", other), }; MathOp::Ext(thing) } else { unreachable!("This is validated out until a polyfill is implemented. 
https://github.com/gfx-rs/wgpu/issues/5276"); } } Mf::Pack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm4x8), Mf::Pack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm4x8), Mf::Pack2x16float => MathOp::Ext(spirv::GlslStd450Op::PackHalf2x16), Mf::Pack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm2x16), Mf::Pack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm2x16), fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => { let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp); let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp); let last_instruction = if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() { self.write_pack4x8_optimized( block, result_type_id, arg0_id, id, is_signed, should_clamp, ) } else { self.write_pack4x8_polyfill( block, result_type_id, arg0_id, id, is_signed, should_clamp, ) }; MathOp::Custom(last_instruction) } Mf::Unpack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm4x8), Mf::Unpack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm4x8), Mf::Unpack2x16float => MathOp::Ext(spirv::GlslStd450Op::UnpackHalf2x16), Mf::Unpack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm2x16), Mf::Unpack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm2x16), fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => { let is_signed = matches!(fun, Mf::Unpack4xI8); let last_instruction = if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() { self.write_unpack4x8_optimized( block, result_type_id, arg0_id, id, is_signed, ) } else { self.write_unpack4x8_polyfill( block, result_type_id, arg0_id, id, is_signed, ) }; MathOp::Custom(last_instruction) } }; block.body.push(match math_op { MathOp::Ext(op) => Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, op, result_type_id, id, &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()], ), MathOp::Custom(inst) => inst, }); id } crate::Expression::LocalVariable(variable) => { if let 
Some(rq_tracker) = self .function .ray_query_initialization_tracker_variables .get(&variable) { self.ray_query_tracker_expr.insert( expr_handle, super::RayQueryTrackers { initialized_tracker: rq_tracker.id, t_max_tracker: self .function .ray_query_t_max_tracker_variables .get(&variable) .expect("Both trackers are set at the same time.") .id, }, ); } self.function.variables[&variable].id } crate::Expression::Load { pointer } => { self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)? } crate::Expression::FunctionArgument(index) => self.function.parameter_id(index), crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } | crate::Expression::RayQueryProceedResult | crate::Expression::SubgroupBallotResult | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle], crate::Expression::As { expr, kind, convert, } => self.write_as_expression(expr, convert, kind, block, result_type_id)?, crate::Expression::ImageLoad { image, coordinate, array_index, sample, level, } => self.write_image_load( result_type_id, image, coordinate, array_index, level, sample, block, )?, crate::Expression::ImageSample { image, sampler, gather, coordinate, array_index, offset, level, depth_ref, clamp_to_edge, } => self.write_image_sample( result_type_id, image, sampler, gather, coordinate, array_index, offset, level, depth_ref, clamp_to_edge, block, )?, crate::Expression::Select { condition, accept, reject, } => { let id = self.gen_id(); let mut condition_id = self.cached[condition]; let accept_id = self.cached[accept]; let reject_id = self.cached[reject]; let condition_ty = self.fun_info[condition] .ty .inner_with(&self.ir_module.types); let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types); if let ( &crate::TypeInner::Scalar( condition_scalar @ crate::Scalar { kind: crate::ScalarKind::Bool, .. }, ), &crate::TypeInner::Vector { size, .. 
}, ) = (condition_ty, object_ty) { self.temp_list.clear(); self.temp_list.resize(size as usize, condition_id); let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector { size, scalar: condition_scalar, }); let id = self.gen_id(); block.body.push(Instruction::composite_construct( bool_vector_type_id, id, &self.temp_list, )); condition_id = id } let instruction = Instruction::select(result_type_id, id, condition_id, accept_id, reject_id); block.body.push(instruction); id } crate::Expression::Derivative { axis, ctrl, expr } => { use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; match ctrl { Ctrl::Coarse | Ctrl::Fine => { self.writer.require_any( "DerivativeControl", &[spirv::Capability::DerivativeControl], )?; } Ctrl::None => {} } let id = self.gen_id(); let expr_id = self.cached[expr]; let op = match (axis, ctrl) { (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse, (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine, (Axis::X, Ctrl::None) => spirv::Op::DPdx, (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse, (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine, (Axis::Y, Ctrl::None) => spirv::Op::DPdy, (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse, (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine, (Axis::Width, Ctrl::None) => spirv::Op::Fwidth, }; block .body .push(Instruction::derivative(op, result_type_id, id, expr_id)); id } crate::Expression::ImageQuery { image, query } => { self.write_image_query(result_type_id, image, query, block)? 
} crate::Expression::Relational { fun, argument } => { use crate::RelationalFunction as Rf; let arg_id = self.cached[argument]; let op = match fun { Rf::All => spirv::Op::All, Rf::Any => spirv::Op::Any, Rf::IsNan => spirv::Op::IsNan, Rf::IsInf => spirv::Op::IsInf, }; let id = self.gen_id(); block .body .push(Instruction::relational(op, result_type_id, id, arg_id)); id } crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, crate::Expression::RayQueryGetIntersection { query, committed } => { let query_id = self.cached[query]; let init_tracker_id = *self .ray_query_tracker_expr .get(&query) .expect("not a cached ray query"); let func_id = self .writer .write_ray_query_get_intersection_function(committed, self.ir_module); let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap(); let intersection_type_id = self.get_handle_type_id(ray_intersection); let id = self.gen_id(); block.body.push(Instruction::function_call( intersection_type_id, id, func_id, &[query_id, init_tracker_id.initialized_tracker], )); id } crate::Expression::RayQueryVertexPositions { query, committed } => { self.writer.require_any( "RayQueryVertexPositions", &[spirv::Capability::RayQueryPositionFetchKHR], )?; self.write_ray_query_return_vertex_position(query, block, committed) } crate::Expression::CooperativeLoad { ref data, .. } => { self.writer.require_any( "CooperativeMatrix", &[spirv::Capability::CooperativeMatrixKHR], )?; let layout = if data.row_major { spirv::CooperativeMatrixLayout::RowMajorKHR } else { spirv::CooperativeMatrixLayout::ColumnMajorKHR }; let layout_id = self.get_index_constant(layout as u32); let stride_id = self.cached[data.stride]; match self.write_access_chain(data.pointer, block, AccessTypeAdjustment::None)? 
                {
                    ExpressionPointer::Ready { pointer_id } => {
                        let id = self.gen_id();
                        block.body.push(Instruction::coop_load(
                            result_type_id,
                            id,
                            pointer_id,
                            layout_id,
                            stride_id,
                        ));
                        id
                    }
                    ExpressionPointer::Conditional { condition, access } => self
                        .write_conditional_indexed_load(
                            result_type_id,
                            condition,
                            block,
                            |id_gen, block| {
                                let pointer_id = access.result_id.unwrap();
                                block.body.push(access);
                                let id = id_gen.next();
                                block.body.push(Instruction::coop_load(
                                    result_type_id,
                                    id,
                                    pointer_id,
                                    layout_id,
                                    stride_id,
                                ));
                                id
                            },
                        ),
                }
            }
            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
                self.writer.require_any(
                    "CooperativeMatrix",
                    &[spirv::Capability::CooperativeMatrixKHR],
                )?;
                let a_id = self.cached[a];
                let b_id = self.cached[b];
                let c_id = self.cached[c];
                let id = self.gen_id();
                block.body.push(Instruction::coop_mul_add(
                    result_type_id,
                    id,
                    a_id,
                    b_id,
                    c_id,
                ));
                id
            }
        };

        self.cached[expr_handle] = id;
        Ok(())
    }

    /// Helper which focuses on generating the `As` expressions and the various conversions
    /// that need to happen because of that.
    fn write_as_expression(
        &mut self,
        expr: Handle,
        convert: Option,
        kind: crate::ScalarKind,
        block: &mut Block,
        result_type_id: u32,
    ) -> Result {
        use crate::ScalarKind as Sk;
        let expr_id = self.cached[expr];
        let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);

        // Matrix casts need special treatment in SPIR-V, as the conversion instructions
        // accept scalars or vectors, but not matrices. In order to cast a matrix
        // we need to cast each column of the matrix individually and construct a new
        // matrix from the converted columns.
        if let crate::TypeInner::Matrix {
            columns,
            rows,
            scalar,
        } = *ty
        {
            let Some(convert) = convert else {
                // No conversion needs to be done, passes through.
                return Ok(expr_id);
            };

            if convert == scalar.width {
                // No conversion needs to be done, passes through.
                return Ok(expr_id);
            }

            if kind != Sk::Float {
                // Only float conversions are supported for matrices.
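                // Illustrative sketch only (hypothetical helper, not naga API): a CPU model of
                // the per-column strategy this branch uses. A float matrix is converted by
                // extracting each column, converting it, and rebuilding the matrix from the
                // converted columns. `[[f32; R]; C]` stands in for a matrix with `C` columns
                // of `R` rows each.
                #[allow(dead_code)]
                fn convert_matrix_columns<const C: usize, const R: usize>(
                    m: [[f32; R]; C],
                ) -> [[f64; R]; C] {
                    // Per column: extract (`OpCompositeExtract`), widen (`OpFConvert`), and
                    // collect into the new matrix (`OpCompositeConstruct`).
                    m.map(|column| column.map(f64::from))
                }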
return Err(Error::Validation("Matrices must be floats")); } // Type of each extracted column let column_src_ty = self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { size: rows, scalar, }))); // Type of the column after conversion let column_dst_ty = self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { size: rows, scalar: crate::Scalar { kind, width: convert, }, }))); let mut components = ArrayVec::::new(); for column in 0..columns as usize { let column_id = self.gen_id(); block.body.push(Instruction::composite_extract( column_src_ty, column_id, expr_id, &[column as u32], )); let column_conv_id = self.gen_id(); block.body.push(Instruction::unary( spirv::Op::FConvert, column_dst_ty, column_conv_id, column_id, )); components.push(column_conv_id); } let construct_id = self.gen_id(); block.body.push(Instruction::composite_construct( result_type_id, construct_id, &components, )); return Ok(construct_id); } let (src_scalar, src_size) = match *ty { crate::TypeInner::Scalar(scalar) => (scalar, None), crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)), ref other => { log::error!("As source {other:?}"); return Err(Error::Validation("Unexpected Expression::As source")); } }; enum Cast { Identity(Word), Unary(spirv::Op, Word), Binary(spirv::Op, Word, Word), Ternary(spirv::Op, Word, Word, Word), } let cast = match (src_scalar.kind, kind, convert) { // Filter out identity casts. Some Adreno drivers are // confused by no-op OpBitCast instructions. 
(src_kind, kind, convert) if src_kind == kind && convert.filter(|&width| width != src_scalar.width).is_none() => { Cast::Identity(expr_id) } (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id), (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id), // casting to a bool - generate `OpXxxNotEqual` (_, Sk::Bool, Some(_)) => { let op = match src_scalar.kind { Sk::Sint | Sk::Uint => spirv::Op::INotEqual, Sk::Float => spirv::Op::FUnordNotEqual, Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(), }; let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?; let zero_id = match src_size { Some(size) => { let ty = LocalType::Numeric(NumericType::Vector { size, scalar: src_scalar, }) .into(); self.temp_list.clear(); self.temp_list.resize(size as _, zero_scalar_id); self.writer.get_constant_composite(ty, &self.temp_list) } None => zero_scalar_id, }; Cast::Binary(op, expr_id, zero_id) } // casting from a bool - generate `OpSelect` (Sk::Bool, _, Some(dst_width)) => { let dst_scalar = crate::Scalar { kind, width: dst_width, }; let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?; let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?; let (accept_id, reject_id) = match src_size { Some(size) => { let ty = LocalType::Numeric(NumericType::Vector { size, scalar: dst_scalar, }) .into(); self.temp_list.clear(); self.temp_list.resize(size as _, zero_scalar_id); let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list); self.temp_list.fill(one_scalar_id); let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list); (vec1_id, vec0_id) } None => (one_scalar_id, zero_scalar_id), }; Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id) } // Avoid undefined behaviour when casting from a float to integer // when the value is out of range for the target type. Additionally // ensure we clamp to the correct value as per the WGSL spec. 
// // https://www.w3.org/TR/WGSL/#floating-point-conversion: // * If X is exactly representable in the target type T, then the // result is that value. // * Otherwise, the result is the value in T closest to // truncate(X) and also exactly representable in the original // floating point type. (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => { let dst_scalar = crate::Scalar { kind, width }; let (min, max) = crate::proc::min_max_float_representable_by(src_scalar, dst_scalar); let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty); let maybe_splat_const = |writer: &mut Writer, const_id| match src_size { None => const_id, Some(size) => { let constituent_ids = [const_id; crate::VectorSize::MAX]; writer.get_constant_composite( LookupType::Local(LocalType::Numeric(NumericType::Vector { size, scalar: src_scalar, })), &constituent_ids[..size as usize], ) } }; let min_const_id = self.writer.get_constant_scalar(min); let min_const_id = maybe_splat_const(self.writer, min_const_id); let max_const_id = self.writer.get_constant_scalar(max); let max_const_id = maybe_splat_const(self.writer, max_const_id); let clamp_id = self.gen_id(); block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GlslStd450Op::FClamp, expr_type_id, clamp_id, &[expr_id, min_const_id, max_const_id], )); let op = match dst_scalar.kind { crate::ScalarKind::Sint => spirv::Op::ConvertFToS, crate::ScalarKind::Uint => spirv::Op::ConvertFToU, _ => unreachable!(), }; Cast::Unary(op, clamp_id) } (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => { Cast::Unary(spirv::Op::FConvert, expr_id) } (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id), (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => { Cast::Unary(spirv::Op::SConvert, expr_id) } (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id), (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => { 
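                // Illustrative sketch only (hypothetical helpers, not naga API): a CPU model of
                // the clamped float-to-integer conversion in the
                // `(Sk::Float, Sk::Sint | Sk::Uint, Some(width))` arm above. The value is first
                // clamped (`FClamp`) to the range representable both in the target integer type
                // and in `f32`, then truncated (`ConvertFToS` / `ConvertFToU`). Because `f32` has
                // only 24 significant bits, the assumed upper bounds are 2^31 - 128 for `i32` and
                // 2^32 - 256 for `u32`. naga itself obtains its bounds from
                // `min_max_float_representable_by`. Rust's `as` already saturates, but it would
                // return `i32::MAX` for large inputs instead of the WGSL-mandated 2^31 - 128.
                #[allow(dead_code)]
                fn f32_to_i32_model(x: f32) -> i32 {
                    x.clamp(-2_147_483_648.0, 2_147_483_520.0) as i32
                }

                #[allow(dead_code)]
                fn f32_to_u32_model(x: f32) -> u32 {
                    x.clamp(0.0, 4_294_967_040.0) as u32
                }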
Cast::Unary(spirv::Op::UConvert, expr_id) } (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => { Cast::Unary(spirv::Op::SConvert, expr_id) } (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => { Cast::Unary(spirv::Op::UConvert, expr_id) } // We assume it's either an identity cast, or int-uint. _ => Cast::Unary(spirv::Op::Bitcast, expr_id), }; Ok(match cast { Cast::Identity(expr) => expr, Cast::Unary(op, op1) => { let id = self.gen_id(); block .body .push(Instruction::unary(op, result_type_id, id, op1)); id } Cast::Binary(op, op1, op2) => { let id = self.gen_id(); block .body .push(Instruction::binary(op, result_type_id, id, op1, op2)); id } Cast::Ternary(op, op1, op2, op3) => { let id = self.gen_id(); block .body .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3)); id } }) } /// Build an `OpAccessChain` instruction. /// /// Emit any needed bounds-checking expressions to `block`. /// /// Give the `OpAccessChain` a result type based on `expr_handle`, adjusted /// according to `type_adjustment`; see the documentation for /// [`AccessTypeAdjustment`] for details. /// /// On success, the return value is an [`ExpressionPointer`] value; see the /// documentation for that type. 
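///
/// For example (illustrative): given the pointer expression
/// `AccessIndex { base: Access { base: GlobalVariable(g), index: i }, index: 2 }`,
/// this function walks from the outermost expression inward, collecting the
/// index operands in the order `[2, i]`. It then reverses them, so the emitted
/// instruction is roughly `%p = OpAccessChain %ptr_ty %g %i %uint_2`, where
/// `%g` is the global's access id. Bounds-check policies may replace `%i`
/// with a clamped index or make the access conditional.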
fn write_access_chain( &mut self, mut expr_handle: Handle<crate::Expression>, block: &mut Block, type_adjustment: AccessTypeAdjustment, ) -> Result<ExpressionPointer, Error> { let result_type_id = { let resolution = &self.fun_info[expr_handle].ty; match type_adjustment { AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution), AccessTypeAdjustment::IntroducePointer(class) => { self.writer.get_resolution_pointer_id(resolution, class) } AccessTypeAdjustment::UseStd140CompatType => { match *resolution.inner_with(&self.ir_module.types) { crate::TypeInner::Pointer { base, space: space @ crate::AddressSpace::Uniform, } => self.writer.get_pointer_type_id( self.writer.std140_compat_uniform_types[&base].type_id, map_storage_class(space), ), _ => unreachable!( "`UseStd140CompatType` must only be used with uniform pointer types" ), } } } }; // The id of the boolean `and` of all dynamic bounds checks up to this point. // // See `extend_bounds_check_condition_chain` for a full explanation. let mut accumulated_checks = None; // Is true if we are accessing into a binding array with a non-uniform index. let mut is_non_uniform_binding_array = false; // The index value if the previously encountered expression was an // `AccessIndex` of a matrix which has been decomposed into individual // column vectors directly in the containing struct. The subsequent // iteration will append the correct index to the list for accessing // said column from the containing struct. let mut prev_decomposed_matrix_index = None; self.temp_list.clear(); let root_id = loop { // If `expr_handle` was spilled, then the temporary variable has exactly // the value we want to start from. if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) { // The root id of the `OpAccessChain` instruction is the temporary // variable we spilled the composite to.
break spilled.id; } expr_handle = match self.ir_function.expressions[expr_handle] { crate::Expression::Access { base, index } => { is_non_uniform_binding_array |= self.is_nonuniform_binding_array_access(base, index); let index = GuardedIndex::Expression(index); let index_id = self.write_access_chain_index(base, index, &mut accumulated_checks, block)?; self.temp_list.push(index_id); base } crate::Expression::AccessIndex { base, index } => { // Decide whether we're indexing a struct (bounds checks // forbidden) or anything else (bounds checks required). let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types); let mut base_ty_handle = self.fun_info[base].ty.handle(); let mut pointer_space = None; if let crate::TypeInner::Pointer { base, space } = *base_ty { base_ty = &self.ir_module.types[base].inner; base_ty_handle = Some(base); pointer_space = Some(space); } match *base_ty { // When indexing a struct bounds checks are forbidden. If accessing the // struct through a uniform address space pointer, where the struct has // been declared with an alternative std140 compatible layout, we must use // the remapped member index. Additionally if the previous iteration was // accessing a column of a matrix member which has been decomposed directly // into the struct, we must ensure we access the correct column. crate::TypeInner::Struct { .. } => { let index = match base_ty_handle.and_then(|handle| { self.writer.std140_compat_uniform_types.get(&handle) }) { Some(std140_type_info) if pointer_space == Some(crate::AddressSpace::Uniform) => { std140_type_info.member_indices[index as usize] + prev_decomposed_matrix_index.take().unwrap_or(0) } _ => index, }; let index_id = self.get_index_constant(index); self.temp_list.push(index_id); } // Bounds checks are not required when indexing a matrix. 
If indexing a // two-row matrix contained within a struct through a uniform address space // pointer, then the matrix's columns will have been decomposed directly into // the containing struct. We skip adding an index to the list on this // iteration and instead adjust the index on the next iteration when // accessing the struct member. _ if is_uniform_matcx2_struct_member_access( self.ir_function, self.fun_info, self.ir_module, base, ) => { assert!(prev_decomposed_matrix_index.is_none()); prev_decomposed_matrix_index = Some(index); } _ => { // `index` is constant, so this can't possibly require // setting `is_non_uniform_binding_array`. // Even though the index value is statically known, `base` // may be a runtime-sized array, so we still need to go // through the bounds check process. let index_id = self.write_access_chain_index( base, GuardedIndex::Known(index), &mut accumulated_checks, block, )?; self.temp_list.push(index_id); } } base } crate::Expression::GlobalVariable(handle) => { let gv = &self.writer.global_variables[handle]; break gv.access_id; } crate::Expression::LocalVariable(variable) => { let local_var = &self.function.variables[&variable]; break local_var.id; } crate::Expression::FunctionArgument(index) => { break self.function.parameter_id(index); } ref other => unimplemented!("Unexpected pointer expression {:?}", other), } }; let (pointer_id, expr_pointer) = if self.temp_list.is_empty() { ( root_id, ExpressionPointer::Ready { pointer_id: root_id, }, ) } else { self.temp_list.reverse(); let pointer_id = self.gen_id(); let access = Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list); // If we generated some bounds checks, we need to leave it to our // caller to generate the branch, the access, the load or store, and // the zero value (for loads). Otherwise, we can emit the access // ourselves, and just hand them the id of the pointer.
let expr_pointer = match accumulated_checks { Some(condition) => ExpressionPointer::Conditional { condition, access }, None => { block.body.push(access); ExpressionPointer::Ready { pointer_id } } }; (pointer_id, expr_pointer) }; // Subsequent load, store and atomic operations require the pointer to be decorated as NonUniform // if the binding array was accessed with a non-uniform index; // see VUID-RuntimeSpirv-NonUniform-06274. if is_non_uniform_binding_array { self.writer .decorate_non_uniform_binding_array_access(pointer_id)?; } Ok(expr_pointer) } fn is_nonuniform_binding_array_access( &mut self, base: Handle<crate::Expression>, index: Handle<crate::Expression>, ) -> bool { let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base] else { return false; }; // The access chain needs to be decorated as NonUniform; // see VUID-RuntimeSpirv-NonUniform-06274. let gvar = &self.ir_module.global_variables[var_handle]; let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else { return false; }; self.fun_info[index].uniformity.non_uniform_result.is_some() } /// Compute a single index operand to an `OpAccessChain` instruction. /// /// Given that we are indexing `base` with `index`, apply the appropriate /// bounds check policies, emitting code to `block` to clamp `index` or /// determine whether it's in bounds. Return the SPIR-V instruction id of /// the index value we should actually use. /// /// Extend `accumulated_checks` to include the results of any needed bounds /// checks. See [`BlockContext::extend_bounds_check_condition_chain`]. fn write_access_chain_index( &mut self, base: Handle<crate::Expression>, index: GuardedIndex, accumulated_checks: &mut Option<Word>, block: &mut Block, ) -> Result<Word, Error> { match self.write_bounds_check(base, index, block)? { BoundsCheckResult::KnownInBounds(known_index) => { // Even if the index is known, `OpAccessChain` // requires expression operands, not literals.
let scalar = crate::Literal::U32(known_index); Ok(self.writer.get_constant_scalar(scalar)) } BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id), BoundsCheckResult::Conditional { condition_id: condition, index_id: index, } => { self.extend_bounds_check_condition_chain(accumulated_checks, condition, block); // Use the index from the `Access` expression unchanged. Ok(index) } } } /// Add a condition to a chain of bounds checks. /// /// As we build an `OpAccessChain` instruction governed by /// [`BoundsCheckPolicy::ReadZeroSkipWrite`], we accumulate a chain of /// dynamic bounds checks, one for each index in the chain, which must all /// be true for that `OpAccessChain`'s execution to be well-defined. This /// function adds the boolean instruction id `comparison_id` to `chain`. /// /// If `chain` is `None`, that means there are no bounds checks in the chain /// yet. If `chain` is `Some(id)`, then `id` is the conjunction of all the /// bounds checks in the chain. /// /// When we have multiple bounds checks, we combine them with /// `OpLogicalAnd`, not a short-circuit branch. This means we might do /// comparisons we don't need to, but we expect these checks to almost /// always succeed, and keeping branches to a minimum is essential. /// /// [`BoundsCheckPolicy::ReadZeroSkipWrite`]: crate::proc::BoundsCheckPolicy fn extend_bounds_check_condition_chain( &mut self, chain: &mut Option<Word>, comparison_id: Word, block: &mut Block, ) { match *chain { Some(ref mut prior_checks) => { let combined = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::LogicalAnd, self.writer.get_bool_type_id(), combined, *prior_checks, comparison_id, )); *prior_checks = combined; } None => { // Start a fresh chain of checks.
*chain = Some(comparison_id); } } } fn write_checked_load( &mut self, pointer: Handle<crate::Expression>, block: &mut Block, access_type_adjustment: AccessTypeAdjustment, result_type_id: Word, ) -> Result<Word, Error> { if let Some(result_id) = self.maybe_write_uniform_matcx2_dynamic_access(pointer, block)? { Ok(result_id) } else if let Some(result_id) = self.maybe_write_load_uniform_matcx2_struct_member(pointer, block)? { Ok(result_id) } else { // If `pointer` refers to a uniform address space pointer to a type // which was declared using a std140 compatible type variant (i.e. // is a two-row matrix, or a struct or array containing such a // matrix), we must ensure the access chain and the type of the load // instruction use the std140 compatible type variant. struct WrappedLoad { access_type_adjustment: AccessTypeAdjustment, r#type: Handle<crate::Type>, } let mut wrapped_load = None; if let crate::TypeInner::Pointer { base: pointer_base_type, space: crate::AddressSpace::Uniform, } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) { if self .writer .std140_compat_uniform_types .contains_key(&pointer_base_type) { wrapped_load = Some(WrappedLoad { access_type_adjustment: AccessTypeAdjustment::UseStd140CompatType, r#type: pointer_base_type, }); }; }; let (load_type_id, access_type_adjustment) = match wrapped_load { Some(ref wrapped_load) => ( self.writer.std140_compat_uniform_types[&wrapped_load.r#type].type_id, wrapped_load.access_type_adjustment, ), None => (result_type_id, access_type_adjustment), }; let load_id = match self.write_access_chain(pointer, block, access_type_adjustment)? { ExpressionPointer::Ready { pointer_id } => { let id = self.gen_id(); let atomic_space = match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) { crate::TypeInner::Pointer { base, space } => { match self.ir_module.types[base].inner { crate::TypeInner::Atomic { ..
} => Some(space), _ => None, } } _ => None, }; let instruction = if let Some(space) = atomic_space { let (semantics, scope) = space.to_spirv_semantics_and_scope(); let scope_constant_id = self.get_scope_constant(scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); Instruction::atomic_load( result_type_id, id, pointer_id, scope_constant_id, semantics_id, ) } else { Instruction::load(load_type_id, id, pointer_id, None) }; block.body.push(instruction); id } ExpressionPointer::Conditional { condition, access } => { //TODO: support atomics? self.write_conditional_indexed_load( load_type_id, condition, block, move |id_gen, block| { // The in-bounds path. Perform the access and the load. let pointer_id = access.result_id.unwrap(); let value_id = id_gen.next(); block.body.push(access); block.body.push(Instruction::load( load_type_id, value_id, pointer_id, None, )); value_id }, ) } }; match wrapped_load { Some(ref wrapped_load) => { // If we loaded a std140 compat type then we must call the // function to convert the loaded value to the regular type. let result_id = self.gen_id(); let function_id = self.writer.wrapped_functions [&WrappedFunction::ConvertFromStd140CompatType { r#type: wrapped_load.r#type, }]; block.body.push(Instruction::function_call( result_type_id, result_id, function_id, &[load_id], )); Ok(result_id) } None => Ok(load_id), } } } fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) { use indexmap::map::Entry; // Make sure we have an internal variable to spill `base` to. let spill_variable_id = match self.function.spilled_composites.entry(base) { Entry::Occupied(preexisting) => preexisting.get().id, Entry::Vacant(vacant) => { // Generate a new internal variable of the appropriate // type for `base`.
let pointer_type_id = self.writer.get_resolution_pointer_id( &self.fun_info[base].ty, spirv::StorageClass::Function, ); let id = self.writer.id_gen.next(); vacant.insert(super::LocalVariable { id, instruction: Instruction::variable( pointer_type_id, id, spirv::StorageClass::Function, None, ), }); id } }; // Perform the store even if we already had a spill variable for `base`. // Consider this code: // // var x = ...; // var y = ...; // var z = ...; // for (var i = 0; i < 2; i++) { // let a = array(i, i, i); // if (i == 0) { // x += a[y]; // } else { // x += a[z]; // } // } // // The value of `a` needs to be spilled so we can subscript it with `y` and `z`. // // When we generate SPIR-V for `a[y]`, we will create the spill // variable, and store `a`'s value in it. // // When we generate SPIR-V for `a[z]`, we will notice that the spill // variable for `a` has already been declared, but it is still essential // that we store `a` into it, so that `a[z]` sees this iteration's value // of `a`. let base_id = self.cached[base]; block .body .push(Instruction::store(spill_variable_id, base_id, None)); } /// Generate an access to a spilled temporary, if necessary. /// /// Given `access`, an [`Access`] or [`AccessIndex`] expression that refers /// to a component of a composite value that has been spilled to a temporary /// variable, determine whether other expressions are going to use /// `access`'s value: /// /// - If so, perform the access and cache that as the value of `access`. /// /// - Otherwise, generate no code and cache no value for `access`. /// /// Return `Ok(0)` if no value was fetched, or `Ok(id)` if we loaded it into /// the instruction given by `id`.
/// /// [`Access`]: crate::Expression::Access /// [`AccessIndex`]: crate::Expression::AccessIndex fn maybe_access_spilled_composite( &mut self, access: Handle<crate::Expression>, block: &mut Block, result_type_id: Word, ) -> Result<Word, Error> { let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r); if access_uses == self.fun_info[access].ref_count { // This expression is only used by other `Access` and // `AccessIndex` expressions, so we don't need to cache a // value for it yet. Ok(0) } else { // There are other expressions that are going to expect this // expression's value to be cached, not just other `Access` or // `AccessIndex` expressions. We must actually perform the // access on the spill variable now. self.write_checked_load( access, block, AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function), result_type_id, ) } } /// Build the instructions for matrix-matrix column operations #[allow(clippy::too_many_arguments)] fn write_matrix_matrix_column_op( &mut self, block: &mut Block, result_id: Word, result_type_id: Word, left_id: Word, right_id: Word, columns: crate::VectorSize, rows: crate::VectorSize, width: u8, op: spirv::Op, ) { self.temp_list.clear(); let vector_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar: crate::Scalar::float(width), }); for index in 0..columns as u32 { let column_id_left = self.gen_id(); let column_id_right = self.gen_id(); let column_id_res = self.gen_id(); block.body.push(Instruction::composite_extract( vector_type_id, column_id_left, left_id, &[index], )); block.body.push(Instruction::composite_extract( vector_type_id, column_id_right, right_id, &[index], )); block.body.push(Instruction::binary( op, vector_type_id, column_id_res, column_id_left, column_id_right, )); self.temp_list.push(column_id_res); } block.body.push(Instruction::composite_construct( result_type_id, result_id, &self.temp_list, )); } /// Build the instructions for vector-scalar multiplication fn write_vector_scalar_mult( &mut
self, block: &mut Block, result_id: Word, result_type_id: Word, vector_id: Word, scalar_id: Word, vector: &crate::TypeInner, ) { let (size, kind) = match *vector { crate::TypeInner::Vector { size, scalar: crate::Scalar { kind, .. }, } => (size, kind), _ => unreachable!(), }; let (op, operand_id) = match kind { crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id), _ => { let operand_id = self.gen_id(); self.temp_list.clear(); self.temp_list.resize(size as usize, scalar_id); block.body.push(Instruction::composite_construct( result_type_id, operand_id, &self.temp_list, )); (spirv::Op::IMul, operand_id) } }; block.body.push(Instruction::binary( op, result_type_id, result_id, vector_id, operand_id, )); } /// Build the instructions for the arithmetic expression of a dot product /// /// The argument `extractor` is a function that maps `(result_id, /// composite_id, index)` to an instruction that extracts the `index`th /// entry of the value with ID `composite_id` and assigns it to the slot /// with id `result_id` (which must have type `result_type_id`). 
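///
/// For example (illustrative): with `size == 3`, the emitted code computes
/// `((null + a0*b0) + a1*b1) + a2*b2` using `OpIMul` and `OpIAdd`. Here
/// `null` is the `OpConstantNull` of `result_type_id`, and the final
/// `OpIAdd` writes directly to `result_id`.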
#[expect(clippy::too_many_arguments)] fn write_dot_product( &mut self, result_id: Word, result_type_id: Word, arg0_id: Word, arg1_id: Word, size: u32, block: &mut Block, extractor: impl Fn(Word, Word, Word) -> Instruction, ) { let mut partial_sum = self.writer.get_constant_null(result_type_id); let last_component = size - 1; for index in 0..=last_component { // compute the product of the current components let a_id = self.gen_id(); block.body.push(extractor(a_id, arg0_id, index)); let b_id = self.gen_id(); block.body.push(extractor(b_id, arg1_id, index)); let prod_id = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::IMul, result_type_id, prod_id, a_id, b_id, )); // choose the id for the next sum, depending on current index let id = if index == last_component { result_id } else { self.gen_id() }; // sum the computed product with the partial sum block.body.push(Instruction::binary( spirv::Op::IAdd, result_type_id, id, partial_sum, prod_id, )); // set the id of the result as the previous partial sum partial_sum = id; } } /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is available. 
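///
/// For example (illustrative): component `i` of the argument lands in bits
/// `8 * i ..= 8 * i + 7` of the result. So `pack4xU8(vec4(1u, 2u, 3u, 4u))`
/// is `0x04030201u`, and `pack4xI8(vec4(-1, 0, 0, 0))` is `0x000000FFu`.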
fn write_pack4x8_optimized( &mut self, block: &mut Block, result_type_id: u32, arg0_id: u32, id: u32, is_signed: bool, should_clamp: bool, ) -> Instruction { let int_type = if is_signed { crate::ScalarKind::Sint } else { crate::ScalarKind::Uint }; let wide_vector_type = NumericType::Vector { size: crate::VectorSize::Quad, scalar: crate::Scalar { kind: int_type, width: 4, }, }; let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type); let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Quad, scalar: crate::Scalar { kind: crate::ScalarKind::Uint, width: 1, }, }); let mut wide_vector = arg0_id; if should_clamp { let (min, max, clamp_op) = if is_signed { ( crate::Literal::I32(-128), crate::Literal::I32(127), spirv::GlslStd450Op::SClamp, ) } else { ( crate::Literal::U32(0), crate::Literal::U32(255), spirv::GlslStd450Op::UClamp, ) }; let [min, max] = [min, max].map(|lit| { let scalar = self.writer.get_constant_scalar(lit); self.writer.get_constant_composite( LookupType::Local(LocalType::Numeric(wide_vector_type)), &[scalar; 4], ) }); let clamp_id = self.gen_id(); block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, clamp_op, wide_vector_type_id, clamp_id, &[wide_vector, min, max], )); wide_vector = clamp_id; } let packed_vector = self.gen_id(); block.body.push(Instruction::unary( spirv::Op::UConvert, // We truncate, so `UConvert` and `SConvert` behave identically. packed_vector_type_id, packed_vector, wide_vector, )); // The SPIR-V spec [1] defines the bit order for bit casting between a vector // and a scalar precisely as required by the WGSL spec [2]. // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector) } /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is not available. 
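///
/// For example (illustrative): the polyfill inserts component `i` with
/// `OpBitFieldInsert` at offset `8 * i` and count 8. With clamping,
/// `pack4xI8Clamp(vec4(-200, 0, 0, 0))` first clamps `-200` to `-128`,
/// whose low 8 bits are `0x80`, so the result is `0x00000080u`.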
fn write_pack4x8_polyfill( &mut self, block: &mut Block, result_type_id: u32, arg0_id: u32, id: u32, is_signed: bool, should_clamp: bool, ) -> Instruction { let int_type = if is_signed { crate::ScalarKind::Sint } else { crate::ScalarKind::Uint }; let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)); let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { kind: int_type, width: 4, })); let mut last_instruction = Instruction::new(spirv::Op::Nop); let zero = self.writer.get_constant_scalar(crate::Literal::U32(0)); let mut preresult = zero; block .body .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed))); let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); const VEC_LENGTH: u8 = 4; for i in 0..u32::from(VEC_LENGTH) { let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8)); let mut extracted = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::CompositeExtract, int_type_id, extracted, arg0_id, i, )); if is_signed { let casted = self.gen_id(); block.body.push(Instruction::unary( spirv::Op::Bitcast, uint_type_id, casted, extracted, )); extracted = casted; } if should_clamp { let (min, max, clamp_op) = if is_signed { ( crate::Literal::I32(-128), crate::Literal::I32(127), spirv::GlslStd450Op::SClamp, ) } else { ( crate::Literal::U32(0), crate::Literal::U32(255), spirv::GlslStd450Op::UClamp, ) }; let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit)); let clamp_id = self.gen_id(); block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, clamp_op, result_type_id, clamp_id, &[extracted, min, max], )); extracted = clamp_id; } let is_last = i == u32::from(VEC_LENGTH - 1); if is_last { last_instruction = Instruction::quaternary( spirv::Op::BitFieldInsert, result_type_id, id, preresult, extracted, offset, eight, ) } else { let new_preresult = self.gen_id(); block.body.push(Instruction::quaternary( spirv::Op::BitFieldInsert, 
result_type_id, new_preresult, preresult, extracted, offset, eight, )); preresult = new_preresult; } } last_instruction } /// Emit code for `unpack4x{I,U}8` if capability "Int8" is available. fn write_unpack4x8_optimized( &mut self, block: &mut Block, result_type_id: u32, arg0_id: u32, id: u32, is_signed: bool, ) -> Instruction { let (int_type, convert_op) = if is_signed { (crate::ScalarKind::Sint, spirv::Op::SConvert) } else { (crate::ScalarKind::Uint, spirv::Op::UConvert) }; let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Quad, scalar: crate::Scalar { kind: int_type, width: 1, }, }); // The SPIR-V spec [1] defines the bit order for bit casting between a vector // and a scalar precisely as required by the WGSL spec [2]. // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast // [2]: https://www.w3.org/TR/WGSL/#unpack4xI8-builtin let packed_vector = self.gen_id(); block.body.push(Instruction::unary( spirv::Op::Bitcast, packed_vector_type_id, packed_vector, arg0_id, )); Instruction::unary(convert_op, result_type_id, id, packed_vector) } /// Emit code for `unpack4x{I,U}8` if capability "Int8" is not available.
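///
/// For example (illustrative): `unpack4xI8(0x000000FFu)` is
/// `vec4(-1, 0, 0, 0)` because `OpBitFieldSExtract` sign-extends each
/// 8-bit field. In contrast, `unpack4xU8(0x000000FFu)` uses
/// `OpBitFieldUExtract` and yields `vec4(255u, 0u, 0u, 0u)`.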
fn write_unpack4x8_polyfill( &mut self, block: &mut Block, result_type_id: u32, arg0_id: u32, id: u32, is_signed: bool, ) -> Instruction { let (int_type, extract_op) = if is_signed { (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract) } else { (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract) }; let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32)); let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { kind: int_type, width: 4, })); block .body .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed)); let arg_id = if is_signed { let new_arg_id = self.gen_id(); block.body.push(Instruction::unary( spirv::Op::Bitcast, sint_type_id, new_arg_id, arg0_id, )); new_arg_id } else { arg0_id }; const VEC_LENGTH: u8 = 4; let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id()); for (i, part_id) in parts.into_iter().enumerate() { let index = self .writer .get_constant_scalar(crate::Literal::U32(i as u32 * 8)); block.body.push(Instruction::ternary( extract_op, int_type_id, part_id, arg_id, index, eight, )); } Instruction::composite_construct(result_type_id, id, &parts) } /// Generate one or more SPIR-V blocks for `naga_block`. /// /// Use `label_id` as the label for the SPIR-V entry point block. /// /// If control reaches the end of the SPIR-V block, terminate it according /// to `exit`. This function's return value indicates whether it acted on /// this parameter or not; see [`BlockExitDisposition`]. /// /// If the block contains [`Break`] or [`Continue`] statements, /// `loop_context` supplies the labels of the SPIR-V blocks to jump to. If /// either of these labels are `None`, then it should have been a Naga /// validation error for the corresponding statement to occur in this /// context. 
/// /// [`Break`]: Statement::Break /// [`Continue`]: Statement::Continue fn write_block( &mut self, label_id: Word, naga_block: &crate::Block, exit: BlockExit, loop_context: LoopContext, debug_info: Option<&DebugInfoInner>, ) -> Result<BlockExitDisposition, Error> { let mut block = Block::new(label_id); for (statement, span) in naga_block.span_iter() { if let (Some(debug_info), false) = ( debug_info, matches!( statement, &(Statement::Block(..) | Statement::Break | Statement::Continue | Statement::Kill | Statement::Return { .. } | Statement::Loop { .. }) ), ) { let loc: crate::SourceLocation = span.location(debug_info.source_code); block.body.push(Instruction::line( debug_info.source_file_id, loc.line_number, loc.line_position, )); }; match *statement { Statement::Emit(ref range) => { for handle in range.clone() { // omit const expressions as we've already cached those if !self.expression_constness.is_const(handle) { self.cache_expression_value(handle, &mut block)?; } } } Statement::Block(ref block_statements) => { let scope_id = self.gen_id(); self.function.consume(block, Instruction::branch(scope_id)); let merge_id = self.gen_id(); let merge_used = self.write_block( scope_id, block_statements, BlockExit::Branch { target: merge_id }, loop_context, debug_info, )?; match merge_used { BlockExitDisposition::Used => { block = Block::new(merge_id); } BlockExitDisposition::Discarded => { return Ok(BlockExitDisposition::Discarded); } } } Statement::If { condition, ref accept, ref reject, } => { // In SPIR-V 1.6, the two target labels of a conditional branch // must differ. If `accept` and `reject` are both empty (e.g. in // `if (condition) {}`), the merge id would be used for both labels. // Since both branches are empty, we can skip the `if` statement entirely.
if !(accept.is_empty() && reject.is_empty()) { let condition_id = self.cached[condition]; let merge_id = self.gen_id(); block.body.push(Instruction::selection_merge( merge_id, spirv::SelectionControl::NONE, )); let accept_id = if accept.is_empty() { None } else { Some(self.gen_id()) }; let reject_id = if reject.is_empty() { None } else { Some(self.gen_id()) }; self.function.consume( block, Instruction::branch_conditional( condition_id, accept_id.unwrap_or(merge_id), reject_id.unwrap_or(merge_id), ), ); if let Some(block_id) = accept_id { // We can ignore the `BlockExitDisposition` returned here because, // even if `merge_id` is not actually reachable, it is always // referred to by the `OpSelectionMerge` instruction we emitted // earlier. let _ = self.write_block( block_id, accept, BlockExit::Branch { target: merge_id }, loop_context, debug_info, )?; } if let Some(block_id) = reject_id { // We can ignore the `BlockExitDisposition` returned here because, // even if `merge_id` is not actually reachable, it is always // referred to by the `OpSelectionMerge` instruction we emitted // earlier. 
let _ = self.write_block( block_id, reject, BlockExit::Branch { target: merge_id }, loop_context, debug_info, )?; } block = Block::new(merge_id); } } Statement::Switch { selector, ref cases, } => { let selector_id = self.cached[selector]; let merge_id = self.gen_id(); block.body.push(Instruction::selection_merge( merge_id, spirv::SelectionControl::NONE, )); let mut default_id = None; // id of previous empty fall-through case let mut last_id = None; let mut raw_cases = Vec::with_capacity(cases.len()); let mut case_ids = Vec::with_capacity(cases.len()); for case in cases.iter() { // take id of previous empty fall-through case or generate a new one let label_id = last_id.take().unwrap_or_else(|| self.gen_id()); if case.fall_through && case.body.is_empty() { last_id = Some(label_id); } case_ids.push(label_id); match case.value { crate::SwitchValue::I32(value) => { raw_cases.push(super::instructions::Case { value: value as Word, label_id, }); } crate::SwitchValue::U32(value) => { raw_cases.push(super::instructions::Case { value, label_id }); } crate::SwitchValue::Default => { default_id = Some(label_id); } } } let default_id = default_id.unwrap(); self.function.consume( block, Instruction::switch(selector_id, default_id, &raw_cases), ); let inner_context = LoopContext { break_id: Some(merge_id), ..loop_context }; for (i, (case, label_id)) in cases .iter() .zip(case_ids.iter()) .filter(|&(case, _)| !(case.fall_through && case.body.is_empty())) .enumerate() { let case_finish_id = if case.fall_through { case_ids[i + 1] } else { merge_id }; // We can ignore the `BlockExitDisposition` returned here because // `case_finish_id` is always referred to by either: // // - the `OpSwitch`, if it's the next case's label for a // fall-through, or // // - the `OpSelectionMerge`, if it's the switch's overall merge // block because there's no fall-through. 
let _ = self.write_block( *label_id, &case.body, BlockExit::Branch { target: case_finish_id, }, inner_context, debug_info, )?; } block = Block::new(merge_id); } Statement::Loop { ref body, ref continuing, break_if, } => { let preamble_id = self.gen_id(); self.function .consume(block, Instruction::branch(preamble_id)); let merge_id = self.gen_id(); let body_id = self.gen_id(); let continuing_id = self.gen_id(); // SPIR-V requires the `OpLoopMerge` to be in a dedicated loop header // block, which the continuing block branches back to, so we have to // start a new block for it. block = Block::new(preamble_id); // HACK: the loop statement begins with a branch instruction, // so we need to put the `OpLine` debug info before the merge instruction. if let Some(debug_info) = debug_info { let loc: crate::SourceLocation = span.location(debug_info.source_code); block.body.push(Instruction::line( debug_info.source_file_id, loc.line_number, loc.line_position, )) } block.body.push(Instruction::loop_merge( merge_id, continuing_id, spirv::SelectionControl::NONE, )); if self.force_loop_bounding { block = self.write_force_bounded_loop_instructions(block, merge_id); } self.function.consume(block, Instruction::branch(body_id)); // We can ignore the `BlockExitDisposition` returned here because, // even if `continuing_id` is not actually reachable, it is always // referred to by the `OpLoopMerge` instruction we emitted earlier. let _ = self.write_block( body_id, body, BlockExit::Branch { target: continuing_id, }, LoopContext { continuing_id: Some(continuing_id), break_id: Some(merge_id), }, debug_info, )?; let exit = match break_if { Some(condition) => BlockExit::BreakIf { condition, preamble_id, }, None => BlockExit::Branch { target: preamble_id, }, }; // We can ignore the `BlockExitDisposition` returned here because, // even if `merge_id` is not actually reachable, it is always referred // to by the `OpLoopMerge` instruction we emitted earlier.
let _ = self.write_block( continuing_id, continuing, exit, LoopContext { continuing_id: None, break_id: Some(merge_id), }, debug_info, )?; block = Block::new(merge_id); } Statement::Break => { self.function .consume(block, Instruction::branch(loop_context.break_id.unwrap())); return Ok(BlockExitDisposition::Discarded); } Statement::Continue => { self.function.consume( block, Instruction::branch(loop_context.continuing_id.unwrap()), ); return Ok(BlockExitDisposition::Discarded); } Statement::Return { value: Some(value) } => { let value_id = self.cached[value]; let instruction = match self.function.entry_point_context { // If this is an entry point, and we need to return anything, // let's instead store the output variables and return `void`. Some(ref context) => self.writer.write_entry_point_return( value_id, self.ir_function.result.as_ref().unwrap(), &context.results, &mut block.body, )?, None => Instruction::return_value(value_id), }; self.function.consume(block, instruction); return Ok(BlockExitDisposition::Discarded); } Statement::Return { value: None } => { self.function.consume(block, Instruction::return_void()); return Ok(BlockExitDisposition::Discarded); } Statement::Kill => { self.function.consume(block, Instruction::kill()); return Ok(BlockExitDisposition::Discarded); } Statement::ControlBarrier(flags) => { self.writer.write_control_barrier(flags, &mut block.body); } Statement::MemoryBarrier(flags) => { self.writer.write_memory_barrier(flags, &mut block); } Statement::Store { pointer, value } => { let value_id = self.cached[value]; match self.write_access_chain( pointer, &mut block, AccessTypeAdjustment::None, )? { ExpressionPointer::Ready { pointer_id } => { let atomic_space = match *self.fun_info[pointer] .ty .inner_with(&self.ir_module.types) { crate::TypeInner::Pointer { base, space } => { match self.ir_module.types[base].inner { crate::TypeInner::Atomic { .. 
} => Some(space), _ => None, } } _ => None, }; let instruction = if let Some(space) = atomic_space { let (semantics, scope) = space.to_spirv_semantics_and_scope(); let scope_constant_id = self.get_scope_constant(scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); Instruction::atomic_store( pointer_id, scope_constant_id, semantics_id, value_id, ) } else { Instruction::store(pointer_id, value_id, None) }; block.body.push(instruction); } ExpressionPointer::Conditional { condition, access } => { let mut selection = Selection::start(&mut block, ()); selection.if_true(self, condition, ()); // The in-bounds path. Perform the access and the store. let pointer_id = access.result_id.unwrap(); selection.block().body.push(access); selection .block() .body .push(Instruction::store(pointer_id, value_id, None)); // Finish the in-bounds block and start the merge block. This // is the block we'll leave current on return. selection.finish(self, ()); } }; } Statement::ImageStore { image, coordinate, array_index, value, } => self.write_image_store(image, coordinate, array_index, value, &mut block)?, Statement::Call { function: local_function, ref arguments, result, } => { let id = self.gen_id(); self.temp_list.clear(); for &argument in arguments { self.temp_list.push(self.cached[argument]); } let type_id = match result { Some(expr) => { self.cached[expr] = id; self.get_expression_type_id(&self.fun_info[expr].ty) } None => self.writer.void_type, }; block.body.push(Instruction::function_call( type_id, id, self.writer.lookup_function[&local_function], &self.temp_list, )); } Statement::Atomic { pointer, ref fun, value, result, } => { let id = self.gen_id(); // Compare-and-exchange operations produce a struct result, // so use `result`'s type if it is available. For no-result // operations, fall back to `value`'s type. 
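                    // For example, a compare-exchange whose `result` has the WGSL type
                    // `__atomic_compare_exchange_result<i32>`, i.e.
                    // `{ old_value: i32, exchanged: bool }`, uses that struct type
                    // here. The `Exchange { compare: Some(_) }` arm below assembles it
                    // with `OpCompositeConstruct` from the `OpAtomicCompareExchange`
                    // result and an `OpIEqual` against the comparand.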
let result_type_id = self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty); if let Some(result) = result { self.cached[result] = id; } let pointer_id = match self.write_access_chain( pointer, &mut block, AccessTypeAdjustment::None, )? { ExpressionPointer::Ready { pointer_id } => pointer_id, ExpressionPointer::Conditional { .. } => { return Err(Error::FeatureNotImplemented( "Atomics out-of-bounds handling", )); } }; let space = self.fun_info[pointer] .ty .inner_with(&self.ir_module.types) .pointer_space() .unwrap(); let (semantics, scope) = space.to_spirv_semantics_and_scope(); let scope_constant_id = self.get_scope_constant(scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); let value_id = self.cached[value]; let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types); let crate::TypeInner::Scalar(scalar) = *value_inner else { return Err(Error::FeatureNotImplemented( "Atomics with non-scalar values", )); }; let instruction = match *fun { crate::AtomicFunction::Add => { let spirv_op = match scalar.kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint => { spirv::Op::AtomicIAdd } crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT, _ => unimplemented!(), }; Instruction::atomic_binary( spirv_op, result_type_id, id, pointer_id, scope_constant_id, semantics_id, value_id, ) } crate::AtomicFunction::Subtract => { let (spirv_op, value_id) = match scalar.kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint => { (spirv::Op::AtomicISub, value_id) } crate::ScalarKind::Float => { // HACK: SPIR-V doesn't have an atomic floating-point subtraction, // so we add the negated value instead.
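                                // Worked example: subtracting `1.5` from an atomic `f32` that
                                // holds `4.0` emits `OpFNegate` (giving `-1.5`) and then
                                // `OpAtomicFAddEXT` with `-1.5`, which stores `2.5` and returns
                                // the original `4.0`, just as an atomic subtraction would.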
let neg_result_id = self.gen_id(); block.body.push(Instruction::unary( spirv::Op::FNegate, result_type_id, neg_result_id, value_id, )); (spirv::Op::AtomicFAddEXT, neg_result_id) } _ => unimplemented!(), }; Instruction::atomic_binary( spirv_op, result_type_id, id, pointer_id, scope_constant_id, semantics_id, value_id, ) } crate::AtomicFunction::And => { let spirv_op = match scalar.kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint => { spirv::Op::AtomicAnd } _ => unimplemented!(), }; Instruction::atomic_binary( spirv_op, result_type_id, id, pointer_id, scope_constant_id, semantics_id, value_id, ) } crate::AtomicFunction::InclusiveOr => { let spirv_op = match scalar.kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint => { spirv::Op::AtomicOr } _ => unimplemented!(), }; Instruction::atomic_binary( spirv_op, result_type_id, id, pointer_id, scope_constant_id, semantics_id, value_id, ) } crate::AtomicFunction::ExclusiveOr => { let spirv_op = match scalar.kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint => { spirv::Op::AtomicXor } _ => unimplemented!(), }; Instruction::atomic_binary( spirv_op, result_type_id, id, pointer_id, scope_constant_id, semantics_id, value_id, ) } crate::AtomicFunction::Min => { let spirv_op = match scalar.kind { crate::ScalarKind::Sint => spirv::Op::AtomicSMin, crate::ScalarKind::Uint => spirv::Op::AtomicUMin, _ => unimplemented!(), }; Instruction::atomic_binary( spirv_op, result_type_id, id, pointer_id, scope_constant_id, semantics_id, value_id, ) } crate::AtomicFunction::Max => { let spirv_op = match scalar.kind { crate::ScalarKind::Sint => spirv::Op::AtomicSMax, crate::ScalarKind::Uint => spirv::Op::AtomicUMax, _ => unimplemented!(), }; Instruction::atomic_binary( spirv_op, result_type_id, id, pointer_id, scope_constant_id, semantics_id, value_id, ) } crate::AtomicFunction::Exchange { compare: None } => { Instruction::atomic_binary( spirv::Op::AtomicExchange, result_type_id, id, pointer_id, scope_constant_id, semantics_id, 
value_id, ) } crate::AtomicFunction::Exchange { compare: Some(cmp) } => { let scalar_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar)); let bool_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL)); let cas_result_id = self.gen_id(); let equality_result_id = self.gen_id(); let equality_operator = match scalar.kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint => { spirv::Op::IEqual } _ => unimplemented!(), }; let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange); cas_instr.set_type(scalar_type_id); cas_instr.set_result(cas_result_id); cas_instr.add_operand(pointer_id); cas_instr.add_operand(scope_constant_id); cas_instr.add_operand(semantics_id); // semantics if equal cas_instr.add_operand(semantics_id); // semantics if not equal cas_instr.add_operand(value_id); cas_instr.add_operand(self.cached[cmp]); block.body.push(cas_instr); block.body.push(Instruction::binary( equality_operator, bool_type_id, equality_result_id, cas_result_id, self.cached[cmp], )); Instruction::composite_construct( result_type_id, id, &[cas_result_id, equality_result_id], ) } }; block.body.push(instruction); } Statement::ImageAtomic { image, coordinate, array_index, fun, value, } => { self.write_image_atomic( image, coordinate, array_index, fun, value, &mut block, )?; } Statement::WorkGroupUniformLoad { pointer, result } => { self.writer .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body); let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); // Match `Expression::Load` behavior, including `OpAtomicLoad` when // loading from a pointer to `atomic`. 
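                    // The load is bracketed by the two workgroup control barriers, so
                    // the statement lowers to roughly:
                    //
                    //     <workgroup control barrier>
                    //     %result = <load through `pointer`, or `OpAtomicLoad`>
                    //     <workgroup control barrier>
                    //
                    // The first barrier makes earlier workgroup writes visible to the
                    // load; the second keeps later writes from racing with it.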
let id = self.write_checked_load( pointer, &mut block, AccessTypeAdjustment::None, result_type_id, )?; self.cached[result] = id; self.writer .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body); } Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } Statement::SubgroupBallot { result, ref predicate, } => { self.write_subgroup_ballot(predicate, result, &mut block)?; } Statement::SubgroupCollectiveOperation { ref op, ref collective_op, argument, result, } => { self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; } Statement::SubgroupGather { ref mode, argument, result, } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } Statement::CooperativeStore { target, ref data } => { let target_id = self.cached[target]; let layout = if data.row_major { spirv::CooperativeMatrixLayout::RowMajorKHR } else { spirv::CooperativeMatrixLayout::ColumnMajorKHR }; let layout_id = self.get_index_constant(layout as u32); let stride_id = self.cached[data.stride]; match self.write_access_chain( data.pointer, &mut block, AccessTypeAdjustment::None, )? { ExpressionPointer::Ready { pointer_id } => { block.body.push(Instruction::coop_store( target_id, pointer_id, layout_id, stride_id, )); } ExpressionPointer::Conditional { condition, access } => { let mut selection = Selection::start(&mut block, ()); selection.if_true(self, condition, ()); // The in-bounds path. Perform the access and the store. let pointer_id = access.result_id.unwrap(); selection.block().body.push(access); selection.block().body.push(Instruction::coop_store( target_id, pointer_id, layout_id, stride_id, )); // Finish the in-bounds block and start the merge block. This // is the block we'll leave current on return. 
selection.finish(self, ()); } }; } Statement::RayPipelineFunction(_) => unreachable!(), } } let termination = match exit { // We're generating code for the top-level Block of the function, so we // need to end it with some kind of return instruction. BlockExit::Return => match self.ir_function.result { Some(ref result) if self.function.entry_point_context.is_none() => { let type_id = self.get_handle_type_id(result.ty); let null_id = self.writer.get_constant_null(type_id); Instruction::return_value(null_id) } _ => Instruction::return_void(), }, BlockExit::Branch { target } => Instruction::branch(target), BlockExit::BreakIf { condition, preamble_id, } => { let condition_id = self.cached[condition]; Instruction::branch_conditional( condition_id, loop_context.break_id.unwrap(), preamble_id, ) } }; self.function.consume(block, termination); Ok(BlockExitDisposition::Used) } pub(super) fn write_function_body( &mut self, entry_id: Word, debug_info: Option<&DebugInfoInner>, ) -> Result<(), Error> { // We can ignore the `BlockExitDisposition` returned here because // `BlockExit::Return` doesn't refer to a block. let _ = self.write_block( entry_id, &self.ir_function.body, BlockExit::Return, LoopContext::default(), debug_info, )?; Ok(()) } } naga-29.0.3/src/back/spv/f16_polyfill.rs000064400000000000000000000061551046102023000160070ustar 00000000000000/*! This module provides functionality for polyfilling `f16` input/output variables when the `StorageInputOutput16` capability is not available or disabled. It works by: 1. Declaring `f16` I/O variables as `f32` in SPIR-V 2. Converting between `f16` and `f32` at runtime using `OpFConvert` 3. Maintaining mappings to track which variables need conversion */ use crate::back::spv::{Instruction, LocalType, NumericType, Word}; use alloc::vec::Vec; /// Manages `f16` I/O polyfill state and operations. 
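///
/// For example (an illustrative sketch; ids and names are made up), a fragment
/// output declared as `vec4<f16>` when `StorageInputOutput16` is not used is
/// meant to come out roughly as:
///
/// ```text
/// %out   = OpVariable %_ptr_Output_v4float Output   ; declared as vec4<f32>
/// ...
/// %as_32 = OpFConvert %v4float %color_f16           ; emit_f16_to_f32_conversion
///          OpStore %out %as_32
/// ```
///
/// `f16` inputs go the other way: the `f32` variable is loaded and then
/// converted with `emit_f32_to_f16_conversion`.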
#[derive(Default)] pub(in crate::back::spv) struct F16IoPolyfill { use_native: bool, io_var_to_f32_type: crate::FastHashMap<Word, Word>, } impl F16IoPolyfill { pub fn new(use_storage_input_output_16: bool) -> Self { Self { use_native: use_storage_input_output_16, io_var_to_f32_type: crate::FastHashMap::default(), } } pub fn needs_polyfill(&self, ty_inner: &crate::TypeInner) -> bool { use crate::{ScalarKind as Sk, TypeInner}; !self.use_native && match *ty_inner { TypeInner::Scalar(ref s) if s.kind == Sk::Float && s.width == 2 => true, TypeInner::Vector { scalar, .. } if scalar.kind == Sk::Float && scalar.width == 2 => { true } _ => false, } } pub fn register_io_var(&mut self, variable_id: Word, f32_type_id: Word) { self.io_var_to_f32_type.insert(variable_id, f32_type_id); } pub fn get_f32_io_type(&self, variable_id: Word) -> Option<Word> { self.io_var_to_f32_type.get(&variable_id).copied() } pub fn emit_f16_to_f32_conversion( f16_value_id: Word, f32_type_id: Word, converted_id: Word, body: &mut Vec<Instruction>, ) { body.push(Instruction::unary( spirv::Op::FConvert, f32_type_id, converted_id, f16_value_id, )); } pub fn emit_f32_to_f16_conversion( f32_value_id: Word, f16_type_id: Word, converted_id: Word, body: &mut Vec<Instruction>, ) { body.push(Instruction::unary( spirv::Op::FConvert, f16_type_id, converted_id, f32_value_id, )); } pub fn create_polyfill_type(ty_inner: &crate::TypeInner) -> Option<LocalType> { use crate::{ScalarKind as Sk, TypeInner}; match *ty_inner { TypeInner::Scalar(ref s) if s.kind == Sk::Float && s.width == 2 => { Some(LocalType::Numeric(NumericType::Scalar(crate::Scalar::F32))) } TypeInner::Vector { size, scalar } if scalar.kind == Sk::Float && scalar.width == 2 => { Some(LocalType::Numeric(NumericType::Vector { size, scalar: crate::Scalar::F32, })) } _ => None, } } } impl crate::back::spv::reclaimable::Reclaimable for F16IoPolyfill { fn reclaim(mut self) -> Self { self.io_var_to_f32_type = self.io_var_to_f32_type.reclaim(); self } }
naga-29.0.3/src/back/spv/helpers.rs000064400000000000000000000155731046102023000151470ustar 00000000000000use alloc::{vec, vec::Vec}; use arrayvec::ArrayVec; use spirv::Word; use crate::{Handle, UniqueArena}; pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> { bytes .chunks(4) .map(|chars| chars.iter().rev().fold(0u32, |u, c| (u << 8) | *c as u32)) .collect() } pub(super) fn string_to_words(input: &str) -> Vec<Word> { let bytes = input.as_bytes(); str_bytes_to_words(bytes) } pub(super) fn str_bytes_to_words(bytes: &[u8]) -> Vec<Word> { let mut words = bytes_to_words(bytes); if bytes.len().is_multiple_of(4) { // nul-termination: the last word is full, so append a zero word. Otherwise // the zero padding in the last word already terminates the string. words.push(0x0u32); } words } /// Split a string into chunks of at most `limit` bytes, keeping each chunk valid UTF-8. #[allow(unstable_name_collisions)] pub(super) fn string_to_byte_chunks(input: &str, limit: usize) -> Vec<&[u8]> { let mut offset: usize = 0; let mut start: usize = 0; let mut words = vec![]; while offset < input.len() { offset = input.floor_char_boundary_polyfill(offset + limit); // Clippy wants us to call as_bytes() first to avoid the UTF-8 check, // but we want to assert the output is valid UTF-8. #[allow(clippy::sliced_string_as_bytes)] words.push(input[start..offset].as_bytes()); start = offset; } words } pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass { match space { crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant, crate::AddressSpace::Function => spirv::StorageClass::Function, crate::AddressSpace::Private => spirv::StorageClass::Private, crate::AddressSpace::Storage { ..
} => spirv::StorageClass::StorageBuffer, crate::AddressSpace::Uniform => spirv::StorageClass::Uniform, crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup, crate::AddressSpace::Immediate => spirv::StorageClass::PushConstant, crate::AddressSpace::TaskPayload => spirv::StorageClass::TaskPayloadWorkgroupEXT, crate::AddressSpace::IncomingRayPayload | crate::AddressSpace::RayPayload => unreachable!(), } } pub(super) fn contains_builtin( binding: Option<&crate::Binding>, ty: Handle<crate::Type>, arena: &UniqueArena<crate::Type>, built_in: crate::BuiltIn, ) -> bool { if let Some(&crate::Binding::BuiltIn(bi)) = binding { bi == built_in } else if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner { members .iter() .any(|member| contains_builtin(member.binding.as_ref(), member.ty, arena, built_in)) } else { false // unreachable } } impl crate::AddressSpace { pub(super) const fn to_spirv_semantics_and_scope( self, ) -> (spirv::MemorySemantics, spirv::Scope) { match self { Self::Storage { .. } => (spirv::MemorySemantics::empty(), spirv::Scope::Device), Self::WorkGroup => (spirv::MemorySemantics::empty(), spirv::Scope::Workgroup), Self::Uniform => (spirv::MemorySemantics::empty(), spirv::Scope::Device), Self::Handle => (spirv::MemorySemantics::empty(), spirv::Scope::Device), _ => (spirv::MemorySemantics::empty(), spirv::Scope::Invocation), } } } /// Return true if the global requires a type decorated with `Block`. /// /// See [`back::spv::GlobalVariable`] for details. /// /// [`back::spv::GlobalVariable`]: super::GlobalVariable pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariable) -> bool { match var.space { crate::AddressSpace::Uniform | crate::AddressSpace::Storage { ..
} | crate::AddressSpace::Immediate => {} _ => return false, }; match ir_module.types[var.ty].inner { crate::TypeInner::Struct { ref members, span: _, } => match members.last() { Some(member) => match ir_module.types[member.ty].inner { // Structs with dynamically sized arrays can't be copied and can't be wrapped. crate::TypeInner::Array { size: crate::ArraySize::Dynamic, .. } => false, _ => true, }, None => false, }, crate::TypeInner::BindingArray { .. } => false, // If it's not a struct or a binding array, wrap it so that it can be decorated with `Block`. _ => true, } } /// Returns true if `pointer` refers to a two-row matrix which is a member of a /// struct in the [`crate::AddressSpace::Uniform`] address space. pub fn is_uniform_matcx2_struct_member_access( ir_function: &crate::Function, fun_info: &crate::valid::FunctionInfo, ir_module: &crate::Module, pointer: Handle<crate::Expression>, ) -> bool { if let crate::TypeInner::Pointer { base: pointer_base_type, space: crate::AddressSpace::Uniform, } = *fun_info[pointer].ty.inner_with(&ir_module.types) { if let crate::TypeInner::Matrix { rows: crate::VectorSize::Bi, .. } = ir_module.types[pointer_base_type].inner { if let crate::Expression::AccessIndex { base: parent_pointer, .. } = ir_function.expressions[pointer] { if let crate::TypeInner::Pointer { base: parent_type, .. } = *fun_info[parent_pointer].ty.inner_with(&ir_module.types) { if let crate::TypeInner::Struct { .. } = ir_module.types[parent_type].inner { return true; } } } } } false } /// HACK: this is taken from unstable `std`; remove it when `str::floor_char_boundary` is stable /// and available in our MSRV.
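///
/// Worked example for `floor_char_boundary_polyfill`: `"héllo"` is encoded as
/// `68 C3 A9 6C 6C 6F`. Byte `0xA9` at index 2 is a UTF-8 continuation byte
/// (`0xA9 as i8` is `-87`, below `-0x40`), while `0xC3` at index 1 starts `é`,
/// so `floor_char_boundary_polyfill(2)` returns `1`.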
trait U8Internal { fn is_utf8_char_boundary_polyfill(&self) -> bool; } impl U8Internal for u8 { fn is_utf8_char_boundary_polyfill(&self) -> bool { // This is bit magic equivalent to: b < 128 || b >= 192 (*self as i8) >= -0x40 } } trait StrUnstable { fn floor_char_boundary_polyfill(&self, index: usize) -> usize; } impl StrUnstable for str { fn floor_char_boundary_polyfill(&self, index: usize) -> usize { if index >= self.len() { self.len() } else { let lower_bound = index.saturating_sub(3); let new_index = self.as_bytes()[lower_bound..=index] .iter() .rposition(|b| b.is_utf8_char_boundary_polyfill()); // We know that the character boundary will be within four bytes. lower_bound + new_index.unwrap() } } } pub enum BindingDecorations { BuiltIn(spirv::BuiltIn, ArrayVec), Location { location: u32, others: ArrayVec, /// If this is `Some`, use Decoration::Index with blend_src as an operand blend_src: Option, }, None, } naga-29.0.3/src/back/spv/image.rs000064400000000000000000001450231046102023000145610ustar 00000000000000/*! Generating SPIR-V for image operations. */ use spirv::Word; use super::{ selection::{MergeTuple, Selection}, Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType, }; use crate::arena::Handle; /// Information about a vector of coordinates. /// /// The coordinate vectors expected by SPIR-V `OpImageRead` and `OpImageFetch` /// supply the array index for arrayed images as an additional component at /// the end, whereas Naga's `ImageLoad`, `ImageStore`, and `ImageSample` carry /// the array index as a separate field. /// /// In the process of generating code to compute the combined vector, we also /// produce SPIR-V types and vector lengths that are useful elsewhere. This /// struct gathers that information into one place, with standard names. struct ImageCoordinates { /// The SPIR-V id of the combined coordinate/index vector value. /// /// Note: when indexing a non-arrayed 1D image, this will be a scalar. 
value_id: Word, /// The SPIR-V id of the type of `value`. type_id: Word, /// The number of components in `value`, if it is a vector, or `None` if it /// is a scalar. size: Option<crate::VectorSize>, } /// A trait for image access (load or store) code generators. /// /// Types implementing this trait hold information about an `ImageStore` or /// `ImageLoad` operation that is not affected by the bounds check policy. The /// `generate` method emits code for the access, given the results of bounds /// checking. /// /// The [`image_load`] bounds check policy affects access coordinates, level of /// detail, and sample index, but never the image id, result type (if any), or /// the specific SPIR-V instruction used. Types that implement this trait gather /// together the latter category, so we don't have to plumb them through the /// bounds-checking code. /// /// [`image_load`]: crate::proc::BoundsCheckPolicies::image_load trait Access { /// The Rust type that represents SPIR-V values and types for this access. /// /// For operations like loads, this is `Word`. For operations like stores, /// this is `()`. /// /// For `ReadZeroSkipWrite`, this will be the type of the selection /// construct that performs the bounds checks, so it must implement /// `MergeTuple`. type Output: MergeTuple + Copy + Clone; /// Write an image access to `block`. /// /// Access the texel at `coordinates_id`. The optional `level_id` indicates /// the level of detail, and `sample_id` is the index of the sample to /// access in a multisampled texel. /// /// This method assumes that `coordinates_id` has already had the image array /// index, if any, folded in, as done by `write_image_coordinates`. /// /// Return the value id produced by the instruction, if any. /// /// Use `id_gen` to generate SPIR-V ids as necessary.
fn generate( &self, id_gen: &mut IdGenerator, coordinates_id: Word, level_id: Option<Word>, sample_id: Option<Word>, block: &mut Block, ) -> Self::Output; /// Return the SPIR-V type of the value produced by the code written by /// `generate`. If the access does not produce a value, `Self::Output` /// should be `()`. fn result_type(&self) -> Self::Output; /// Construct the SPIR-V 'zero' value to be returned for an out-of-bounds /// access under the `ReadZeroSkipWrite` policy. If the access does not /// produce a value, `Self::Output` should be `()`. fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Self::Output; } /// Texel access information for an [`ImageLoad`] expression. /// /// [`ImageLoad`]: crate::Expression::ImageLoad struct Load { /// The specific opcode we'll use to perform the fetch. Storage images /// require `OpImageRead`, while sampled images require `OpImageFetch`. opcode: spirv::Op, /// The type id produced by the actual image access instruction. type_id: Word, /// The id of the image being accessed. image_id: Word, } impl Load { fn from_image_expr( ctx: &mut BlockContext<'_>, image_id: Word, image_class: crate::ImageClass, result_type_id: Word, ) -> Result<Self, Error> { let opcode = match image_class { crate::ImageClass::Storage { .. } => spirv::Op::ImageRead, crate::ImageClass::Depth { .. } | crate::ImageClass::Sampled { .. } => { spirv::Op::ImageFetch } crate::ImageClass::External => unimplemented!(), }; // `OpImageRead` and `OpImageFetch` instructions produce vec4 // values. Most of the time, we can just use `result_type_id` for // this. The exception is that `Expression::ImageLoad` from a depth // image produces a scalar `f32`, so in that case we need to find // the right SPIR-V type for the access instruction here. let type_id = match image_class { crate::ImageClass::Depth { ..
} => ctx.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Quad, scalar: crate::Scalar::F32, }), _ => result_type_id, }; Ok(Load { opcode, type_id, image_id, }) } } impl Access for Load { type Output = Word; /// Write an instruction to access a given texel of this image. fn generate( &self, id_gen: &mut IdGenerator, coordinates_id: Word, level_id: Option<Word>, sample_id: Option<Word>, block: &mut Block, ) -> Word { let texel_id = id_gen.next(); let mut instruction = Instruction::image_fetch_or_read( self.opcode, self.type_id, texel_id, self.image_id, coordinates_id, ); match (level_id, sample_id) { (None, None) => {} (Some(level_id), None) => { instruction.add_operand(spirv::ImageOperands::LOD.bits()); instruction.add_operand(level_id); } (None, Some(sample_id)) => { instruction.add_operand(spirv::ImageOperands::SAMPLE.bits()); instruction.add_operand(sample_id); } // There's no such thing as a multi-sampled mipmap. (Some(_), Some(_)) => unreachable!(), } block.body.push(instruction); texel_id } fn result_type(&self) -> Word { self.type_id } fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Word { ctx.writer.get_constant_null(self.type_id) } } /// Texel access information for an [`ImageStore`] statement. /// /// [`ImageStore`]: crate::Statement::ImageStore struct Store { /// The id of the image being written to. image_id: Word, /// The value we're going to write to the texel. value_id: Word, } impl Access for Store { /// Stores don't generate any value. type Output = (); fn generate( &self, _id_gen: &mut IdGenerator, coordinates_id: Word, _level_id: Option<Word>, _sample_id: Option<Word>, block: &mut Block, ) { block.body.push(Instruction::image_write( self.image_id, coordinates_id, self.value_id, )); } /// Stores don't generate any value, so this just returns `()`. fn result_type(&self) {} /// Stores don't generate any value, so this just returns `()`.
fn out_of_bounds_value(&self, _ctx: &mut BlockContext<'_>) {} } impl BlockContext<'_> { /// Extend image coordinates with an array index, if necessary. /// /// Whereas [`Expression::ImageLoad`] and [`ImageSample`] treat the array /// index as a separate operand from the coordinates, SPIR-V image access /// instructions include the array index in the `coordinates` operand. This /// function builds a SPIR-V coordinate vector from a Naga coordinate vector /// and array index, if one is supplied, and returns an `ImageCoordinates` /// struct describing what it built. /// /// If `array_index` is `Some(expr)`, then this function constructs a new /// vector that is `coordinates` with `array_index` concatenated onto the /// end: a `vec2` becomes a `vec3`, a scalar becomes a `vec2`, and so on. /// /// If `array_index` is `None`, then the return value uses `coordinates` /// unchanged. Note that, when indexing a non-arrayed 1D image, this will be /// a scalar value. /// /// If needed, this function generates code to convert the array index, /// always an integer scalar, to match the component type of `coordinates`. /// Naga's `ImageLoad` and SPIR-V's `OpImageRead`, `OpImageFetch`, and /// `OpImageWrite` all use integer coordinates, while Naga's `ImageSample` /// and SPIR-V's `OpImageSample...` instructions all take floating-point /// coordinate vectors. /// /// [`Expression::ImageLoad`]: crate::Expression::ImageLoad /// [`ImageSample`]: crate::Expression::ImageSample fn write_image_coordinates( &mut self, coordinates: Handle<crate::Expression>, array_index: Option<Handle<crate::Expression>>, block: &mut Block, ) -> Result<ImageCoordinates, Error> { use crate::TypeInner as Ti; use crate::VectorSize as Vs; let coordinates_id = self.cached[coordinates]; let ty = &self.fun_info[coordinates].ty; let inner_ty = ty.inner_with(&self.ir_module.types); // If there's no array index, the image coordinates are exactly the // `coordinate` field of the `Expression::ImageLoad`. No work is needed.
let array_index = match array_index { None => { let value_id = coordinates_id; let type_id = self.get_expression_type_id(ty); let size = match *inner_ty { Ti::Scalar { .. } => None, Ti::Vector { size, .. } => Some(size), _ => return Err(Error::Validation("coordinate type")), }; return Ok(ImageCoordinates { value_id, type_id, size, }); } Some(ix) => ix, }; // Find the component type of `coordinates`, and figure out the size the // combined coordinate vector will have. let (component_scalar, size) = match *inner_ty { Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Vs::Bi), Ti::Vector { scalar: scalar @ crate::Scalar { width: 4, .. }, size: Vs::Bi, } => (scalar, Vs::Tri), Ti::Vector { scalar: scalar @ crate::Scalar { width: 4, .. }, size: Vs::Tri, } => (scalar, Vs::Quad), Ti::Vector { size: Vs::Quad, .. } => { return Err(Error::Validation("extending vec4 coordinate")); } ref other => { log::error!("wrong coordinate type {other:?}"); return Err(Error::Validation("coordinate type")); } }; // Convert the index to the coordinate component type, if necessary. 
let array_index_id = self.cached[array_index]; let ty = &self.fun_info[array_index].ty; let inner_ty = ty.inner_with(&self.ir_module.types); let array_index_scalar = match *inner_ty { Ti::Scalar( scalar @ crate::Scalar { kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, width: 4, }, ) => scalar, _ => unreachable!("we only allow i32 and u32"), }; let cast = match (component_scalar.kind, array_index_scalar.kind) { (crate::ScalarKind::Sint, crate::ScalarKind::Sint) | (crate::ScalarKind::Uint, crate::ScalarKind::Uint) => None, (crate::ScalarKind::Sint, crate::ScalarKind::Uint) | (crate::ScalarKind::Uint, crate::ScalarKind::Sint) => Some(spirv::Op::Bitcast), (crate::ScalarKind::Float, crate::ScalarKind::Sint) => Some(spirv::Op::ConvertSToF), (crate::ScalarKind::Float, crate::ScalarKind::Uint) => Some(spirv::Op::ConvertUToF), (crate::ScalarKind::Bool, _) => unreachable!("we don't allow bool for component"), (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => { unreachable!("we don't allow bool or float for array index") } (crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat, _) | (_, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat) => { unreachable!("abstract types should never reach backends") } }; let reconciled_array_index_id = if let Some(cast) = cast { let component_ty_id = self.get_numeric_type_id(NumericType::Scalar(component_scalar)); let reconciled_id = self.gen_id(); block.body.push(Instruction::unary( cast, component_ty_id, reconciled_id, array_index_id, )); reconciled_id } else { array_index_id }; // Find the SPIR-V type for the combined coordinates/index vector. let type_id = self.get_numeric_type_id(NumericType::Vector { size, scalar: component_scalar, }); // Schmear the coordinates and index together. 
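        // Example: an `ImageLoad` from a 2D array image with `vec2<i32>`
        // coordinates and a `u32` array index bitcasts the index to `i32`
        // above, and the instruction below then builds the `vec3<i32>` value
        // `(x, y, layer)` that `OpImageFetch` or `OpImageRead` expects.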
let value_id = self.gen_id(); block.body.push(Instruction::composite_construct( type_id, value_id, &[coordinates_id, reconciled_array_index_id], )); Ok(ImageCoordinates { value_id, type_id, size: Some(size), }) } pub(super) fn get_handle_id(&mut self, expr_handle: Handle<crate::Expression>) -> Word { let id = match self.ir_function.expressions[expr_handle] { crate::Expression::GlobalVariable(handle) => { self.writer.global_variables[handle].handle_id } crate::Expression::FunctionArgument(i) => { self.function.parameters[i as usize].handle_id } crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => { self.cached[expr_handle] } ref other => unreachable!("Unexpected image expression {:?}", other), }; if id == 0 { unreachable!( "Image expression {:?} doesn't have a handle ID", expr_handle ); } id } /// Generate a vector or scalar 'one' for arithmetic on `coordinates`. /// /// If `coordinates` is a scalar, return a scalar one. Otherwise, return /// a vector of ones. fn write_coordinate_one(&mut self, coordinates: &ImageCoordinates) -> Result<Word, Error> { let one = self.get_scope_constant(1); match coordinates.size { None => Ok(one), Some(vector_size) => { let ones = [one; 4]; let id = self.gen_id(); Instruction::constant_composite( coordinates.type_id, id, &ones[..vector_size as usize], ) .to_words(&mut self.writer.logical_layout.declarations); Ok(id) } } } /// Generate code to restrict `input` to fall between zero and one less than /// `size_id`. /// /// Both must be 32-bit scalar integer values, whose type is given by /// `type_id`. The computed value is also of type `type_id`. fn restrict_scalar( &mut self, type_id: Word, input_id: Word, size_id: Word, block: &mut Block, ) -> Result<Word, Error> { let i32_one_id = self.get_scope_constant(1); // Subtract one from `size` to get the largest valid value.
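        // Worked example: if `size` is `4`, the limit is `3`. `UMin(2, 3)` is
        // `2`, `UMin(5, 3)` is `3`, and an input of `-1` (bit pattern
        // `0xFFFF_FFFF` when compared as unsigned) also clamps to `3`.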
let limit_id = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::ISub, type_id, limit_id, size_id, i32_one_id, )); // Use an unsigned minimum, to handle both positive out-of-range values // and negative values in a single instruction: negative values of // `input_id` get treated as very large positive values. let restricted_id = self.gen_id(); block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GlslStd450Op::UMin, type_id, restricted_id, &[input_id, limit_id], )); Ok(restricted_id) } /// Write instructions to query the size of an image. /// /// This takes care of selecting the right instruction depending on whether /// a level of detail parameter is present. fn write_coordinate_bounds( &mut self, type_id: Word, image_id: Word, level_id: Option, block: &mut Block, ) -> Word { let coordinate_bounds_id = self.gen_id(); match level_id { Some(level_id) => { // A level of detail was provided, so fetch the image size for // that level. let mut inst = Instruction::image_query( spirv::Op::ImageQuerySizeLod, type_id, coordinate_bounds_id, image_id, ); inst.add_operand(level_id); block.body.push(inst); } _ => { // No level of detail was given. block.body.push(Instruction::image_query( spirv::Op::ImageQuerySize, type_id, coordinate_bounds_id, image_id, )); } } coordinate_bounds_id } /// Write code to restrict coordinates for an image reference. /// /// First, clamp the level of detail or sample index to fall within bounds. /// Then, obtain the image size, possibly using the clamped level of detail. /// Finally, use an unsigned minimum instruction to force all coordinates /// into range. /// /// Return a triple `(COORDS, LEVEL, SAMPLE)`, where `COORDS` is a coordinate /// vector (including the array index, if any), `LEVEL` is an optional level /// of detail, and `SAMPLE` is an optional sample index, all guaranteed to /// be in-bounds for `image_id`. 
/// /// The result is usually a vector, but it is a scalar when indexing /// non-arrayed 1D images. fn write_restricted_coordinates( &mut self, image_id: Word, coordinates: ImageCoordinates, level_id: Option, sample_id: Option, block: &mut Block, ) -> Result<(Word, Option, Option), Error> { self.writer.require_any( "the `Restrict` image bounds check policy", &[spirv::Capability::ImageQuery], )?; let i32_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32)); // If `level` is `Some`, clamp it to fall within bounds. This must // happen first, because we'll use it to query the image size for // clamping the actual coordinates. let level_id = level_id .map(|level_id| { // Find the number of mipmap levels in this image. let num_levels_id = self.gen_id(); block.body.push(Instruction::image_query( spirv::Op::ImageQueryLevels, i32_type_id, num_levels_id, image_id, )); self.restrict_scalar(i32_type_id, level_id, num_levels_id, block) }) .transpose()?; // If `sample_id` is `Some`, clamp it to fall within bounds. let sample_id = sample_id .map(|sample_id| { // Find the number of samples per texel. let num_samples_id = self.gen_id(); block.body.push(Instruction::image_query( spirv::Op::ImageQuerySamples, i32_type_id, num_samples_id, image_id, )); self.restrict_scalar(i32_type_id, sample_id, num_samples_id, block) }) .transpose()?; // Obtain the image bounds, including the array element count. let coordinate_bounds_id = self.write_coordinate_bounds(coordinates.type_id, image_id, level_id, block); // Compute maximum valid values from the bounds. let ones = self.write_coordinate_one(&coordinates)?; let coordinate_limit_id = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::ISub, coordinates.type_id, coordinate_limit_id, coordinate_bounds_id, ones, )); // Restrict the coordinates to fall within those bounds. 
// // Use an unsigned minimum, to handle both positive out-of-range values // and negative values in a single instruction: negative values of // `coordinates` get treated as very large positive values. let restricted_coordinates_id = self.gen_id(); block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GlslStd450Op::UMin, coordinates.type_id, restricted_coordinates_id, &[coordinates.value_id, coordinate_limit_id], )); Ok((restricted_coordinates_id, level_id, sample_id)) } fn write_conditional_image_access( &mut self, image_id: Word, coordinates: ImageCoordinates, level_id: Option, sample_id: Option, block: &mut Block, access: &A, ) -> Result { self.writer.require_any( "the `ReadZeroSkipWrite` image bounds check policy", &[spirv::Capability::ImageQuery], )?; let bool_type_id = self.writer.get_bool_type_id(); let i32_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32)); let null_id = access.out_of_bounds_value(self); let mut selection = Selection::start(block, access.result_type()); // If `level_id` is `Some`, check whether it is within bounds. This must // happen first, because we'll be supplying this as an argument when we // query the image size. if let Some(level_id) = level_id { // Find the number of mipmap levels in this image. let num_levels_id = self.gen_id(); selection.block().body.push(Instruction::image_query( spirv::Op::ImageQueryLevels, i32_type_id, num_levels_id, image_id, )); let lod_cond_id = self.gen_id(); selection.block().body.push(Instruction::binary( spirv::Op::ULessThan, bool_type_id, lod_cond_id, level_id, num_levels_id, )); selection.if_true(self, lod_cond_id, null_id); } // If `sample_id` is `Some`, check whether it is in bounds. if let Some(sample_id) = sample_id { // Find the number of samples per texel. 
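The single-`UMin` trick used by `restrict_scalar` and `write_restricted_coordinates` can be checked on the CPU. This is a minimal sketch that assumes 32-bit coordinates and a positive size, as validation guarantees. The functions are illustrative only and do not exist in Naga. Note that a negative input lands on the last valid index rather than on zero, so this is not ordinary clamping.

```rust
/// Restrict `input` to `0..size` with a single unsigned minimum, as the
/// `Restrict` policy does: reinterpret the value as `u32`, then take
/// `min(input, size - 1)`. Negative inputs become huge unsigned values, so
/// they land on the last valid index instead of on zero.
fn restrict_scalar(input: i32, size: i32) -> i32 {
    assert!(size > 0, "zero-sized images are rejected by validation");
    let limit = (size - 1) as u32;
    (input as u32).min(limit) as i32
}

/// Component-wise form, as applied to a whole coordinate vector
/// (including the array layer) against the queried image size.
fn restrict_coords<const N: usize>(coords: [i32; N], bounds: [i32; N]) -> [i32; N] {
    std::array::from_fn(|i| restrict_scalar(coords[i], bounds[i]))
}

fn main() {
    assert_eq!(restrict_scalar(2, 4), 2); // in range: unchanged
    assert_eq!(restrict_scalar(9, 4), 3); // too large: last valid index
    assert_eq!(restrict_scalar(-1, 4), 3); // negative: also last valid index
    println!("{:?}", restrict_coords([-5, 1, 7], [8, 8, 2]));
}
```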
let num_samples_id = self.gen_id(); selection.block().body.push(Instruction::image_query( spirv::Op::ImageQuerySamples, i32_type_id, num_samples_id, image_id, )); let samples_cond_id = self.gen_id(); selection.block().body.push(Instruction::binary( spirv::Op::ULessThan, bool_type_id, samples_cond_id, sample_id, num_samples_id, )); selection.if_true(self, samples_cond_id, null_id); } // Obtain the image bounds, including any array element count. let coordinate_bounds_id = self.write_coordinate_bounds( coordinates.type_id, image_id, level_id, selection.block(), ); // Compare the coordinates against the bounds. let coords_numeric_type = match coordinates.size { Some(size) => NumericType::Vector { size, scalar: crate::Scalar::BOOL, }, None => NumericType::Scalar(crate::Scalar::BOOL), }; let coords_bool_type_id = self.get_numeric_type_id(coords_numeric_type); let coords_conds_id = self.gen_id(); selection.block().body.push(Instruction::binary( spirv::Op::ULessThan, coords_bool_type_id, coords_conds_id, coordinates.value_id, coordinate_bounds_id, )); // If the comparison above was a vector comparison, then we need to // check that all components of the comparison are true. let coords_cond_id = if coords_bool_type_id != bool_type_id { let id = self.gen_id(); selection.block().body.push(Instruction::relational( spirv::Op::All, bool_type_id, id, coords_conds_id, )); id } else { coords_conds_id }; selection.if_true(self, coords_cond_id, null_id); // All conditions are met. We can carry out the access. let texel_id = access.generate( &mut self.writer.id_gen, coordinates.value_id, level_id, sample_id, selection.block(), ); // This, then, is the value of the 'true' branch. Ok(selection.finish(self, texel_id)) } /// Generate code for an `ImageLoad` expression. /// /// The arguments are the components of an `Expression::ImageLoad` variant. 
#[allow(clippy::too_many_arguments)] pub(super) fn write_image_load( &mut self, result_type_id: Word, image: Handle, coordinate: Handle, array_index: Option>, level: Option>, sample: Option>, block: &mut Block, ) -> Result { let image_id = self.get_handle_id(image); let image_type = self.fun_info[image].ty.inner_with(&self.ir_module.types); let image_class = match *image_type { crate::TypeInner::Image { class, .. } => class, _ => return Err(Error::Validation("image type")), }; let access = Load::from_image_expr(self, image_id, image_class, result_type_id)?; let coordinates = self.write_image_coordinates(coordinate, array_index, block)?; let level_id = level.map(|expr| self.cached[expr]); let sample_id = sample.map(|expr| self.cached[expr]); // Perform the access, according to the bounds check policy. let access_id = match self.writer.bounds_check_policies.image_load { crate::proc::BoundsCheckPolicy::Restrict => { let (coords, level_id, sample_id) = self.write_restricted_coordinates( image_id, coordinates, level_id, sample_id, block, )?; access.generate(&mut self.writer.id_gen, coords, level_id, sample_id, block) } crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => self .write_conditional_image_access( image_id, coordinates, level_id, sample_id, block, &access, )?, crate::proc::BoundsCheckPolicy::Unchecked => access.generate( &mut self.writer.id_gen, coordinates.value_id, level_id, sample_id, block, ), }; // For depth images, `ImageLoad` expressions produce a single f32, // whereas the SPIR-V instructions always produce a vec4. So we may have // to pull out the component we need. let result_id = if result_type_id == access.result_type() { // The instruction produced the type we expected. We can use // its result as-is. access_id } else { // For `ImageClass::Depth` images, SPIR-V gave us four components, // but we only want the first one. 
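To make the three arms of the bounds check policy match in `write_image_load` concrete, here is a hypothetical CPU model. A slice stands in for a 1D image and `Policy` mirrors the variant names only. The behavior of `Unchecked` on real hardware is left to the driver, so the sketch simply lets Rust's slice check panic instead.

```rust
/// Hypothetical mirror of the three `BoundsCheckPolicy` variants.
#[derive(Clone, Copy, Debug)]
enum Policy {
    Restrict,
    ReadZeroSkipWrite,
    Unchecked,
}

/// Read `texels[index]` under `policy`. A slice stands in for a 1D image,
/// and the `i32` index models a signed shader coordinate.
fn load(texels: &[u32], index: i32, policy: Policy) -> u32 {
    assert!(!texels.is_empty(), "validation rejects empty images");
    let len = texels.len() as u32;
    match policy {
        // Force the index into range with an unsigned minimum.
        Policy::Restrict => texels[(index as u32).min(len - 1) as usize],
        // Unsigned less-than also catches negative indices; the
        // out-of-bounds branch yields a zero value instead of reading.
        Policy::ReadZeroSkipWrite => {
            if (index as u32) < len {
                texels[index as usize]
            } else {
                0
            }
        }
        // No check. On a GPU the out-of-bounds result is up to the driver;
        // in this sketch Rust's own slice check panics instead.
        Policy::Unchecked => texels[index as usize],
    }
}

fn main() {
    let texels: [u32; 4] = [10, 20, 30, 40];
    for policy in [Policy::Restrict, Policy::ReadZeroSkipWrite] {
        println!("{:?}: {} {}", policy, load(&texels, 1, policy), load(&texels, -1, policy));
    }
    println!("Unchecked: {}", load(&texels, 2, Policy::Unchecked));
}
```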
let component_id = self.gen_id(); block.body.push(Instruction::composite_extract( result_type_id, component_id, access_id, &[0], )); component_id }; Ok(result_id) } /// Generate code for an `ImageSample` expression. /// /// The arguments are the components of an `Expression::ImageSample` variant. #[allow(clippy::too_many_arguments)] pub(super) fn write_image_sample( &mut self, result_type_id: Word, image: Handle, sampler: Handle, gather: Option, coordinate: Handle, array_index: Option>, offset: Option>, level: crate::SampleLevel, depth_ref: Option>, clamp_to_edge: bool, block: &mut Block, ) -> Result { use super::instructions::SampleLod; // image let image_id = self.get_handle_id(image); let image_type = self.fun_info[image].ty.handle().unwrap(); // SPIR-V doesn't know about our `Depth` class, and it returns // `vec4`, so we need to grab the first component out of it. let needs_sub_access = match self.ir_module.types[image_type].inner { crate::TypeInner::Image { class: crate::ImageClass::Depth { .. }, .. } => depth_ref.is_none() && gather.is_none(), _ => false, }; let sample_result_type_id = if needs_sub_access { self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Quad, scalar: crate::Scalar::F32, }) } else { result_type_id }; // OpTypeSampledImage let image_type_id = self.get_handle_type_id(image_type); let sampled_image_type_id = self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id })); let sampler_id = self.get_handle_id(sampler); let coordinates = self.write_image_coordinates(coordinate, array_index, block)?; let coordinates_id = if clamp_to_edge { self.writer.require_any( "clamp sample coordinates to edge", &[spirv::Capability::ImageQuery], )?; // clamp_to_edge can only be used with Level 0, and no array offset, offset, // depth_ref or gather. This should have been caught by validation. 
Rather // than entirely duplicate validation code here just ensure the level is // zero, as we rely on that to query the texture size in order to calculate // the clamped coordinates. if level != crate::SampleLevel::Zero { return Err(Error::Validation( "ImageSample::clamp_to_edge requires SampleLevel::Zero", )); } // Query the size of level 0 of the texture. let image_size_id = self.gen_id(); let vec2u_type_id = self.writer.get_vec2u_type_id(); let const_zero_uint_id = self.writer.get_constant_scalar(crate::Literal::U32(0)); let mut query_inst = Instruction::image_query( spirv::Op::ImageQuerySizeLod, vec2u_type_id, image_size_id, image_id, ); query_inst.add_operand(const_zero_uint_id); block.body.push(query_inst); let image_size_f_id = self.gen_id(); let vec2f_type_id = self.writer.get_vec2f_type_id(); block.body.push(Instruction::unary( spirv::Op::ConvertUToF, vec2f_type_id, image_size_f_id, image_size_id, )); // Calculate the top-left and bottom-right margin for clamping to. I.e. a // half-texel from each side. 
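The clamp-to-edge arithmetic (a half-texel margin of `0.5 / size`, an upper bound of `1.0 - margin`, then `NClamp`) can be reproduced on the CPU. The following is a minimal sketch with a hypothetical `clamp_to_edge` helper. It assumes the NaN behavior of GLSL.std.450 `NClamp` as described via its `NMin`/`NMax` semantics, which Rust's `f32::max`/`f32::min` happen to share.

```rust
/// Clamp normalized coordinates so that filtering never samples past the
/// border of mip level 0: each component stays at least half a texel away
/// from 0.0 and from 1.0.
fn clamp_to_edge(uv: [f32; 2], size: [u32; 2]) -> [f32; 2] {
    std::array::from_fn(|i| {
        // Half a texel in normalized units, e.g. 0.125 for a 4-texel axis.
        let margin = 0.5 / size[i] as f32;
        let upper = 1.0 - margin;
        // `f32::max`/`f32::min` return the non-NaN operand, which matches the
        // NMax/NMin semantics behind GLSL.std.450 `NClamp`: a NaN input ends
        // up at `margin`.
        uv[i].max(margin).min(upper)
    })
}

fn main() {
    println!("{:?}", clamp_to_edge([-0.3, 0.5], [4, 8]));
}
```

For a 4-texel axis the usable range is therefore `[0.125, 0.875]`, which keeps the bilinear footprint inside the texture.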
let const_0_5_f32_id = self.writer.get_constant_scalar(crate::Literal::F32(0.5)); let const_0_5_vec2f_id = self.writer.get_constant_composite( LookupType::Local(LocalType::Numeric(NumericType::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::F32, })), &[const_0_5_f32_id, const_0_5_f32_id], ); let margin_left_id = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::FDiv, vec2f_type_id, margin_left_id, const_0_5_vec2f_id, image_size_f_id, )); let const_1_f32_id = self.writer.get_constant_scalar(crate::Literal::F32(1.0)); let const_1_vec2f_id = self.writer.get_constant_composite( LookupType::Local(LocalType::Numeric(NumericType::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::F32, })), &[const_1_f32_id, const_1_f32_id], ); let margin_right_id = self.gen_id(); block.body.push(Instruction::binary( spirv::Op::FSub, vec2f_type_id, margin_right_id, const_1_vec2f_id, margin_left_id, )); // Clamp the coords to the calculated margins let clamped_coords_id = self.gen_id(); block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GlslStd450Op::NClamp, vec2f_type_id, clamped_coords_id, &[coordinates.value_id, margin_left_id, margin_right_id], )); clamped_coords_id } else { coordinates.value_id }; let sampled_image_id = self.gen_id(); block.body.push(Instruction::sampled_image( sampled_image_type_id, sampled_image_id, image_id, sampler_id, )); let id = self.gen_id(); let depth_id = depth_ref.map(|handle| self.cached[handle]); let mut mask = spirv::ImageOperands::empty(); mask.set(spirv::ImageOperands::CONST_OFFSET, offset.is_some()); let mut main_instruction = match (level, gather) { (_, Some(component)) => { let component_id = self.get_index_constant(component as u32); let mut inst = Instruction::image_gather( sample_result_type_id, id, sampled_image_id, coordinates_id, component_id, depth_id, ); if !mask.is_empty() { inst.add_operand(mask.bits()); } inst } (crate::SampleLevel::Zero, None) => { let mut inst = 
Instruction::image_sample( sample_result_type_id, id, SampleLod::Explicit, sampled_image_id, coordinates_id, depth_id, ); let zero_id = self.writer.get_constant_scalar(crate::Literal::F32(0.0)); mask |= spirv::ImageOperands::LOD; inst.add_operand(mask.bits()); inst.add_operand(zero_id); inst } (crate::SampleLevel::Auto, None) => { let mut inst = Instruction::image_sample( sample_result_type_id, id, SampleLod::Implicit, sampled_image_id, coordinates_id, depth_id, ); if !mask.is_empty() { inst.add_operand(mask.bits()); } inst } (crate::SampleLevel::Exact(lod_handle), None) => { let mut inst = Instruction::image_sample( sample_result_type_id, id, SampleLod::Explicit, sampled_image_id, coordinates_id, depth_id, ); let mut lod_id = self.cached[lod_handle]; // SPIR-V expects the LOD to be a float for all image classes. // lod_id, however, will be an integer for depth images, // therefore we must do a conversion. if matches!( self.ir_module.types[image_type].inner, crate::TypeInner::Image { class: crate::ImageClass::Depth { .. }, .. 
} ) { let lod_f32_id = self.gen_id(); let f32_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32)); let convert_op = match *self.fun_info[lod_handle] .ty .inner_with(&self.ir_module.types) { crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Uint, width: 4, }) => spirv::Op::ConvertUToF, crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Sint, width: 4, }) => spirv::Op::ConvertSToF, _ => unreachable!(), }; block.body.push(Instruction::unary( convert_op, f32_type_id, lod_f32_id, lod_id, )); lod_id = lod_f32_id; } mask |= spirv::ImageOperands::LOD; inst.add_operand(mask.bits()); inst.add_operand(lod_id); inst } (crate::SampleLevel::Bias(bias_handle), None) => { let mut inst = Instruction::image_sample( sample_result_type_id, id, SampleLod::Implicit, sampled_image_id, coordinates_id, depth_id, ); let bias_id = self.cached[bias_handle]; mask |= spirv::ImageOperands::BIAS; inst.add_operand(mask.bits()); inst.add_operand(bias_id); inst } (crate::SampleLevel::Gradient { x, y }, None) => { let mut inst = Instruction::image_sample( sample_result_type_id, id, SampleLod::Explicit, sampled_image_id, coordinates_id, depth_id, ); let x_id = self.cached[x]; let y_id = self.cached[y]; mask |= spirv::ImageOperands::GRAD; inst.add_operand(mask.bits()); inst.add_operand(x_id); inst.add_operand(y_id); inst } }; if let Some(offset_const) = offset { let offset_id = self.cached[offset_const]; main_instruction.add_operand(offset_id); } block.body.push(main_instruction); let id = if needs_sub_access { let sub_id = self.gen_id(); block.body.push(Instruction::composite_extract( result_type_id, sub_id, id, &[0], )); sub_id } else { id }; Ok(id) } /// Generate code for an `ImageQuery` expression. /// /// The arguments are the components of an `Expression::ImageQuery` variant. 
pub(super) fn write_image_query( &mut self, result_type_id: Word, image: Handle, query: crate::ImageQuery, block: &mut Block, ) -> Result { use crate::{ImageClass as Ic, ImageDimension as Id, ImageQuery as Iq}; let image_id = self.get_handle_id(image); let image_type = self.fun_info[image].ty.handle().unwrap(); let (dim, arrayed, class) = match self.ir_module.types[image_type].inner { crate::TypeInner::Image { dim, arrayed, class, } => (dim, arrayed, class), _ => { return Err(Error::Validation("image type")); } }; self.writer .require_any("image queries", &[spirv::Capability::ImageQuery])?; let id = match query { Iq::Size { level } => { let dim_coords = match dim { Id::D1 => 1, Id::D2 | Id::Cube => 2, Id::D3 => 3, }; let array_coords = usize::from(arrayed); let vector_size = match dim_coords + array_coords { 2 => Some(crate::VectorSize::Bi), 3 => Some(crate::VectorSize::Tri), 4 => Some(crate::VectorSize::Quad), _ => None, }; let vector_numeric_type = match vector_size { Some(size) => NumericType::Vector { size, scalar: crate::Scalar::U32, }, None => NumericType::Scalar(crate::Scalar::U32), }; let extended_size_type_id = self.get_numeric_type_id(vector_numeric_type); let (query_op, level_id) = match class { Ic::Sampled { multi: true, .. } | Ic::Depth { multi: true } | Ic::Storage { .. } => (spirv::Op::ImageQuerySize, None), _ => { let level_id = match level { Some(expr) => self.cached[expr], None => self.get_index_constant(0), }; (spirv::Op::ImageQuerySizeLod, Some(level_id)) } }; // The ID of the vector returned by SPIR-V, which contains the dimensions // as well as the layer count. 
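The size and layer queries in `write_image_query` rely on simple component bookkeeping. SPIR-V returns the spatial size followed by the layer count for arrayed images. When the result must be narrowed, the backend keeps the spatial components, and for cubes it duplicates component 0. The sketch below models that bookkeeping with a hypothetical `Dim` enum. It is an illustration only, not Naga code.

```rust
/// Hypothetical stand-in for Naga's `ImageDimension`.
#[derive(Clone, Copy)]
enum Dim {
    D1,
    D2,
    D3,
    Cube,
}

/// Number of spatial components reported for each dimension.
fn dim_coords(dim: Dim) -> usize {
    match dim {
        Dim::D1 => 1,
        Dim::D2 | Dim::Cube => 2,
        Dim::D3 => 3,
    }
}

/// Length of the vector `OpImageQuerySize[Lod]` returns: the spatial
/// components, plus one for the layer count of arrayed images.
fn extended_len(dim: Dim, arrayed: bool) -> usize {
    dim_coords(dim) + usize::from(arrayed)
}

/// Components to keep when the query result must be narrowed to the
/// spatial size alone. Cube faces are square, so component 0 (the width)
/// is duplicated.
fn size_components(dim: Dim) -> Vec<u32> {
    match dim {
        Dim::Cube => vec![0, 0],
        _ => (0..dim_coords(dim) as u32).collect(),
    }
}

/// For `NumLayers`, the layer count is the last component of the extended
/// vector, which sits right after the spatial components.
fn layer_component(dim: Dim) -> u32 {
    dim_coords(dim) as u32
}

fn main() {
    for (name, dim) in [("1d", Dim::D1), ("2d", Dim::D2), ("3d", Dim::D3), ("cube", Dim::Cube)] {
        println!(
            "{name}: extended_len(arrayed) = {}, keep {:?}, layers at {}",
            extended_len(dim, true),
            size_components(dim),
            layer_component(dim)
        );
    }
}
```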
let id_extended = self.gen_id(); let mut inst = Instruction::image_query( query_op, extended_size_type_id, id_extended, image_id, ); if let Some(expr_id) = level_id { inst.add_operand(expr_id); } block.body.push(inst); if result_type_id != extended_size_type_id { let id = self.gen_id(); let components = match dim { // always pick the first component, and duplicate it for all 3 dimensions Id::Cube => &[0u32, 0][..], _ => &[0u32, 1, 2, 3][..dim_coords], }; block.body.push(Instruction::vector_shuffle( result_type_id, id, id_extended, id_extended, components, )); id } else { id_extended } } Iq::NumLevels => { let query_id = self.gen_id(); block.body.push(Instruction::image_query( spirv::Op::ImageQueryLevels, result_type_id, query_id, image_id, )); query_id } Iq::NumLayers => { let vec_size = match dim { Id::D1 => crate::VectorSize::Bi, Id::D2 | Id::Cube => crate::VectorSize::Tri, Id::D3 => crate::VectorSize::Quad, }; let extended_size_type_id = self.get_numeric_type_id(NumericType::Vector { size: vec_size, scalar: crate::Scalar::U32, }); let id_extended = self.gen_id(); let mut inst = Instruction::image_query( spirv::Op::ImageQuerySizeLod, extended_size_type_id, id_extended, image_id, ); inst.add_operand(self.get_index_constant(0)); block.body.push(inst); let extract_id = self.gen_id(); block.body.push(Instruction::composite_extract( result_type_id, extract_id, id_extended, &[vec_size as u32 - 1], )); extract_id } Iq::NumSamples => { let query_id = self.gen_id(); block.body.push(Instruction::image_query( spirv::Op::ImageQuerySamples, result_type_id, query_id, image_id, )); query_id } }; Ok(id) } pub(super) fn write_image_store( &mut self, image: Handle, coordinate: Handle, array_index: Option>, value: Handle, block: &mut Block, ) -> Result<(), Error> { let image_id = self.get_handle_id(image); let coordinates = self.write_image_coordinates(coordinate, array_index, block)?; let value_id = self.cached[value]; let write = Store { image_id, value_id }; match 
*self.fun_info[image].ty.inner_with(&self.ir_module.types) { crate::TypeInner::Image { class: crate::ImageClass::Storage { format: crate::StorageFormat::Bgra8Unorm, .. }, .. } => self.writer.require_any( "Bgra8Unorm storage write", &[spirv::Capability::StorageImageWriteWithoutFormat], )?, _ => {} } write.generate( &mut self.writer.id_gen, coordinates.value_id, None, None, block, ); Ok(()) } pub(super) fn write_image_atomic( &mut self, image: Handle, coordinate: Handle, array_index: Option>, fun: crate::AtomicFunction, value: Handle, block: &mut Block, ) -> Result<(), Error> { let image_id = match self.ir_function.originating_global(image) { Some(handle) => self.writer.global_variables[handle].var_id, _ => return Err(Error::Validation("Unexpected image type")), }; let crate::TypeInner::Image { class, .. } = *self.fun_info[image].ty.inner_with(&self.ir_module.types) else { return Err(Error::Validation("Invalid image type")); }; let crate::ImageClass::Storage { format, .. } = class else { return Err(Error::Validation("Invalid image class")); }; let scalar = format.into(); let scalar_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar)); let pointer_type_id = self.get_pointer_type_id(scalar_type_id, spirv::StorageClass::Image); let signed = scalar.kind == crate::ScalarKind::Sint; if scalar.width == 8 { self.writer .require_any("64 bit image atomics", &[spirv::Capability::Int64Atomics])?; } let pointer_id = self.gen_id(); let coordinates = self.write_image_coordinates(coordinate, array_index, block)?; let sample_id = self.writer.get_constant_scalar(crate::Literal::U32(0)); block.body.push(Instruction::image_texel_pointer( pointer_type_id, pointer_id, image_id, coordinates.value_id, sample_id, )); let op = match fun { crate::AtomicFunction::Add => spirv::Op::AtomicIAdd, crate::AtomicFunction::Subtract => spirv::Op::AtomicISub, crate::AtomicFunction::And => spirv::Op::AtomicAnd, crate::AtomicFunction::ExclusiveOr => spirv::Op::AtomicXor, 
crate::AtomicFunction::InclusiveOr => spirv::Op::AtomicOr, crate::AtomicFunction::Min if signed => spirv::Op::AtomicSMin, crate::AtomicFunction::Min => spirv::Op::AtomicUMin, crate::AtomicFunction::Max if signed => spirv::Op::AtomicSMax, crate::AtomicFunction::Max => spirv::Op::AtomicUMax, crate::AtomicFunction::Exchange { .. } => { return Err(Error::Validation("Exchange atomics are not supported yet")) } }; let result_type_id = self.get_expression_type_id(&self.fun_info[value].ty); let id = self.gen_id(); let space = crate::AddressSpace::Handle; let (semantics, scope) = space.to_spirv_semantics_and_scope(); let scope_constant_id = self.get_scope_constant(scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); let value_id = self.cached[value]; block.body.push(Instruction::image_atomic( op, result_type_id, id, pointer_id, scope_constant_id, semantics_id, value_id, )); Ok(()) } } naga-29.0.3/src/back/spv/index.rs000064400000000000000000000610151046102023000146040ustar 00000000000000/*! Bounds-checking for SPIR-V output. */ use super::{ helpers::{global_needs_wrapper, map_storage_class}, selection::Selection, Block, BlockContext, Error, IdGenerator, Instruction, Word, }; use crate::{ arena::Handle, proc::{index::GuardedIndex, BoundsCheckPolicy}, }; /// The results of performing a bounds check. /// /// On success, [`write_bounds_check`](BlockContext::write_bounds_check) /// returns a value of this type. The caller can assume that the right /// policy has been applied, and simply do what the variant says. #[derive(Debug)] pub(super) enum BoundsCheckResult { /// The index is statically known and in bounds, with the given value. KnownInBounds(u32), /// The given instruction computes the index to be used. /// /// When [`BoundsCheckPolicy::Restrict`] is in force, this is a /// clamped version of the index the user supplied. /// /// When [`BoundsCheckPolicy::Unchecked`] is in force, this is /// simply the index the user supplied. 
This variant indicates /// that we couldn't prove statically that the index was in /// bounds; otherwise we would have returned [`KnownInBounds`]. /// /// [`KnownInBounds`]: BoundsCheckResult::KnownInBounds Computed(Word), /// The given instruction computes a boolean condition which is true /// if the index is in bounds. /// /// This is returned when [`BoundsCheckPolicy::ReadZeroSkipWrite`] /// is in force. Conditional { /// The access should only be permitted if this value is true. condition_id: Word, /// The access should use this index value. index_id: Word, }, } /// A value that we either know at translation time, or need to compute at runtime. #[derive(Copy, Clone)] pub(super) enum MaybeKnown<T> { /// The value is known at shader translation time. Known(T), /// The value is computed by the instruction with the given id. Computed(Word), } impl BlockContext<'_> { /// Emit code to compute the length of a run-time array. /// /// Given `array`, an expression referring to a runtime-sized array, return the /// instruction id for the array's length. /// /// Runtime-sized arrays may only appear in the values of global /// variables, which must have one of the following Naga types: /// /// 1. A runtime-sized array. /// 2. A struct whose last member is a runtime-sized array. /// 3. A binding array of 2. /// /// Thus, the expression `array` has the form of: /// /// - An optional [`AccessIndex`], for case 2, applied to... /// - An optional [`Access`] or [`AccessIndex`], for case 3, applied to... /// - A [`GlobalVariable`]. /// /// The generated SPIR-V takes into account wrapped globals; see /// [`back::spv::GlobalVariable`] for details.
/// /// [`GlobalVariable`]: crate::Expression::GlobalVariable /// [`AccessIndex`]: crate::Expression::AccessIndex /// [`Access`]: crate::Expression::Access /// [`base`]: crate::Expression::Access::base /// [`back::spv::GlobalVariable`]: super::GlobalVariable pub(super) fn write_runtime_array_length( &mut self, array: Handle, block: &mut Block, ) -> Result { // The index into the binding array, if any. let binding_array_index_id: Option; // The handle to the Naga IR global we're referring to. let global_handle: Handle; // At the Naga type level, if the runtime-sized array is the final member of a // struct, this is that member's index. // // This does not cover wrappers: if this backend wrapped the Naga global's // type in a synthetic SPIR-V struct (see `global_needs_wrapper`), this is // `None`. let opt_last_member_index: Option; // Inspect `array` and decide whether we have a binding array and/or an // enclosing struct. match self.ir_function.expressions[array] { crate::Expression::AccessIndex { base, index } => { match self.ir_function.expressions[base] { crate::Expression::AccessIndex { base: base_outer, index: index_outer, } => match self.ir_function.expressions[base_outer] { // An `AccessIndex` of an `AccessIndex` must be a // binding array holding structs whose last members are // runtime-sized arrays. crate::Expression::GlobalVariable(handle) => { let index_id = self.get_index_constant(index_outer); binding_array_index_id = Some(index_id); global_handle = handle; opt_last_member_index = Some(index); } _ => { return Err(Error::Validation( "array length expression: AccessIndex(AccessIndex(Global))", )) } }, crate::Expression::Access { base: base_outer, index: index_outer, } => match self.ir_function.expressions[base_outer] { // Similarly, an `AccessIndex` of an `Access` must be a // binding array holding structs whose last members are // runtime-sized arrays. 
crate::Expression::GlobalVariable(handle) => { let index_id = self.cached[index_outer]; binding_array_index_id = Some(index_id); global_handle = handle; opt_last_member_index = Some(index); } _ => { return Err(Error::Validation( "array length expression: AccessIndex(Access(Global))", )) } }, crate::Expression::GlobalVariable(handle) => { // An outer `AccessIndex` applied directly to a // `GlobalVariable`. Since binding arrays can only contain // structs, this must be referring to the last member of a // struct that is a runtime-sized array. binding_array_index_id = None; global_handle = handle; opt_last_member_index = Some(index); } _ => { return Err(Error::Validation( "array length expression: AccessIndex()", )) } } } crate::Expression::GlobalVariable(handle) => { // A direct reference to a global variable. This must hold the // runtime-sized array directly. binding_array_index_id = None; global_handle = handle; opt_last_member_index = None; } _ => return Err(Error::Validation("array length expression case-4")), }; // The verifier should have checked this, but make sure the inspection above // agrees with the type about whether a binding array is involved. // // Eventually we do want to support `binding_array>`. This check // ensures that whoever relaxes the validator will get an error message from // us, not just bogus SPIR-V. let global = &self.ir_module.global_variables[global_handle]; match ( &self.ir_module.types[global.ty].inner, binding_array_index_id, ) { (&crate::TypeInner::BindingArray { .. }, Some(_)) => {} (_, None) => {} _ => { return Err(Error::Validation( "array length expression: bad binding array inference", )) } } // SPIR-V allows runtime-sized arrays to appear only as the last member of a // struct. Determine this member's index. 
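The inspection in `write_runtime_array_length` accepts only three expression shapes, and then it decides which struct member `OpArrayLength` should name. Here is a reduced, hypothetical model of both steps. `Expr`, `Plan`, `classify` and `op_array_length_member` are illustrative names, not Naga types. The sketch assumes, as the comments in the source describe, that a bare runtime-sized array global is wrapped in a one-member struct.

```rust
/// A much-reduced, hypothetical model of the expressions that can reach
/// `arrayLength`.
enum Expr {
    Global(&'static str),
    /// Constant index, like Naga's `AccessIndex`.
    AccessIndex { base: Box<Expr>, index: u32 },
    /// Runtime index, like Naga's `Access`; `index` names the value.
    Access { base: Box<Expr>, index: &'static str },
}

#[derive(Debug, PartialEq)]
enum BindingIndex {
    Constant(u32),
    Dynamic(&'static str),
}

#[derive(Debug, PartialEq)]
struct Plan {
    global: &'static str,
    binding_array_index: Option<BindingIndex>,
    last_member_index: Option<u32>,
}

fn plan(global: &'static str, binding: Option<BindingIndex>, member: Option<u32>) -> Plan {
    Plan { global, binding_array_index: binding, last_member_index: member }
}

fn classify(expr: &Expr) -> Result<Plan, &'static str> {
    match expr {
        // Case 1: the global is itself the runtime-sized array.
        Expr::Global(g) => Ok(plan(*g, None, None)),
        Expr::AccessIndex { base, index } => match &**base {
            // Case 2: last member of a struct held by the global.
            Expr::Global(g) => Ok(plan(*g, None, Some(*index))),
            // Case 3: last member of a struct inside a binding array.
            Expr::AccessIndex { base: outer, index: slot } => match &**outer {
                Expr::Global(g) => Ok(plan(*g, Some(BindingIndex::Constant(*slot)), Some(*index))),
                _ => Err("AccessIndex(AccessIndex(Global)) expected"),
            },
            Expr::Access { base: outer, index: slot } => match &**outer {
                Expr::Global(g) => Ok(plan(*g, Some(BindingIndex::Dynamic(*slot)), Some(*index))),
                _ => Err("AccessIndex(Access(Global)) expected"),
            },
        },
        Expr::Access { .. } => Err("unexpected array length expression"),
    }
}

/// Member index to hand `OpArrayLength`. A bare runtime-sized array global
/// is assumed to have been wrapped in a one-member struct, hence index 0.
fn op_array_length_member(last_member: Option<u32>, needs_wrapper: bool) -> Result<u32, &'static str> {
    match (last_member, needs_wrapper) {
        (Some(index), false) => Ok(index),
        (None, true) => Ok(0),
        _ => Err("bad wrapper struct inference"),
    }
}

fn main() {
    // lights[i].items, where `lights` is a binding array and `items` is member 2.
    let e = Expr::AccessIndex {
        base: Box::new(Expr::Access { base: Box::new(Expr::Global("lights")), index: "i" }),
        index: 2,
    };
    let p = classify(&e).unwrap();
    println!("{:?} -> member {:?}", p, op_array_length_member(p.last_member_index, false));
}
```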
let gvar = self.writer.global_variables[global_handle].clone(); let global = &self.ir_module.global_variables[global_handle]; let needs_wrapper = global_needs_wrapper(self.ir_module, global); let (last_member_index, gvar_id) = match (opt_last_member_index, needs_wrapper) { (Some(index), false) => { // At the Naga type level, the runtime-sized array appears as the // final member of a struct, whose index is `index`. We didn't need to // wrap this, since the Naga type meets SPIR-V's requirements already. (index, gvar.access_id) } (None, true) => { // At the Naga type level, the runtime-sized array does not appear // within a struct. We wrapped this in an OpTypeStruct with nothing // else in it, so the index is zero. OpArrayLength wants the pointer // to the wrapper struct, so use `gvar.var_id`. (0, gvar.var_id) } _ => { return Err(Error::Validation( "array length expression: bad SPIR-V wrapper struct inference", )); } }; let structure_id = match binding_array_index_id { // We are indexing inside a binding array, generate the access op. Some(index_id) => { let element_type_id = match self.ir_module.types[global.ty].inner { crate::TypeInner::BindingArray { base, size: _ } => { let base_id = self.get_handle_type_id(base); let class = map_storage_class(global.space); self.get_pointer_type_id(base_id, class) } _ => return Err(Error::Validation("array length expression case-5")), }; let structure_id = self.gen_id(); block.body.push(Instruction::access_chain( element_type_id, structure_id, gvar_id, &[index_id], )); structure_id } None => gvar_id, }; let length_id = self.gen_id(); block.body.push(Instruction::array_length( self.writer.get_u32_type_id(), length_id, structure_id, last_member_index, )); Ok(length_id) } /// Compute the length of a subscriptable value. /// /// Given `sequence`, an expression referring to some indexable type, return /// its length. The result may either be computed by SPIR-V instructions, or /// known at shader translation time. 
/// /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically /// sized, or use a specializable constant as its length. fn write_sequence_length( &mut self, sequence: Handle, block: &mut Block, ) -> Result, Error> { let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types); match sequence_ty.indexable_length_resolved(self.ir_module) { Ok(crate::proc::IndexableLength::Known(known_length)) => { Ok(MaybeKnown::Known(known_length)) } Ok(crate::proc::IndexableLength::Dynamic) => { let length_id = self.write_runtime_array_length(sequence, block)?; Ok(MaybeKnown::Computed(length_id)) } Err(err) => { log::error!("Sequence length for {sequence:?} failed: {err}"); Err(Error::Validation("indexable length")) } } } /// Compute the maximum valid index of a subscriptable value. /// /// Given `sequence`, an expression referring to some indexable type, return /// its maximum valid index - one less than its length. The result may /// either be computed, or known at shader translation time. /// /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically /// sized, or use a specializable constant as its length. fn write_sequence_max_index( &mut self, sequence: Handle, block: &mut Block, ) -> Result, Error> { match self.write_sequence_length(sequence, block)? { MaybeKnown::Known(known_length) => { // We should have thrown out all attempts to subscript zero-length // sequences during validation, so the following subtraction should never // underflow. assert!(known_length > 0); // Compute the max index from the length now. Ok(MaybeKnown::Known(known_length - 1)) } MaybeKnown::Computed(length_id) => { // Emit code to compute the max index from the length. 
                let const_one_id = self.get_index_constant(1);
                let max_index_id = self.gen_id();
                block.body.push(Instruction::binary(
                    spirv::Op::ISub,
                    self.writer.get_u32_type_id(),
                    max_index_id,
                    length_id,
                    const_one_id,
                ));
                Ok(MaybeKnown::Computed(max_index_id))
            }
        }
    }

    /// Restrict an index to be in range for a vector, matrix, or array.
    ///
    /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds
    /// index is left unchanged. An out-of-bounds index is replaced with some
    /// arbitrary in-bounds index. Note that this is not necessarily clamping;
    /// for example, negative indices might be changed to refer to the last
    /// element of the sequence, not the first, as clamping would do.
    ///
    /// Either return the restricted index value, if known, or add instructions
    /// to `block` to compute it, and return the id of the result. See the
    /// documentation for `BoundsCheckResult` for details.
    ///
    /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
    /// `Pointer` to any of those, or a `ValuePointer`. An array may be
    /// fixed-size, dynamically sized, or use a specializable constant as its
    /// length.
    pub(super) fn write_restricted_index(
        &mut self,
        sequence: Handle<crate::Expression>,
        index: GuardedIndex,
        block: &mut Block,
    ) -> Result<BoundsCheckResult, Error> {
        let max_index = self.write_sequence_max_index(sequence, block)?;

        // If both are known, we can compute the index to be used
        // right now.
        if let (GuardedIndex::Known(index), MaybeKnown::Known(max_index)) = (index, max_index) {
            let restricted = core::cmp::min(index, max_index);
            return Ok(BoundsCheckResult::KnownInBounds(restricted));
        }

        let index_id = match index {
            GuardedIndex::Known(value) => self.get_index_constant(value),
            GuardedIndex::Expression(expr) => self.cached[expr],
        };

        let max_index_id = match max_index {
            MaybeKnown::Known(value) => self.get_index_constant(value),
            MaybeKnown::Computed(id) => id,
        };

        // One or the other of the index or length is dynamic, so emit code for
        // BoundsCheckPolicy::Restrict.
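        //
        // `UMin` treats both operands as unsigned integers. With a maximum
        // index of 3 (a length of 4), it behaves like this host-side sketch:
        //
        //     core::cmp::min(2u32, 3) == 2           // in bounds: unchanged
        //     core::cmp::min(9u32, 3) == 3           // too large: last element
        //     core::cmp::min(-1i32 as u32, 3) == 3   // negative: last element
        //
        // A negative index reinterpreted as unsigned is huge, which is why it
        // selects the last element rather than the first.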
        let restricted_index_id = self.gen_id();
        block.body.push(Instruction::ext_inst_gl_op(
            self.writer.gl450_ext_inst_id,
            spirv::GlslStd450Op::UMin,
            self.writer.get_u32_type_id(),
            restricted_index_id,
            &[index_id, max_index_id],
        ));
        Ok(BoundsCheckResult::Computed(restricted_index_id))
    }

    /// Write an index bounds comparison to `block`, if needed.
    ///
    /// This is used to implement [`BoundsCheckPolicy::ReadZeroSkipWrite`].
    ///
    /// If we're able to determine statically that `index` is in bounds for
    /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual
    /// value of the index. (In principle, one could know that the index is in
    /// bounds without knowing its specific value, but in our simple-minded
    /// situation, we always know it.)
    ///
    /// If instead we must generate code to perform the comparison at run time,
    /// return `Conditional { condition_id, index_id }`, where `condition_id` is
    /// the id of an instruction producing a boolean value that is true if
    /// `index` is in bounds for `sequence`, and `index_id` is the id of the
    /// index value itself.
    ///
    /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
    /// `Pointer` to any of those, or a `ValuePointer`. An array may be
    /// fixed-size, dynamically sized, or use a specializable constant as its
    /// length.
    fn write_index_comparison(
        &mut self,
        sequence: Handle<crate::Expression>,
        index: GuardedIndex,
        block: &mut Block,
    ) -> Result<BoundsCheckResult, Error> {
        let length = self.write_sequence_length(sequence, block)?;

        // If both are known, we can decide whether the index is in
        // bounds right now.
        if let (GuardedIndex::Known(index), MaybeKnown::Known(length)) = (index, length) {
            if index < length {
                return Ok(BoundsCheckResult::KnownInBounds(index));
            }

            // In theory, when `index` is bad, we could return a new
            // `KnownOutOfBounds` variant here. But it's simpler just to fall
            // through and let the bounds check take place. The shader is broken
            // anyway, so it doesn't make sense to invest in emitting the ideal
            // code for it.
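            //
            // For example, a known index of 5 into a sequence whose known
            // length is 4 falls through to the code below, which compares the
            // two constants with `OpULessThan`. That condition is always false
            // at run time, so the access always takes the out-of-bounds path
            // of `ReadZeroSkipWrite`: loads yield zero and stores are skipped.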
        }

        let index_id = match index {
            GuardedIndex::Known(value) => self.get_index_constant(value),
            GuardedIndex::Expression(expr) => self.cached[expr],
        };
        let length_id = match length {
            MaybeKnown::Known(value) => self.get_index_constant(value),
            MaybeKnown::Computed(id) => id,
        };

        // Compare the index against the length.
        let condition_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::ULessThan,
            self.writer.get_bool_type_id(),
            condition_id,
            index_id,
            length_id,
        ));

        // Indicate that we did generate the check.
        Ok(BoundsCheckResult::Conditional {
            condition_id,
            index_id,
        })
    }

    /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`.
    ///
    /// Generate code to load a value of `result_type` if `condition` is true,
    /// and generate a null value of that type if it is false. Call `emit_load`
    /// to emit the instructions to perform the load. Return the id of the
    /// merged value of the two branches.
    pub(super) fn write_conditional_indexed_load<F>(
        &mut self,
        result_type: Word,
        condition: Word,
        block: &mut Block,
        emit_load: F,
    ) -> Word
    where
        F: FnOnce(&mut IdGenerator, &mut Block) -> Word,
    {
        // For the out-of-bounds case, we produce a zero value.
        let null_id = self.writer.get_constant_null(result_type);

        let mut selection = Selection::start(block, result_type);

        // As it turns out, we don't actually need a full 'if-then-else'
        // structure for this: SPIR-V constants are declared up front, so the
        // 'else' block would have no instructions. Instead we emit something
        // like this:
        //
        //     result = zero;
        //     if in_bounds {
        //         result = do the load;
        //     }
        //     use result;

        // Continue only if the index was in bounds. Otherwise, branch to the
        // merge block.
        selection.if_true(self, condition, null_id);

        // The in-bounds path. Perform the access and the load.
        let loaded_value = emit_load(&mut self.writer.id_gen, selection.block());

        selection.finish(self, loaded_value)
    }

    /// Emit code for bounds checks for an array, vector, or matrix access.
    ///
    /// This tries to handle all the critical steps for bounds checks:
    ///
    /// - First, select the appropriate bounds check policy for `base`,
    ///   depending on its address space.
    ///
    /// - Next, analyze `index` to see if its value is known at
    ///   compile time, in which case we can decide statically whether
    ///   the index is in bounds.
    ///
    /// - If the index's value is not known at compile time, emit code to:
    ///
    ///     - restrict its value (for [`BoundsCheckPolicy::Restrict`]), or
    ///
    ///     - check whether it's in bounds (for
    ///       [`BoundsCheckPolicy::ReadZeroSkipWrite`]).
    ///
    /// Return a [`BoundsCheckResult`] indicating how the index should be
    /// consumed. See that type's documentation for details.
    pub(super) fn write_bounds_check(
        &mut self,
        base: Handle<crate::Expression>,
        mut index: GuardedIndex,
        block: &mut Block,
    ) -> Result<BoundsCheckResult, Error> {
        // If the value of `index` is known at compile time, find it now.
        index.try_resolve_to_constant(&self.ir_function.expressions, self.ir_module);

        let policy = self.writer.bounds_check_policies.choose_policy(
            base,
            &self.ir_module.types,
            self.fun_info,
        );

        Ok(match policy {
            BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?,
            BoundsCheckPolicy::ReadZeroSkipWrite => {
                self.write_index_comparison(base, index, block)?
            }
            BoundsCheckPolicy::Unchecked => match index {
                GuardedIndex::Known(value) => BoundsCheckResult::KnownInBounds(value),
                GuardedIndex::Expression(expr) => BoundsCheckResult::Computed(self.cached[expr]),
            },
        })
    }

    /// Emit code to subscript a vector by value with a computed index.
    ///
    /// Return the id of the element value.
    ///
    /// If `base_id_override` is provided, it is used as the vector expression
    /// to be subscripted into, rather than the cached value of `base`.
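    ///
    /// The emitted access depends on the bounds check result: a constant
    /// index known to be in bounds uses `OpCompositeExtract`; a computed index
    /// (restricted or unchecked) uses `OpVectorExtractDynamic`; and a run-time
    /// check under `ReadZeroSkipWrite` wraps `OpVectorExtractDynamic` in a
    /// conditional load that produces a zero value when the index is out of
    /// bounds.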
    pub(super) fn write_vector_access(
        &mut self,
        result_type_id: Word,
        base: Handle<crate::Expression>,
        base_id_override: Option<Word>,
        index: GuardedIndex,
        block: &mut Block,
    ) -> Result<Word, Error> {
        let base_id = base_id_override.unwrap_or_else(|| self.cached[base]);

        let result_id = match self.write_bounds_check(base, index, block)? {
            BoundsCheckResult::KnownInBounds(known_index) => {
                let result_id = self.gen_id();
                block.body.push(Instruction::composite_extract(
                    result_type_id,
                    result_id,
                    base_id,
                    &[known_index],
                ));
                result_id
            }
            BoundsCheckResult::Computed(computed_index_id) => {
                let result_id = self.gen_id();
                block.body.push(Instruction::vector_extract_dynamic(
                    result_type_id,
                    result_id,
                    base_id,
                    computed_index_id,
                ));
                result_id
            }
            BoundsCheckResult::Conditional {
                condition_id,
                index_id,
            } => {
                // Run-time bounds checks were required. Emit
                // conditional load.
                self.write_conditional_indexed_load(
                    result_type_id,
                    condition_id,
                    block,
                    |id_gen, block| {
                        // The in-bounds path. Generate the access.
                        let element_id = id_gen.next();
                        block.body.push(Instruction::vector_extract_dynamic(
                            result_type_id,
                            element_id,
                            base_id,
                            index_id,
                        ));
                        element_id
                    },
                )
            }
        };

        Ok(result_id)
    }
}
naga-29.0.3/src/back/spv/instructions.rs000064400000000000000000001207711046102023000162460ustar 00000000000000
use alloc::{vec, vec::Vec};

use spirv::{Op, Word};

use super::{block::DebugInfoInner, helpers};

pub(super) enum Signedness {
    Unsigned = 0,
    Signed = 1,
}

pub(super) enum SampleLod {
    Explicit,
    Implicit,
}

pub(super) struct Case {
    pub value: Word,
    pub label_id: Word,
}

impl super::Instruction {
    //
    // Debug Instructions
    //

    pub(super) fn string(name: &str, id: Word) -> Self {
        let mut instruction = Self::new(Op::String);
        instruction.set_result(id);
        instruction.add_operands(helpers::string_to_words(name));
        instruction
    }

    pub(super) fn source(
        source_language: spirv::SourceLanguage,
        version: u32,
        source: &Option<DebugInfoInner>,
    ) -> Self {
        let mut instruction = Self::new(Op::Source);
        instruction.add_operand(source_language as u32);
        instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes()));
        if let Some(source) = source.as_ref() {
            instruction.add_operand(source.source_file_id);
            instruction.add_operands(helpers::string_to_words(source.source_code));
        }
        instruction
    }

    pub(super) fn source_continued(source: &[u8]) -> Self {
        let mut instruction = Self::new(Op::SourceContinued);
        instruction.add_operands(helpers::str_bytes_to_words(source));
        instruction
    }

    pub(super) fn source_auto_continued(
        source_language: spirv::SourceLanguage,
        version: u32,
        source: &Option<DebugInfoInner>,
    ) -> Vec<Self> {
        let mut instructions = vec![];

        let with_continue = source.as_ref().and_then(|debug_info| {
            (debug_info.source_code.len() > u16::MAX as usize).then_some(debug_info)
        });
        if let Some(debug_info) = with_continue {
            let mut instruction = Self::new(Op::Source);
            instruction.add_operand(source_language as u32);
            instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes()));

            let words = helpers::string_to_byte_chunks(debug_info.source_code, u16::MAX as usize);
            instruction.add_operand(debug_info.source_file_id);
            instruction.add_operands(helpers::str_bytes_to_words(words[0]));
            instructions.push(instruction);
            for word_bytes in words[1..].iter() {
                let instruction_continue = Self::source_continued(word_bytes);
                instructions.push(instruction_continue);
            }
        } else {
            let instruction = Self::source(source_language, version, source);
            instructions.push(instruction);
        }
        instructions
    }

    pub(super) fn name(target_id: Word, name: &str) -> Self {
        let mut instruction = Self::new(Op::Name);
        instruction.add_operand(target_id);
        instruction.add_operands(helpers::string_to_words(name));
        instruction
    }

    pub(super) fn member_name(target_id: Word, member: Word, name: &str) -> Self {
        let mut instruction = Self::new(Op::MemberName);
        instruction.add_operand(target_id);
        instruction.add_operand(member);
        instruction.add_operands(helpers::string_to_words(name));
        instruction
    }

    pub(super) fn line(file: Word, line: Word, column: Word) -> Self {
        let mut
        instruction = Self::new(Op::Line);
        instruction.add_operand(file);
        instruction.add_operand(line);
        instruction.add_operand(column);
        instruction
    }

    //
    // Annotation Instructions
    //

    pub(super) fn decorate(
        target_id: Word,
        decoration: spirv::Decoration,
        operands: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::Decorate);
        instruction.add_operand(target_id);
        instruction.add_operand(decoration as u32);
        for operand in operands {
            instruction.add_operand(*operand)
        }
        instruction
    }

    pub(super) fn member_decorate(
        target_id: Word,
        member_index: Word,
        decoration: spirv::Decoration,
        operands: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::MemberDecorate);
        instruction.add_operand(target_id);
        instruction.add_operand(member_index);
        instruction.add_operand(decoration as u32);
        for operand in operands {
            instruction.add_operand(*operand)
        }
        instruction
    }

    //
    // Extension Instructions
    //

    pub(super) fn extension(name: &str) -> Self {
        let mut instruction = Self::new(Op::Extension);
        instruction.add_operands(helpers::string_to_words(name));
        instruction
    }

    pub(super) fn ext_inst_import(id: Word, name: &str) -> Self {
        let mut instruction = Self::new(Op::ExtInstImport);
        instruction.set_result(id);
        instruction.add_operands(helpers::string_to_words(name));
        instruction
    }

    pub(super) fn ext_inst_gl_op(
        set_id: Word,
        op: spirv::GlslStd450Op,
        result_type_id: Word,
        id: Word,
        operands: &[Word],
    ) -> Self {
        Self::ext_inst(set_id, op as u32, result_type_id, id, operands)
    }

    pub(super) fn ext_inst(
        set_id: Word,
        op: u32,
        result_type_id: Word,
        id: Word,
        operands: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::ExtInst);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(set_id);
        instruction.add_operand(op);
        for operand in operands {
            instruction.add_operand(*operand)
        }
        instruction
    }

    //
    // Mode-Setting Instructions
    //

    pub(super) fn memory_model(
        addressing_model: spirv::AddressingModel,
        memory_model: spirv::MemoryModel,
    ) -> Self {
        let mut
        instruction = Self::new(Op::MemoryModel);
        instruction.add_operand(addressing_model as u32);
        instruction.add_operand(memory_model as u32);
        instruction
    }

    pub(super) fn entry_point(
        execution_model: spirv::ExecutionModel,
        entry_point_id: Word,
        name: &str,
        interface_ids: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::EntryPoint);
        instruction.add_operand(execution_model as u32);
        instruction.add_operand(entry_point_id);
        instruction.add_operands(helpers::string_to_words(name));

        for interface_id in interface_ids {
            instruction.add_operand(*interface_id);
        }

        instruction
    }

    pub(super) fn execution_mode(
        entry_point_id: Word,
        execution_mode: spirv::ExecutionMode,
        args: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::ExecutionMode);
        instruction.add_operand(entry_point_id);
        instruction.add_operand(execution_mode as u32);
        for arg in args {
            instruction.add_operand(*arg);
        }
        instruction
    }

    pub(super) fn capability(capability: spirv::Capability) -> Self {
        let mut instruction = Self::new(Op::Capability);
        instruction.add_operand(capability as u32);
        instruction
    }

    //
    // Type-Declaration Instructions
    //

    pub(super) fn type_void(id: Word) -> Self {
        let mut instruction = Self::new(Op::TypeVoid);
        instruction.set_result(id);
        instruction
    }

    pub(super) fn type_bool(id: Word) -> Self {
        let mut instruction = Self::new(Op::TypeBool);
        instruction.set_result(id);
        instruction
    }

    pub(super) fn type_int(id: Word, width: Word, signedness: Signedness) -> Self {
        let mut instruction = Self::new(Op::TypeInt);
        instruction.set_result(id);
        instruction.add_operand(width);
        instruction.add_operand(signedness as u32);
        instruction
    }

    pub(super) fn type_float(id: Word, width: Word) -> Self {
        let mut instruction = Self::new(Op::TypeFloat);
        instruction.set_result(id);
        instruction.add_operand(width);
        instruction
    }

    pub(super) fn type_vector(
        id: Word,
        component_type_id: Word,
        component_count: crate::VectorSize,
    ) -> Self {
        let mut instruction = Self::new(Op::TypeVector);
        instruction.set_result(id);
        instruction.add_operand(component_type_id);
        instruction.add_operand(component_count as u32);
        instruction
    }

    pub(super) fn type_matrix(
        id: Word,
        column_type_id: Word,
        column_count: crate::VectorSize,
    ) -> Self {
        let mut instruction = Self::new(Op::TypeMatrix);
        instruction.set_result(id);
        instruction.add_operand(column_type_id);
        instruction.add_operand(column_count as u32);
        instruction
    }

    pub(super) fn type_coop_matrix(
        id: Word,
        scalar_type_id: Word,
        scope_id: Word,
        row_count_id: Word,
        column_count_id: Word,
        matrix_use_id: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR);
        instruction.set_result(id);
        instruction.add_operand(scalar_type_id);
        instruction.add_operand(scope_id);
        instruction.add_operand(row_count_id);
        instruction.add_operand(column_count_id);
        instruction.add_operand(matrix_use_id);
        instruction
    }

    pub(super) fn type_image(
        id: Word,
        sampled_type_id: Word,
        dim: spirv::Dim,
        flags: super::ImageTypeFlags,
        image_format: spirv::ImageFormat,
    ) -> Self {
        let mut instruction = Self::new(Op::TypeImage);
        instruction.set_result(id);
        instruction.add_operand(sampled_type_id);
        instruction.add_operand(dim as u32);
        instruction.add_operand(flags.contains(super::ImageTypeFlags::DEPTH) as u32);
        instruction.add_operand(flags.contains(super::ImageTypeFlags::ARRAYED) as u32);
        instruction.add_operand(flags.contains(super::ImageTypeFlags::MULTISAMPLED) as u32);
        instruction.add_operand(if flags.contains(super::ImageTypeFlags::SAMPLED) {
            1
        } else {
            2
        });
        instruction.add_operand(image_format as u32);
        instruction
    }

    pub(super) fn type_sampler(id: Word) -> Self {
        let mut instruction = Self::new(Op::TypeSampler);
        instruction.set_result(id);
        instruction
    }

    pub(super) fn type_acceleration_structure(id: Word) -> Self {
        let mut instruction = Self::new(Op::TypeAccelerationStructureKHR);
        instruction.set_result(id);
        instruction
    }

    pub(super) fn type_ray_query(id: Word) -> Self {
        let mut instruction = Self::new(Op::TypeRayQueryKHR);
        instruction.set_result(id);
        instruction
    }

    pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self {
        let mut instruction = Self::new(Op::TypeSampledImage);
        instruction.set_result(id);
        instruction.add_operand(image_type_id);
        instruction
    }

    pub(super) fn type_array(id: Word, element_type_id: Word, length_id: Word) -> Self {
        let mut instruction = Self::new(Op::TypeArray);
        instruction.set_result(id);
        instruction.add_operand(element_type_id);
        instruction.add_operand(length_id);
        instruction
    }

    pub(super) fn type_runtime_array(id: Word, element_type_id: Word) -> Self {
        let mut instruction = Self::new(Op::TypeRuntimeArray);
        instruction.set_result(id);
        instruction.add_operand(element_type_id);
        instruction
    }

    pub(super) fn type_struct(id: Word, member_ids: &[Word]) -> Self {
        let mut instruction = Self::new(Op::TypeStruct);
        instruction.set_result(id);

        for member_id in member_ids {
            instruction.add_operand(*member_id)
        }

        instruction
    }

    pub(super) fn type_pointer(
        id: Word,
        storage_class: spirv::StorageClass,
        type_id: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::TypePointer);
        instruction.set_result(id);
        instruction.add_operand(storage_class as u32);
        instruction.add_operand(type_id);
        instruction
    }

    pub(super) fn type_function(id: Word, return_type_id: Word, parameter_ids: &[Word]) -> Self {
        let mut instruction = Self::new(Op::TypeFunction);
        instruction.set_result(id);
        instruction.add_operand(return_type_id);

        for parameter_id in parameter_ids {
            instruction.add_operand(*parameter_id);
        }

        instruction
    }

    //
    // Constant-Creation Instructions
    //

    pub(super) fn constant_null(result_type_id: Word, id: Word) -> Self {
        let mut instruction = Self::new(Op::ConstantNull);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction
    }

    pub(super) fn constant_true(result_type_id: Word, id: Word) -> Self {
        let mut instruction = Self::new(Op::ConstantTrue);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction
    }

    pub(super) fn constant_false(result_type_id: Word, id:
        Word) -> Self {
        let mut instruction = Self::new(Op::ConstantFalse);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction
    }

    pub(super) fn constant_16bit(result_type_id: Word, id: Word, low: Word) -> Self {
        Self::constant(result_type_id, id, &[low])
    }

    pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self {
        Self::constant(result_type_id, id, &[value])
    }

    pub(super) fn constant_64bit(result_type_id: Word, id: Word, low: Word, high: Word) -> Self {
        Self::constant(result_type_id, id, &[low, high])
    }

    pub(super) fn constant(result_type_id: Word, id: Word, values: &[Word]) -> Self {
        let mut instruction = Self::new(Op::Constant);
        instruction.set_type(result_type_id);
        instruction.set_result(id);

        for value in values {
            instruction.add_operand(*value);
        }

        instruction
    }

    pub(super) fn constant_composite(
        result_type_id: Word,
        id: Word,
        constituent_ids: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::ConstantComposite);
        instruction.set_type(result_type_id);
        instruction.set_result(id);

        for constituent_id in constituent_ids {
            instruction.add_operand(*constituent_id);
        }

        instruction
    }

    //
    // Memory Instructions
    //

    pub(super) fn variable(
        result_type_id: Word,
        id: Word,
        storage_class: spirv::StorageClass,
        initializer_id: Option<Word>,
    ) -> Self {
        let mut instruction = Self::new(Op::Variable);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(storage_class as u32);

        if let Some(initializer_id) = initializer_id {
            instruction.add_operand(initializer_id);
        }

        instruction
    }

    pub(super) fn load(
        result_type_id: Word,
        id: Word,
        pointer_id: Word,
        memory_access: Option<spirv::MemoryAccess>,
    ) -> Self {
        let mut instruction = Self::new(Op::Load);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(pointer_id);

        if let Some(memory_access) = memory_access {
            instruction.add_operand(memory_access.bits());
        }

        instruction
    }

    pub(super) fn atomic_load(
        result_type_id: Word,
        id: Word,
        pointer_id: Word,
        scope_id: Word,
        semantics_id: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::AtomicLoad);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(pointer_id);
        instruction.add_operand(scope_id);
        instruction.add_operand(semantics_id);
        instruction
    }

    pub(super) fn store(
        pointer_id: Word,
        value_id: Word,
        memory_access: Option<spirv::MemoryAccess>,
    ) -> Self {
        let mut instruction = Self::new(Op::Store);
        instruction.add_operand(pointer_id);
        instruction.add_operand(value_id);

        if let Some(memory_access) = memory_access {
            instruction.add_operand(memory_access.bits());
        }

        instruction
    }

    pub(super) fn atomic_store(
        pointer_id: Word,
        scope_id: Word,
        semantics_id: Word,
        value_id: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::AtomicStore);
        instruction.add_operand(pointer_id);
        instruction.add_operand(scope_id);
        instruction.add_operand(semantics_id);
        instruction.add_operand(value_id);
        instruction
    }

    pub(super) fn access_chain(
        result_type_id: Word,
        id: Word,
        base_id: Word,
        index_ids: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::AccessChain);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(base_id);

        for index_id in index_ids {
            instruction.add_operand(*index_id);
        }

        instruction
    }

    pub(super) fn array_length(
        result_type_id: Word,
        id: Word,
        structure_id: Word,
        array_member: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::ArrayLength);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(structure_id);
        instruction.add_operand(array_member);
        instruction
    }

    //
    // Function Instructions
    //

    pub(super) fn function(
        return_type_id: Word,
        id: Word,
        function_control: spirv::FunctionControl,
        function_type_id: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::Function);
        instruction.set_type(return_type_id);
        instruction.set_result(id);
        instruction.add_operand(function_control.bits());
        instruction.add_operand(function_type_id);
        instruction
    }

    pub(super) fn
    function_parameter(result_type_id: Word, id: Word) -> Self {
        let mut instruction = Self::new(Op::FunctionParameter);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction
    }

    pub(super) const fn function_end() -> Self {
        Self::new(Op::FunctionEnd)
    }

    pub(super) fn function_call(
        result_type_id: Word,
        id: Word,
        function_id: Word,
        argument_ids: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::FunctionCall);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(function_id);

        for argument_id in argument_ids {
            instruction.add_operand(*argument_id);
        }

        instruction
    }

    //
    // Image Instructions
    //

    pub(super) fn sampled_image(
        result_type_id: Word,
        id: Word,
        image: Word,
        sampler: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::SampledImage);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(image);
        instruction.add_operand(sampler);
        instruction
    }

    pub(super) fn image_sample(
        result_type_id: Word,
        id: Word,
        lod: SampleLod,
        sampled_image: Word,
        coordinates: Word,
        depth_ref: Option<Word>,
    ) -> Self {
        let op = match (lod, depth_ref) {
            (SampleLod::Explicit, None) => Op::ImageSampleExplicitLod,
            (SampleLod::Implicit, None) => Op::ImageSampleImplicitLod,
            (SampleLod::Explicit, Some(_)) => Op::ImageSampleDrefExplicitLod,
            (SampleLod::Implicit, Some(_)) => Op::ImageSampleDrefImplicitLod,
        };

        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(sampled_image);
        instruction.add_operand(coordinates);
        if let Some(dref) = depth_ref {
            instruction.add_operand(dref);
        }

        instruction
    }

    pub(super) fn image_gather(
        result_type_id: Word,
        id: Word,
        sampled_image: Word,
        coordinates: Word,
        component_id: Word,
        depth_ref: Option<Word>,
    ) -> Self {
        let op = match depth_ref {
            None => Op::ImageGather,
            Some(_) => Op::ImageDrefGather,
        };

        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
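        // `OpImageGather` takes a component index after the coordinates,
        // whereas `OpImageDrefGather` takes the depth reference in that same
        // position instead, so exactly one of the two is pushed below.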
        instruction.add_operand(sampled_image);
        instruction.add_operand(coordinates);
        if let Some(dref) = depth_ref {
            instruction.add_operand(dref);
        } else {
            instruction.add_operand(component_id);
        }

        instruction
    }

    pub(super) fn image_fetch_or_read(
        op: Op,
        result_type_id: Word,
        id: Word,
        image: Word,
        coordinates: Word,
    ) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(image);
        instruction.add_operand(coordinates);
        instruction
    }

    pub(super) fn image_write(image: Word, coordinates: Word, value: Word) -> Self {
        let mut instruction = Self::new(Op::ImageWrite);
        instruction.add_operand(image);
        instruction.add_operand(coordinates);
        instruction.add_operand(value);
        instruction
    }

    pub(super) fn image_texel_pointer(
        result_type_id: Word,
        id: Word,
        image: Word,
        coordinates: Word,
        sample: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::ImageTexelPointer);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(image);
        instruction.add_operand(coordinates);
        instruction.add_operand(sample);
        instruction
    }

    pub(super) fn image_atomic(
        op: Op,
        result_type_id: Word,
        id: Word,
        pointer: Word,
        scope_id: Word,
        semantics_id: Word,
        value: Word,
    ) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(pointer);
        instruction.add_operand(scope_id);
        instruction.add_operand(semantics_id);
        instruction.add_operand(value);
        instruction
    }

    pub(super) fn image_query(op: Op, result_type_id: Word, id: Word, image: Word) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(image);
        instruction
    }

    //
    // Ray Query Instructions
    //
    #[allow(clippy::too_many_arguments)]
    pub(super) fn ray_query_initialize(
        query: Word,
        acceleration_structure: Word,
        ray_flags: Word,
        cull_mask: Word,
        ray_origin: Word,
        ray_tmin: Word,
        ray_dir: Word,
        ray_tmax: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::RayQueryInitializeKHR);
        instruction.add_operand(query);
        instruction.add_operand(acceleration_structure);
        instruction.add_operand(ray_flags);
        instruction.add_operand(cull_mask);
        instruction.add_operand(ray_origin);
        instruction.add_operand(ray_tmin);
        instruction.add_operand(ray_dir);
        instruction.add_operand(ray_tmax);
        instruction
    }

    pub(super) fn ray_query_proceed(result_type_id: Word, id: Word, query: Word) -> Self {
        let mut instruction = Self::new(Op::RayQueryProceedKHR);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(query);
        instruction
    }

    pub(super) fn ray_query_generate_intersection(query: Word, hit: Word) -> Self {
        let mut instruction = Self::new(Op::RayQueryGenerateIntersectionKHR);
        instruction.add_operand(query);
        instruction.add_operand(hit);
        instruction
    }

    pub(super) fn ray_query_confirm_intersection(query: Word) -> Self {
        let mut instruction = Self::new(Op::RayQueryConfirmIntersectionKHR);
        instruction.add_operand(query);
        instruction
    }

    pub(super) fn ray_query_return_vertex_position(
        result_type_id: Word,
        id: Word,
        query: Word,
        intersection: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::RayQueryGetIntersectionTriangleVertexPositionsKHR);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(query);
        instruction.add_operand(intersection);
        instruction
    }

    pub(super) fn ray_query_get_intersection(
        op: Op,
        result_type_id: Word,
        id: Word,
        query: Word,
        intersection: Word,
    ) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(query);
        instruction.add_operand(intersection);
        instruction
    }

    pub(super) fn ray_query_get_t_min(result_type_id: Word, id: Word, query: Word) -> Self {
        let mut instruction = Self::new(Op::RayQueryGetRayTMinKHR);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(query);
        instruction
    }

    pub(super) fn ray_query_terminate(query: Word) -> Self {
        let mut instruction = Self::new(Op::RayQueryTerminateKHR);
        instruction.add_operand(query);
        instruction
    }

    //
    // Conversion Instructions
    //
    pub(super) fn unary(op: Op, result_type_id: Word, id: Word, value: Word) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(value);
        instruction
    }

    //
    // Composite Instructions
    //

    pub(super) fn composite_construct(
        result_type_id: Word,
        id: Word,
        constituent_ids: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::CompositeConstruct);
        instruction.set_type(result_type_id);
        instruction.set_result(id);

        for constituent_id in constituent_ids {
            instruction.add_operand(*constituent_id);
        }

        instruction
    }

    pub(super) fn composite_extract(
        result_type_id: Word,
        id: Word,
        composite_id: Word,
        indices: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::CompositeExtract);
        instruction.set_type(result_type_id);
        instruction.set_result(id);

        instruction.add_operand(composite_id);
        for index in indices {
            instruction.add_operand(*index);
        }

        instruction
    }

    pub(super) fn vector_extract_dynamic(
        result_type_id: Word,
        id: Word,
        vector_id: Word,
        index_id: Word,
    ) -> Self {
        let mut instruction = Self::new(Op::VectorExtractDynamic);
        instruction.set_type(result_type_id);
        instruction.set_result(id);

        instruction.add_operand(vector_id);
        instruction.add_operand(index_id);

        instruction
    }

    pub(super) fn vector_shuffle(
        result_type_id: Word,
        id: Word,
        v1_id: Word,
        v2_id: Word,
        components: &[Word],
    ) -> Self {
        let mut instruction = Self::new(Op::VectorShuffle);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(v1_id);
        instruction.add_operand(v2_id);

        for &component in components {
            instruction.add_operand(component);
        }

        instruction
    }

    //
    // Arithmetic Instructions
    //

    pub(super) fn binary(
        op: Op,
        result_type_id: Word,
        id: Word,
        operand_1: Word,
        operand_2: Word,
    ) -> Self {
        let mut
        instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(operand_1);
        instruction.add_operand(operand_2);
        instruction
    }

    pub(super) fn ternary(
        op: Op,
        result_type_id: Word,
        id: Word,
        operand_1: Word,
        operand_2: Word,
        operand_3: Word,
    ) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(operand_1);
        instruction.add_operand(operand_2);
        instruction.add_operand(operand_3);
        instruction
    }

    pub(super) fn quaternary(
        op: Op,
        result_type_id: Word,
        id: Word,
        operand_1: Word,
        operand_2: Word,
        operand_3: Word,
        operand_4: Word,
    ) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(operand_1);
        instruction.add_operand(operand_2);
        instruction.add_operand(operand_3);
        instruction.add_operand(operand_4);
        instruction
    }

    pub(super) fn relational(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(expr_id);
        instruction
    }

    pub(super) fn atomic_binary(
        op: Op,
        result_type_id: Word,
        id: Word,
        pointer: Word,
        scope_id: Word,
        semantics_id: Word,
        value: Word,
    ) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(pointer);
        instruction.add_operand(scope_id);
        instruction.add_operand(semantics_id);
        instruction.add_operand(value);
        instruction
    }

    //
    // Bit Instructions
    //

    //
    // Relational and Logical Instructions
    //

    //
    // Derivative Instructions
    //

    pub(super) fn derivative(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
        let mut instruction = Self::new(op);
        instruction.set_type(result_type_id);
        instruction.set_result(id);
        instruction.add_operand(expr_id);
        instruction
    }

    //
    // Control-Flow Instructions
    //

    pub(super) fn phi(
        result_type_id:
Word, result_id: Word, var_parent_pairs: &[(Word, Word)], ) -> Self { let mut instruction = Self::new(Op::Phi); instruction.add_operand(result_type_id); instruction.add_operand(result_id); for &(variable, parent) in var_parent_pairs { instruction.add_operand(variable); instruction.add_operand(parent); } instruction } pub(super) fn selection_merge( merge_id: Word, selection_control: spirv::SelectionControl, ) -> Self { let mut instruction = Self::new(Op::SelectionMerge); instruction.add_operand(merge_id); instruction.add_operand(selection_control.bits()); instruction } pub(super) fn loop_merge( merge_id: Word, continuing_id: Word, selection_control: spirv::SelectionControl, ) -> Self { let mut instruction = Self::new(Op::LoopMerge); instruction.add_operand(merge_id); instruction.add_operand(continuing_id); instruction.add_operand(selection_control.bits()); instruction } pub(super) fn label(id: Word) -> Self { let mut instruction = Self::new(Op::Label); instruction.set_result(id); instruction } pub(super) fn branch(id: Word) -> Self { let mut instruction = Self::new(Op::Branch); instruction.add_operand(id); instruction } // TODO Branch Weights not implemented. 
pub(super) fn branch_conditional( condition_id: Word, true_label: Word, false_label: Word, ) -> Self { let mut instruction = Self::new(Op::BranchConditional); instruction.add_operand(condition_id); instruction.add_operand(true_label); instruction.add_operand(false_label); instruction } pub(super) fn switch(selector_id: Word, default_id: Word, cases: &[Case]) -> Self { let mut instruction = Self::new(Op::Switch); instruction.add_operand(selector_id); instruction.add_operand(default_id); for case in cases { instruction.add_operand(case.value); instruction.add_operand(case.label_id); } instruction } pub(super) fn select( result_type_id: Word, id: Word, condition_id: Word, accept_id: Word, reject_id: Word, ) -> Self { let mut instruction = Self::new(Op::Select); instruction.add_operand(result_type_id); instruction.add_operand(id); instruction.add_operand(condition_id); instruction.add_operand(accept_id); instruction.add_operand(reject_id); instruction } pub(super) const fn kill() -> Self { Self::new(Op::Kill) } pub(super) const fn return_void() -> Self { Self::new(Op::Return) } pub(super) fn return_value(value_id: Word) -> Self { let mut instruction = Self::new(Op::ReturnValue); instruction.add_operand(value_id); instruction } // // Atomic Instructions // // // Primitive Instructions // // Barriers pub(super) fn control_barrier( exec_scope_id: Word, mem_scope_id: Word, semantics_id: Word, ) -> Self { let mut instruction = Self::new(Op::ControlBarrier); instruction.add_operand(exec_scope_id); instruction.add_operand(mem_scope_id); instruction.add_operand(semantics_id); instruction } pub(super) fn memory_barrier(mem_scope_id: Word, semantics_id: Word) -> Self { let mut instruction = Self::new(Op::MemoryBarrier); instruction.add_operand(mem_scope_id); instruction.add_operand(semantics_id); instruction } // Group Instructions pub(super) fn group_non_uniform_ballot( result_type_id: Word, id: Word, exec_scope_id: Word, predicate: Word, ) -> Self { let mut instruction = 
Self::new(Op::GroupNonUniformBallot); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(predicate); instruction } pub(super) fn group_non_uniform_broadcast_first( result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, ) -> Self { let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); instruction } pub(super) fn group_non_uniform_gather( op: Op, result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, index: Word, ) -> Self { let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); instruction.add_operand(index); instruction } pub(super) fn group_non_uniform_arithmetic( op: Op, result_type_id: Word, id: Word, exec_scope_id: Word, group_op: Option<spirv::GroupOperation>, value: Word, ) -> Self { let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); if let Some(group_op) = group_op { instruction.add_operand(group_op as u32); } instruction.add_operand(value); instruction } pub(super) fn group_non_uniform_quad_swap( result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, direction: Word, ) -> Self { let mut instruction = Self::new(Op::GroupNonUniformQuadSwap); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); instruction.add_operand(direction); instruction } // Cooperative operations pub(super) fn coop_load( result_type_id: Word, id: Word, pointer_id: Word, layout_id: Word, stride_id: Word, ) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixLoadKHR); instruction.set_type(result_type_id); instruction.set_result(id);
instruction.add_operand(pointer_id); instruction.add_operand(layout_id); instruction.add_operand(stride_id); instruction } pub(super) fn coop_store(id: Word, pointer_id: Word, layout_id: Word, stride_id: Word) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixStoreKHR); instruction.add_operand(pointer_id); instruction.add_operand(id); instruction.add_operand(layout_id); instruction.add_operand(stride_id); instruction } pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(a); instruction.add_operand(b); instruction.add_operand(c); instruction } } impl From<crate::StorageFormat> for spirv::ImageFormat { fn from(format: crate::StorageFormat) -> Self { use crate::StorageFormat as Sf; match format { Sf::R8Unorm => Self::R8, Sf::R8Snorm => Self::R8Snorm, Sf::R8Uint => Self::R8ui, Sf::R8Sint => Self::R8i, Sf::R16Uint => Self::R16ui, Sf::R16Sint => Self::R16i, Sf::R16Float => Self::R16f, Sf::Rg8Unorm => Self::Rg8, Sf::Rg8Snorm => Self::Rg8Snorm, Sf::Rg8Uint => Self::Rg8ui, Sf::Rg8Sint => Self::Rg8i, Sf::R32Uint => Self::R32ui, Sf::R32Sint => Self::R32i, Sf::R32Float => Self::R32f, Sf::Rg16Uint => Self::Rg16ui, Sf::Rg16Sint => Self::Rg16i, Sf::Rg16Float => Self::Rg16f, Sf::Rgba8Unorm => Self::Rgba8, Sf::Rgba8Snorm => Self::Rgba8Snorm, Sf::Rgba8Uint => Self::Rgba8ui, Sf::Rgba8Sint => Self::Rgba8i, Sf::Bgra8Unorm => Self::Unknown, Sf::Rgb10a2Uint => Self::Rgb10a2ui, Sf::Rgb10a2Unorm => Self::Rgb10A2, Sf::Rg11b10Ufloat => Self::R11fG11fB10f, Sf::R64Uint => Self::R64ui, Sf::Rg32Uint => Self::Rg32ui, Sf::Rg32Sint => Self::Rg32i, Sf::Rg32Float => Self::Rg32f, Sf::Rgba16Uint => Self::Rgba16ui, Sf::Rgba16Sint => Self::Rgba16i, Sf::Rgba16Float => Self::Rgba16f, Sf::Rgba32Uint => Self::Rgba32ui, Sf::Rgba32Sint => Self::Rgba32i, Sf::Rgba32Float => Self::Rgba32f, Sf::R16Unorm => Self::R16,
Sf::R16Snorm => Self::R16Snorm, Sf::Rg16Unorm => Self::Rg16, Sf::Rg16Snorm => Self::Rg16Snorm, Sf::Rgba16Unorm => Self::Rgba16, Sf::Rgba16Snorm => Self::Rgba16Snorm, } } } impl From<crate::ImageDimension> for spirv::Dim { fn from(dim: crate::ImageDimension) -> Self { use crate::ImageDimension as Id; match dim { Id::D1 => Self::Dim1D, Id::D2 => Self::Dim2D, Id::D3 => Self::Dim3D, Id::Cube => Self::DimCube, } } } impl From<crate::CooperativeRole> for spirv::CooperativeMatrixUse { fn from(role: crate::CooperativeRole) -> Self { match role { crate::CooperativeRole::A => Self::MatrixAKHR, crate::CooperativeRole::B => Self::MatrixBKHR, crate::CooperativeRole::C => Self::MatrixAccumulatorKHR, } } } naga-29.0.3/src/back/spv/layout.rs000064400000000000000000000161511046102023000150130ustar 00000000000000use alloc::{vec, vec::Vec}; use core::iter; use spirv::{Op, Word, MAGIC_NUMBER}; use super::{Instruction, LogicalLayout, PhysicalLayout}; #[cfg(test)] use alloc::format; // https://github.com/KhronosGroup/SPIRV-Headers/pull/195 const GENERATOR: Word = 28; impl PhysicalLayout { pub(super) const fn new(major_version: u8, minor_version: u8) -> Self { let version = ((major_version as u32) << 16) | ((minor_version as u32) << 8); PhysicalLayout { magic_number: MAGIC_NUMBER, version, generator: GENERATOR, bound: 0, instruction_schema: 0x0u32, } } pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) { sink.extend(iter::once(self.magic_number)); sink.extend(iter::once(self.version)); sink.extend(iter::once(self.generator)); sink.extend(iter::once(self.bound)); sink.extend(iter::once(self.instruction_schema)); } /// Returns `(major, minor)`.
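    ///
    /// This undoes the packing in `PhysicalLayout::new`. For example, a `version`
    /// word of `0x0001_0300` (SPIR-V 1.3) yields `(1, 3)`: `0x0001_0300 >> 16`
    /// truncates to `1`, and `0x0001_0300 >> 8 = 0x0103` truncates to `3`.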
pub(super) const fn lang_version(&self) -> (u8, u8) { let major = (self.version >> 16) as u8; let minor = (self.version >> 8) as u8; (major, minor) } } impl super::reclaimable::Reclaimable for PhysicalLayout { fn reclaim(self) -> Self { PhysicalLayout { magic_number: self.magic_number, version: self.version, generator: self.generator, instruction_schema: self.instruction_schema, bound: 0, } } } impl LogicalLayout { pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) { sink.extend(self.capabilities.iter().cloned()); sink.extend(self.extensions.iter().cloned()); sink.extend(self.ext_inst_imports.iter().cloned()); sink.extend(self.memory_model.iter().cloned()); sink.extend(self.entry_points.iter().cloned()); sink.extend(self.execution_modes.iter().cloned()); sink.extend(self.debugs.iter().cloned()); sink.extend(self.annotations.iter().cloned()); sink.extend(self.declarations.iter().cloned()); sink.extend(self.function_declarations.iter().cloned()); sink.extend(self.function_definitions.iter().cloned()); } } impl super::reclaimable::Reclaimable for LogicalLayout { fn reclaim(self) -> Self { Self { capabilities: self.capabilities.reclaim(), extensions: self.extensions.reclaim(), ext_inst_imports: self.ext_inst_imports.reclaim(), memory_model: self.memory_model.reclaim(), entry_points: self.entry_points.reclaim(), execution_modes: self.execution_modes.reclaim(), debugs: self.debugs.reclaim(), annotations: self.annotations.reclaim(), declarations: self.declarations.reclaim(), function_declarations: self.function_declarations.reclaim(), function_definitions: self.function_definitions.reclaim(), } } } impl Instruction { pub(super) const fn new(op: Op) -> Self { Instruction { op, wc: 1, // Always start at 1 for the first word (OP + WC), type_id: None, result_id: None, operands: vec![], } } pub(super) fn set_type(&mut self, id: Word) { assert!(self.type_id.is_none(), "Type can only be set once"); self.type_id = Some(id); self.wc += 1; } pub(super) fn set_result(&mut self,
id: Word) { assert!(self.result_id.is_none(), "Result can only be set once"); self.result_id = Some(id); self.wc += 1; } pub(super) fn add_operand(&mut self, operand: Word) { self.operands.push(operand); self.wc += 1; } pub(super) fn add_operands(&mut self, operands: Vec<Word>) { for operand in operands.into_iter() { self.add_operand(operand) } } pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) { sink.extend(Some((self.wc << 16) | self.op as u32)); sink.extend(self.type_id); sink.extend(self.result_id); sink.extend(self.operands.iter().cloned()); } } impl Instruction { #[cfg(test)] fn validate(&self, words: &[Word]) { let mut inst_index = 0; let (wc, op) = ((words[inst_index] >> 16) as u16, words[inst_index] as u16); inst_index += 1; assert_eq!(wc, words.len() as u16); assert_eq!(op, self.op as u16); if let Some(type_id) = self.type_id { assert_eq!(words[inst_index], type_id); inst_index += 1; } if let Some(result_id) = self.result_id { assert_eq!(words[inst_index], result_id); inst_index += 1; } for (op_index, i) in (inst_index..wc as usize).enumerate() { assert_eq!(words[i], self.operands[op_index]); } } } #[test] fn test_physical_layout_in_words() { let bound = 5; // The least and most significant bytes of `version` must both be zero // according to the SPIR-V spec.
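    // Worked example of the packing `PhysicalLayout::new(1, 2)` performs:
    // (1 << 16) | (2 << 8) = 0x0001_0000 | 0x0000_0200 = 0x0001_0200,
    // which leaves the most and least significant bytes zero as required.
    assert_eq!((1u32 << 16) | (2u32 << 8), 0x0001_0200);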
let version = 0x0001_0200; let mut output = vec![]; let mut layout = PhysicalLayout::new(1, 2); layout.bound = bound; layout.in_words(&mut output); assert_eq!(&output, &[MAGIC_NUMBER, version, GENERATOR, bound, 0,]); } #[test] fn test_logical_layout_in_words() { let mut output = vec![]; let mut layout = LogicalLayout::default(); let layout_vectors = 11; let mut instructions = Vec::with_capacity(layout_vectors); let vector_names = &[ "Capabilities", "Extensions", "External Instruction Imports", "Memory Model", "Entry Points", "Execution Modes", "Debugs", "Annotations", "Declarations", "Function Declarations", "Function Definitions", ]; for (i, _) in vector_names.iter().enumerate().take(layout_vectors) { let mut dummy_instruction = Instruction::new(Op::Constant); dummy_instruction.set_type((i + 1) as u32); dummy_instruction.set_result((i + 2) as u32); dummy_instruction.add_operand((i + 3) as u32); dummy_instruction.add_operands(super::helpers::string_to_words( format!("This is the vector: {}", vector_names[i]).as_str(), )); instructions.push(dummy_instruction); } instructions[0].to_words(&mut layout.capabilities); instructions[1].to_words(&mut layout.extensions); instructions[2].to_words(&mut layout.ext_inst_imports); instructions[3].to_words(&mut layout.memory_model); instructions[4].to_words(&mut layout.entry_points); instructions[5].to_words(&mut layout.execution_modes); instructions[6].to_words(&mut layout.debugs); instructions[7].to_words(&mut layout.annotations); instructions[8].to_words(&mut layout.declarations); instructions[9].to_words(&mut layout.function_declarations); instructions[10].to_words(&mut layout.function_definitions); layout.in_words(&mut output); let mut index: usize = 0; for instruction in instructions { let wc = instruction.wc as usize; instruction.validate(&output[index..index + wc]); index += wc; } } naga-29.0.3/src/back/spv/mesh_shader.rs000064400000000000000000001207511046102023000157620ustar 00000000000000use alloc::vec::Vec; use 
spirv::Word; use crate::{ back::spv::{ helpers::BindingDecorations, writer::FunctionInterface, Block, EntryPointContext, Error, Instruction, WriterFlags, }, non_max_u32::NonMaxU32, Handle, }; #[derive(Clone)] pub struct MeshReturnMember { pub ty_id: u32, pub binding: crate::Binding, } struct PerOutputTypeMeshReturnInfo { max_length_constant: Word, array_type_id: Word, struct_members: Vec<MeshReturnMember>, // * Most builtins must be in the same block. // * All bindings must be in their own unique block. // * The primitive indices builtin family needs its own block. // * Cull primitive doesn't care about having its own block, but // some older validation layers didn't respect this. builtin_block: Option<Word>, bindings: Vec<Word>, } pub struct MeshReturnInfo { /// Id of the workgroup variable containing the data to be output out_variable_id: Word, /// All members of the output variable struct type out_members: Vec<MeshReturnMember>, /// Id of the input variable for the local invocation index local_invocation_index_var_id: Word, /// Id of the `u32` constant holding the total workgroup size (the product of its dimensions) workgroup_size: u32, /// Vertex-specific info vertex_info: PerOutputTypeMeshReturnInfo, /// Primitive-specific info primitive_info: PerOutputTypeMeshReturnInfo, /// Array variable for the primitive indices builtin primitive_indices: Option<Word>, } impl super::Writer { /// Declares an `Output` array variable of element type `ty`, whose length is the constant `array_size_id`, to hold part of the mesh shader output, and returns the variable's id pub(super) fn write_mesh_return_global_variable( &mut self, ty: u32, array_size_id: u32, ) -> Result<Word, Error> { let array_ty = self.id_gen.next(); Instruction::type_array(array_ty, ty, array_size_id) .to_words(&mut self.logical_layout.declarations); let ptr_ty = self.get_pointer_type_id(array_ty, spirv::StorageClass::Output); let var_id = self.id_gen.next(); Instruction::variable(ptr_ty, var_id, spirv::StorageClass::Output, None) .to_words(&mut self.logical_layout.declarations); Ok(var_id) } /// Prepares a mesh shader entry point for writing: collects the output struct members, /// declares the per-vertex and per-primitive output variables, and stores the resulting /// `MeshReturnInfo` in `ep_context` for the function writer pub(super)
fn write_entry_point_mesh_shader_info( &mut self, iface: &mut FunctionInterface, local_invocation_index_id: Option<Word>, ir_module: &crate::Module, ep_context: &mut EntryPointContext, ) -> Result<(), Error> { let Some(ref mesh_info) = iface.mesh_info else { return Ok(()); }; // Collect the members in the output structs let out_members: Vec<MeshReturnMember> = match &ir_module.types[ir_module.global_variables[mesh_info.output_variable].ty] { &crate::Type { inner: crate::TypeInner::Struct { ref members, .. }, .. } => members .iter() .map(|a| MeshReturnMember { ty_id: self.get_handle_type_id(a.ty), binding: a.binding.clone().unwrap(), }) .collect(), _ => unreachable!(), }; let vertex_array_type_id = out_members .iter() .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Vertices)) .unwrap() .ty_id; let primitive_array_type_id = out_members .iter() .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Primitives)) .unwrap() .ty_id; let vertex_members = match &ir_module.types[mesh_info.vertex_output_type] { &crate::Type { inner: crate::TypeInner::Struct { ref members, .. }, .. } => members .iter() .map(|a| MeshReturnMember { ty_id: self.get_handle_type_id(a.ty), binding: a.binding.clone().unwrap(), }) .collect(), _ => unreachable!(), }; let primitive_members = match &ir_module.types[mesh_info.primitive_output_type] { &crate::Type { inner: crate::TypeInner::Struct { ref members, .. }, ..
} => members .iter() .map(|a| MeshReturnMember { ty_id: self.get_handle_type_id(a.ty), binding: a.binding.clone().unwrap(), }) .collect(), _ => unreachable!(), }; // In the final return, we do a giant memcpy, for which this is helpful let local_invocation_index_var_id = match local_invocation_index_id { Some(a) => a, None => { let u32_id = self.get_u32_type_id(); let var = self.id_gen.next(); Instruction::variable( self.get_pointer_type_id(u32_id, spirv::StorageClass::Input), var, spirv::StorageClass::Input, None, ) .to_words(&mut self.logical_layout.declarations); Instruction::decorate( var, spirv::Decoration::BuiltIn, &[spirv::BuiltIn::LocalInvocationIndex as u32], ) .to_words(&mut self.logical_layout.annotations); iface.varying_ids.push(var); var } }; // This is the information that is passed to the function writer // so that it can write the final return logic let mut mesh_return_info = MeshReturnInfo { out_variable_id: self.global_variables[mesh_info.output_variable].var_id, out_members, local_invocation_index_var_id, workgroup_size: self .get_constant_scalar(crate::Literal::U32(iface.workgroup_size.iter().product())), vertex_info: PerOutputTypeMeshReturnInfo { array_type_id: vertex_array_type_id, struct_members: vertex_members, max_length_constant: self .get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices)), bindings: Vec::new(), builtin_block: None, }, primitive_info: PerOutputTypeMeshReturnInfo { array_type_id: primitive_array_type_id, struct_members: primitive_members, max_length_constant: self .get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives)), bindings: Vec::new(), builtin_block: None, }, primitive_indices: None, }; let vert_array_size_id = self.get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices)); let prim_array_size_id = self.get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives)); // Create the actual output variables and types. 
// According to SPIR-V, // * All builtins must be in the same output `Block` (except builtins for different output types like vertex/primitive) // * Each member with `location` must be in its own `Block` decorated `struct` // * Some builtins like CullPrimitiveEXT don't care as much (older validation layers don't know this! Wonderful!) // * Some builtins like the indices ones need to be in their own output variable without a struct wrapper // Write vertex builtin block if mesh_return_info .vertex_info .struct_members .iter() .any(|a| matches!(a.binding, crate::Binding::BuiltIn(..))) { let builtin_block_ty_id = self.id_gen.next(); let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]); let mut bi_index = 0; let mut decorations = Vec::new(); for member in &mesh_return_info.vertex_info.struct_members { if let crate::Binding::BuiltIn(_) = member.binding { ins.add_operand(member.ty_id); let binding = self.map_binding( ir_module, iface.stage, spirv::StorageClass::Output, // Unused except in fragment shaders with other conditions, so we can pass null Handle::new(NonMaxU32::new(0).unwrap()), &member.binding, )?; match binding { BindingDecorations::BuiltIn(bi, others) => { decorations.push(Instruction::member_decorate( builtin_block_ty_id, bi_index, spirv::Decoration::BuiltIn, &[bi as Word], )); for other in others { decorations.push(Instruction::member_decorate( builtin_block_ty_id, bi_index, other, &[], )); } } _ => unreachable!(), } bi_index += 1; } } ins.to_words(&mut self.logical_layout.declarations); decorations.push(Instruction::decorate( builtin_block_ty_id, spirv::Decoration::Block, &[], )); for dec in decorations { dec.to_words(&mut self.logical_layout.annotations); } let v = self.write_mesh_return_global_variable(builtin_block_ty_id, vert_array_size_id)?; iface.varying_ids.push(v); if self.flags.contains(WriterFlags::DEBUG) { self.debugs .push(Instruction::name(v, "naga_vertex_builtin_outputs")); } mesh_return_info.vertex_info.builtin_block = Some(v); } 
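        // For illustration only (symbolic IDs; the real members depend on which builtins
        // the vertex output struct declares): a vertex struct whose only builtin is
        // `Position` makes the block above emit roughly the following, with the
        // decorations going to the annotations section and the types and variable to
        // the declarations section:
        //
        //   OpMemberDecorate %VertexBuiltins 0 BuiltIn Position
        //   OpDecorate %VertexBuiltins Block
        //   %VertexBuiltins = OpTypeStruct %v4float
        //   %arr = OpTypeArray %VertexBuiltins %max_vertices
        //   %ptr = OpTypePointer Output %arr
        //   %var = OpVariable %ptr Output
        //
        // With `WriterFlags::DEBUG`, `%var` is also named `naga_vertex_builtin_outputs`.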
// Write primitive builtin block if mesh_return_info .primitive_info .struct_members .iter() .any(|a| { !matches!( a.binding, crate::Binding::BuiltIn( crate::BuiltIn::PointIndex | crate::BuiltIn::LineIndices | crate::BuiltIn::TriangleIndices ) | crate::Binding::Location { .. } ) }) { let builtin_block_ty_id = self.id_gen.next(); let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]); let mut bi_index = 0; let mut decorations = Vec::new(); for member in &mesh_return_info.primitive_info.struct_members { if let crate::Binding::BuiltIn(bi) = member.binding { // These need to be in their own block, unlike other builtins if matches!( bi, crate::BuiltIn::PointIndex | crate::BuiltIn::LineIndices | crate::BuiltIn::TriangleIndices, ) { continue; } ins.add_operand(member.ty_id); let binding = self.map_binding( ir_module, iface.stage, spirv::StorageClass::Output, // Unused except in fragment shaders with other conditions, so we can pass null Handle::new(NonMaxU32::new(0).unwrap()), &member.binding, )?; match binding { BindingDecorations::BuiltIn(bi, others) => { decorations.push(Instruction::member_decorate( builtin_block_ty_id, bi_index, spirv::Decoration::BuiltIn, &[bi as Word], )); for other in others { decorations.push(Instruction::member_decorate( builtin_block_ty_id, bi_index, other, &[], )); } } _ => unreachable!(), } bi_index += 1; } } ins.to_words(&mut self.logical_layout.declarations); decorations.push(Instruction::decorate( builtin_block_ty_id, spirv::Decoration::Block, &[], )); for dec in decorations { dec.to_words(&mut self.logical_layout.annotations); } let v = self.write_mesh_return_global_variable(builtin_block_ty_id, prim_array_size_id)?; Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[]) .to_words(&mut self.logical_layout.annotations); iface.varying_ids.push(v); if self.flags.contains(WriterFlags::DEBUG) { self.debugs .push(Instruction::name(v, "naga_primitive_builtin_outputs")); } mesh_return_info.primitive_info.builtin_block = 
Some(v); } // Write vertex binding output blocks (1 array per output struct member) for member in &mesh_return_info.vertex_info.struct_members { match member.binding { crate::Binding::Location { location, .. } => { // Create variable let v = self.write_mesh_return_global_variable(member.ty_id, vert_array_size_id)?; // Decorate the variable with Location Instruction::decorate(v, spirv::Decoration::Location, &[location]) .to_words(&mut self.logical_layout.annotations); iface.varying_ids.push(v); mesh_return_info.vertex_info.bindings.push(v); } crate::Binding::BuiltIn(_) => (), } } // Write primitive binding output blocks (1 array per output struct member) // Also write indices output block for member in &mesh_return_info.primitive_info.struct_members { match member.binding { crate::Binding::BuiltIn( crate::BuiltIn::PointIndex | crate::BuiltIn::LineIndices | crate::BuiltIn::TriangleIndices, ) => { // This is written here instead of as part of the builtin block let v = self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?; // This shouldn't be marked as PerPrimitiveEXT Instruction::decorate( v, spirv::Decoration::BuiltIn, &[match member.binding.to_built_in().unwrap() { crate::BuiltIn::PointIndex => spirv::BuiltIn::PrimitivePointIndicesEXT, crate::BuiltIn::LineIndices => spirv::BuiltIn::PrimitiveLineIndicesEXT, crate::BuiltIn::TriangleIndices => { spirv::BuiltIn::PrimitiveTriangleIndicesEXT } _ => unreachable!(), } as Word], ) .to_words(&mut self.logical_layout.annotations); iface.varying_ids.push(v); if self.flags.contains(WriterFlags::DEBUG) { self.debugs .push(Instruction::name(v, "naga_primitive_indices_outputs")); } mesh_return_info.primitive_indices = Some(v); } crate::Binding::Location { location, .. 
} => { // Create variable let v = self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?; // Decorate the variable with Location Instruction::decorate(v, spirv::Decoration::Location, &[location]) .to_words(&mut self.logical_layout.annotations); // Decorate it with PerPrimitiveEXT Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[]) .to_words(&mut self.logical_layout.annotations); iface.varying_ids.push(v); mesh_return_info.primitive_info.bindings.push(v); } crate::Binding::BuiltIn(_) => (), } } // Store this where it can be read later during function write ep_context.mesh_state = Some(mesh_return_info); Ok(()) } pub(super) fn write_entry_point_task_return( &mut self, value_id: Word, body: &mut Vec<Instruction>, task_payload: Word, ) -> Result<Instruction, Error> { // OpEmitMeshTasksEXT must be called right before exiting (after setting other // output variables if there are any) // Extract the three `u32` components of the returned `vec3<u32>` let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()]; for (i, &value) in values.iter().enumerate() { let instruction = Instruction::composite_extract( self.get_u32_type_id(), value, value_id, &[i as Word], ); body.push(instruction); } let mut instruction = Instruction::new(spirv::Op::EmitMeshTasksEXT); for id in values { instruction.add_operand(id); } // We have to include the task payload in our call instruction.add_operand(task_payload); Ok(instruction) } /// Writes a loop that runs `loop_body_block` while the index stored in `index_var` is /// less than `count_id`, advancing the index by the workgroup size on each iteration #[allow(clippy::too_many_arguments)] fn write_mesh_copy_loop( &mut self, body: &mut Vec<Instruction>, mut loop_body_block: Vec<Instruction>, loop_header: u32, loop_merge: u32, count_id: u32, index_var: u32, return_info: &MeshReturnInfo, ) { let u32_id = self.get_u32_type_id(); let condition_check = self.id_gen.next(); let loop_continue = self.id_gen.next(); let loop_body = self.id_gen.next(); // Loop header { body.push(Instruction::label(loop_header)); body.push(Instruction::loop_merge( loop_merge, loop_continue, spirv::SelectionControl::empty(), ));
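            // For reference, the blocks this function emits form the following
            // structured loop (symbolic IDs; `%i` is the value loaded from `index_var`):
            //
            //   %loop_header:     OpLoopMerge %loop_merge %loop_continue None
            //                     OpBranch %condition_check
            //   %condition_check: %cond = OpULessThan %bool %i %count_id
            //                     OpBranchConditional %cond %loop_body %loop_merge
            //   %loop_body:       <loop_body_block>
            //                     OpBranch %loop_continue
            //   %loop_continue:   %next = OpIAdd %uint %i %workgroup_size
            //                     OpStore %index_var %next
            //                     OpBranch %loop_header
            //
            // The `%loop_merge` label itself is emitted by the caller after this returns.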
body.push(Instruction::branch(condition_check)); } // Condition check - check if i is less than the number of elements (vertices or primitives) to copy { body.push(Instruction::label(condition_check)); let val_i = self.id_gen.next(); body.push(Instruction::load(u32_id, val_i, index_var, None)); let cond = self.id_gen.next(); body.push(Instruction::binary( spirv::Op::ULessThan, self.get_bool_type_id(), cond, val_i, count_id, )); body.push(Instruction::branch_conditional(cond, loop_body, loop_merge)); } // Loop body { body.push(Instruction::label(loop_body)); body.append(&mut loop_body_block); body.push(Instruction::branch(loop_continue)); } // Loop continue - increment i by the workgroup size { body.push(Instruction::label(loop_continue)); let prev_val_i = self.id_gen.next(); body.push(Instruction::load(u32_id, prev_val_i, index_var, None)); let new_val_i = self.id_gen.next(); body.push(Instruction::binary( spirv::Op::IAdd, u32_id, new_val_i, prev_val_i, return_info.workgroup_size, )); body.push(Instruction::store(index_var, new_val_i, None)); body.push(Instruction::branch(loop_header)); } } /// This generates the instructions used to copy all parts of a single output vertex/primitive /// to their individual output locations fn write_mesh_copy_body( &mut self, is_primitive: bool, return_info: &MeshReturnInfo, index_var: u32, vert_array_ptr: u32, prim_array_ptr: u32, ) -> Vec<Instruction> { let u32_type_id = self.get_u32_type_id(); let mut body = Vec::new(); // Current index to copy let val_i = self.id_gen.next(); body.push(Instruction::load(u32_type_id, val_i, index_var, None)); let info = if is_primitive { &return_info.primitive_info } else { &return_info.vertex_info }; let array_ptr = if is_primitive { prim_array_ptr } else { vert_array_ptr }; let mut builtin_index = 0; let mut binding_index = 0; // Write individual members of the vertex or primitive for (member_id, member) in info.struct_members.iter().enumerate() { let val_to_copy_ptr = self.id_gen.next(); body.push(Instruction::access_chain( self.get_pointer_type_id(member.ty_id,
spirv::StorageClass::Workgroup), val_to_copy_ptr, array_ptr, &[ val_i, self.get_constant_scalar(crate::Literal::U32(member_id as u32)), ], )); let val_to_copy = self.id_gen.next(); body.push(Instruction::load( member.ty_id, val_to_copy, val_to_copy_ptr, None, )); let mut needs_y_flip = false; let ptr_to_copy_to = self.id_gen.next(); // Get a pointer to the struct member to copy match member.binding { crate::Binding::BuiltIn( crate::BuiltIn::PointIndex | crate::BuiltIn::LineIndices | crate::BuiltIn::TriangleIndices, ) => { body.push(Instruction::access_chain( self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output), ptr_to_copy_to, return_info.primitive_indices.unwrap(), &[val_i], )); } crate::Binding::BuiltIn(bi) => { body.push(Instruction::access_chain( self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output), ptr_to_copy_to, info.builtin_block.unwrap(), &[ val_i, self.get_constant_scalar(crate::Literal::U32(builtin_index)), ], )); needs_y_flip = matches!(bi, crate::BuiltIn::Position { .. }) && self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE); builtin_index += 1; } crate::Binding::Location { .. 
} => { body.push(Instruction::access_chain( self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output), ptr_to_copy_to, info.bindings[binding_index], &[val_i], )); binding_index += 1; } } body.push(Instruction::store(ptr_to_copy_to, val_to_copy, None)); // Flip the vertex position y coordinate in some cases // Can't use epilogue flip because can't read from this storage class if needs_y_flip { let prev_y = self.id_gen.next(); body.push(Instruction::composite_extract( self.get_f32_type_id(), prev_y, val_to_copy, &[1], )); let new_y = self.id_gen.next(); body.push(Instruction::unary( spirv::Op::FNegate, self.get_f32_type_id(), new_y, prev_y, )); let new_ptr_to_copy_to = self.id_gen.next(); body.push(Instruction::access_chain( self.get_f32_pointer_type_id(spirv::StorageClass::Output), new_ptr_to_copy_to, ptr_to_copy_to, &[self.get_constant_scalar(crate::Literal::U32(1))], )); body.push(Instruction::store(new_ptr_to_copy_to, new_y, None)); } } body } /// Writes the return call for a mesh shader, which involves copying previously /// written vertices/primitives into the actual output location. 
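    ///
    /// The emitted code clamps the `VertexCount` and `PrimitiveCount` values to the
    /// declared maxima with `UMin`, emits `OpSetMeshOutputsEXT`, and then runs two copy
    /// loops (vertices first, then primitives). Each loop starts at the invocation's
    /// local index and advances by the workgroup size, so the invocations share the copy.
    /// The sketch below only simulates that work split on the CPU for a hypothetical
    /// workgroup of 4 invocations writing 6 vertices; it does not call into naga, and it
    /// runs the invocations one after another rather than in parallel.
    ///
    /// ```
    /// let workgroup_size = 4u32;
    /// let vertex_count = 6u32;
    /// // `copied_by[v]` records which invocation copied vertex `v`.
    /// let mut copied_by: Vec<Option<u32>> = vec![None; vertex_count as usize];
    /// for local_invocation_index in 0..workgroup_size {
    ///     // Same strided loop every invocation runs in the emitted SPIR-V.
    ///     let mut i = local_invocation_index;
    ///     while i < vertex_count {
    ///         assert!(copied_by[i as usize].is_none(), "vertex copied twice");
    ///         copied_by[i as usize] = Some(local_invocation_index);
    ///         i += workgroup_size;
    ///     }
    /// }
    /// // Every vertex is copied exactly once; invocations 0 and 1 each copy two.
    /// assert_eq!(copied_by, [Some(0), Some(1), Some(2), Some(3), Some(0), Some(1)]);
    /// ```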
pub(super) fn write_mesh_shader_return( &mut self, return_info: &MeshReturnInfo, block: &mut Block, loop_counter_vertices: u32, loop_counter_primitives: u32, local_invocation_index_id: Word, ) -> Result<(), Error> { let u32_id = self.get_u32_type_id(); // Load the actual vertex and primitive counts let mut load_u32_by_member_index = |members: &[MeshReturnMember], bi: crate::BuiltIn, max: u32| { let member_index = members .iter() .position(|a| a.binding == crate::Binding::BuiltIn(bi)) .unwrap() as u32; let ptr_id = self.id_gen.next(); block.body.push(Instruction::access_chain( self.get_pointer_type_id(u32_id, spirv::StorageClass::Workgroup), ptr_id, return_info.out_variable_id, &[self.get_constant_scalar(crate::Literal::U32(member_index))], )); let before_min_id = self.id_gen.next(); block .body .push(Instruction::load(u32_id, before_min_id, ptr_id, None)); // Clamp the values let id = self.id_gen.next(); block.body.push(Instruction::ext_inst_gl_op( self.gl450_ext_inst_id, spirv::GlslStd450Op::UMin, u32_id, id, &[before_min_id, max], )); id }; let vert_count_id = load_u32_by_member_index( &return_info.out_members, crate::BuiltIn::VertexCount, return_info.vertex_info.max_length_constant, ); let prim_count_id = load_u32_by_member_index( &return_info.out_members, crate::BuiltIn::PrimitiveCount, return_info.primitive_info.max_length_constant, ); // Get pointers to the arrays of data to extract let mut get_array_ptr = |bi: crate::BuiltIn, array_type_id: u32| { let id = self.id_gen.next(); block.body.push(Instruction::access_chain( self.get_pointer_type_id(array_type_id, spirv::StorageClass::Workgroup), id, return_info.out_variable_id, &[self.get_constant_scalar(crate::Literal::U32( return_info .out_members .iter() .position(|a| a.binding == crate::Binding::BuiltIn(bi)) .unwrap() as u32, ))], )); id }; let vert_array_ptr = get_array_ptr( crate::BuiltIn::Vertices, return_info.vertex_info.array_type_id, ); let prim_array_ptr = get_array_ptr( crate::BuiltIn::Primitives, 
return_info.primitive_info.array_type_id, ); // This must be called exactly once before any other mesh outputs are written { let mut ins = Instruction::new(spirv::Op::SetMeshOutputsEXT); ins.add_operand(vert_count_id); ins.add_operand(prim_count_id); block.body.push(ins); } // Iterate over every returned vertex and primitive, splitting each one // out into the per-output arrays. let vertex_loop_header = self.id_gen.next(); let prim_loop_header = self.id_gen.next(); let in_between_loops = self.id_gen.next(); let func_end = self.id_gen.next(); block.body.push(Instruction::store( loop_counter_vertices, local_invocation_index_id, None, )); block.body.push(Instruction::branch(vertex_loop_header)); let vertex_copy_body = self.write_mesh_copy_body( false, return_info, loop_counter_vertices, vert_array_ptr, prim_array_ptr, ); // Write vertex copy loop self.write_mesh_copy_loop( &mut block.body, vertex_copy_body, vertex_loop_header, in_between_loops, vert_count_id, loop_counter_vertices, return_info, ); // Between the two loops, reset the primitive counter to this invocation's index { block.body.push(Instruction::label(in_between_loops)); block.body.push(Instruction::store( loop_counter_primitives, local_invocation_index_id, None, )); block.body.push(Instruction::branch(prim_loop_header)); } let primitive_copy_body = self.write_mesh_copy_body( true, return_info, loop_counter_primitives, vert_array_ptr, prim_array_ptr, ); // Write primitive copy loop self.write_mesh_copy_loop( &mut block.body, primitive_copy_body, prim_loop_header, func_end, prim_count_id, loop_counter_primitives, return_info, ); block.body.push(Instruction::label(func_end)); Ok(()) } pub(super) fn write_mesh_shader_wrapper( &mut self, return_info: &MeshReturnInfo, inner_id: u32, ) -> Result<Word, Error> { let out_id = self.id_gen.next(); let mut function = super::Function::default(); let lookup_function_type = super::LookupFunctionType { parameter_type_ids: alloc::vec![], return_type_id: self.void_type, }; let function_type =
self.get_function_type(lookup_function_type); function.signature = Some(Instruction::function( self.void_type, out_id, spirv::FunctionControl::empty(), function_type, )); let u32_id = self.get_u32_type_id(); { let mut block = Block::new(self.id_gen.next()); // Function-local counters that track the current index to copy in the copy loops // of the mesh shader return. SPIR-V requires `Function` storage class variables to be // declared at the start of the function's first block, so they are inserted at the top. let loop_counter_vertices = self.id_gen.next(); let loop_counter_primitives = self.id_gen.next(); block.body.insert( 0, Instruction::variable( self.get_pointer_type_id(u32_id, spirv::StorageClass::Function), loop_counter_vertices, spirv::StorageClass::Function, None, ), ); block.body.insert( 1, Instruction::variable( self.get_pointer_type_id(u32_id, spirv::StorageClass::Function), loop_counter_primitives, spirv::StorageClass::Function, None, ), ); let local_invocation_index_id = self.id_gen.next(); block.body.push(Instruction::load( u32_id, local_invocation_index_id, return_info.local_invocation_index_var_id, None, )); block.body.push(Instruction::function_call( self.void_type, self.id_gen.next(), inner_id, &[], )); self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body); self.write_mesh_shader_return( return_info, &mut block, loop_counter_vertices, loop_counter_primitives, local_invocation_index_id, )?; function.consume(block, Instruction::return_void()); } function.to_words(&mut self.logical_layout.function_definitions); Ok(out_id) } pub(super) fn write_task_shader_wrapper( &mut self, task_payload: Word, inner_id: u32, ) -> Result<Word, Error> { let out_id = self.id_gen.next(); let mut function = super::Function::default(); let lookup_function_type = super::LookupFunctionType { parameter_type_ids: alloc::vec![], return_type_id: self.void_type, }; let function_type = self.get_function_type(lookup_function_type); function.signature = Some(Instruction::function( self.void_type, out_id,
spirv::FunctionControl::empty(), function_type, )); { let mut block = Block::new(self.id_gen.next()); let result = self.id_gen.next(); block.body.push(Instruction::function_call( self.get_vec3u_type_id(), result, inner_id, &[], )); self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body); let final_value = if let Some(task_limits) = self.task_dispatch_limits { let zero_u32 = self.get_constant_scalar(crate::Literal::U32(0)); let max_per_dim = self.get_constant_scalar(crate::Literal::U32( task_limits.max_mesh_workgroups_per_dim, )); let max_total = self.get_constant_scalar(crate::Literal::U32( task_limits.max_mesh_workgroups_total, )); let combined_struct_type = self.get_tuple_of_u32s_ty_id(); let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()]; for (i, value) in values.into_iter().enumerate() { block.body.push(Instruction::composite_extract( self.get_u32_type_id(), value, result, &[i as u32], )); } let prod_1 = self.id_gen.next(); let overflows = [self.id_gen.next(), self.id_gen.next()]; { let struct_out = self.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::UMulExtended, combined_struct_type, struct_out, values[0], values[1], )); block.body.push(Instruction::composite_extract( self.get_u32_type_id(), prod_1, struct_out, &[0], )); block.body.push(Instruction::composite_extract( self.get_u32_type_id(), overflows[0], struct_out, &[1], )); } let prod_final = self.id_gen.next(); { let struct_out = self.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::UMulExtended, combined_struct_type, struct_out, prod_1, values[2], )); block.body.push(Instruction::composite_extract( self.get_u32_type_id(), prod_final, struct_out, &[0], )); block.body.push(Instruction::composite_extract( self.get_u32_type_id(), overflows[1], struct_out, &[1], )); } let total_too_large = self.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::UGreaterThan, self.get_bool_type_id(), total_too_large, prod_final, max_total, )); let 
too_large = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()]; for (i, value) in values.into_iter().enumerate() { block.body.push(Instruction::binary( spirv::Op::UGreaterThan, self.get_bool_type_id(), too_large[i], value, max_per_dim, )); } let overflow_happens = [self.id_gen.next(), self.id_gen.next()]; for (i, value) in overflows.into_iter().enumerate() { block.body.push(Instruction::binary( spirv::Op::INotEqual, self.get_bool_type_id(), overflow_happens[i], value, zero_u32, )); } let mut current_violates_limits = total_too_large; for is_too_large in too_large { let new = self.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::LogicalOr, self.get_bool_type_id(), new, current_violates_limits, is_too_large, )); current_violates_limits = new; } for overflow_happens in overflow_happens { let new = self.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::LogicalOr, self.get_bool_type_id(), new, current_violates_limits, overflow_happens, )); current_violates_limits = new; } let zero_vec3 = self.id_gen.next(); block.body.push(Instruction::composite_construct( self.get_vec3u_type_id(), zero_vec3, &[zero_u32, zero_u32, zero_u32], )); let final_result = self.id_gen.next(); block.body.push(Instruction::select( self.get_vec3u_type_id(), final_result, current_violates_limits, zero_vec3, result, )); final_result } else { result }; let ins = self.write_entry_point_task_return(final_value, &mut block.body, task_payload)?; function.consume(block, ins); } function.to_words(&mut self.logical_layout.function_definitions); Ok(out_id) } } naga-29.0.3/src/back/spv/mod.rs000064400000000000000000001240561046102023000142610ustar 00000000000000/*! Backend for [SPIR-V][spv] (Standard Portable Intermediate Representation). # Layout of values in `uniform` buffers WGSL's ["Internal Layout of Values"][ilov] rules specify the memory layout of each WGSL type. 
The memory layout is important for data stored in `uniform` and `storage` buffers, especially when exchanging data with CPU code. Both WGSL and Vulkan specify some conditions that a type's memory layout must satisfy in order to use that type in a `uniform` or `storage` buffer. For `storage` buffers, the WGSL and Vulkan restrictions are compatible, but for `uniform` buffers, WGSL allows some types that Vulkan does not, requiring adjustments when emitting SPIR-V for `uniform` buffers. ## Padding in two-row matrices SPIR-V provides detailed control over the layout of matrix types, and is capable of describing the WGSL memory layout. However, Vulkan imposes additional restrictions. Vulkan's ["extended layout"][extended-layout] (also known as std140) rules apply to types used in `uniform` buffers. Under these rules, matrices are defined in terms of arrays of their vector type, and arrays are defined to have an alignment equal to the alignment of their element type rounded up to a multiple of 16. This means that each column of the matrix has a minimum alignment of 16. WGSL, and consequently Naga IR, on the other hand, specifies column alignment equal to the alignment of the vector type, without rounding it up to 16. For a two-row `f32` matrix, whose `vec2<f32>` columns have an alignment of 8, the two rules therefore place the columns at different offsets. To compensate for this, for any `struct` used as a `uniform` buffer which contains a two-row matrix, we declare an additional "std140 compatible" type in which each column of the matrix has been decomposed into a separate member of the containing struct. For example, the following WGSL struct type: ```ignore struct Baz { m: mat3x2<f32>, } ``` is rendered as the SPIR-V struct type: ```ignore OpTypeStruct %v2float %v2float %v2float ``` This has the effect that struct indices in Naga IR for such types do not correspond to the struct indices used in SPIR-V. A mapping of struct indices for these types is maintained in [`Std140CompatTypeInfo`].
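
As a rough illustration of the mismatch described above, the following sketch does two things. It computes the byte offsets of the columns of a two-row `f32` matrix under the WGSL rule and under the std140 rule. It also computes the member index remapping that results from decomposing such a matrix into its containing struct. This is a standalone sketch, not code from this backend. The names `wgsl_column_offsets`, `std140_column_offsets`, `Member`, and `member_indices` are hypothetical. The assumption that each Naga IR member maps to the SPIR-V index of its first decomposed member is ours.

```rust
/// Byte offsets of the columns of a two-row `f32` matrix (columns are
/// `vec2<f32>`: size 8, alignment 8) under WGSL's layout rules, where the
/// column stride equals the vector's own alignment.
fn wgsl_column_offsets(columns: u32) -> Vec<u32> {
    let vec2_f32_align = 8;
    (0..columns).map(|i| i * vec2_f32_align).collect()
}

/// Byte offsets of the same columns under std140: the matrix is treated as
/// an array of columns, and the array stride is rounded up to a multiple of 16.
fn std140_column_offsets(columns: u32) -> Vec<u32> {
    let vec2_f32_size = 8;
    let stride = (vec2_f32_size + 15) / 16 * 16;
    (0..columns).map(|i| i * stride).collect()
}

/// Hypothetical description of a Naga IR struct member, for illustration only.
enum Member {
    /// Any member that keeps a single SPIR-V member index.
    Other,
    /// A two-row matrix, decomposed into one SPIR-V member per column.
    MatCx2 { columns: u32 },
}

/// Maps each Naga IR member index to the SPIR-V index of its first member
/// after two-row matrices have been decomposed into their columns
/// (an assumed interpretation of the mapping, for illustration).
fn member_indices(members: &[Member]) -> Vec<u32> {
    let mut next = 0;
    members
        .iter()
        .map(|member| {
            let index = next;
            next += match *member {
                Member::Other => 1,
                Member::MatCx2 { columns } => columns,
            };
            index
        })
        .collect()
}
```

Consider `struct Baz { a: f32, m: mat3x2<f32>, b: f32 }`. This sketch gives WGSL column offsets `[0, 8, 16]` against std140 offsets `[0, 16, 32]`, and it maps Naga IR member indices `[0, 1, 2]` to SPIR-V indices `[0, 1, 4]`. Under std140, the decomposed `vec2<f32>` members keep their natural 8-byte alignment. That is why splitting the columns into the struct recovers the WGSL layout.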
Additionally, any two-row matrices that are declared directly as uniform buffers without being wrapped in a struct are declared as a struct containing a vector member for each column. Any array of a two-row matrix in a uniform buffer is declared as an array of a struct containing a vector member for each column. Within a uniform buffer, any struct with a member that requires a std140 compatible type declaration, and any array whose element type requires one, itself requires a std140 compatible type declaration. Whenever a value of such a type is [`loaded`] we insert code to convert the loaded value from the std140 compatible type to the regular type. This occurs in `BlockContext::write_checked_load`, making use of the wrapper function defined by `Writer::write_wrapped_convert_from_std140_compat_type`. For matrices that have been decomposed as separate columns in the containing struct, we load each column separately and then construct the matrix in `BlockContext::maybe_write_load_uniform_matcx2_struct_member`. Whenever a column of a matrix that has been decomposed into its containing struct is [`accessed`] with a constant index, we adjust the emitted access chain to access the column from the containing struct instead, in `BlockContext::write_access_chain`. Whenever a column of a uniform buffer two-row matrix is [`dynamically accessed`], we must first load the matrix, converting it from its std140 compatible type as described above, and then access the column using the wrapper function defined by `Writer::write_wrapped_matcx2_get_column`. This is handled by `BlockContext::maybe_write_uniform_matcx2_dynamic_access`. Note that this approach differs somewhat from the equivalent code in the HLSL backend. For HLSL, all structs containing two-row matrices (or arrays of such) have their declarations modified, not just those used as uniform buffers.
Two-row matrices and arrays of such only use modified type declarations when used as uniform buffers, or additionally when used as a struct member in any context. This avoids the need to convert struct values when loading from uniform buffers, but the conversion is still required when loading arrays and matrices from uniform buffers or from any struct. In contrast, the approach used here always requires converting *any* affected type when loading from a uniform buffer, but consistently *only* when loading from a uniform buffer. As a result, we only have to handle loads and not stores, since uniform buffers are read-only. [spv]: https://www.khronos.org/registry/SPIR-V/ [ilov]: https://gpuweb.github.io/gpuweb/wgsl/#internal-value-layout [extended-layout]: https://docs.vulkan.org/spec/latest/chapters/interfaces.html#interfaces-resources-layout [`loaded`]: crate::Expression::Load [`accessed`]: crate::Expression::AccessIndex [`dynamically accessed`]: crate::Expression::Access */ mod block; mod f16_polyfill; mod helpers; mod image; mod index; mod instructions; mod layout; mod mesh_shader; mod ray; mod reclaimable; mod selection; mod subgroup; mod writer; pub use mesh_shader::{MeshReturnInfo, MeshReturnMember}; pub use spirv::{Capability, SourceLanguage}; use alloc::{string::String, vec::Vec}; use core::ops; use spirv::Word; use thiserror::Error; use crate::arena::{Handle, HandleVec}; use crate::back::TaskDispatchLimits; use crate::proc::{BoundsCheckPolicies, TypeResolution}; #[derive(Clone)] struct PhysicalLayout { magic_number: Word, version: Word, generator: Word, bound: Word, instruction_schema: Word, } #[derive(Default)] struct LogicalLayout { capabilities: Vec<Word>, extensions: Vec<Word>, ext_inst_imports: Vec<Word>, memory_model: Vec<Word>, entry_points: Vec<Word>, execution_modes: Vec<Word>, debugs: Vec<Word>, annotations: Vec<Word>, declarations: Vec<Word>, function_declarations: Vec<Word>, function_definitions: Vec<Word>, } #[derive(Clone)] struct Instruction { op: spirv::Op, wc: u32, type_id: Option<Word>, result_id:
Option<Word>, operands: Vec<Word>, } const BITS_PER_BYTE: crate::Bytes = 8; #[derive(Clone, Debug, Error)] pub enum Error { #[error("The requested entry point couldn't be found")] EntryPointNotFound, #[error("target SPIRV-{0}.{1} is not supported")] UnsupportedVersion(u8, u8), #[error("using {0} requires at least one of the capabilities {1:?}, but none are available")] MissingCapabilities(&'static str, Vec<Capability>), #[error("unimplemented {0}")] FeatureNotImplemented(&'static str), #[error("module is not validated properly: {0}")] Validation(&'static str), #[error("overrides should not be present at this stage")] Override, #[error(transparent)] ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError), #[error("module requires SPIRV-{0}.{1}, which isn't supported")] SpirvVersionTooLow(u8, u8), #[error("mapping of {0:?} is missing")] MissingBinding(crate::ResourceBinding), } #[derive(Default)] struct IdGenerator(Word); impl IdGenerator { const fn next(&mut self) -> Word { self.0 += 1; self.0 } } #[derive(Debug, Clone)] pub struct DebugInfo<'a> { pub source_code: &'a str, pub file_name: &'a str, pub language: SourceLanguage, } /// A SPIR-V block to which we are still adding instructions. /// /// A `Block` represents a SPIR-V block that does not yet have a termination /// instruction like `OpBranch` or `OpReturn`. /// /// The `OpLabel` that starts the block is implicit. It will be emitted based on /// `label_id` when we write the block to a `LogicalLayout`. /// /// To terminate a `Block`, pass the block and the termination instruction to /// `Function::consume`. This takes ownership of the `Block` and transforms it /// into a `TerminatedBlock`. struct Block { label_id: Word, body: Vec<Instruction>, } /// A SPIR-V block that ends with a termination instruction.
struct TerminatedBlock { label_id: Word, body: Vec<Instruction>, } impl Block { const fn new(label_id: Word) -> Self { Block { label_id, body: Vec::new(), } } } struct LocalVariable { id: Word, instruction: Instruction, } struct ResultMember { id: Word, type_id: Word, built_in: Option<crate::BuiltIn>, } struct EntryPointContext { argument_ids: Vec<Word>, results: Vec<ResultMember>, task_payload_variable_id: Option<Word>, mesh_state: Option<MeshReturnInfo>, } #[derive(Default)] struct Function { signature: Option<Instruction>, parameters: Vec<FunctionArgument>, variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>, /// Map from a local variable that is a ray query to its u32 tracker. ray_query_initialization_tracker_variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>, /// Map from a local variable that is a ray query to its tracker for the t max. ray_query_t_max_tracker_variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>, /// List of local variables used as counters to ensure that all loops are bounded. force_loop_bounding_vars: Vec<LocalVariable>, /// A map from a Naga expression to the temporary SPIR-V variable we have /// spilled its value to, if any. /// /// Naga IR lets us apply [`Access`] expressions to expressions whose value /// is an array or matrix---not a pointer to such---but SPIR-V doesn't have /// instructions that can do the same. So when we encounter such code, we /// spill the expression's value to a generated temporary variable. We can /// obtain a pointer to that variable, and then use an `OpAccessChain` instruction to /// do whatever series of [`Access`] and [`AccessIndex`] operations we need /// (with bounds checks). Finally, we generate an `OpLoad` to get the final /// value. /// /// [`Access`]: crate::Expression::Access /// [`AccessIndex`]: crate::Expression::AccessIndex spilled_composites: crate::FastIndexMap<Handle<crate::Expression>, LocalVariable>, /// A set of expressions that are either in [`spilled_composites`] or refer /// to some component/element of such.
/// /// [`spilled_composites`]: Function::spilled_composites spilled_accesses: crate::arena::HandleSet<crate::Expression>, /// A map taking each expression to the number of [`Access`] and /// [`AccessIndex`] expressions that use it as a base value. If an /// expression has no entry, its count is zero: it is never used as an /// [`Access`] or [`AccessIndex`] base. /// /// We use this, together with [`ExpressionInfo::ref_count`], to recognize /// the tips of chains of [`Access`] and [`AccessIndex`] expressions that /// access spilled values --- expressions in [`spilled_composites`]. We /// defer generating code for the chain until we reach its tip, so we can /// handle it with a single instruction. /// /// [`Access`]: crate::Expression::Access /// [`AccessIndex`]: crate::Expression::AccessIndex /// [`ExpressionInfo::ref_count`]: crate::valid::ExpressionInfo /// [`spilled_composites`]: Function::spilled_composites access_uses: crate::FastHashMap<Handle<crate::Expression>, usize>, blocks: Vec<TerminatedBlock>, entry_point_context: Option<EntryPointContext>, } impl Function { fn consume(&mut self, mut block: Block, termination: Instruction) { block.body.push(termination); self.blocks.push(TerminatedBlock { label_id: block.label_id, body: block.body, }) } fn parameter_id(&self, index: u32) -> Word { match self.entry_point_context { Some(ref context) => context.argument_ids[index as usize], None => self.parameters[index as usize] .instruction .result_id .unwrap(), } } } /// Characteristics of a SPIR-V `OpTypeImage` type. /// /// SPIR-V requires non-composite types to be unique, including images. Since we /// use `LocalType` for this deduplication, it's essential that `LocalImageType` /// be equal whenever the corresponding `OpTypeImage`s would be. To reduce the /// likelihood of mistakes, we use fields that correspond exactly to the /// operands of an `OpTypeImage` instruction, using the actual SPIR-V types /// where practical.
#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] struct LocalImageType { sampled_type: crate::Scalar, dim: spirv::Dim, flags: ImageTypeFlags, image_format: spirv::ImageFormat, } bitflags::bitflags! { /// Flags corresponding to the boolean(-ish) parameters to OpTypeImage. #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub struct ImageTypeFlags: u8 { const DEPTH = 0x1; const ARRAYED = 0x2; const MULTISAMPLED = 0x4; const SAMPLED = 0x8; } } impl LocalImageType { /// Construct a `LocalImageType` from the fields of a `TypeInner::Image`. fn from_inner(dim: crate::ImageDimension, arrayed: bool, class: crate::ImageClass) -> Self { let make_flags = |multi: bool, other: ImageTypeFlags| -> ImageTypeFlags { let mut flags = other; flags.set(ImageTypeFlags::ARRAYED, arrayed); flags.set(ImageTypeFlags::MULTISAMPLED, multi); flags }; let dim = spirv::Dim::from(dim); match class { crate::ImageClass::Sampled { kind, multi } => LocalImageType { sampled_type: crate::Scalar { kind, width: 4 }, dim, flags: make_flags(multi, ImageTypeFlags::SAMPLED), image_format: spirv::ImageFormat::Unknown, }, crate::ImageClass::Depth { multi } => LocalImageType { sampled_type: crate::Scalar { kind: crate::ScalarKind::Float, width: 4, }, dim, flags: make_flags(multi, ImageTypeFlags::DEPTH | ImageTypeFlags::SAMPLED), image_format: spirv::ImageFormat::Unknown, }, crate::ImageClass::Storage { format, access: _ } => LocalImageType { sampled_type: format.into(), dim, flags: make_flags(false, ImageTypeFlags::empty()), image_format: format.into(), }, crate::ImageClass::External => unimplemented!(), } } } /// A numeric type, for use in [`LocalType`]. 
#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] enum NumericType { Scalar(crate::Scalar), Vector { size: crate::VectorSize, scalar: crate::Scalar, }, Matrix { columns: crate::VectorSize, rows: crate::VectorSize, scalar: crate::Scalar, }, } impl NumericType { const fn from_inner(inner: &crate::TypeInner) -> Option<Self> { match *inner { crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => { Some(NumericType::Scalar(scalar)) } crate::TypeInner::Vector { size, scalar } => Some(NumericType::Vector { size, scalar }), crate::TypeInner::Matrix { columns, rows, scalar, } => Some(NumericType::Matrix { columns, rows, scalar, }), _ => None, } } const fn scalar(self) -> crate::Scalar { match self { NumericType::Scalar(scalar) | NumericType::Vector { scalar, .. } | NumericType::Matrix { scalar, .. } => scalar, } } const fn with_scalar(self, scalar: crate::Scalar) -> Self { match self { NumericType::Scalar(_) => NumericType::Scalar(scalar), NumericType::Vector { size, .. } => NumericType::Vector { size, scalar }, NumericType::Matrix { columns, rows, .. } => NumericType::Matrix { columns, rows, scalar, }, } } } /// A cooperative type, for use in [`LocalType`]. #[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] enum CooperativeType { Matrix { columns: crate::CooperativeSize, rows: crate::CooperativeSize, scalar: crate::Scalar, role: crate::CooperativeRole, }, } impl CooperativeType { const fn from_inner(inner: &crate::TypeInner) -> Option<Self> { match *inner { crate::TypeInner::CooperativeMatrix { columns, rows, scalar, role, } => Some(Self::Matrix { columns, rows, scalar, role, }), _ => None, } } } /// A SPIR-V type constructed during code generation. /// /// This is the variant of [`LookupType`] used to represent types that might not /// be available in the arena. Variants are present here for one of two reasons: /// /// - They represent types synthesized during code generation, as explained /// in the documentation for [`LookupType`].
/// /// - They represent types for which SPIR-V forbids duplicate `OpType...` /// instructions, requiring deduplication. /// /// This is not a complete copy of [`TypeInner`]: for example, SPIR-V generation /// never synthesizes new struct types, so `LocalType` has nothing for that. /// /// Each `LocalType` variant should be handled identically to its analogous /// `TypeInner` variant. You can use the [`Writer::localtype_from_inner`] /// function to help with this, by converting everything possible to a /// `LocalType` before inspecting it. /// /// ## `LocalType` equality and SPIR-V `OpType` uniqueness /// /// The definition of `Eq` on `LocalType` is carefully chosen to help us follow /// certain SPIR-V rules. SPIR-V §2.8 requires some classes of `OpType...` /// instructions to be unique; for example, you can't have two `OpTypeInt 32 1` /// instructions in the same module. All 32-bit signed integers must use the /// same type id. /// /// All SPIR-V types that must be unique can be represented as a `LocalType`, /// and two `LocalType`s are always `Eq` if SPIR-V would require them to use the /// same `OpType...` instruction. This lets us avoid duplicates by recording the /// ids of the type instructions we've already generated in a hash table, /// [`Writer::lookup_type`], keyed by `LocalType`. /// /// As another example, [`LocalImageType`], stored in the `LocalType::Image` /// variant, is designed to help us deduplicate `OpTypeImage` instructions. See /// its documentation for details. /// /// SPIR-V does not require pointer types to be unique - but different /// SPIR-V ids are considered to be distinct pointer types. Since Naga /// uses structural type equality, we need to represent each Naga /// equivalence class with a single SPIR-V `OpTypePointer`. /// /// As it always must, the `Hash` implementation respects the `Eq` relation. /// /// [`TypeInner`]: crate::TypeInner #[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] enum LocalType { /// A numeric type. 
Numeric(NumericType), Cooperative(CooperativeType), Pointer { base: Word, class: spirv::StorageClass, }, Image(LocalImageType), SampledImage { image_type_id: Word, }, Sampler, BindingArray { base: Handle<crate::Type>, size: u32, }, AccelerationStructure, RayQuery, } /// A type encountered during SPIR-V generation. /// /// In the process of writing SPIR-V, we need to synthesize various types for /// intermediate results and such: pointer types, vector/matrix component types, /// or even booleans, which usually appear in SPIR-V code even when they're not /// used by the module source. /// /// However, we can't use `crate::Type` or `crate::TypeInner` for these, as the /// type arena may not contain what we need (it only contains types used /// directly by other parts of the IR), and the IR module is immutable, so we /// can't add anything to it. /// /// So for local use in the SPIR-V writer, we use this type, which holds either /// a handle into the arena, or a [`LocalType`] containing something synthesized /// locally. /// /// This is very similar to the [`proc::TypeResolution`] enum, with `LocalType` /// playing the role of `TypeInner`. However, `LocalType` also has other /// properties needed for SPIR-V generation; see the description of /// [`LocalType`] for details.
/// /// [`proc::TypeResolution`]: crate::proc::TypeResolution #[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] enum LookupType { Handle(Handle<crate::Type>), Local(LocalType), } impl From<LocalType> for LookupType { fn from(local: LocalType) -> Self { Self::Local(local) } } #[derive(Debug, PartialEq, Clone, Hash, Eq)] struct LookupFunctionType { parameter_type_ids: Vec<Word>, return_type_id: Word, } #[derive(Debug, PartialEq, Clone, Hash, Eq)] enum LookupRayQueryFunction { Initialize, Proceed, GenerateIntersection, ConfirmIntersection, GetVertexPositions { committed: bool }, GetIntersection { committed: bool }, Terminate, } #[derive(Debug)] enum Dimension { Scalar, Vector, Matrix, CooperativeMatrix, } /// Key used to look up an operation which we have wrapped in a helper /// function, which should be called instead of directly emitting code /// for the expression. See [`Writer::wrapped_functions`]. #[derive(Debug, Eq, PartialEq, Hash)] enum WrappedFunction { BinaryOp { op: crate::BinaryOperator, left_type_id: Word, right_type_id: Word, }, ConvertFromStd140CompatType { r#type: Handle<crate::Type>, }, MatCx2GetColumn { r#type: Handle<crate::Type>, }, } /// A map from evaluated [`Expression`](crate::Expression)s to their SPIR-V ids. /// /// When we emit code to evaluate a given `Expression`, we record the /// SPIR-V id of its value here, under its `Handle<Expression>` index. /// /// A `CachedExpressions` value can be indexed by a `Handle<Expression>` value.
/// /// [emit]: index.html#expression-evaluation-time-and-scope #[derive(Default)] struct CachedExpressions { ids: HandleVec<crate::Expression, Word>, } impl CachedExpressions { fn reset(&mut self, length: usize) { self.ids.clear(); self.ids.resize(length, 0); } } impl ops::Index<Handle<crate::Expression>> for CachedExpressions { type Output = Word; fn index(&self, h: Handle<crate::Expression>) -> &Word { let id = &self.ids[h]; if *id == 0 { unreachable!("Expression {:?} is not cached!", h); } id } } impl ops::IndexMut<Handle<crate::Expression>> for CachedExpressions { fn index_mut(&mut self, h: Handle<crate::Expression>) -> &mut Word { let id = &mut self.ids[h]; if *id != 0 { unreachable!("Expression {:?} is already cached!", h); } id } } impl reclaimable::Reclaimable for CachedExpressions { fn reclaim(self) -> Self { CachedExpressions { ids: self.ids.reclaim(), } } } #[derive(Eq, Hash, PartialEq)] enum CachedConstant { Literal(crate::proc::HashableLiteral), Composite { ty: LookupType, constituent_ids: Vec<Word>, }, ZeroValue(Word), } /// The SPIR-V representation of a [`crate::GlobalVariable`]. /// /// In the Vulkan spec 1.3.296, the section [Descriptor Set Interface][dsi] says: /// /// > Variables identified with the `Uniform` storage class are used to access /// > transparent buffer backed resources. Such variables *must* be: /// > /// > - typed as `OpTypeStruct`, or an array of this type, /// > /// > - identified with a `Block` or `BufferBlock` decoration, and /// > /// > - laid out explicitly using the `Offset`, `ArrayStride`, and `MatrixStride` /// > decorations as specified in "Offset and Stride Assignment". /// /// This is followed by identical language for the `StorageBuffer` storage class, /// except that a `BufferBlock` decoration is not allowed.
/// /// When we encounter a global variable in the [`Storage`] or [`Uniform`] /// address spaces whose type is not already [`Struct`], this backend implicitly /// wraps the global variable in a struct: we generate a SPIR-V global variable /// holding an `OpTypeStruct` with a single member, whose type is the Naga /// global's type, decorated as required above. /// /// The [`helpers::global_needs_wrapper`] function determines whether a given /// [`crate::GlobalVariable`] needs to be wrapped. /// /// [dsi]: https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#interfaces-resources-descset /// [`Storage`]: crate::AddressSpace::Storage /// [`Uniform`]: crate::AddressSpace::Uniform /// [`Struct`]: crate::TypeInner::Struct #[derive(Clone)] struct GlobalVariable { /// The SPIR-V id of the `OpVariable` that declares the global. /// /// If this global has been implicitly wrapped in an `OpTypeStruct`, this id /// refers to the wrapper, not the original Naga value it contains. If you /// need the Naga value, use [`access_id`] instead of this field. /// /// If this global is not implicitly wrapped, this is the same as /// [`access_id`]. /// /// This is used to compute the `access_id` pointer in function prologues, /// and for `ArrayLength` expressions, which need to pass the wrapper /// struct. /// /// [`access_id`]: GlobalVariable::access_id var_id: Word, /// The loaded value of an `AddressSpace::Handle` global variable. /// /// If the current function uses this global variable, this is the id of an /// `OpLoad` instruction in the function's prologue that loads its value. /// (This value is assigned as we write the prologue code of each function.) /// It is then used for all operations on the global, such as `OpImageSample`. handle_id: Word, /// The SPIR-V id of a pointer to this variable's Naga IR value.
/// /// If the current function uses this global variable, and it has been /// implicitly wrapped in an `OpTypeStruct`, this is the id of an /// `OpAccessChain` instruction in the function's prologue that refers to /// the wrapped value inside the struct. (This value is assigned as we write /// the prologue code of each function.) If you need the wrapper struct /// itself, use [`var_id`] instead of this field. /// /// If this global is not implicitly wrapped, this is the same as /// [`var_id`]. /// /// [`var_id`]: GlobalVariable::var_id access_id: Word, } impl GlobalVariable { const fn dummy() -> Self { Self { var_id: 0, handle_id: 0, access_id: 0, } } const fn new(id: Word) -> Self { Self { var_id: id, handle_id: 0, access_id: 0, } } /// Prepare `self` for use within a single function. const fn reset_for_function(&mut self) { self.handle_id = 0; self.access_id = 0; } } struct FunctionArgument { /// Actual instruction of the argument. instruction: Instruction, handle_id: Word, } /// Tracks the expressions for which the backend emits the following instructions: /// - OpConstantTrue /// - OpConstantFalse /// - OpConstant /// - OpConstantComposite /// - OpConstantNull struct ExpressionConstnessTracker { inner: crate::arena::HandleSet<crate::Expression>, } impl ExpressionConstnessTracker { fn from_arena(arena: &crate::Arena<crate::Expression>) -> Self { let mut inner = crate::arena::HandleSet::for_arena(arena); for (handle, expr) in arena.iter() { let insert = match *expr { crate::Expression::Literal(_) | crate::Expression::ZeroValue(_) | crate::Expression::Constant(_) => true, crate::Expression::Compose { ref components, .. } => { components.iter().all(|&h| inner.contains(h)) } crate::Expression::Splat { value, .. } => inner.contains(value), _ => false, }; if insert { inner.insert(handle); } } Self { inner } } fn is_const(&self, value: Handle<crate::Expression>) -> bool { self.inner.contains(value) } } /// General information needed to emit SPIR-V for Naga statements.
struct BlockContext<'w> { /// The writer handling the module to which this code belongs. writer: &'w mut Writer, /// The [`Module`](crate::Module) for which we're generating code. ir_module: &'w crate::Module, /// The [`Function`](crate::Function) for which we're generating code. ir_function: &'w crate::Function, /// Information module validation produced about /// [`ir_function`](BlockContext::ir_function). fun_info: &'w crate::valid::FunctionInfo, /// The [`spv::Function`](Function) to which we are contributing SPIR-V instructions. function: &'w mut Function, /// SPIR-V ids for expressions we've evaluated. cached: CachedExpressions, /// The `Writer`'s temporary vector, for convenience. temp_list: Vec, /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`. expression_constness: ExpressionConstnessTracker, force_loop_bounding: bool, /// Map from an expression whose type is a ray query, or a pointer to a ray query, to its trackers. /// Note: this is sparse, so it can't be a `HandleVec`. ray_query_tracker_expr: crate::FastHashMap, RayQueryTrackers>, } #[derive(Clone, Copy)] struct RayQueryTrackers { // Initialization tracker. initialized_tracker: Word, // Tracks the tmax passed to ray query initialization. // Unlike HLSL's equivalent, SPIR-V's getter for the current committed t has UB (instead of just // returning tmax) if there was no previous hit (though in some places it treats the behaviour as // defined), therefore we must track the tmax passed to ray query initialization.
t_max_tracker: Word, } impl BlockContext<'_> { const fn gen_id(&mut self) -> Word { self.writer.id_gen.next() } fn get_type_id(&mut self, lookup_type: LookupType) -> Word { self.writer.get_type_id(lookup_type) } fn get_handle_type_id(&mut self, handle: Handle) -> Word { self.writer.get_handle_type_id(handle) } fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word { self.writer.get_expression_type_id(tr) } fn get_index_constant(&mut self, index: Word) -> Word { self.writer.get_constant_scalar(crate::Literal::U32(index)) } fn get_scope_constant(&mut self, scope: Word) -> Word { self.writer .get_constant_scalar(crate::Literal::I32(scope as _)) } fn get_pointer_type_id(&mut self, base: Word, class: spirv::StorageClass) -> Word { self.writer.get_pointer_type_id(base, class) } fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word { self.writer.get_numeric_type_id(numeric) } } /// Information about a type for which we have declared a std140 layout /// compatible variant, because the type is used in a uniform but does not /// adhere to std140 requirements. The uniform will be declared using the /// type `type_id`, and the result of any `Load` will be immediately converted /// to the base type. This is used for matrices with 2 rows, as well as any /// arrays or structs containing such matrices. pub struct Std140CompatTypeInfo { /// ID of the std140 compatible type declaration. type_id: Word, /// For structs, a mapping of Naga IR struct member indices to the indices /// used in the generated SPIR-V. For non-struct types this will be empty. member_indices: Vec, } pub struct Writer { physical_layout: PhysicalLayout, logical_layout: LogicalLayout, id_gen: IdGenerator, /// The set of capabilities modules are permitted to use. /// /// This is initialized from `Options::capabilities`. capabilities_available: Option>, /// The set of capabilities used by this module. /// /// If `capabilities_available` is `Some`, then this is always a subset of /// that. 
capabilities_used: crate::FastIndexSet, /// The set of SPIR-V extensions used. extensions_used: crate::FastIndexSet<&'static str>, debug_strings: Vec, debugs: Vec, annotations: Vec, flags: WriterFlags, bounds_check_policies: BoundsCheckPolicies, zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode, force_loop_bounding: bool, use_storage_input_output_16: bool, void_type: Word, tuple_of_u32s_ty_id: Option, //TODO: convert most of these into vectors, addressable by handle indices lookup_type: crate::FastHashMap, lookup_function: crate::FastHashMap, Word>, lookup_function_type: crate::FastHashMap, /// Operations which have been wrapped in a helper function. The value is /// the ID of the function, which should be called instead of emitting code /// for the operation directly. wrapped_functions: crate::FastHashMap, /// Indexed by const-expression handle indices. constant_ids: HandleVec, cached_constants: crate::FastHashMap, global_variables: HandleVec, std140_compat_uniform_types: crate::FastHashMap, Std140CompatTypeInfo>, fake_missing_bindings: bool, binding_map: BindingMap, // Cached expressions are only meaningful within a BlockContext, but we // retain the table here between functions to save heap allocations. saved_cached: CachedExpressions, gl450_ext_inst_id: Word, // Just a temporary list of SPIR-V ids temp_list: Vec, ray_query_functions: crate::FastHashMap, /// F16 I/O polyfill manager for handling `f16` input/output variables /// when the `StorageInputOutput16` capability is not available. io_f16_polyfills: f16_polyfill::F16IoPolyfill, /// Non-semantic debug printf extension `OpExtInstImport`. debug_printf: Option, pub(crate) ray_query_initialization_tracking: bool, /// Limits on the number of mesh shader workgroups a task shader workgroup can dispatch. /// /// Metal, for example, limits a task shader dispatch to 1024 workgroups. Dispatching more is /// undefined behavior, so when these limits are set, the generated code can check the requested count and dispatch zero workgroups if it exceeds them.
task_dispatch_limits: Option, /// If true, naga may generate checks that the primitive indices are valid in the output. /// /// Currently this validation is unimplemented. mesh_shader_primitive_indices_clamp: bool, } bitflags::bitflags! { #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct WriterFlags: u32 { /// Include debug labels for everything. const DEBUG = 0x1; /// Flip Y coordinate of [`BuiltIn::Position`] output. /// /// [`BuiltIn::Position`]: crate::BuiltIn::Position const ADJUST_COORDINATE_SPACE = 0x2; /// Emit [`OpName`][op] for input/output locations. /// /// Contrary to the spec, some drivers treat it as semantic, not allowing /// any conflicts. /// /// [op]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpName const LABEL_VARYINGS = 0x4; /// Emit [`PointSize`] output builtin to vertex shaders, which is /// required for drawing with `PointList` topology. /// /// [`PointSize`]: crate::BuiltIn::PointSize const FORCE_POINT_SIZE = 0x8; /// Clamp [`BuiltIn::FragDepth`] output between 0 and 1. /// /// [`BuiltIn::FragDepth`]: crate::BuiltIn::FragDepth const CLAMP_FRAG_DEPTH = 0x10; /// If the arguments used to initialize a ray query are invalid, use the debug /// printf extension to print a message to the command line instead of failing silently. /// /// Note: `VK_KHR_shader_non_semantic_info` must be enabled. This has no /// effect if `options.ray_query_initialization_tracking` is set to false. const PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL = 0x20; } } #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct BindingInfo { pub descriptor_set: u32, pub binding: u32, /// If the binding is an unsized binding array, this overrides the size. pub binding_array_size: Option, } // `BTreeMap` is used instead of `HashMap` so that the map itself can be hashed.
pub type BindingMap = alloc::collections::BTreeMap; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum ZeroInitializeWorkgroupMemoryMode { /// Via `VK_KHR_zero_initialize_workgroup_memory` or Vulkan 1.3 Native, /// Via assignments + barrier Polyfill, None, } #[derive(Debug, Clone)] pub struct Options<'a> { /// (Major, Minor) target version of the SPIR-V. pub lang_version: (u8, u8), /// Configuration flags for the writer. pub flags: WriterFlags, /// Don't panic on missing bindings. Instead use fake values for `Binding` /// and `DescriptorSet` decorations. This may result in invalid SPIR-V. pub fake_missing_bindings: bool, /// Map of resources to information about the binding. pub binding_map: BindingMap, /// If given, the set of capabilities modules are allowed to use. Code that /// requires capabilities beyond these is rejected with an error. /// /// If this is `None`, all capabilities are permitted. pub capabilities: Option>, /// How generated code should handle array, vector, matrix, or image texel /// indices that are out of range. pub bounds_check_policies: BoundsCheckPolicies, /// How workgroup variables should be zero-initialized. pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode, /// If set, loops will have code injected into them, forcing the compiler /// to treat the number of iterations as bounded. pub force_loop_bounding: bool, /// If set, ray queries will get a variable to track their state to prevent /// misuse. pub ray_query_initialization_tracking: bool, /// Whether to use the `StorageInputOutput16` capability for `f16` shader I/O. /// When false, `f16` I/O is polyfilled using `f32` types with conversions.
pub use_storage_input_output_16: bool, pub debug_info: Option>, pub task_dispatch_limits: Option, pub mesh_shader_primitive_indices_clamp: bool, } impl Default for Options<'_> { fn default() -> Self { let mut flags = WriterFlags::ADJUST_COORDINATE_SPACE | WriterFlags::LABEL_VARYINGS | WriterFlags::CLAMP_FRAG_DEPTH; if cfg!(debug_assertions) { flags |= WriterFlags::DEBUG; } Options { lang_version: (1, 0), flags, fake_missing_bindings: true, binding_map: BindingMap::default(), capabilities: None, bounds_check_policies: BoundsCheckPolicies::default(), zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill, force_loop_bounding: true, ray_query_initialization_tracking: true, use_storage_input_output_16: true, debug_info: None, task_dispatch_limits: None, mesh_shader_primitive_indices_clamp: true, } } } // A subset of options meant to be changed per pipeline. #[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { /// The stage of the entry point. pub shader_stage: crate::ShaderStage, /// The name of the entry point. /// /// If no matching entry point is found while creating a [`Writer`], an error is returned.
pub entry_point: String, } pub fn write_vec( module: &crate::Module, info: &crate::valid::ModuleInfo, options: &Options, pipeline_options: Option<&PipelineOptions>, ) -> Result, Error> { let mut words: Vec = Vec::new(); let mut w = Writer::new(options)?; w.write( module, info, pipeline_options, &options.debug_info, &mut words, )?; Ok(words) } pub fn supported_capabilities() -> crate::valid::Capabilities { use crate::valid::Capabilities as Caps; Caps::IMMEDIATES | Caps::FLOAT64 | Caps::PRIMITIVE_INDEX | Caps::TEXTURE_AND_SAMPLER_BINDING_ARRAY | Caps::BUFFER_BINDING_ARRAY | Caps::STORAGE_TEXTURE_BINDING_ARRAY | Caps::STORAGE_BUFFER_BINDING_ARRAY | Caps::ACCELERATION_STRUCTURE_BINDING_ARRAY | Caps::CLIP_DISTANCE // No cull distance | Caps::STORAGE_TEXTURE_16BIT_NORM_FORMATS | Caps::MULTIVIEW | Caps::EARLY_DEPTH_TEST | Caps::MULTISAMPLED_SHADING | Caps::RAY_QUERY | Caps::DUAL_SOURCE_BLENDING | Caps::CUBE_ARRAY_TEXTURES | Caps::SHADER_INT64 | Caps::SUBGROUP | Caps::SUBGROUP_BARRIER | Caps::SUBGROUP_VERTEX_STAGE | Caps::SHADER_INT64_ATOMIC_MIN_MAX | Caps::SHADER_INT64_ATOMIC_ALL_OPS | Caps::SHADER_FLOAT32_ATOMIC | Caps::TEXTURE_ATOMIC | Caps::TEXTURE_INT64_ATOMIC | Caps::RAY_HIT_VERTEX_POSITION | Caps::SHADER_FLOAT16 // No TEXTURE_EXTERNAL | Caps::SHADER_FLOAT16_IN_FLOAT32 | Caps::SHADER_BARYCENTRICS | Caps::MESH_SHADER | Caps::MESH_SHADER_POINT_TOPOLOGY | Caps::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING // No BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING | Caps::STORAGE_TEXTURE_BINDING_ARRAY_NON_UNIFORM_INDEXING | Caps::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING | Caps::COOPERATIVE_MATRIX | Caps::PER_VERTEX // No RAY_TRACING_PIPELINE | Caps::DRAW_INDEX | Caps::MEMORY_DECORATION_COHERENT | Caps::MEMORY_DECORATION_VOLATILE } naga-29.0.3/src/back/spv/ray/mod.rs000064400000000000000000000002421046102023000150420ustar 00000000000000/*! Module for code shared between ray queries and ray tracing pipeline code. 
Ray tracing pipelines are not yet implemented, so this is empty. */ pub mod query; naga-29.0.3/src/back/spv/ray/query.rs000064400000000000000000002107671046102023000154470ustar 00000000000000/*! Generating SPIR-V for ray query operations. */ use alloc::{vec, vec::Vec}; use super::super::{ Block, BlockContext, Function, FunctionArgument, Instruction, LocalType, LookupFunctionType, LookupRayQueryFunction, NumericType, Writer, WriterFlags, }; use crate::{arena::Handle, back::RayQueryPoint}; /// Helper function that writes SPIR-V checking whether `flag` is set in the `u32` value `id`; /// returns the id of the boolean result. fn write_ray_flags_contains_flags( writer: &mut Writer, block: &mut Block, id: spirv::Word, flag: u32, ) -> spirv::Word { let bit_id = writer.get_constant_scalar(crate::Literal::U32(flag)); let zero_id = writer.get_constant_scalar(crate::Literal::U32(0)); let u32_type_id = writer.get_u32_type_id(); let bool_ty = writer.get_bool_type_id(); let and_id = writer.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::BitwiseAnd, u32_type_id, and_id, id, bit_id, )); let eq_id = writer.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::INotEqual, bool_ty, eq_id, and_id, zero_id, )); eq_id } impl Writer { /// Writes a logical `and` of two scalar booleans and returns the id of the result. fn write_logical_and( &mut self, block: &mut Block, one: spirv::Word, two: spirv::Word, ) -> spirv::Word { let id = self.id_gen.next(); let bool_id = self.get_bool_type_id(); block.body.push(Instruction::binary( spirv::Op::LogicalAnd, bool_id, id, one, two, )); id } /// Writes the logical `and` of all of `bools`, which must be non-empty. fn write_reduce_and(&mut self, block: &mut Block, mut bools: Vec) -> spirv::Word { // The `and` of all the booleans combined so far. let mut current_combined = bools.pop().unwrap(); for boolean in bools { current_combined = self.write_logical_and(block, current_combined, boolean) } current_combined } // Returns the id of the function, the function itself, and the ids of its arguments.
fn write_function_signature( &mut self, arg_types: &[spirv::Word], return_ty: spirv::Word, ) -> (spirv::Word, Function, Vec) { let func_ty = self.get_function_type(LookupFunctionType { parameter_type_ids: Vec::from(arg_types), return_type_id: return_ty, }); let mut function = Function::default(); let func_id = self.id_gen.next(); function.signature = Some(Instruction::function( return_ty, func_id, spirv::FunctionControl::empty(), func_ty, )); let mut arg_ids = Vec::with_capacity(arg_types.len()); for (idx, &arg_ty) in arg_types.iter().enumerate() { let id = self.id_gen.next(); let instruction = Instruction::function_parameter(arg_ty, id); function.parameters.push(FunctionArgument { instruction, handle_id: idx as u32, }); arg_ids.push(id); } (func_id, function, arg_ids) } pub(in super::super) fn write_ray_query_get_intersection_function( &mut self, is_committed: bool, ir_module: &crate::Module, ) -> spirv::Word { if let Some(&word) = self.ray_query_functions .get(&LookupRayQueryFunction::GetIntersection { committed: is_committed, }) { return word; } let ray_intersection = ir_module.special_types.ray_intersection.unwrap(); let intersection_type_id = self.get_handle_type_id(ray_intersection); let intersection_pointer_type_id = self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function); let flag_type_id = self.get_u32_type_id(); let flag_pointer_type_id = self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function); let transform_type_id = self.get_numeric_type_id(NumericType::Matrix { columns: crate::VectorSize::Quad, rows: crate::VectorSize::Tri, scalar: crate::Scalar::F32, }); let transform_pointer_type_id = self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function); let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::F32, }); let barycentrics_pointer_type_id = self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function); let 
bool_type_id = self.get_bool_type_id(); let bool_pointer_type_id = self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function); let scalar_type_id = self.get_f32_type_id(); let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function); let argument_type_id = self.get_ray_query_pointer_id(); let (func_id, mut function, arg_ids) = self.write_function_signature( &[argument_type_id, flag_pointer_type_id], intersection_type_id, ); let query_id = arg_ids[0]; let intersection_tracker_id = arg_ids[1]; let label_id = self.id_gen.next(); let mut block = Block::new(label_id); let blank_intersection = self.get_constant_null(intersection_type_id); let blank_intersection_id = self.id_gen.next(); // This must be before everything else in the function. block.body.push(Instruction::variable( intersection_pointer_type_id, blank_intersection_id, spirv::StorageClass::Function, Some(blank_intersection), )); let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed { spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR } else { spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR } as _)); let loaded_ray_query_tracker_id = self.id_gen.next(); block.body.push(Instruction::load( flag_type_id, loaded_ray_query_tracker_id, intersection_tracker_id, None, )); let proceeded_id = write_ray_flags_contains_flags( self, &mut block, loaded_ray_query_tracker_id, RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, loaded_ray_query_tracker_id, RayQueryPoint::FINISHED_TRAVERSAL.bits(), ); let proceed_finished_correct_id = if is_committed { finished_proceed_id } else { let not_finished_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::LogicalNot, bool_type_id, not_finished_id, finished_proceed_id, )); not_finished_id }; let is_valid_id = self.write_logical_and(&mut block, proceed_finished_correct_id, proceeded_id); let valid_id = 
self.id_gen.next(); let mut valid_block = Block::new(valid_id); let final_label_id = self.id_gen.next(); let mut final_block = Block::new(final_label_id); block.body.push(Instruction::selection_merge( final_label_id, spirv::SelectionControl::NONE, )); function.consume( block, Instruction::branch_conditional(is_valid_id, valid_id, final_label_id), ); let raw_kind_id = self.id_gen.next(); valid_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionTypeKHR, flag_type_id, raw_kind_id, query_id, intersection_id, )); let kind_id = if is_committed { // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType` raw_kind_id } else { // Remap from the candidate kind to IR let condition_id = self.id_gen.next(); let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32( spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _, )); valid_block.body.push(Instruction::binary( spirv::Op::IEqual, self.get_bool_type_id(), condition_id, raw_kind_id, committed_triangle_kind_id, )); let kind_id = self.id_gen.next(); valid_block.body.push(Instruction::select( flag_type_id, kind_id, condition_id, self.get_constant_scalar(crate::Literal::U32( crate::RayQueryIntersection::Triangle as _, )), self.get_constant_scalar(crate::Literal::U32( crate::RayQueryIntersection::Aabb as _, )), )); kind_id }; let idx_id = self.get_index_constant(0); let access_idx = self.id_gen.next(); valid_block.body.push(Instruction::access_chain( flag_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); valid_block .body .push(Instruction::store(access_idx, kind_id, None)); let not_none_comp_id = self.id_gen.next(); let none_id = self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _)); valid_block.body.push(Instruction::binary( spirv::Op::INotEqual, self.get_bool_type_id(), not_none_comp_id, kind_id, none_id, )); let not_none_label_id = self.id_gen.next(); let mut 
not_none_block = Block::new(not_none_label_id); let outer_merge_label_id = self.id_gen.next(); let outer_merge_block = Block::new(outer_merge_label_id); valid_block.body.push(Instruction::selection_merge( outer_merge_label_id, spirv::SelectionControl::NONE, )); function.consume( valid_block, Instruction::branch_conditional( not_none_comp_id, not_none_label_id, outer_merge_label_id, ), ); let instance_custom_index_id = self.id_gen.next(); not_none_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR, flag_type_id, instance_custom_index_id, query_id, intersection_id, )); let instance_id = self.id_gen.next(); not_none_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionInstanceIdKHR, flag_type_id, instance_id, query_id, intersection_id, )); let sbt_record_offset_id = self.id_gen.next(); not_none_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR, flag_type_id, sbt_record_offset_id, query_id, intersection_id, )); let geometry_index_id = self.id_gen.next(); not_none_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionGeometryIndexKHR, flag_type_id, geometry_index_id, query_id, intersection_id, )); let primitive_index_id = self.id_gen.next(); not_none_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR, flag_type_id, primitive_index_id, query_id, intersection_id, )); //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`, // but it's not a property of an intersection. 
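// Sketch (hypothetical, not part of naga): the member indices that the
// `OpAccessChain`s in this function use when storing into the returned
// intersection value. The labels are descriptive names derived from the
// SPIR-V getters used here; they are assumed, not copied from the IR's
// ray intersection type definition.
#[allow(dead_code)]
const SKETCH_INTERSECTION_MEMBERS: [&str; 11] = [
    "kind",                  // 0: intersection kind
    "t",                     // 1: RayQueryGetIntersectionTKHR
    "instance_custom_index", // 2: RayQueryGetIntersectionInstanceCustomIndexKHR
    "instance_id",           // 3: RayQueryGetIntersectionInstanceIdKHR
    "sbt_record_offset",     // 4: ...InstanceShaderBindingTableRecordOffsetKHR
    "geometry_index",        // 5: RayQueryGetIntersectionGeometryIndexKHR
    "primitive_index",       // 6: RayQueryGetIntersectionPrimitiveIndexKHR
    "barycentrics",          // 7: written in `tri_block`
    "front_face",            // 8: written in `tri_block`
    "object_to_world",       // 9: RayQueryGetIntersectionObjectToWorldKHR
    "world_to_object",       // 10: RayQueryGetIntersectionWorldToObjectKHR
];

// Looks up the member index for one of the hypothetical labels above.
#[allow(dead_code)]
fn sketch_intersection_member_index(label: &str) -> Option<u32> {
    SKETCH_INTERSECTION_MEMBERS
        .iter()
        .position(|&name| name == label)
        .map(|index| index as u32)
}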
let object_to_world_id = self.id_gen.next(); not_none_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionObjectToWorldKHR, transform_type_id, object_to_world_id, query_id, intersection_id, )); let world_to_object_id = self.id_gen.next(); not_none_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionWorldToObjectKHR, transform_type_id, world_to_object_id, query_id, intersection_id, )); // instance custom index let idx_id = self.get_index_constant(2); let access_idx = self.id_gen.next(); not_none_block.body.push(Instruction::access_chain( flag_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); not_none_block.body.push(Instruction::store( access_idx, instance_custom_index_id, None, )); // instance let idx_id = self.get_index_constant(3); let access_idx = self.id_gen.next(); not_none_block.body.push(Instruction::access_chain( flag_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); not_none_block .body .push(Instruction::store(access_idx, instance_id, None)); let idx_id = self.get_index_constant(4); let access_idx = self.id_gen.next(); not_none_block.body.push(Instruction::access_chain( flag_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); not_none_block .body .push(Instruction::store(access_idx, sbt_record_offset_id, None)); let idx_id = self.get_index_constant(5); let access_idx = self.id_gen.next(); not_none_block.body.push(Instruction::access_chain( flag_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); not_none_block .body .push(Instruction::store(access_idx, geometry_index_id, None)); let idx_id = self.get_index_constant(6); let access_idx = self.id_gen.next(); not_none_block.body.push(Instruction::access_chain( flag_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); not_none_block .body .push(Instruction::store(access_idx, primitive_index_id, None)); let idx_id = self.get_index_constant(9); let 
access_idx = self.id_gen.next(); not_none_block.body.push(Instruction::access_chain( transform_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); not_none_block .body .push(Instruction::store(access_idx, object_to_world_id, None)); let idx_id = self.get_index_constant(10); let access_idx = self.id_gen.next(); not_none_block.body.push(Instruction::access_chain( transform_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); not_none_block .body .push(Instruction::store(access_idx, world_to_object_id, None)); let tri_comp_id = self.id_gen.next(); let tri_id = self.get_constant_scalar(crate::Literal::U32( crate::RayQueryIntersection::Triangle as _, )); not_none_block.body.push(Instruction::binary( spirv::Op::IEqual, self.get_bool_type_id(), tri_comp_id, kind_id, tri_id, )); let tri_label_id = self.id_gen.next(); let mut tri_block = Block::new(tri_label_id); let merge_label_id = self.id_gen.next(); let merge_block = Block::new(merge_label_id); // t { let block = if is_committed { &mut not_none_block } else { &mut tri_block }; let t_id = self.id_gen.next(); block.body.push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionTKHR, scalar_type_id, t_id, query_id, intersection_id, )); let idx_id = self.get_index_constant(1); let access_idx = self.id_gen.next(); block.body.push(Instruction::access_chain( float_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); block.body.push(Instruction::store(access_idx, t_id, None)); } not_none_block.body.push(Instruction::selection_merge( merge_label_id, spirv::SelectionControl::NONE, )); // Barycentrics and front face are only defined for triangle intersections, so branch on `tri_comp_id`. function.consume( not_none_block, Instruction::branch_conditional(tri_comp_id, tri_label_id, merge_label_id), ); let barycentrics_id = self.id_gen.next(); tri_block.body.push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionBarycentricsKHR, barycentrics_type_id, barycentrics_id, query_id, intersection_id, )); let front_face_id = self.id_gen.next();
tri_block.body.push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionFrontFaceKHR, bool_type_id, front_face_id, query_id, intersection_id, )); let idx_id = self.get_index_constant(7); let access_idx = self.id_gen.next(); tri_block.body.push(Instruction::access_chain( barycentrics_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); tri_block .body .push(Instruction::store(access_idx, barycentrics_id, None)); let idx_id = self.get_index_constant(8); let access_idx = self.id_gen.next(); tri_block.body.push(Instruction::access_chain( bool_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); tri_block .body .push(Instruction::store(access_idx, front_face_id, None)); function.consume(tri_block, Instruction::branch(merge_label_id)); function.consume(merge_block, Instruction::branch(outer_merge_label_id)); function.consume(outer_merge_block, Instruction::branch(final_label_id)); let loaded_blank_intersection_id = self.id_gen.next(); final_block.body.push(Instruction::load( intersection_type_id, loaded_blank_intersection_id, blank_intersection_id, None, )); function.consume( final_block, Instruction::return_value(loaded_blank_intersection_id), ); function.to_words(&mut self.logical_layout.function_definitions); self.ray_query_functions.insert( LookupRayQueryFunction::GetIntersection { committed: is_committed, }, func_id, ); func_id } fn write_ray_query_initialize(&mut self, ir_module: &crate::Module) -> spirv::Word { if let Some(&word) = self .ray_query_functions .get(&LookupRayQueryFunction::Initialize) { return word; } let ray_query_type_id = self.get_ray_query_pointer_id(); let acceleration_structure_type_id = self.get_localtype_id(LocalType::AccelerationStructure); let ray_desc_type_id = self.get_handle_type_id( ir_module .special_types .ray_desc .expect("ray desc should be set if ray queries are being initialized"), ); let u32_ty = self.get_u32_type_id(); let u32_ptr_ty = self.get_pointer_type_id(u32_ty, 
spirv::StorageClass::Function); let f32_type_id = self.get_f32_type_id(); let f32_ptr_ty = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function); let bool_type_id = self.get_bool_type_id(); let bool_vec3_type_id = self.get_vec3_bool_type_id(); let (func_id, mut function, arg_ids) = self.write_function_signature( &[ ray_query_type_id, acceleration_structure_type_id, ray_desc_type_id, u32_ptr_ty, f32_ptr_ty, ], self.void_type, ); let query_id = arg_ids[0]; let acceleration_structure_id = arg_ids[1]; let desc_id = arg_ids[2]; let init_tracker_id = arg_ids[3]; let t_max_tracker_id = arg_ids[4]; let label_id = self.id_gen.next(); let mut block = Block::new(label_id); let flag_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)); //Note: composite extract indices and types must match `generate_ray_desc_type` let ray_flags_id = self.id_gen.next(); block.body.push(Instruction::composite_extract( flag_type_id, ray_flags_id, desc_id, &[0], )); let cull_mask_id = self.id_gen.next(); block.body.push(Instruction::composite_extract( flag_type_id, cull_mask_id, desc_id, &[1], )); let tmin_id = self.id_gen.next(); block.body.push(Instruction::composite_extract( f32_type_id, tmin_id, desc_id, &[2], )); let tmax_id = self.id_gen.next(); block.body.push(Instruction::composite_extract( f32_type_id, tmax_id, desc_id, &[3], )); block .body .push(Instruction::store(t_max_tracker_id, tmax_id, None)); let vector_type_id = self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Tri, scalar: crate::Scalar::F32, }); let ray_origin_id = self.id_gen.next(); block.body.push(Instruction::composite_extract( vector_type_id, ray_origin_id, desc_id, &[4], )); let ray_dir_id = self.id_gen.next(); block.body.push(Instruction::composite_extract( vector_type_id, ray_dir_id, desc_id, &[5], )); let valid_id = self.ray_query_initialization_tracking.then(||{ let tmin_le_tmax_id = self.id_gen.next(); // Check both that tmin is less than or equal to tmax 
(https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06350) // and, implicitly, that neither tmin nor tmax is NaN (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06351) // because this ordered comparison is false if either operand is NaN. block.body.push(Instruction::binary( spirv::Op::FOrdLessThanEqual, bool_type_id, tmin_le_tmax_id, tmin_id, tmax_id, )); // Check that tmin is greater than or equal to 0 (and // therefore so is tmax, since it is greater than // or equal to tmin) (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06349). let tmin_ge_zero_id = self.id_gen.next(); let zero_id = self.get_constant_scalar(crate::Literal::F32(0.0)); block.body.push(Instruction::binary( spirv::Op::FOrdGreaterThanEqual, bool_type_id, tmin_ge_zero_id, tmin_id, zero_id, )); // Check that the ray origin is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348) let ray_origin_infinite_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::IsInf, bool_vec3_type_id, ray_origin_infinite_id, ray_origin_id, )); let any_ray_origin_infinite_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::Any, bool_type_id, any_ray_origin_infinite_id, ray_origin_infinite_id, )); let ray_origin_nan_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::IsNan, bool_vec3_type_id, ray_origin_nan_id, ray_origin_id, )); let any_ray_origin_nan_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::Any, bool_type_id, any_ray_origin_nan_id, ray_origin_nan_id, )); let ray_origin_not_finite_id = self.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::LogicalOr, bool_type_id, ray_origin_not_finite_id, any_ray_origin_nan_id, any_ray_origin_infinite_id, )); let all_ray_origin_finite_id = self.id_gen.next();
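// Sketch (hypothetical, not part of naga): a CPU-side model of the validity
// predicate that this closure builds in SPIR-V. It can help when reasoning about
// which ray descriptions are rejected. Flags are modelled as booleans, so no
// particular `RayFlag` bit values are assumed. All names here are illustrative.
#[allow(dead_code)]
#[derive(Clone, Copy)]
struct SketchRayDesc {
    tmin: f32,
    tmax: f32,
    origin: [f32; 3],
    dir: [f32; 3],
}

#[allow(dead_code)]
#[derive(Clone, Copy, Default)]
struct SketchRayFlags {
    skip_triangles: bool,
    skip_aabbs: bool,
    cull_back_facing: bool,
    cull_front_facing: bool,
    force_opaque: bool,
    force_no_opaque: bool,
    cull_opaque: bool,
    cull_no_opaque: bool,
}

// Same pairwise scheme as `write_less_than_2_true`: `and` every pair,
// `or` the results together, then negate.
#[allow(dead_code)]
fn sketch_fewer_than_two_true(bools: &[bool]) -> bool {
    let mut any_pair_true = false;
    for (i, &a) in bools.iter().enumerate() {
        for &b in &bools[i + 1..] {
            any_pair_true |= a && b;
        }
    }
    !any_pair_true
}

#[allow(dead_code)]
fn sketch_ray_desc_is_valid(desc: &SketchRayDesc, flags: &SketchRayFlags) -> bool {
    // Rust's `<=` and `>=` on `f32` are false when either side is NaN, like the
    // ordered comparisons `FOrdLessThanEqual` and `FOrdGreaterThanEqual`.
    let tmin_le_tmax = desc.tmin <= desc.tmax;
    let tmin_ge_zero = desc.tmin >= 0.0;
    // `is_finite` rejects both infinities and NaN (the `IsInf`/`IsNan` pair).
    let origin_finite = desc.origin.iter().all(|c| c.is_finite());
    let dir_finite = desc.dir.iter().all(|c| c.is_finite());
    tmin_le_tmax
        && tmin_ge_zero
        && origin_finite
        && dir_finite
        && sketch_fewer_than_two_true(&[flags.skip_triangles, flags.skip_aabbs])
        && sketch_fewer_than_two_true(&[
            flags.skip_triangles,
            flags.cull_back_facing,
            flags.cull_front_facing,
        ])
        && sketch_fewer_than_two_true(&[
            flags.force_opaque,
            flags.force_no_opaque,
            flags.cull_opaque,
            flags.cull_no_opaque,
        ])
}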
block.body.push(Instruction::unary( spirv::Op::LogicalNot, bool_type_id, all_ray_origin_finite_id, ray_origin_not_finite_id, )); // Check that the ray direction is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348) let ray_dir_infinite_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::IsInf, bool_vec3_type_id, ray_dir_infinite_id, ray_dir_id, )); let any_ray_dir_infinite_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::Any, bool_type_id, any_ray_dir_infinite_id, ray_dir_infinite_id, )); let ray_dir_nan_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::IsNan, bool_vec3_type_id, ray_dir_nan_id, ray_dir_id, )); let any_ray_dir_nan_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::Any, bool_type_id, any_ray_dir_nan_id, ray_dir_nan_id, )); let ray_dir_not_finite_id = self.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::LogicalOr, bool_type_id, ray_dir_not_finite_id, any_ray_dir_nan_id, any_ray_dir_infinite_id, )); let all_ray_dir_finite_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::LogicalNot, bool_type_id, all_ray_dir_finite_id, ray_dir_not_finite_id, )); /// Writes SPIR-V that checks whether fewer than two of `bools` are true. /// /// Pops each boolean in turn and `and`s it with every boolean still in the list, so every pair is checked for both being true. /// These checks are then `or`ed together, which yields whether two or more booleans are true, and the result is negated.
fn write_less_than_2_true( writer: &mut Writer, block: &mut Block, mut bools: Vec<spirv::Word>, ) -> spirv::Word { assert!(bools.len() > 1, "Must have multiple booleans!"); let bool_ty = writer.get_bool_type_id(); let mut each_two_true = Vec::new(); while let Some(last_bool) = bools.pop() { for &bool in &bools { let both_true_id = writer.write_logical_and( block, last_bool, bool, ); each_two_true.push(both_true_id); } } let mut all_or_id = each_two_true.pop().expect("since this must have multiple booleans, there must be at least one thing in `each_two_true`"); for two_true in each_two_true { let new_all_or_id = writer.id_gen.next(); block.body.push(Instruction::binary( spirv::Op::LogicalOr, bool_ty, new_all_or_id, all_or_id, two_true, )); all_or_id = new_all_or_id; } let less_than_two_id = writer.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::LogicalNot, bool_ty, less_than_two_id, all_or_id, )); less_than_two_id } // Check that at most one of skip triangles and skip AABBs is // present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06889) let contains_skip_triangles = write_ray_flags_contains_flags( self, &mut block, ray_flags_id, crate::RayFlag::SKIP_TRIANGLES.bits(), ); let contains_skip_aabbs = write_ray_flags_contains_flags( self, &mut block, ray_flags_id, crate::RayFlag::SKIP_AABBS.bits(), ); let not_contain_skip_triangles_aabbs = write_less_than_2_true( self, &mut block, vec![contains_skip_triangles, contains_skip_aabbs], ); // Check that at most one of skip triangles (taken from above check), // cull back facing, and cull front facing is present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06890) let contains_cull_back = write_ray_flags_contains_flags( self, &mut block, ray_flags_id, crate::RayFlag::CULL_BACK_FACING.bits(), ); let contains_cull_front = write_ray_flags_contains_flags( self, &mut block, ray_flags_id, 
crate::RayFlag::CULL_FRONT_FACING.bits(), ); let not_contain_skip_triangles_cull = write_less_than_2_true( self, &mut block, vec![ contains_skip_triangles, contains_cull_back, contains_cull_front, ], ); // Check that at most one of force opaque, force not opaque, cull opaque, // and cull not opaque is present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06891) let contains_opaque = write_ray_flags_contains_flags( self, &mut block, ray_flags_id, crate::RayFlag::FORCE_OPAQUE.bits(), ); let contains_no_opaque = write_ray_flags_contains_flags( self, &mut block, ray_flags_id, crate::RayFlag::FORCE_NO_OPAQUE.bits(), ); let contains_cull_opaque = write_ray_flags_contains_flags( self, &mut block, ray_flags_id, crate::RayFlag::CULL_OPAQUE.bits(), ); let contains_cull_no_opaque = write_ray_flags_contains_flags( self, &mut block, ray_flags_id, crate::RayFlag::CULL_NO_OPAQUE.bits(), ); let not_contain_multiple_opaque = write_less_than_2_true( self, &mut block, vec![ contains_opaque, contains_no_opaque, contains_cull_opaque, contains_cull_no_opaque, ], ); // Combine all checks into a single flag saying whether the call is valid or not. self.write_reduce_and( &mut block, vec![ tmin_le_tmax_id, tmin_ge_zero_id, all_ray_origin_finite_id, all_ray_dir_finite_id, not_contain_skip_triangles_aabbs, not_contain_skip_triangles_cull, not_contain_multiple_opaque, ], ) }); let merge_label_id = self.id_gen.next(); let merge_block = Block::new(merge_label_id); // NOTE: this block will be unreachable if initialization tracking is disabled.
let invalid_label_id = self.id_gen.next(); let mut invalid_block = Block::new(invalid_label_id); let valid_label_id = self.id_gen.next(); let mut valid_block = Block::new(valid_label_id); match valid_id { Some(all_valid_id) => { block.body.push(Instruction::selection_merge( merge_label_id, spirv::SelectionControl::NONE, )); function.consume( block, Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id), ); } None => { function.consume(block, Instruction::branch(valid_label_id)); } } valid_block.body.push(Instruction::ray_query_initialize( query_id, acceleration_structure_id, ray_flags_id, cull_mask_id, ray_origin_id, tmin_id, ray_dir_id, tmax_id, )); let const_initialized = self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::INITIALIZED.bits())); valid_block .body .push(Instruction::store(init_tracker_id, const_initialized, None)); function.consume(valid_block, Instruction::branch(merge_label_id)); if self .flags .contains(WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL) { self.write_debug_printf( &mut invalid_block, "Naga ignored invalid arguments to rayQueryInitialize with flags: %u t_min: %f t_max: %f origin: %v3f dir: %v3f", &[ ray_flags_id, tmin_id, tmax_id, ray_origin_id, ray_dir_id, ], ); } function.consume(invalid_block, Instruction::branch(merge_label_id)); function.consume(merge_block, Instruction::return_void()); function.to_words(&mut self.logical_layout.function_definitions); self.ray_query_functions .insert(LookupRayQueryFunction::Initialize, func_id); func_id } fn write_ray_query_proceed(&mut self) -> spirv::Word { if let Some(&word) = self .ray_query_functions .get(&LookupRayQueryFunction::Proceed) { return word; } let ray_query_type_id = self.get_ray_query_pointer_id(); let u32_ty = self.get_u32_type_id(); let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function); let bool_type_id = self.get_bool_type_id(); let bool_ptr_ty = self.get_pointer_type_id(bool_type_id, 
spirv::StorageClass::Function); let (func_id, mut function, arg_ids) = self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], bool_type_id); let query_id = arg_ids[0]; let init_tracker_id = arg_ids[1]; let block_id = self.id_gen.next(); let mut block = Block::new(block_id); // TODO: perhaps this could be replaced with an OpPhi? let proceeded_id = self.id_gen.next(); let const_false = self.get_constant_scalar(crate::Literal::Bool(false)); block.body.push(Instruction::variable( bool_ptr_ty, proceeded_id, spirv::StorageClass::Function, Some(const_false), )); let initialized_tracker_id = self.id_gen.next(); block.body.push(Instruction::load( u32_ty, initialized_tracker_id, init_tracker_id, None, )); let merge_id = self.id_gen.next(); let mut merge_block = Block::new(merge_id); let valid_block_id = self.id_gen.next(); let mut valid_block = Block::new(valid_block_id); let instruction = if self.ray_query_initialization_tracking { let is_initialized = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, RayQueryPoint::INITIALIZED.bits(), ); block.body.push(Instruction::selection_merge( merge_id, spirv::SelectionControl::NONE, )); Instruction::branch_conditional(is_initialized, valid_block_id, merge_id) } else { Instruction::branch(valid_block_id) }; function.consume(block, instruction); let has_proceeded = self.id_gen.next(); valid_block.body.push(Instruction::ray_query_proceed( bool_type_id, has_proceeded, query_id, )); valid_block .body .push(Instruction::store(proceeded_id, has_proceeded, None)); let add_flag_finished = self.get_constant_scalar(crate::Literal::U32( (RayQueryPoint::PROCEED | RayQueryPoint::FINISHED_TRAVERSAL).bits(), )); let add_flag_continuing = self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::PROCEED.bits())); let add_flags_id = self.id_gen.next(); valid_block.body.push(Instruction::select( u32_ty, add_flags_id, has_proceeded, add_flag_continuing, add_flag_finished, )); let final_flags = self.id_gen.next(); 
valid_block.body.push(Instruction::binary( spirv::Op::BitwiseOr, u32_ty, final_flags, initialized_tracker_id, add_flags_id, )); valid_block .body .push(Instruction::store(init_tracker_id, final_flags, None)); function.consume(valid_block, Instruction::branch(merge_id)); let loaded_proceeded_id = self.id_gen.next(); merge_block.body.push(Instruction::load( bool_type_id, loaded_proceeded_id, proceeded_id, None, )); function.consume(merge_block, Instruction::return_value(loaded_proceeded_id)); function.to_words(&mut self.logical_layout.function_definitions); self.ray_query_functions .insert(LookupRayQueryFunction::Proceed, func_id); func_id } fn write_ray_query_generate_intersection(&mut self) -> spirv::Word { if let Some(&word) = self .ray_query_functions .get(&LookupRayQueryFunction::GenerateIntersection) { return word; } let ray_query_type_id = self.get_ray_query_pointer_id(); let u32_ty = self.get_u32_type_id(); let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function); let f32_type_id = self.get_f32_type_id(); let f32_ptr_type_id = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function); let bool_type_id = self.get_bool_type_id(); let (func_id, mut function, arg_ids) = self.write_function_signature( &[ray_query_type_id, u32_ptr_ty, f32_type_id, f32_ptr_type_id], self.void_type, ); let query_id = arg_ids[0]; let init_tracker_id = arg_ids[1]; let depth_id = arg_ids[2]; let t_max_tracker_id = arg_ids[3]; let block_id = self.id_gen.next(); let mut block = Block::new(block_id); let current_t = self.id_gen.next(); block.body.push(Instruction::variable( f32_ptr_type_id, current_t, spirv::StorageClass::Function, None, )); let valid_id = self.id_gen.next(); let mut valid_block = Block::new(valid_id); let final_label_id = self.id_gen.next(); let final_block = Block::new(final_label_id); let 
instruction = if self.ray_query_initialization_tracking { let initialized_tracker_id = self.id_gen.next(); block.body.push(Instruction::load( u32_ty, initialized_tracker_id, init_tracker_id, None, )); let proceeded_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, RayQueryPoint::FINISHED_TRAVERSAL.bits(), ); // Can't find anything to suggest double calling this function is invalid. let not_finished_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::LogicalNot, bool_type_id, not_finished_id, finished_proceed_id, )); let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id); block.body.push(Instruction::selection_merge( final_label_id, spirv::SelectionControl::NONE, )); Instruction::branch_conditional(is_valid_id, valid_id, final_label_id) } else { Instruction::branch(valid_id) }; function.consume(block, instruction); let intersection_id = self.get_constant_scalar(crate::Literal::U32( spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _, )); let committed_intersection_id = self.get_constant_scalar(crate::Literal::U32( spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, )); let raw_kind_id = self.id_gen.next(); valid_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionTypeKHR, u32_ty, raw_kind_id, query_id, intersection_id, )); let candidate_aabb_id = self.get_constant_scalar(crate::Literal::U32( spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR as _, )); let intersection_aabb_id = self.id_gen.next(); valid_block.body.push(Instruction::binary( spirv::Op::IEqual, bool_type_id, intersection_aabb_id, raw_kind_id, candidate_aabb_id, )); // Check that the provided t value is between t min and the current committed // t value, 
(https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryGenerateIntersectionKHR-06353) // Get the tmin let t_min_id = self.id_gen.next(); valid_block.body.push(Instruction::ray_query_get_t_min( f32_type_id, t_min_id, query_id, )); // Get the current committed t, or tmax if no hit. // Basically emulate HLSL's (easier) version // Pseudo-code: // ````wgsl // // start of function // var current_t:f32; // ... // let committed_type_id = RayQueryGetIntersectionTypeKHR(query_id); // if committed_type_id == Committed_None { // current_t = load(t_max_tracker); // } else { // current_t = RayQueryGetIntersectionTKHR(query_id); // } // ... // ```` let committed_type_id = self.id_gen.next(); valid_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionTypeKHR, u32_ty, committed_type_id, query_id, committed_intersection_id, )); let no_committed = self.id_gen.next(); valid_block.body.push(Instruction::binary( spirv::Op::IEqual, bool_type_id, no_committed, committed_type_id, self.get_constant_scalar(crate::Literal::U32( spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR as _, )), )); let next_valid_block_id = self.id_gen.next(); let no_committed_block_id = self.id_gen.next(); let mut no_committed_block = Block::new(no_committed_block_id); let committed_block_id = self.id_gen.next(); let mut committed_block = Block::new(committed_block_id); valid_block.body.push(Instruction::selection_merge( next_valid_block_id, spirv::SelectionControl::NONE, )); function.consume( valid_block, Instruction::branch_conditional( no_committed, no_committed_block_id, committed_block_id, ), ); // Assign t_max to current_t let t_max_id = self.id_gen.next(); no_committed_block.body.push(Instruction::load( f32_type_id, t_max_id, t_max_tracker_id, None, )); no_committed_block .body .push(Instruction::store(current_t, t_max_id, None)); function.consume(no_committed_block, 
Instruction::branch(next_valid_block_id)); // Assign the committed intersection's t to current_t let latest_t_id = self.id_gen.next(); committed_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionTKHR, f32_type_id, latest_t_id, query_id, committed_intersection_id, )); committed_block .body .push(Instruction::store(current_t, latest_t_id, None)); function.consume(committed_block, Instruction::branch(next_valid_block_id)); let mut valid_block = Block::new(next_valid_block_id); let t_ge_t_min = self.id_gen.next(); valid_block.body.push(Instruction::binary( spirv::Op::FOrdGreaterThanEqual, bool_type_id, t_ge_t_min, depth_id, t_min_id, )); let t_current = self.id_gen.next(); valid_block .body .push(Instruction::load(f32_type_id, t_current, current_t, None)); let t_le_t_current = self.id_gen.next(); valid_block.body.push(Instruction::binary( spirv::Op::FOrdLessThanEqual, bool_type_id, t_le_t_current, depth_id, t_current, )); let t_in_range = self.id_gen.next(); valid_block.body.push(Instruction::binary( spirv::Op::LogicalAnd, bool_type_id, t_in_range, t_ge_t_min, t_le_t_current, )); let call_valid_id = self.id_gen.next(); valid_block.body.push(Instruction::binary( spirv::Op::LogicalAnd, bool_type_id, call_valid_id, t_in_range, intersection_aabb_id, )); let generate_label_id = self.id_gen.next(); let mut generate_block = Block::new(generate_label_id); let merge_label_id = self.id_gen.next(); let merge_block = Block::new(merge_label_id); valid_block.body.push(Instruction::selection_merge( merge_label_id, spirv::SelectionControl::NONE, )); function.consume( valid_block, Instruction::branch_conditional(call_valid_id, generate_label_id, merge_label_id), ); generate_block .body .push(Instruction::ray_query_generate_intersection( query_id, depth_id, )); function.consume(generate_block, Instruction::branch(merge_label_id)); function.consume(merge_block, Instruction::branch(final_label_id)); function.consume(final_block, 
Instruction::return_void()); 
function.to_words(&mut self.logical_layout.function_definitions); self.ray_query_functions .insert(LookupRayQueryFunction::GenerateIntersection, func_id); func_id } fn write_ray_query_confirm_intersection(&mut self) -> spirv::Word { if let Some(&word) = self .ray_query_functions .get(&LookupRayQueryFunction::ConfirmIntersection) { return word; } let ray_query_type_id = self.get_ray_query_pointer_id(); let u32_ty = self.get_u32_type_id(); let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function); let bool_type_id = self.get_bool_type_id(); let (func_id, mut function, arg_ids) = self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type); let query_id = arg_ids[0]; let init_tracker_id = arg_ids[1]; let block_id = self.id_gen.next(); let mut block = Block::new(block_id); let valid_id = self.id_gen.next(); let mut valid_block = Block::new(valid_id); let final_label_id = self.id_gen.next(); let final_block = Block::new(final_label_id); let instruction = if self.ray_query_initialization_tracking { let initialized_tracker_id = self.id_gen.next(); block.body.push(Instruction::load( u32_ty, initialized_tracker_id, init_tracker_id, None, )); let proceeded_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, RayQueryPoint::FINISHED_TRAVERSAL.bits(), ); // Although it seems strange to call this twice, I (Vecvec) can't find anything to suggest double calling this function is invalid. 
let not_finished_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::LogicalNot, bool_type_id, not_finished_id, finished_proceed_id, )); let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id); block.body.push(Instruction::selection_merge( final_label_id, spirv::SelectionControl::NONE, )); Instruction::branch_conditional(is_valid_id, valid_id, final_label_id) } else { Instruction::branch(valid_id) }; function.consume(block, instruction); let intersection_id = self.get_constant_scalar(crate::Literal::U32( spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _, )); let raw_kind_id = self.id_gen.next(); valid_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionTypeKHR, u32_ty, raw_kind_id, query_id, intersection_id, )); let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32( spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _, )); let intersection_tri_id = self.id_gen.next(); valid_block.body.push(Instruction::binary( spirv::Op::IEqual, bool_type_id, intersection_tri_id, raw_kind_id, candidate_tri_id, )); let generate_label_id = self.id_gen.next(); let mut generate_block = Block::new(generate_label_id); let merge_label_id = self.id_gen.next(); let merge_block = Block::new(merge_label_id); valid_block.body.push(Instruction::selection_merge( merge_label_id, spirv::SelectionControl::NONE, )); function.consume( valid_block, Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id), ); generate_block .body .push(Instruction::ray_query_confirm_intersection(query_id)); function.consume(generate_block, Instruction::branch(merge_label_id)); function.consume(merge_block, Instruction::branch(final_label_id)); function.consume(final_block, Instruction::return_void()); self.ray_query_functions .insert(LookupRayQueryFunction::ConfirmIntersection, func_id); function.to_words(&mut 
self.logical_layout.function_definitions); func_id } fn write_ray_query_get_vertex_positions( &mut self, is_committed: bool, ir_module: &crate::Module, ) -> spirv::Word { if let Some(&word) = self.ray_query_functions .get(&LookupRayQueryFunction::GetVertexPositions { committed: is_committed, }) { return word; } let (committed_ty, committed_tri_ty) = if is_committed { ( spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as u32, spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR as u32, ) } else { ( spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as u32, spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as u32, ) }; let ray_query_type_id = self.get_ray_query_pointer_id(); let u32_ty = self.get_u32_type_id(); let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function); let rq_get_vertex_positions_ty_id = self.get_handle_type_id( *ir_module .special_types .ray_vertex_return .as_ref() .expect("must be generated when reading in get vertex position"), ); let ptr_return_ty = self.get_pointer_type_id(rq_get_vertex_positions_ty_id, spirv::StorageClass::Function); let bool_type_id = self.get_bool_type_id(); let (func_id, mut function, arg_ids) = self.write_function_signature( &[ray_query_type_id, u32_ptr_ty], rq_get_vertex_positions_ty_id, ); let query_id = arg_ids[0]; let init_tracker_id = arg_ids[1]; let block_id = self.id_gen.next(); let mut block = Block::new(block_id); let return_id = self.id_gen.next(); block.body.push(Instruction::variable( ptr_return_ty, return_id, spirv::StorageClass::Function, Some(self.get_constant_null(rq_get_vertex_positions_ty_id)), )); let valid_id = self.id_gen.next(); let mut valid_block = Block::new(valid_id); let final_label_id = self.id_gen.next(); let mut final_block = Block::new(final_label_id); let instruction = if self.ray_query_initialization_tracking { let initialized_tracker_id = self.id_gen.next(); 
block.body.push(Instruction::load( u32_ty, initialized_tracker_id, init_tracker_id, None, )); let proceeded_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, RayQueryPoint::FINISHED_TRAVERSAL.bits(), ); let correct_finish_id = if is_committed { finished_proceed_id } else { let not_finished_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::LogicalNot, bool_type_id, not_finished_id, finished_proceed_id, )); not_finished_id }; let is_valid_id = self.write_logical_and(&mut block, correct_finish_id, proceeded_id); block.body.push(Instruction::selection_merge( final_label_id, spirv::SelectionControl::NONE, )); Instruction::branch_conditional(is_valid_id, valid_id, final_label_id) } else { Instruction::branch(valid_id) }; function.consume(block, instruction); let intersection_id = self.get_constant_scalar(crate::Literal::U32(committed_ty)); let raw_kind_id = self.id_gen.next(); valid_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionTypeKHR, u32_ty, raw_kind_id, query_id, intersection_id, )); let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(committed_tri_ty)); let intersection_tri_id = self.id_gen.next(); valid_block.body.push(Instruction::binary( spirv::Op::IEqual, bool_type_id, intersection_tri_id, raw_kind_id, candidate_tri_id, )); let generate_label_id = self.id_gen.next(); let mut vertex_return_block = Block::new(generate_label_id); let merge_label_id = self.id_gen.next(); let merge_block = Block::new(merge_label_id); valid_block.body.push(Instruction::selection_merge( merge_label_id, spirv::SelectionControl::NONE, )); function.consume( valid_block, Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id), ); let vertices_id = self.id_gen.next(); vertex_return_block .body 
.push(Instruction::ray_query_return_vertex_position( rq_get_vertex_positions_ty_id, vertices_id, query_id, intersection_id, )); vertex_return_block .body .push(Instruction::store(return_id, vertices_id, None)); function.consume(vertex_return_block, Instruction::branch(merge_label_id)); function.consume(merge_block, Instruction::branch(final_label_id)); let loaded_pos_id = self.id_gen.next(); final_block.body.push(Instruction::load( rq_get_vertex_positions_ty_id, loaded_pos_id, return_id, None, )); function.consume(final_block, Instruction::return_value(loaded_pos_id)); self.ray_query_functions.insert( LookupRayQueryFunction::GetVertexPositions { committed: is_committed, }, func_id, ); function.to_words(&mut self.logical_layout.function_definitions); func_id } fn write_ray_query_terminate(&mut self) -> spirv::Word { if let Some(&word) = self .ray_query_functions .get(&LookupRayQueryFunction::Terminate) { return word; } let ray_query_type_id = self.get_ray_query_pointer_id(); let u32_ty = self.get_u32_type_id(); let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function); let bool_type_id = self.get_bool_type_id(); let (func_id, mut function, arg_ids) = self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type); let query_id = arg_ids[0]; let init_tracker_id = arg_ids[1]; let block_id = self.id_gen.next(); let mut block = Block::new(block_id); let initialized_tracker_id = self.id_gen.next(); block.body.push(Instruction::load( u32_ty, initialized_tracker_id, init_tracker_id, None, )); let merge_id = self.id_gen.next(); let merge_block = Block::new(merge_id); let valid_block_id = self.id_gen.next(); let mut valid_block = Block::new(valid_block_id); let instruction = if self.ray_query_initialization_tracking { let has_proceeded = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, 
initialized_tracker_id, RayQueryPoint::FINISHED_TRAVERSAL.bits(), ); let not_finished_id = self.id_gen.next(); block.body.push(Instruction::unary( spirv::Op::LogicalNot, bool_type_id, not_finished_id, finished_proceed_id, )); let valid_call = self.write_logical_and(&mut block, not_finished_id, has_proceeded); block.body.push(Instruction::selection_merge( merge_id, spirv::SelectionControl::NONE, )); Instruction::branch_conditional(valid_call, valid_block_id, merge_id) } else { Instruction::branch(valid_block_id) }; function.consume(block, instruction); valid_block .body .push(Instruction::ray_query_terminate(query_id)); function.consume(valid_block, Instruction::branch(merge_id)); function.consume(merge_block, Instruction::return_void()); function.to_words(&mut self.logical_layout.function_definitions); self.ray_query_functions .insert(LookupRayQueryFunction::Terminate, func_id); func_id } } impl BlockContext<'_> { pub(in super::super) fn write_ray_query_function( &mut self, query: Handle<crate::Expression>, function: &crate::RayQueryFunction, block: &mut Block, ) { let query_id = self.cached[query]; let tracker_ids = *self .ray_query_tracker_expr .get(&query) .expect("not a cached ray query"); match *function { crate::RayQueryFunction::Initialize { acceleration_structure, descriptor, } => { let desc_id = self.cached[descriptor]; let acc_struct_id = self.get_handle_id(acceleration_structure); let func = self.writer.write_ray_query_initialize(self.ir_module); let func_id = self.gen_id(); block.body.push(Instruction::function_call( self.writer.void_type, func_id, func, &[ query_id, acc_struct_id, desc_id, tracker_ids.initialized_tracker, tracker_ids.t_max_tracker, ], )); } crate::RayQueryFunction::Proceed { result } => { let id = self.gen_id(); self.cached[result] = id; let bool_ty = self.writer.get_bool_type_id(); let func_id = self.writer.write_ray_query_proceed(); block.body.push(Instruction::function_call( bool_ty, id, func_id, &[query_id, 
tracker_ids.initialized_tracker], )); } 
crate::RayQueryFunction::GenerateIntersection { hit_t } => { let hit_id = self.cached[hit_t]; let func_id = self.writer.write_ray_query_generate_intersection(); let func_call_id = self.gen_id(); block.body.push(Instruction::function_call( self.writer.void_type, func_call_id, func_id, &[ query_id, tracker_ids.initialized_tracker, hit_id, tracker_ids.t_max_tracker, ], )); } crate::RayQueryFunction::ConfirmIntersection => { let func_id = self.writer.write_ray_query_confirm_intersection(); let func_call_id = self.gen_id(); block.body.push(Instruction::function_call( self.writer.void_type, func_call_id, func_id, &[query_id, tracker_ids.initialized_tracker], )); } crate::RayQueryFunction::Terminate => { let id = self.gen_id(); let func_id = self.writer.write_ray_query_terminate(); block.body.push(Instruction::function_call( self.writer.void_type, id, func_id, &[query_id, tracker_ids.initialized_tracker], )); } } } pub(in super::super) fn write_ray_query_return_vertex_position( &mut self, query: Handle<crate::Expression>, block: &mut Block, is_committed: bool, ) -> spirv::Word { let fn_id = self .writer .write_ray_query_get_vertex_positions(is_committed, self.ir_module); let query_id = self.cached[query]; let tracker_id = *self .ray_query_tracker_expr .get(&query) .expect("not a cached ray query"); let rq_get_vertex_positions_ty_id = self.get_handle_type_id( *self .ir_module .special_types .ray_vertex_return .as_ref() .expect("must be generated when reading in get vertex position"), ); let func_call_id = self.gen_id(); block.body.push(Instruction::function_call( rq_get_vertex_positions_ty_id, func_call_id, fn_id, &[query_id, tracker_id.initialized_tracker], )); func_call_id } } naga-29.0.3/src/back/spv/reclaimable.rs000064400000000000000000000045541046102023000157420ustar 00000000000000/*! Reusing collections' previous allocations. */ use alloc::vec::Vec; /// A value that can be reset to its initial state, retaining its current allocations.
/// /// Naga attempts to lower the cost of SPIR-V generation by allowing clients to /// reuse the same `Writer` for multiple module translations. Reusing a `Writer` /// means that the `Vec`s, `HashMap`s, and other heap-allocated structures the /// `Writer` uses internally begin the translation with heap-allocated buffers /// ready to use. /// /// But this approach introduces the risk of `Writer` state leaking from one /// module to the next. When a developer adds fields to `Writer` or its internal /// types, they must remember to reset their contents between modules. /// /// One trick to ensure that every field has been accounted for is to use Rust's /// struct literal syntax to construct a new, reset value. If a developer adds a /// field but neglects to update the reset code, the compiler will complain /// that a field is missing from the literal. This trait's `reclaim` method /// takes `self` by value, and returns `Self` by value, encouraging the use of /// struct literal expressions in its implementation. pub trait Reclaimable { /// Clear `self`, retaining its current memory allocations. /// /// Shrink the buffer if it's currently much larger than was actually used. /// This prevents a module with exceptionally large allocations from causing /// the `Writer` to retain more memory than it needs indefinitely. fn reclaim(self) -> Self; } // Stock values for various collections.
impl<T> Reclaimable for Vec<T> { fn reclaim(mut self) -> Self { self.clear(); self } } impl<K, V, S> Reclaimable for hashbrown::HashMap<K, V, S> { fn reclaim(mut self) -> Self { self.clear(); self } } impl<T, S> Reclaimable for hashbrown::HashSet<T, S> { fn reclaim(mut self) -> Self { self.clear(); self } } impl<T, S> Reclaimable for indexmap::IndexSet<T, S> { fn reclaim(mut self) -> Self { self.clear(); self } } impl<K, V> Reclaimable for alloc::collections::BTreeMap<K, V> { fn reclaim(mut self) -> Self { self.clear(); self } } impl<T, U> Reclaimable for crate::arena::HandleVec<T, U> { fn reclaim(mut self) -> Self { self.clear(); self } } naga-29.0.3/src/back/spv/selection.rs000064400000000000000000000233641046102023000154670ustar 00000000000000/*! Generate SPIR-V conditional structures. Builders for `if` structures with `and`s. The types in this module track the information needed to emit SPIR-V code for complex conditional structures, like those whose conditions involve short-circuiting 'and' and 'or' structures. These track labels and can emit `OpPhi` instructions to merge values produced along different paths. This currently only supports exactly the forms Naga uses, so it doesn't support `or` or `else`, and only supports zero or one merged values. Naga needs to emit code roughly like this: ```ignore value = DEFAULT; if COND1 && COND2 { value = THEN_VALUE; } // use value ``` Assuming `ctx` and `block` are mutable references to a [`BlockContext`] and the current [`Block`], and `merge_type` is the SPIR-V type for the merged value `value`, we can build SPIR-V for the code above like so: ```ignore let mut cond = Selection::start(block, merge_type); // ... compute `cond1` ... cond.if_true(ctx, cond1, DEFAULT); // ... compute `cond2` ... cond.if_true(ctx, cond2, DEFAULT); // ... compute THEN_VALUE let merged_value = cond.finish(ctx, THEN_VALUE); ``` After this, `merged_value` is either `DEFAULT` or `THEN_VALUE`, depending on the path by which the merged block was reached.
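As a rough, hypothetical illustration of this bookkeeping (not Naga's API), the sketch below replaces SPIR-V ids with plain integers. Each `if_true` call records the value its block contributes if it branches straight to the merge block, `finish` records the final block's value, and a toy `phi` function picks the value recorded for whichever predecessor was actually taken at run time. `ToySelection`, `phi`, and `run` are illustrative names only.

```rust
// Hypothetical toy model, not Naga's API: block labels are plain `u32`s and
// merged values are plain `i32`s instead of SPIR-V ids.
struct ToySelection {
    // `(value, predecessor label)` pairs, mirroring `Selection::values`.
    incoming: Vec<(i32, u32)>,
    current_label: u32,
    next_label: u32,
}

impl ToySelection {
    fn start(header_label: u32) -> Self {
        ToySelection {
            incoming: Vec::new(),
            current_label: header_label,
            next_label: header_label + 1,
        }
    }

    // Like `if_true`: record what the current block contributes when its
    // condition is false and it branches straight to the merge block, then
    // continue in a fresh block.
    fn if_true(&mut self, value_if_false: i32) {
        self.incoming.push((value_if_false, self.current_label));
        self.current_label = self.next_label;
        self.next_label += 1;
    }

    // Like `finish`: the last block always branches to the merge block.
    fn finish(mut self, final_value: i32) -> Vec<(i32, u32)> {
        self.incoming.push((final_value, self.current_label));
        self.incoming
    }
}

// Evaluate an `OpPhi`-style merge: select the value recorded for the
// predecessor block that actually branched to the merge block.
fn phi(incoming: &[(i32, u32)], came_from: u32) -> i32 {
    incoming
        .iter()
        .find(|&&(_, label)| label == came_from)
        .map(|&(value, _)| value)
        .expect("not a predecessor of the merge block")
}

// Models `if cond1 && cond2 { value = THEN_VALUE }` with DEFAULT = 0 and
// THEN_VALUE = 42. Label 0 tests `cond1`, label 1 tests `cond2`, and
// label 2 is the body that produces THEN_VALUE.
fn run(cond1: bool, cond2: bool) -> i32 {
    const DEFAULT: i32 = 0;
    const THEN_VALUE: i32 = 42;
    let mut sel = ToySelection::start(0);
    sel.if_true(DEFAULT);
    sel.if_true(DEFAULT);
    let incoming = sel.finish(THEN_VALUE);
    // Which block reached the merge block depends on the runtime conditions.
    let came_from = if !cond1 { 0 } else if !cond2 { 1 } else { 2 };
    phi(&incoming, came_from)
}
```

With these definitions, `run(true, true)` returns `42` (the `THEN_VALUE`), while any false condition yields `0` (the `DEFAULT`), because the phi sees a different predecessor label. Recording every `(value, label)` pair at the moment a branch to the merge block is emitted is what makes the phi complete by construction.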
`Selection` takes care of writing all branch instructions, including an `OpSelectionMerge` annotation in the header block; starting new blocks and assigning them labels; and emitting the `OpPhi` instructions that gather together the right sources for the merged values, for every path through the selection construct. When there is no merged value to produce, you can pass `()` for `merge_type` and the merge values. In this case no `OpPhi` instructions are produced, and the `finish` method returns `()`. To enforce proper nesting, a `Selection` takes ownership of the `&mut Block` pointer for the duration of its lifetime. To obtain the block for generating code in the selection's body, call the `Selection::block` method. */ use alloc::{vec, vec::Vec}; use spirv::Word; use super::{Block, BlockContext, Instruction}; /// A private struct recording what we know about the selection construct so far. pub(super) struct Selection<'b, M: MergeTuple> { /// The block pointer we're emitting code into. block: &'b mut Block, /// The label of the selection construct's merge block, or `None` if we /// haven't yet written the `OpSelectionMerge` instruction. merge_label: Option<Word>, /// A list of `(VALUES, PARENT)` pairs, used to build `OpPhi` instructions in /// the merge block. Each `PARENT` is the label of a predecessor block of /// the merge block. The corresponding `VALUES` holds the ids of the values /// that `PARENT` contributes to the merged values. /// /// We emit all branches to the merge block, so we know all its /// predecessors. And we refuse to emit a branch unless we're given the /// values the branching block contributes to the merge, so we always have /// everything we need to emit the correct phis, by construction. values: Vec<(M, Word)>, /// The types of the values in each element of `values`. merge_types: M, } impl<'b, M: MergeTuple> Selection<'b, M> { /// Start a new selection construct. /// /// The `block` argument indicates the selection's header block.
/// /// The `merge_types` argument should be a `Word` or tuple of `Word`s, each /// value being the SPIR-V result type id of an `OpPhi` instruction that /// will be written to the selection's merge block when this selection's /// [`finish`] method is called. This argument may also be `()`, for /// selections that produce no values. /// /// (This function writes no code to `block` itself; it simply constructs a /// fresh `Selection`.) /// /// [`finish`]: Selection::finish pub(super) const fn start(block: &'b mut Block, merge_types: M) -> Self { Selection { block, merge_label: None, values: vec![], merge_types, } } pub(super) const fn block(&mut self) -> &mut Block { self.block } /// Branch to a successor block if `cond` is true, otherwise merge. /// /// If `cond` is false, branch to the merge block, using `values` as the /// merged values. Otherwise, proceed to a new block. /// /// The `values` argument must be the same shape as the `merge_types` /// argument passed to `Selection::start`. pub(super) fn if_true(&mut self, ctx: &mut BlockContext, cond: Word, values: M) { self.values.push((values, self.block.label_id)); let merge_label = self.make_merge_label(ctx); let next_label = ctx.gen_id(); ctx.function.consume( core::mem::replace(self.block, Block::new(next_label)), Instruction::branch_conditional(cond, next_label, merge_label), ); } /// Emit an unconditional branch to the merge block, and compute merged /// values. /// /// Use `final_values` as the merged values contributed by the current /// block, and transition to the merge block, emitting `OpPhi` instructions /// to produce the merged values. This must be the same shape as the /// `merge_types` argument passed to [`Selection::start`]. /// /// Return the SPIR-V ids of the merged values. This value has the same /// shape as the `merge_types` argument passed to `Selection::start`. pub(super) fn finish(self, ctx: &mut BlockContext, final_values: M) -> M { match self { Selection { merge_label: None, .. 
} => { // We didn't actually emit any branches, so `self.values` must // be empty, and `final_values` are the only sources we have for // the merged values. Easy peasy. final_values } Selection { block, merge_label: Some(merge_label), mut values, merge_types, } => { // Emit the final branch and transition to the merge block. values.push((final_values, block.label_id)); ctx.function.consume( core::mem::replace(block, Block::new(merge_label)), Instruction::branch(merge_label), ); // Now that we're in the merge block, build the phi instructions. merge_types.write_phis(ctx, block, &values) } } } /// Return the id of the merge block, writing a merge instruction if needed. fn make_merge_label(&mut self, ctx: &mut BlockContext) -> Word { match self.merge_label { None => { let merge_label = ctx.gen_id(); self.block.body.push(Instruction::selection_merge( merge_label, spirv::SelectionControl::NONE, )); self.merge_label = Some(merge_label); merge_label } Some(merge_label) => merge_label, } } } /// A trait to help `Selection` manage any number of merged values. /// /// Some selection constructs, like a `ReadZeroSkipWrite` bounds check on a /// [`Load`] expression, produce a single merged value. Others produce no merged /// value, like a bounds check on a [`Store`] statement. /// /// To let `Selection` work nicely with both cases, we let the merge type /// argument passed to [`Selection::start`] be any type that implements this /// `MergeTuple` trait. `MergeTuple` can then be implemented for `()`, `Word`, /// `(Word, Word)`, and so on. /// /// A `MergeTuple` type can represent either a bunch of SPIR-V types or values; /// the `merge_types` argument to `Selection::start` holds type ids, whereas the /// `values` arguments to the [`if_true`] and [`finish`] methods hold value ids. /// The merged values returned by `finish` are likewise value ids.
/// /// In fact, since Naga only uses zero- and single-valued selection constructs /// at present, we only implement `MergeTuple` for `()` and `Word`. But if you /// add more cases, feel free to add more implementations. With const generics, /// we could have a single implementation of `MergeTuple` for arrays of any /// length, and be done with it. /// /// [`Load`]: crate::Expression::Load /// [`Store`]: crate::Statement::Store /// [`if_true`]: Selection::if_true /// [`finish`]: Selection::finish pub(super) trait MergeTuple: Sized { /// Write `OpPhi` instructions for the given set of predecessors. /// /// The `predecessors` slice should hold `(VALUES, LABEL)` pairs, /// where each `VALUES` holds the values contributed by the branch from /// `LABEL`, which should be one of the current block's predecessors. fn write_phis( self, ctx: &mut BlockContext, block: &mut Block, predecessors: &[(Self, Word)], ) -> Self; } /// Selections that produce a single merged value. /// /// For example, `ImageLoad` with `BoundsCheckPolicy::ReadZeroSkipWrite` either /// returns a texel value or zeros. impl MergeTuple for Word { fn write_phis( self, ctx: &mut BlockContext, block: &mut Block, predecessors: &[(Word, Word)], ) -> Word { let merged_value = ctx.gen_id(); block .body .push(Instruction::phi(self, merged_value, predecessors)); merged_value } } /// Selections that produce no merged values. /// /// For example, `ImageStore` under `BoundsCheckPolicy::ReadZeroSkipWrite` /// either does the store or skips it, but in neither case does it produce a /// value. impl MergeTuple for () { /// No phis need to be generated.
fn write_phis(self, _: &mut BlockContext, _: &mut Block, _: &[((), Word)]) {} } naga-29.0.3/src/back/spv/subgroup.rs000064400000000000000000000217731046102023000153520ustar 00000000000000use super::{Block, BlockContext, Error, Instruction, NumericType}; use crate::{arena::Handle, TypeInner}; impl BlockContext<'_> { pub(super) fn write_subgroup_ballot( &mut self, predicate: &Option>, result: Handle, block: &mut Block, ) -> Result<(), Error> { self.writer.require_any( "GroupNonUniformBallot", &[spirv::Capability::GroupNonUniformBallot], )?; let vec4_u32_type_id = self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Quad, scalar: crate::Scalar::U32, }); let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); let predicate = if let Some(predicate) = *predicate { self.cached[predicate] } else { self.writer.get_constant_scalar(crate::Literal::Bool(true)) }; let id = self.gen_id(); block.body.push(Instruction::group_non_uniform_ballot( vec4_u32_type_id, id, exec_scope_id, predicate, )); self.cached[result] = id; Ok(()) } pub(super) fn write_subgroup_operation( &mut self, op: &crate::SubgroupOperation, collective_op: &crate::CollectiveOperation, argument: Handle, result: Handle, block: &mut Block, ) -> Result<(), Error> { use crate::SubgroupOperation as sg; match *op { sg::All | sg::Any => { self.writer.require_any( "GroupNonUniformVote", &[spirv::Capability::GroupNonUniformVote], )?; } _ => { self.writer.require_any( "GroupNonUniformArithmetic", &[spirv::Capability::GroupNonUniformArithmetic], )?; } } let id = self.gen_id(); let result_ty = &self.fun_info[result].ty; let result_type_id = self.get_expression_type_id(result_ty); let result_ty_inner = result_ty.inner_with(&self.ir_module.types); let (is_scalar, scalar) = match *result_ty_inner { TypeInner::Scalar(kind) => (true, kind), TypeInner::Vector { scalar: kind, .. 
} => (false, kind), _ => unimplemented!(), }; use crate::ScalarKind as sk; let spirv_op = match (scalar.kind, *op) { (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, (_, sg::All | sg::Any) => unimplemented!(), (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd, (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd, (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul, (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul, (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax, (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax, (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax, (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin, (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin, (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin, (_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(), (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd, (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr, (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor, (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd, (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr, (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor, (_, sg::And | sg::Or | sg::Xor) => unimplemented!(), }; let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); use crate::CollectiveOperation as c; let group_op = match *op { sg::All | sg::Any => None, _ => Some(match *collective_op { c::Reduce => spirv::GroupOperation::Reduce, c::InclusiveScan => spirv::GroupOperation::InclusiveScan, c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, }), }; let arg_id = self.cached[argument]; block.body.push(Instruction::group_non_uniform_arithmetic( spirv_op, result_type_id, id, exec_scope_id, group_op, arg_id, )); self.cached[result] = id; Ok(()) } pub(super) fn write_subgroup_gather( &mut self, 
mode: &crate::GatherMode, argument: Handle, result: Handle, block: &mut Block, ) -> Result<(), Error> { match *mode { crate::GatherMode::BroadcastFirst => { self.writer.require_any( "GroupNonUniformBallot", &[spirv::Capability::GroupNonUniformBallot], )?; } crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) | crate::GatherMode::Broadcast(_) => { self.writer.require_any( "GroupNonUniformShuffle", &[spirv::Capability::GroupNonUniformShuffle], )?; } crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => { self.writer.require_any( "GroupNonUniformShuffleRelative", &[spirv::Capability::GroupNonUniformShuffleRelative], )?; } crate::GatherMode::QuadBroadcast(_) | crate::GatherMode::QuadSwap(_) => { self.writer.require_any( "GroupNonUniformQuad", &[spirv::Capability::GroupNonUniformQuad], )?; } } let id = self.gen_id(); let result_ty = &self.fun_info[result].ty; let result_type_id = self.get_expression_type_id(result_ty); let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); let arg_id = self.cached[argument]; match *mode { crate::GatherMode::BroadcastFirst => { block .body .push(Instruction::group_non_uniform_broadcast_first( result_type_id, id, exec_scope_id, arg_id, )); } crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) | crate::GatherMode::ShuffleXor(index) | crate::GatherMode::QuadBroadcast(index) => { let index_id = self.cached[index]; let op = match *mode { crate::GatherMode::BroadcastFirst => unreachable!(), // Use shuffle to emit broadcast to allow the index to // be dynamically uniform on Vulkan 1.1. 
The argument to // OpGroupNonUniformBroadcast must be a constant pre SPIR-V // 1.5 (vulkan 1.2) crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle, crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, crate::GatherMode::QuadBroadcast(_) => spirv::Op::GroupNonUniformQuadBroadcast, crate::GatherMode::QuadSwap(_) => unreachable!(), }; block.body.push(Instruction::group_non_uniform_gather( op, result_type_id, id, exec_scope_id, arg_id, index_id, )); } crate::GatherMode::QuadSwap(direction) => { let direction = self.get_index_constant(match direction { crate::Direction::X => 0, crate::Direction::Y => 1, crate::Direction::Diagonal => 2, }); block.body.push(Instruction::group_non_uniform_quad_swap( result_type_id, id, exec_scope_id, arg_id, direction, )); } } self.cached[result] = id; Ok(()) } } naga-29.0.3/src/back/spv/writer.rs000064400000000000000000004771751046102023000150330ustar 00000000000000use alloc::{format, string::String, vec, vec::Vec}; use arrayvec::ArrayVec; use hashbrown::hash_map::Entry; use spirv::Word; use super::{ block::DebugInfoInner, helpers::{contains_builtin, global_needs_wrapper, map_storage_class}, Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo, EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE, }; use crate::{ arena::{Handle, HandleVec, UniqueArena}, back::spv::{ helpers::{is_uniform_matcx2_struct_member_access, BindingDecorations}, BindingInfo, Std140CompatTypeInfo, WrappedFunction, }, common::ForDebugWithTypes as _, 
proc::{Alignment, TypeResolution}, valid::{FunctionInfo, ModuleInfo}, }; pub struct FunctionInterface<'a> { pub varying_ids: &'a mut Vec, pub stage: crate::ShaderStage, pub task_payload: Option>, pub mesh_info: Option, pub workgroup_size: [u32; 3], } impl Function { pub(super) fn to_words(&self, sink: &mut impl Extend) { self.signature.as_ref().unwrap().to_words(sink); for argument in self.parameters.iter() { argument.instruction.to_words(sink); } for (index, block) in self.blocks.iter().enumerate() { Instruction::label(block.label_id).to_words(sink); if index == 0 { for local_var in self.variables.values() { local_var.instruction.to_words(sink); } for local_var in self.ray_query_initialization_tracker_variables.values() { local_var.instruction.to_words(sink); } for local_var in self.ray_query_t_max_tracker_variables.values() { local_var.instruction.to_words(sink); } for local_var in self.force_loop_bounding_vars.iter() { local_var.instruction.to_words(sink); } for internal_var in self.spilled_composites.values() { internal_var.instruction.to_words(sink); } } for instruction in block.body.iter() { instruction.to_words(sink); } } Instruction::function_end().to_words(sink); } } impl Writer { pub fn new(options: &Options) -> Result { let (major, minor) = options.lang_version; if major != 1 { return Err(Error::UnsupportedVersion(major, minor)); } let mut capabilities_used = crate::FastIndexSet::default(); capabilities_used.insert(spirv::Capability::Shader); let mut id_gen = IdGenerator::default(); let gl450_ext_inst_id = id_gen.next(); let void_type = id_gen.next(); Ok(Writer { physical_layout: PhysicalLayout::new(major, minor), logical_layout: LogicalLayout::default(), id_gen, capabilities_available: options.capabilities.clone(), capabilities_used, extensions_used: crate::FastIndexSet::default(), debug_strings: vec![], debugs: vec![], annotations: vec![], flags: options.flags, bounds_check_policies: options.bounds_check_policies, zero_initialize_workgroup_memory: 
options.zero_initialize_workgroup_memory, force_loop_bounding: options.force_loop_bounding, ray_query_initialization_tracking: options.ray_query_initialization_tracking, use_storage_input_output_16: options.use_storage_input_output_16, void_type, tuple_of_u32s_ty_id: None, lookup_type: crate::FastHashMap::default(), lookup_function: crate::FastHashMap::default(), lookup_function_type: crate::FastHashMap::default(), wrapped_functions: crate::FastHashMap::default(), constant_ids: HandleVec::new(), cached_constants: crate::FastHashMap::default(), global_variables: HandleVec::new(), std140_compat_uniform_types: crate::FastHashMap::default(), fake_missing_bindings: options.fake_missing_bindings, binding_map: options.binding_map.clone(), saved_cached: CachedExpressions::default(), gl450_ext_inst_id, temp_list: Vec::new(), ray_query_functions: crate::FastHashMap::default(), io_f16_polyfills: super::f16_polyfill::F16IoPolyfill::new( options.use_storage_input_output_16, ), debug_printf: None, task_dispatch_limits: options.task_dispatch_limits, mesh_shader_primitive_indices_clamp: options.mesh_shader_primitive_indices_clamp, }) } pub fn set_options(&mut self, options: &Options) -> Result<(), Error> { let (major, minor) = options.lang_version; if major != 1 { return Err(Error::UnsupportedVersion(major, minor)); } self.physical_layout = PhysicalLayout::new(major, minor); self.capabilities_available = options.capabilities.clone(); self.flags = options.flags; self.bounds_check_policies = options.bounds_check_policies; self.zero_initialize_workgroup_memory = options.zero_initialize_workgroup_memory; self.force_loop_bounding = options.force_loop_bounding; self.use_storage_input_output_16 = options.use_storage_input_output_16; self.binding_map = options.binding_map.clone(); self.io_f16_polyfills = super::f16_polyfill::F16IoPolyfill::new(options.use_storage_input_output_16); self.task_dispatch_limits = options.task_dispatch_limits; self.mesh_shader_primitive_indices_clamp = 
options.mesh_shader_primitive_indices_clamp; Ok(()) } /// Returns `(major, minor)` of the SPIR-V language version. pub const fn lang_version(&self) -> (u8, u8) { self.physical_layout.lang_version() } /// Reset `Writer` to its initial state, retaining any allocations. /// /// Why not just implement `Reclaimable` for `Writer`? By design, /// `Reclaimable::reclaim` requires ownership of the value, not just /// `&mut`; see the trait documentation. But we need to use this method /// from functions like `Writer::write`, which only have `&mut Writer`. /// Workarounds include unsafe code (`core::ptr::read`, then `write`, ugh) /// or something like a `Default` impl that returns an oddly-initialized /// `Writer`, which is worse. fn reset(&mut self) { use super::reclaimable::Reclaimable; use core::mem::take; let mut id_gen = IdGenerator::default(); let gl450_ext_inst_id = id_gen.next(); let void_type = id_gen.next(); // Every field of the old writer that is not determined by the `Options` // passed to `Writer::new` should be reset somehow. 
let fresh = Writer { // Copied from the old Writer: flags: self.flags, bounds_check_policies: self.bounds_check_policies, zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory, force_loop_bounding: self.force_loop_bounding, ray_query_initialization_tracking: self.ray_query_initialization_tracking, use_storage_input_output_16: self.use_storage_input_output_16, capabilities_available: take(&mut self.capabilities_available), fake_missing_bindings: self.fake_missing_bindings, binding_map: take(&mut self.binding_map), task_dispatch_limits: self.task_dispatch_limits, mesh_shader_primitive_indices_clamp: self.mesh_shader_primitive_indices_clamp, // Initialized afresh: id_gen, void_type, tuple_of_u32s_ty_id: None, gl450_ext_inst_id, // Reclaimed: capabilities_used: take(&mut self.capabilities_used).reclaim(), extensions_used: take(&mut self.extensions_used).reclaim(), physical_layout: self.physical_layout.clone().reclaim(), logical_layout: take(&mut self.logical_layout).reclaim(), debug_strings: take(&mut self.debug_strings).reclaim(), debugs: take(&mut self.debugs).reclaim(), annotations: take(&mut self.annotations).reclaim(), lookup_type: take(&mut self.lookup_type).reclaim(), lookup_function: take(&mut self.lookup_function).reclaim(), lookup_function_type: take(&mut self.lookup_function_type).reclaim(), wrapped_functions: take(&mut self.wrapped_functions).reclaim(), constant_ids: take(&mut self.constant_ids).reclaim(), cached_constants: take(&mut self.cached_constants).reclaim(), global_variables: take(&mut self.global_variables).reclaim(), std140_compat_uniform_types: take(&mut self.std140_compat_uniform_types).reclaim(), saved_cached: take(&mut self.saved_cached).reclaim(), temp_list: take(&mut self.temp_list).reclaim(), ray_query_functions: take(&mut self.ray_query_functions).reclaim(), io_f16_polyfills: take(&mut self.io_f16_polyfills).reclaim(), debug_printf: None, }; *self = fresh; self.capabilities_used.insert(spirv::Capability::Shader); } /// 
Indicate that the code requires any one of the listed capabilities. /// /// If nothing in `capabilities` appears in the available capabilities /// specified in the [`Options`] from which this `Writer` was created, /// return an error. The `what` string is used in the error message to /// explain what provoked the requirement. (If no available capabilities were /// given, assume everything is available.) /// /// The first acceptable capability will be added to this `Writer`'s /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the /// result. For this reason, more specific capabilities should be listed /// before more general ones. /// /// [`capabilities_used`]: Writer::capabilities_used pub(super) fn require_any( &mut self, what: &'static str, capabilities: &[spirv::Capability], ) -> Result<(), Error> { match *capabilities { [] => Ok(()), [first, ..] => { // Find the first acceptable capability, or return an error if // there is none. let selected = match self.capabilities_available { None => first, Some(ref available) => { match capabilities .iter() // need explicit type for hashbrown::HashSet::contains fn call to keep rustc happy .find(|cap| available.contains::<spirv::Capability>(cap)) { Some(&cap) => cap, None => { return Err(Error::MissingCapabilities(what, capabilities.to_vec())) } } } }; self.capabilities_used.insert(selected); Ok(()) } } } /// Indicate that the code requires all of the listed capabilities. /// /// If all entries of `capabilities` appear in the available capabilities /// specified in the [`Options`] from which this `Writer` was created /// (including the case where [`Options::capabilities`] is `None`), add /// them all to this `Writer`'s [`capabilities_used`] table, and return /// `Ok(())`. If at least one of the listed capabilities is not available, /// do not add anything to the `capabilities_used` table, and return the /// first unavailable requested capability, wrapped in `Err()`.
/// /// This method does not return an [`enum@Error`] in case of failure /// because it may be used in cases where the caller can recover (e.g., /// with a polyfill) if the requested capabilities are not available. In /// that case, it would be unnecessary work to find *all* the unavailable /// requested capabilities, and to allocate a `Vec` for them, just so we /// could return an [`Error::MissingCapabilities`]. /// /// [`capabilities_used`]: Writer::capabilities_used pub(super) fn require_all( &mut self, capabilities: &[spirv::Capability], ) -> Result<(), spirv::Capability> { if let Some(ref available) = self.capabilities_available { for requested in capabilities { if !available.contains(requested) { return Err(*requested); } } } for requested in capabilities { self.capabilities_used.insert(*requested); } Ok(()) } /// Indicate that the code uses the given extension. pub(super) fn use_extension(&mut self, extension: &'static str) { self.extensions_used.insert(extension); } pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word { match self.lookup_type.entry(lookup_ty) { Entry::Occupied(e) => *e.get(), Entry::Vacant(e) => { let local = match lookup_ty { LookupType::Handle(_handle) => unreachable!("Handles are populated at start"), LookupType::Local(local) => local, }; let id = self.id_gen.next(); e.insert(id); self.write_type_declaration_local(id, local); id } } } pub(super) fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word { self.get_type_id(LookupType::Handle(handle)) } pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType { match *tr { TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle), TypeResolution::Value(ref inner) => { let inner_local_type = self.localtype_from_inner(inner).unwrap(); LookupType::Local(inner_local_type) } } } pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word { let lookup_ty = self.get_expression_lookup_type(tr); self.get_type_id(lookup_ty) }
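/*
A standalone sketch of the two capability policies of `require_any` and `require_all`, with a hypothetical `Cap` enum and `Caps` struct standing in for `spirv::Capability` and the relevant `Writer` fields. `require_any` records the first listed capability that is available (or simply the first one when no availability set was given); `require_all` records nothing unless every listed capability is available, and otherwise reports the first missing one. A `Vec` replaces Naga's ordered set, so duplicates are filtered by hand.

```rust
use std::collections::HashSet;

/// Hypothetical capability names; Naga uses `spirv::Capability`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Cap {
    Shader,
    Preferred,
    Fallback,
    Unavailable,
}

struct Caps {
    /// `None` means "assume everything is available".
    available: Option<HashSet<Cap>>,
    /// Insertion-ordered, without duplicates (Naga uses an index set).
    used: Vec<Cap>,
}

impl Caps {
    fn mark_used(&mut self, cap: Cap) {
        if !self.used.contains(&cap) {
            self.used.push(cap);
        }
    }

    /// Record the first acceptable capability, or fail listing all of them.
    fn require_any(&mut self, caps: &[Cap]) -> Result<(), Vec<Cap>> {
        let Some(&first) = caps.first() else {
            return Ok(()); // nothing requested
        };
        let selected = match self.available {
            None => first,
            Some(ref avail) => match caps.iter().find(|c| avail.contains(*c)) {
                Some(&c) => c,
                None => return Err(caps.to_vec()),
            },
        };
        self.mark_used(selected);
        Ok(())
    }

    /// All or nothing: on failure nothing is recorded, so a caller can
    /// fall back to a polyfill.
    fn require_all(&mut self, caps: &[Cap]) -> Result<(), Cap> {
        if let Some(ref avail) = self.available {
            if let Some(&missing) = caps.iter().find(|c| !avail.contains(*c)) {
                return Err(missing);
            }
        }
        for &cap in caps {
            self.mark_used(cap);
        }
        Ok(())
    }
}

fn main() {
    let mut caps = Caps {
        available: Some(HashSet::from([Cap::Shader, Cap::Fallback])),
        used: vec![Cap::Shader],
    };
    // `Preferred` is listed first but unavailable, so `Fallback` is chosen.
    assert_eq!(caps.require_any(&[Cap::Preferred, Cap::Fallback]), Ok(()));
    assert_eq!(caps.used, vec![Cap::Shader, Cap::Fallback]);
    // One capability is missing: report it and record nothing.
    assert_eq!(
        caps.require_all(&[Cap::Shader, Cap::Unavailable]),
        Err(Cap::Unavailable)
    );
    assert_eq!(caps.used.len(), 2);
    println!("used: {:?}", caps.used);
}
```

Listing order matters only for `require_any`, which is why the doc comment asks for more specific capabilities first.
*/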
pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word { self.get_type_id(LookupType::Local(local)) } pub(super) fn get_pointer_type_id(&mut self, base: Word, class: spirv::StorageClass) -> Word { self.get_type_id(LookupType::Local(LocalType::Pointer { base, class })) } pub(super) fn get_handle_pointer_type_id( &mut self, base: Handle, class: spirv::StorageClass, ) -> Word { let base_id = self.get_handle_type_id(base); self.get_pointer_type_id(base_id, class) } pub(super) fn get_ray_query_pointer_id(&mut self) -> Word { let rq_id = self.get_type_id(LookupType::Local(LocalType::RayQuery)); self.get_pointer_type_id(rq_id, spirv::StorageClass::Function) } /// Return a SPIR-V type for a pointer to `resolution`. /// /// The given `resolution` must be one that we can represent /// either as a `LocalType::Pointer` or `LocalType::LocalPointer`. pub(super) fn get_resolution_pointer_id( &mut self, resolution: &TypeResolution, class: spirv::StorageClass, ) -> Word { let resolution_type_id = self.get_expression_type_id(resolution); self.get_pointer_type_id(resolution_type_id, class) } pub(super) fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word { self.get_type_id(LocalType::Numeric(numeric).into()) } pub(super) fn get_u32_type_id(&mut self) -> Word { self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)) } pub(super) fn get_f32_type_id(&mut self) -> Word { self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32)) } pub(super) fn get_vec2u_type_id(&mut self) -> Word { self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::U32, }) } pub(super) fn get_vec2f_type_id(&mut self) -> Word { self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::F32, }) } pub(super) fn get_vec3u_type_id(&mut self) -> Word { self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Tri, scalar: crate::Scalar::U32, }) } pub(super) fn 
get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { let f32_id = self.get_f32_type_id(); self.get_pointer_type_id(f32_id, class) } pub(super) fn get_vec2u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { let vec2u_id = self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::U32, }); self.get_pointer_type_id(vec2u_id, class) } pub(super) fn get_bool_type_id(&mut self) -> Word { self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL)) } pub(super) fn get_vec2_bool_type_id(&mut self) -> Word { self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::BOOL, }) } pub(super) fn get_vec3_bool_type_id(&mut self) -> Word { self.get_numeric_type_id(NumericType::Vector { size: crate::VectorSize::Tri, scalar: crate::Scalar::BOOL, }) } /// Used for "mulhi" to get the upper bits of multiplication. /// /// More specifically, `OpUMulExtended` multiplies 2 numbers and returns the lower and upper bits of the result /// as a user-defined struct type with 2 u32s. This defines that struct. pub(super) fn get_tuple_of_u32s_ty_id(&mut self) -> Word { if let Some(val) = self.tuple_of_u32s_ty_id { val } else { let id = self.id_gen.next(); let u32_id = self.get_u32_type_id(); let ins = Instruction::type_struct(id, &[u32_id, u32_id]); ins.to_words(&mut self.logical_layout.declarations); self.tuple_of_u32s_ty_id = Some(id); id } } pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) { self.annotations .push(Instruction::decorate(id, decoration, operands)); } /// Return `inner` as a `LocalType`, if that's possible. /// /// If `inner` can be represented as a `LocalType`, return /// `Some(local_type)`. /// /// Otherwise, return `None`. In this case, the type must always be looked /// up using a `LookupType::Handle`. 
fn localtype_from_inner(&mut self, inner: &crate::TypeInner) -> Option<LocalType> { Some(match *inner { crate::TypeInner::Scalar(_) | crate::TypeInner::Atomic(_) | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => { // We expect `NumericType::from_inner` to handle all // these cases, so unwrap. LocalType::Numeric(NumericType::from_inner(inner).unwrap()) } crate::TypeInner::CooperativeMatrix { .. } => { LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap()) } crate::TypeInner::Pointer { base, space } => { let base_type_id = self.get_handle_type_id(base); LocalType::Pointer { base: base_type_id, class: map_storage_class(space), } } crate::TypeInner::ValuePointer { size, scalar, space, } => { let base_numeric_type = match size { Some(size) => NumericType::Vector { size, scalar }, None => NumericType::Scalar(scalar), }; LocalType::Pointer { base: self.get_numeric_type_id(base_numeric_type), class: map_storage_class(space), } } crate::TypeInner::Image { dim, arrayed, class, } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)), crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler, crate::TypeInner::AccelerationStructure { .. } => LocalType::AccelerationStructure, crate::TypeInner::RayQuery { .. } => LocalType::RayQuery, crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } | crate::TypeInner::BindingArray { .. } => return None, }) } /// Resolve the [`BindingInfo`] for a [`crate::ResourceBinding`] from the /// provided [`Writer::binding_map`]. /// /// If the specified resource is not present in the binding map, this will /// return an error, unless [`Writer::fake_missing_bindings`] is set.
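/*
A sketch of the lookup-with-fallback logic in `resolve_resource_binding`, using hypothetical mirror types: `ResourceBinding`, `BindingInfo`, and `Error` here are simplified stand-ins, and in particular the type of `binding_array_size` is assumed. An explicit entry in the binding map always wins; a missing entry becomes an identity mapping (group to descriptor set, binding to binding) only when faking missing bindings is enabled.

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ResourceBinding {
    group: u32,
    binding: u32,
}

#[derive(Clone, Copy, Debug, PartialEq)]
struct BindingInfo {
    descriptor_set: u32,
    binding: u32,
    binding_array_size: Option<u32>, // assumed type, for illustration only
}

#[derive(Debug, PartialEq)]
enum Error {
    MissingBinding(ResourceBinding),
}

fn resolve(
    map: &HashMap<ResourceBinding, BindingInfo>,
    fake_missing_bindings: bool,
    rb: &ResourceBinding,
) -> Result<BindingInfo, Error> {
    match map.get(rb) {
        // An explicit mapping always wins.
        Some(target) => Ok(*target),
        // Otherwise, optionally fall back to an identity mapping.
        None if fake_missing_bindings => Ok(BindingInfo {
            descriptor_set: rb.group,
            binding: rb.binding,
            binding_array_size: None,
        }),
        None => Err(Error::MissingBinding(*rb)),
    }
}

fn main() {
    let mapped = ResourceBinding { group: 0, binding: 1 };
    let unmapped = ResourceBinding { group: 2, binding: 3 };
    let mut map = HashMap::new();
    map.insert(mapped, BindingInfo { descriptor_set: 5, binding: 7, binding_array_size: None });

    assert_eq!(resolve(&map, false, &mapped).unwrap().descriptor_set, 5);
    assert_eq!(resolve(&map, false, &unmapped), Err(Error::MissingBinding(unmapped)));
    let faked = resolve(&map, true, &unmapped).unwrap();
    assert_eq!((faked.descriptor_set, faked.binding), (2, 3));
    println!("ok");
}
```
*/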
fn resolve_resource_binding( &self, res_binding: &crate::ResourceBinding, ) -> Result<BindingInfo, Error> { match self.binding_map.get(res_binding) { Some(target) => Ok(*target), None if self.fake_missing_bindings => Ok(BindingInfo { descriptor_set: res_binding.group, binding: res_binding.binding, binding_array_size: None, }), None => Err(Error::MissingBinding(*res_binding)), } } /// Emits code for any wrapper functions required by the expressions in `ir_function`. /// The IDs of any emitted functions will be stored in [`Self::wrapped_functions`]. fn write_wrapped_functions( &mut self, ir_function: &crate::Function, info: &FunctionInfo, ir_module: &crate::Module, ) -> Result<(), Error> { log::trace!("Generating wrapped functions for {:?}", ir_function.name); for (expr_handle, expr) in ir_function.expressions.iter() { match *expr { crate::Expression::Binary { op, left, right } => { let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types); if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) { match (op, expr_ty.scalar().kind) { // Division and modulo are undefined behaviour when the // dividend is the minimum representable value and the divisor // is negative one, or when the divisor is zero. These wrapped // functions override the divisor to one in these cases, // matching the WGSL spec. ( crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo, crate::ScalarKind::Sint | crate::ScalarKind::Uint, ) => { self.write_wrapped_binary_op( op, expr_ty, &info[left].ty, &info[right].ty, )?; } _ => {} } } } crate::Expression::Load { pointer } => { if let crate::TypeInner::Pointer { base: pointer_type, space: crate::AddressSpace::Uniform, } = *info[pointer].ty.inner_with(&ir_module.types) { if self.std140_compat_uniform_types.contains_key(&pointer_type) { // Loading a std140 compat type requires the wrapper function // to convert to the regular type.
                            self.write_wrapped_convert_from_std140_compat_type(
                                ir_module,
                                pointer_type,
                            )?;
                        }
                    }
                }
                crate::Expression::Access { base, .. } => {
                    if let crate::TypeInner::Pointer {
                        base: base_type,
                        space: crate::AddressSpace::Uniform,
                    } = *info[base].ty.inner_with(&ir_module.types)
                    {
                        // Dynamic accesses of a two-row matrix's columns require a
                        // wrapper function.
                        if let crate::TypeInner::Matrix {
                            rows: crate::VectorSize::Bi,
                            ..
                        } = ir_module.types[base_type].inner
                        {
                            self.write_wrapped_matcx2_get_column(ir_module, base_type)?;
                            // If the matrix is *not* directly a member of a struct, then
                            // we additionally require a wrapper function to convert from
                            // the std140 compat type to the regular type.
                            if !is_uniform_matcx2_struct_member_access(
                                ir_function,
                                info,
                                ir_module,
                                base,
                            ) {
                                self.write_wrapped_convert_from_std140_compat_type(
                                    ir_module,
                                    base_type,
                                )?;
                            }
                        }
                    }
                }
                _ => {}
            }
        }

        Ok(())
    }

    /// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
    ///
    /// Define a function that performs an integer division or modulo operation,
    /// except that when the divisor is zero, or when a signed dividend is the
    /// minimum representable value and the divisor is -1 (signed overflow), the
    /// divisor is replaced with one. Division therefore returns the numerator
    /// unchanged and modulo returns zero, rather than exhibiting undefined
    /// behavior.
    ///
    /// Store the generated function's id in the [`wrapped_functions`] table.
    ///
    /// The operator `op` must be either [`Divide`] or [`Modulo`].
    ///
    /// # Panics
    ///
    /// The `return_type`, `left_type`, and `right_type` arguments must all be
    /// integer scalars or vectors. If not, this function panics.
    ///
    /// [`wrapped_functions`]: Writer::wrapped_functions
    /// [`Divide`]: crate::BinaryOperator::Divide
    /// [`Modulo`]: crate::BinaryOperator::Modulo
    fn write_wrapped_binary_op(
        &mut self,
        op: crate::BinaryOperator,
        return_type: NumericType,
        left_type: &TypeResolution,
        right_type: &TypeResolution,
    ) -> Result<(), Error> {
        let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
        let left_type_id = self.get_expression_type_id(left_type);
        let right_type_id = self.get_expression_type_id(right_type);

        // Check if we've already emitted this function.
        let wrapped = WrappedFunction::BinaryOp {
            op,
            left_type_id,
            right_type_id,
        };
        let function_id = match self.wrapped_functions.entry(wrapped) {
            Entry::Occupied(_) => return Ok(()),
            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
        };

        let scalar = return_type.scalar();

        if self.flags.contains(WriterFlags::DEBUG) {
            let function_name = match op {
                crate::BinaryOperator::Divide => "naga_div",
                crate::BinaryOperator::Modulo => "naga_mod",
                _ => unreachable!(),
            };
            self.debugs
                .push(Instruction::name(function_id, function_name));
        }
        let mut function = Function::default();

        let function_type_id = self.get_function_type(LookupFunctionType {
            parameter_type_ids: vec![left_type_id, right_type_id],
            return_type_id,
        });
        function.signature = Some(Instruction::function(
            return_type_id,
            function_id,
            spirv::FunctionControl::empty(),
            function_type_id,
        ));

        let lhs_id = self.id_gen.next();
        let rhs_id = self.id_gen.next();
        if self.flags.contains(WriterFlags::DEBUG) {
            self.debugs.push(Instruction::name(lhs_id, "lhs"));
            self.debugs.push(Instruction::name(rhs_id, "rhs"));
        }
        let left_par = Instruction::function_parameter(left_type_id, lhs_id);
        let right_par = Instruction::function_parameter(right_type_id, rhs_id);
        for instruction in [left_par, right_par] {
            function.parameters.push(FunctionArgument {
                instruction,
                handle_id: 0,
            });
        }

        let label_id = self.id_gen.next();
        let mut block = Block::new(label_id);

        let bool_type =
            return_type.with_scalar(crate::Scalar::BOOL);
        let bool_type_id = self.get_numeric_type_id(bool_type);

        let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
            NumericType::Scalar(_) => const_id,
            NumericType::Vector { size, .. } => {
                let constituent_ids = [const_id; crate::VectorSize::MAX];
                writer.get_constant_composite(
                    LookupType::Local(LocalType::Numeric(return_type)),
                    &constituent_ids[..size as usize],
                )
            }
            NumericType::Matrix { .. } => unreachable!(),
        };

        let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
        let composite_zero_id = maybe_splat_const(self, const_zero_id);
        let rhs_eq_zero_id = self.id_gen.next();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            rhs_eq_zero_id,
            rhs_id,
            composite_zero_id,
        ));
        let divisor_selector_id = match scalar.kind {
            crate::ScalarKind::Sint => {
                let (const_min_id, const_neg_one_id) = match scalar.width {
                    4 => Ok((
                        self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
                        self.get_constant_scalar(crate::Literal::I32(-1i32)),
                    )),
                    8 => Ok((
                        self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
                        self.get_constant_scalar(crate::Literal::I64(-1i64)),
                    )),
                    _ => Err(Error::Validation("Unexpected scalar width")),
                }?;
                let composite_min_id = maybe_splat_const(self, const_min_id);
                let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);

                let lhs_eq_int_min_id = self.id_gen.next();
                block.body.push(Instruction::binary(
                    spirv::Op::IEqual,
                    bool_type_id,
                    lhs_eq_int_min_id,
                    lhs_id,
                    composite_min_id,
                ));
                let rhs_eq_neg_one_id = self.id_gen.next();
                block.body.push(Instruction::binary(
                    spirv::Op::IEqual,
                    bool_type_id,
                    rhs_eq_neg_one_id,
                    rhs_id,
                    composite_neg_one_id,
                ));
                let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
                block.body.push(Instruction::binary(
                    spirv::Op::LogicalAnd,
                    bool_type_id,
                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
                    lhs_eq_int_min_id,
                    rhs_eq_neg_one_id,
                ));
                let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
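                // Sketch (hypothetical helpers, not part of Naga): a scalar `i32`/`u32`
                // model of the `naga_div` / `naga_mod` wrappers this function emits.
                // The selector built here is `rhs == 0 || (lhs == MIN && rhs == -1)`
                // for signed integers and `rhs == 0` for unsigned ones. When it is
                // true the divisor becomes one, so division yields `lhs` and modulo
                // yields zero. Rust's `/` and `%` truncate toward zero like `OpSDiv`
                // and `OpSRem`, and the substitution avoids their overflow and
                // divide-by-zero panics.
                #[allow(dead_code)]
                fn naga_div_i32_sketch(lhs: i32, rhs: i32) -> i32 {
                    let divisor = if rhs == 0 || (lhs == i32::MIN && rhs == -1) { 1 } else { rhs };
                    lhs / divisor
                }
                #[allow(dead_code)]
                fn naga_mod_i32_sketch(lhs: i32, rhs: i32) -> i32 {
                    let divisor = if rhs == 0 || (lhs == i32::MIN && rhs == -1) { 1 } else { rhs };
                    // The result takes the sign of the dividend, as with `OpSRem`.
                    lhs % divisor
                }
                #[allow(dead_code)]
                fn naga_div_u32_sketch(lhs: u32, rhs: u32) -> u32 {
                    // Unsigned operands only need the divide-by-zero guard.
                    let divisor = if rhs == 0 { 1 } else { rhs };
                    lhs / divisor
                }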
                block.body.push(Instruction::binary(
                    spirv::Op::LogicalOr,
                    bool_type_id,
                    rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
                    rhs_eq_zero_id,
                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
                ));
                rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
            }
            crate::ScalarKind::Uint => rhs_eq_zero_id,
            _ => unreachable!(),
        };

        let const_one_id = self.get_constant_scalar_with(1, scalar)?;
        let composite_one_id = maybe_splat_const(self, const_one_id);
        let divisor_id = self.id_gen.next();
        block.body.push(Instruction::select(
            right_type_id,
            divisor_id,
            divisor_selector_id,
            composite_one_id,
            rhs_id,
        ));
        let op = match (op, scalar.kind) {
            (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
            (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
            (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
            (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
            _ => unreachable!(),
        };
        let return_id = self.id_gen.next();
        block.body.push(Instruction::binary(
            op,
            return_type_id,
            return_id,
            lhs_id,
            divisor_id,
        ));

        function.consume(block, Instruction::return_value(return_id));
        function.to_words(&mut self.logical_layout.function_definitions);
        Ok(())
    }

    /// Writes a wrapper function to convert from a std140 compat type to its
    /// corresponding regular type.
    ///
    /// See [`Self::write_std140_compat_type_declaration`] for more details.
    fn write_wrapped_convert_from_std140_compat_type(
        &mut self,
        ir_module: &crate::Module,
        r#type: Handle<crate::Type>,
    ) -> Result<(), Error> {
        // Check if we've already emitted this function.
        let wrapped = WrappedFunction::ConvertFromStd140CompatType { r#type };
        let function_id = match self.wrapped_functions.entry(wrapped) {
            Entry::Occupied(_) => return Ok(()),
            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
        };
        if self.flags.contains(WriterFlags::DEBUG) {
            self.debugs.push(Instruction::name(
                function_id,
                &format!("{:?}_from_std140", r#type.for_debug(&ir_module.types)),
            ));
        }
        let param_type_id = self.std140_compat_uniform_types[&r#type].type_id;
        let return_type_id = self.get_handle_type_id(r#type);

        let mut function = Function::default();
        let function_type_id = self.get_function_type(LookupFunctionType {
            parameter_type_ids: vec![param_type_id],
            return_type_id,
        });
        function.signature = Some(Instruction::function(
            return_type_id,
            function_id,
            spirv::FunctionControl::empty(),
            function_type_id,
        ));
        let param_id = self.id_gen.next();
        function.parameters.push(FunctionArgument {
            instruction: Instruction::function_parameter(param_type_id, param_id),
            handle_id: 0,
        });

        let label_id = self.id_gen.next();
        let mut block = Block::new(label_id);

        let result_id = match ir_module.types[r#type].inner {
            // Param is a struct containing a vector member for each of the
            // matrix's columns. Extract each column from the struct, then
            // construct a matrix from them.
            crate::TypeInner::Matrix {
                columns,
                rows: rows @ crate::VectorSize::Bi,
                scalar,
            } => {
                let column_type_id =
                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });

                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
                for column in 0..columns as u32 {
                    let column_id = self.id_gen.next();
                    block.body.push(Instruction::composite_extract(
                        column_type_id,
                        column_id,
                        param_id,
                        &[column],
                    ));
                    column_ids.push(column_id);
                }
                let result_id = self.id_gen.next();
                block.body.push(Instruction::composite_construct(
                    return_type_id,
                    result_id,
                    &column_ids,
                ));
                result_id
            }
            // Param is an array where the base type is the std140 compatible
            // type corresponding to `base`. Iterate through each element and
            // call its conversion function, then construct a new array.
            crate::TypeInner::Array { base, size, .. } => {
                // Ensure the conversion function for the array's base type is
                // declared.
                self.write_wrapped_convert_from_std140_compat_type(ir_module, base)?;

                let element_type_id = self.get_handle_type_id(base);
                let std140_element_type_id = self.std140_compat_uniform_types[&base].type_id;
                let element_conversion_function_id = self.wrapped_functions
                    [&WrappedFunction::ConvertFromStd140CompatType { r#type: base }];
                let mut element_ids = Vec::new();
                let size = match size.resolve(ir_module.to_ctx())? {
                    crate::proc::IndexableLength::Known(size) => size,
                    crate::proc::IndexableLength::Dynamic => {
                        return Err(Error::Validation(
                            "Uniform buffers cannot contain dynamic arrays",
                        ))
                    }
                };
                for i in 0..size {
                    let std140_element_id = self.id_gen.next();
                    block.body.push(Instruction::composite_extract(
                        std140_element_type_id,
                        std140_element_id,
                        param_id,
                        &[i],
                    ));
                    let element_id = self.id_gen.next();
                    block.body.push(Instruction::function_call(
                        element_type_id,
                        element_id,
                        element_conversion_function_id,
                        &[std140_element_id],
                    ));
                    element_ids.push(element_id);
                }
                let result_id = self.id_gen.next();
                block.body.push(Instruction::composite_construct(
                    return_type_id,
                    result_id,
                    &element_ids,
                ));
                result_id
            }
            // Param is a struct where each two-row matrix member has been
            // decomposed into separate vector members for each column.
            // Other members use their std140 compatible type if one exists, or
            // else their regular type. Iterate through each member, converting
            // or composing any matrices if required, then finally compose into
            // the struct.
            crate::TypeInner::Struct { ref members, ..
            } => {
                let mut member_ids = Vec::new();
                let mut next_index = 0;
                for member in members {
                    let member_id = self.id_gen.next();
                    let member_type_id = self.get_handle_type_id(member.ty);
                    match ir_module.types[member.ty].inner {
                        crate::TypeInner::Matrix {
                            columns,
                            rows: rows @ crate::VectorSize::Bi,
                            scalar,
                        } => {
                            let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
                            let column_type_id = self
                                .get_numeric_type_id(NumericType::Vector { size: rows, scalar });
                            for _ in 0..columns as u32 {
                                let column_id = self.id_gen.next();
                                block.body.push(Instruction::composite_extract(
                                    column_type_id,
                                    column_id,
                                    param_id,
                                    &[next_index],
                                ));
                                column_ids.push(column_id);
                                next_index += 1;
                            }
                            block.body.push(Instruction::composite_construct(
                                member_type_id,
                                member_id,
                                &column_ids,
                            ));
                        }
                        _ => {
                            // Copy the id out so the map is not borrowed across
                            // the `&mut self` calls below.
                            let std140_type_id = self
                                .std140_compat_uniform_types
                                .get(&member.ty)
                                .map(|std140_type_info| std140_type_info.type_id);
                            match std140_type_id {
                                Some(std140_type_id) => {
                                    // Ensure the conversion function for the member's
                                    // type is declared. Only members that have a
                                    // std140 compat type need one.
                                    self.write_wrapped_convert_from_std140_compat_type(
                                        ir_module, member.ty,
                                    )?;
                                    let std140_member_id = self.id_gen.next();
                                    block.body.push(Instruction::composite_extract(
                                        std140_type_id,
                                        std140_member_id,
                                        param_id,
                                        &[next_index],
                                    ));
                                    let function_id = self.wrapped_functions
                                        [&WrappedFunction::ConvertFromStd140CompatType {
                                            r#type: member.ty,
                                        }];
                                    block.body.push(Instruction::function_call(
                                        member_type_id,
                                        member_id,
                                        function_id,
                                        &[std140_member_id],
                                    ));
                                    next_index += 1;
                                }
                                None => {
                                    // No conversion needed: extract the member directly
                                    // into `member_id`, which is pushed to `member_ids`.
                                    block.body.push(Instruction::composite_extract(
                                        member_type_id,
                                        member_id,
                                        param_id,
                                        &[next_index],
                                    ));
                                    next_index += 1;
                                }
                            }
                        }
                    }
                    member_ids.push(member_id);
                }
                let result_id = self.id_gen.next();
                block.body.push(Instruction::composite_construct(
                    return_type_id,
                    result_id,
                    &member_ids,
                ));
                result_id
            }
            _ => unreachable!(),
        };

        function.consume(block, Instruction::return_value(result_id));
        function.to_words(&mut self.logical_layout.function_definitions);
        Ok(())
    }

    /// Writes a wrapper function to get an `OpTypeVector` column from an
    /// `OpTypeMatrix` with a dynamic index.
    ///
    /// This is used when accessing a column of a [`TypeInner::Matrix`] through
    /// a [`Uniform`] address space pointer. In such cases, the matrix will have
    /// been declared in SPIR-V using an alternative type where each column is a
    /// member of a containing struct. SPIR-V is unable to dynamically access
    /// struct members, so instead we load the matrix then call this function to
    /// access a column from the loaded value.
    ///
    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
    /// [`Uniform`]: crate::AddressSpace::Uniform
    fn write_wrapped_matcx2_get_column(
        &mut self,
        ir_module: &crate::Module,
        r#type: Handle<crate::Type>,
    ) -> Result<(), Error> {
        let wrapped = WrappedFunction::MatCx2GetColumn { r#type };
        let function_id = match self.wrapped_functions.entry(wrapped) {
            Entry::Occupied(_) => return Ok(()),
            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
        };
        if self.flags.contains(WriterFlags::DEBUG) {
            self.debugs.push(Instruction::name(
                function_id,
                &format!("{:?}_get_column", r#type.for_debug(&ir_module.types)),
            ));
        }

        let crate::TypeInner::Matrix {
            columns,
            rows: rows @ crate::VectorSize::Bi,
            scalar,
        } = ir_module.types[r#type].inner
        else {
            unreachable!();
        };

        let mut function = Function::default();
        let matrix_type_id = self.get_handle_type_id(r#type);
        let column_index_type_id = self.get_u32_type_id();
        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
        let matrix_param_id = self.id_gen.next();
        let column_index_param_id = self.id_gen.next();
        function.parameters.push(FunctionArgument {
            instruction: Instruction::function_parameter(matrix_type_id, matrix_param_id),
            handle_id: 0,
        });
        function.parameters.push(FunctionArgument {
            instruction: Instruction::function_parameter(
                column_index_type_id,
                column_index_param_id,
            ),
            handle_id: 0,
        });
        let function_type_id = self.get_function_type(LookupFunctionType {
            parameter_type_ids: vec![matrix_type_id, column_index_type_id],
            return_type_id: column_type_id,
        });
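        // Sketch (hypothetical, not part of Naga): a plain-Rust model of the value
        // the generated `OpSwitch`/`OpPhi` wrapper returns for a `matCx2` column
        // lookup. `SketchPolicy` is an illustrative stand-in that mirrors the three
        // `crate::proc::BoundsCheckPolicy` variants handled below. `Unchecked` is
        // modelled with `unreachable!`, standing in for the `OpUnreachable` block.
        #[allow(dead_code)]
        #[derive(Clone, Copy)]
        enum SketchPolicy {
            Restrict,
            ReadZeroSkipWrite,
            Unchecked,
        }
        #[allow(dead_code)]
        fn matcx2_get_column_sketch(
            columns: &[[f32; 2]],
            index: u32,
            policy: SketchPolicy,
        ) -> [f32; 2] {
            match columns.get(index as usize) {
                // In-bounds index: the matching `OpSwitch` case extracts the column.
                Some(column) => *column,
                None => match policy {
                    // The default label is the final column's case block.
                    SketchPolicy::Restrict => columns[columns.len() - 1],
                    // The default label is the merge block, where `OpPhi` yields zero.
                    SketchPolicy::ReadZeroSkipWrite => [0.0; 2],
                    // The default label is a block containing `OpUnreachable`.
                    SketchPolicy::Unchecked => unreachable!("column index out of bounds"),
                },
            }
        }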
        function.signature = Some(Instruction::function(
            column_type_id,
            function_id,
            spirv::FunctionControl::empty(),
            function_type_id,
        ));

        let label_id = self.id_gen.next();
        let mut block = Block::new(label_id);

        // Create a switch case for each column in the matrix, where each case
        // extracts its column from the matrix. Finally we use OpPhi to return
        // the correct column.
        let merge_id = self.id_gen.next();
        block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));
        let cases = (0..columns as u32)
            .map(|i| super::instructions::Case {
                value: i,
                label_id: self.id_gen.next(),
            })
            .collect::<Vec<_>>();

        // Which label we branch to in the default (column index out-of-bounds)
        // case depends on our bounds check policy.
        let default_id = match self.bounds_check_policies.index {
            // For `Restrict`, treat the same as the final column.
            crate::proc::BoundsCheckPolicy::Restrict => cases.last().unwrap().label_id,
            // For `ReadZeroSkipWrite`, branch directly to the merge block. This
            // will be handled in the `OpPhi` below to produce a zero value.
            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => merge_id,
            // For `Unchecked` we create a new block containing an
            // `OpUnreachable`.
            crate::proc::BoundsCheckPolicy::Unchecked => self.id_gen.next(),
        };
        function.consume(
            block,
            Instruction::switch(column_index_param_id, default_id, &cases),
        );

        // Emit a block for each case, and produce a list of variable and parent
        // block IDs that will be used in an `OpPhi` below to select the right
        // value.
        let mut var_parent_pairs = cases
            .into_iter()
            .map(|case| {
                let mut block = Block::new(case.label_id);
                let column_id = self.id_gen.next();
                block.body.push(Instruction::composite_extract(
                    column_type_id,
                    column_id,
                    matrix_param_id,
                    &[case.value],
                ));
                function.consume(block, Instruction::branch(merge_id));
                (column_id, case.label_id)
            })
            // Need capacity for up to 4 columns plus possibly a default case.
            .collect::<ArrayVec<_, 5>>();

        // Emit a block or append the variable and parent `OpPhi` pair for the
        // column index out-of-bounds case, if required.
        match self.bounds_check_policies.index {
            // Don't need to do anything for `Restrict` as we have branched from
            // the final column case's block.
            crate::proc::BoundsCheckPolicy::Restrict => {}
            // For `ReadZeroSkipWrite` we have branched directly from the block
            // containing the `OpSwitch`. The `OpPhi` should produce a zero
            // value.
            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
                var_parent_pairs.push((self.get_constant_null(column_type_id), label_id));
            }
            // For `Unchecked` create a new block containing `OpUnreachable`.
            // This does not need to be handled by the `OpPhi`.
            crate::proc::BoundsCheckPolicy::Unchecked => {
                function.consume(
                    Block::new(default_id),
                    Instruction::new(spirv::Op::Unreachable),
                );
            }
        }

        let mut block = Block::new(merge_id);
        let result_id = self.id_gen.next();
        block.body.push(Instruction::phi(
            column_type_id,
            result_id,
            &var_parent_pairs,
        ));

        function.consume(block, Instruction::return_value(result_id));
        function.to_words(&mut self.logical_layout.function_definitions);
        Ok(())
    }

    fn write_function(
        &mut self,
        ir_function: &crate::Function,
        info: &FunctionInfo,
        ir_module: &crate::Module,
        mut interface: Option<FunctionInterface>,
        debug_info: &Option,
    ) -> Result<Word, Error> {
        self.write_wrapped_functions(ir_function, info, ir_module)?;

        log::trace!("Generating code for {:?}", ir_function.name);
        let mut function = Function::default();

        let prelude_id = self.id_gen.next();
        let mut prelude = Block::new(prelude_id);
        let mut ep_context = EntryPointContext {
            argument_ids: Vec::new(),
            results: Vec::new(),
            task_payload_variable_id: if let Some(ref i) = interface {
                i.task_payload.map(|a| self.global_variables[a].var_id)
            } else {
                None
            },
            mesh_state: None,
        };

        let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());

        let mut local_invocation_index_var_id = None;
        let mut local_invocation_index_id = None;

        for argument in
            ir_function.arguments.iter()
        {
            let class = spirv::StorageClass::Input;
            let handle_ty = ir_module.types[argument.ty].inner.is_handle();
            let argument_type_id = if handle_ty {
                self.get_handle_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant)
            } else {
                self.get_handle_type_id(argument.ty)
            };

            if let Some(ref mut iface) = interface {
                let id = if let Some(ref binding) = argument.binding {
                    let name = argument.name.as_deref();

                    let varying_id = self.write_varying(
                        ir_module,
                        iface.stage,
                        class,
                        name,
                        argument.ty,
                        binding,
                    )?;
                    iface.varying_ids.push(varying_id);
                    let id = self.load_io_with_f16_polyfill(
                        &mut prelude.body,
                        varying_id,
                        argument_type_id,
                    );
                    if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
                        local_invocation_index_id = Some(id);
                        local_invocation_index_var_id = Some(varying_id);
                    }

                    id
                } else if let crate::TypeInner::Struct { ref members, .. } =
                    ir_module.types[argument.ty].inner
                {
                    let struct_id = self.id_gen.next();
                    let mut constituent_ids = Vec::with_capacity(members.len());
                    for member in members {
                        let type_id = self.get_handle_type_id(member.ty);
                        let name = member.name.as_deref();
                        let binding = member.binding.as_ref().unwrap();
                        let varying_id = self.write_varying(
                            ir_module,
                            iface.stage,
                            class,
                            name,
                            member.ty,
                            binding,
                        )?;
                        iface.varying_ids.push(varying_id);
                        let id =
                            self.load_io_with_f16_polyfill(&mut prelude.body, varying_id, type_id);
                        constituent_ids.push(id);
                        if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
                            local_invocation_index_id = Some(id);
                            local_invocation_index_var_id = Some(varying_id);
                        }
                    }
                    prelude.body.push(Instruction::composite_construct(
                        argument_type_id,
                        struct_id,
                        &constituent_ids,
                    ));
                    struct_id
                } else {
                    unreachable!("Missing argument binding on an entry point");
                };
                ep_context.argument_ids.push(id);
            } else {
                let argument_id = self.id_gen.next();
                let instruction = Instruction::function_parameter(argument_type_id, argument_id);
                if self.flags.contains(WriterFlags::DEBUG) {
                    if
                    let Some(ref name) = argument.name {
                        self.debugs.push(Instruction::name(argument_id, name));
                    }
                }
                function.parameters.push(FunctionArgument {
                    instruction,
                    handle_id: if handle_ty {
                        let id = self.id_gen.next();
                        prelude.body.push(Instruction::load(
                            self.get_handle_type_id(argument.ty),
                            id,
                            argument_id,
                            None,
                        ));
                        id
                    } else {
                        0
                    },
                });
                parameter_type_ids.push(argument_type_id);
            };
        }

        let return_type_id = match ir_function.result {
            Some(ref result) => {
                if let Some(ref mut iface) = interface {
                    let mut has_point_size = false;
                    let class = spirv::StorageClass::Output;
                    if let Some(ref binding) = result.binding {
                        has_point_size |=
                            *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
                        let type_id = self.get_handle_type_id(result.ty);
                        let varying_id =
                            if *binding == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize) {
                                0
                            } else {
                                let varying_id = self.write_varying(
                                    ir_module,
                                    iface.stage,
                                    class,
                                    None,
                                    result.ty,
                                    binding,
                                )?;
                                iface.varying_ids.push(varying_id);
                                varying_id
                            };
                        ep_context.results.push(ResultMember {
                            id: varying_id,
                            type_id,
                            built_in: binding.to_built_in(),
                        });
                    } else if let crate::TypeInner::Struct { ref members, .. } =
                        ir_module.types[result.ty].inner
                    {
                        for member in members {
                            let type_id = self.get_handle_type_id(member.ty);
                            let name = member.name.as_deref();
                            let binding = member.binding.as_ref().unwrap();
                            has_point_size |=
                                *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
                            // This isn't an actual builtin in SPIR-V. It can only appear as the
                            // output of a task shader and the output is used when writing the
                            // entry point return, in which case the id is ignored anyway.
                            let varying_id = if *binding
                                == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize)
                            {
                                0
                            } else {
                                let varying_id = self.write_varying(
                                    ir_module,
                                    iface.stage,
                                    class,
                                    name,
                                    member.ty,
                                    binding,
                                )?;
                                iface.varying_ids.push(varying_id);
                                varying_id
                            };
                            ep_context.results.push(ResultMember {
                                id: varying_id,
                                type_id,
                                built_in: binding.to_built_in(),
                            });
                        }
                    } else {
                        unreachable!("Missing result binding on an entry point");
                    }

                    if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
                        && iface.stage == crate::ShaderStage::Vertex
                        && !has_point_size
                    {
                        // add point size artificially
                        let varying_id = self.id_gen.next();
                        let pointer_type_id = self.get_f32_pointer_type_id(class);
                        Instruction::variable(pointer_type_id, varying_id, class, None)
                            .to_words(&mut self.logical_layout.declarations);
                        self.decorate(
                            varying_id,
                            spirv::Decoration::BuiltIn,
                            &[spirv::BuiltIn::PointSize as u32],
                        );
                        iface.varying_ids.push(varying_id);

                        let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
                        prelude
                            .body
                            .push(Instruction::store(varying_id, default_value_id, None));
                    }
                    if iface.stage == crate::ShaderStage::Task {
                        self.get_vec3u_type_id()
                    } else {
                        self.void_type
                    }
                } else {
                    self.get_handle_type_id(result.ty)
                }
            }
            None => self.void_type,
        };

        if let Some(ref mut iface) = interface {
            if let Some(task_payload) = iface.task_payload {
                iface
                    .varying_ids
                    .push(self.global_variables[task_payload].var_id);
            }
            self.write_entry_point_mesh_shader_info(
                iface,
                local_invocation_index_var_id,
                ir_module,
                &mut ep_context,
            )?;
        }

        let lookup_function_type = LookupFunctionType {
            parameter_type_ids,
            return_type_id,
        };

        let function_id = self.id_gen.next();
        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(ref name) = ir_function.name {
                self.debugs.push(Instruction::name(function_id, name));
            }
        }

        let function_type = self.get_function_type(lookup_function_type);
        function.signature = Some(Instruction::function(
            return_type_id,
            function_id,
            spirv::FunctionControl::empty(),
            function_type,
        ));

        if
            interface.is_some()
        {
            function.entry_point_context = Some(ep_context);
        }

        // fill up the `GlobalVariable::access_id`
        for gv in self.global_variables.iter_mut() {
            gv.reset_for_function();
        }
        for (handle, var) in ir_module.global_variables.iter() {
            if info[handle].is_empty() {
                continue;
            }

            let mut gv = self.global_variables[handle].clone();
            if let Some(ref mut iface) = interface {
                // Since SPIR-V 1.4, the entry point interface must list every global
                // variable the entry point uses. The task payload is added separately.
                if self.physical_layout.version >= 0x10400 && iface.task_payload != Some(handle) {
                    iface.varying_ids.push(gv.var_id);
                }
            }

            match ir_module.types[var.ty].inner {
                // Binding arrays are skipped: the array itself cannot be loaded, so
                // the element must be loaded after indexing.
                crate::TypeInner::BindingArray { .. } => {
                    gv.access_id = gv.var_id;
                }
                _ => {
                    // Handle globals are pre-emitted and should be loaded automatically.
                    if var.space == crate::AddressSpace::Handle {
                        let var_type_id = self.get_handle_type_id(var.ty);
                        let id = self.id_gen.next();
                        prelude
                            .body
                            .push(Instruction::load(var_type_id, id, gv.var_id, None));
                        gv.access_id = gv.var_id;
                        gv.handle_id = id;
                    } else if global_needs_wrapper(ir_module, var) {
                        let class = map_storage_class(var.space);
                        let pointer_type_id = match self.std140_compat_uniform_types.get(&var.ty) {
                            Some(std140_type_info) if var.space == crate::AddressSpace::Uniform => {
                                self.get_pointer_type_id(std140_type_info.type_id, class)
                            }
                            _ => self.get_handle_pointer_type_id(var.ty, class),
                        };
                        let index_id = self.get_index_constant(0);
                        let id = self.id_gen.next();
                        prelude.body.push(Instruction::access_chain(
                            pointer_type_id,
                            id,
                            gv.var_id,
                            &[index_id],
                        ));
                        gv.access_id = id;
                    } else {
                        // by default, the variable ID is accessed as is
                        gv.access_id = gv.var_id;
                    };
                }
            }

            // work around borrow checking in the presence of `self.xxx()` calls
            self.global_variables[handle] = gv;
        }

        // Create a `BlockContext` for generating SPIR-V for the function's
        // body.
        let mut context = BlockContext {
            ir_module,
            ir_function,
            fun_info: info,
            function: &mut function,
            // Re-use the cached expression table from prior functions.
            cached: core::mem::take(&mut self.saved_cached),

            // Steal the Writer's temp list for a bit.
            temp_list: core::mem::take(&mut self.temp_list),
            force_loop_bounding: self.force_loop_bounding,
            writer: self,
            expression_constness: super::ExpressionConstnessTracker::from_arena(
                &ir_function.expressions,
            ),
            ray_query_tracker_expr: crate::FastHashMap::default(),
        };

        // fill up the pre-emitted and const expressions
        context.cached.reset(ir_function.expressions.len());
        for (handle, expr) in ir_function.expressions.iter() {
            if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
                || context.expression_constness.is_const(handle)
            {
                context.cache_expression_value(handle, &mut prelude)?;
            }
        }

        for (handle, variable) in ir_function.local_variables.iter() {
            let id = context.gen_id();

            if context.writer.flags.contains(WriterFlags::DEBUG) {
                if let Some(ref name) = variable.name {
                    context.writer.debugs.push(Instruction::name(id, name));
                }
            }

            let init_word = variable.init.map(|constant| context.cached[constant]);
            let pointer_type_id = context
                .writer
                .get_handle_pointer_type_id(variable.ty, spirv::StorageClass::Function);
            let instruction = Instruction::variable(
                pointer_type_id,
                id,
                spirv::StorageClass::Function,
                init_word.or_else(|| match ir_module.types[variable.ty].inner {
                    crate::TypeInner::RayQuery { .. } => None,
                    _ => {
                        let type_id = context.get_handle_type_id(variable.ty);
                        Some(context.writer.write_constant_null(type_id))
                    }
                }),
            );

            context
                .function
                .variables
                .insert(handle, LocalVariable { id, instruction });

            if let crate::TypeInner::RayQuery { .. } = ir_module.types[variable.ty].inner {
                // Don't refactor this into a struct: Although spirv itself allows opaque types in structs,
                // the vulkan environment for spirv does not. Putting ray queries into structs can cause
                // confusing bugs.
                let u32_type_id = context.writer.get_u32_type_id();
                let ptr_u32_type_id = context
                    .writer
                    .get_pointer_type_id(u32_type_id, spirv::StorageClass::Function);
                let tracker_id = context.gen_id();
                let tracker_init_id = context.writer.get_constant_scalar(crate::Literal::U32(
                    crate::back::RayQueryPoint::empty().bits(),
                ));
                let tracker_instruction = Instruction::variable(
                    ptr_u32_type_id,
                    tracker_id,
                    spirv::StorageClass::Function,
                    Some(tracker_init_id),
                );

                context
                    .function
                    .ray_query_initialization_tracker_variables
                    .insert(
                        handle,
                        LocalVariable {
                            id: tracker_id,
                            instruction: tracker_instruction,
                        },
                    );
                let f32_type_id = context.writer.get_f32_type_id();
                let ptr_f32_type_id = context
                    .writer
                    .get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
                let t_max_tracker_id = context.gen_id();
                let t_max_tracker_init_id =
                    context.writer.get_constant_scalar(crate::Literal::F32(0.0));
                let t_max_tracker_instruction = Instruction::variable(
                    ptr_f32_type_id,
                    t_max_tracker_id,
                    spirv::StorageClass::Function,
                    Some(t_max_tracker_init_id),
                );

                context.function.ray_query_t_max_tracker_variables.insert(
                    handle,
                    LocalVariable {
                        id: t_max_tracker_id,
                        instruction: t_max_tracker_instruction,
                    },
                );
            }
        }

        for (handle, expr) in ir_function.expressions.iter() {
            match *expr {
                crate::Expression::LocalVariable(_) => {
                    // Cache the `OpVariable` instruction we generated above as
                    // the value of this expression.
                    context.cache_expression_value(handle, &mut prelude)?;
                }
                crate::Expression::Access { base, .. }
                | crate::Expression::AccessIndex { base, .. } => {
                    // Count references to `base` by `Access` and `AccessIndex`
                    // instructions. See `access_uses` for details.
                    *context.function.access_uses.entry(base).or_insert(0) += 1;
                }
                _ => {}
            }
        }

        let next_id = context.gen_id();

        context
            .function
            .consume(prelude, Instruction::branch(next_id));

        let workgroup_vars_init_exit_block_id =
            match (context.writer.zero_initialize_workgroup_memory, interface) {
                (
                    super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
                    Some(
                        ref mut interface @ FunctionInterface {
                            stage:
                                crate::ShaderStage::Compute
                                | crate::ShaderStage::Mesh
                                | crate::ShaderStage::Task,
                            ..
                        },
                    ),
                ) => context.writer.generate_workgroup_vars_init_block(
                    next_id,
                    ir_module,
                    info,
                    local_invocation_index_id,
                    interface,
                    context.function,
                ),
                _ => None,
            };

        let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
            exit_id
        } else {
            next_id
        };

        context.write_function_body(main_id, debug_info.as_ref())?;

        // Consume the `BlockContext`, ending its borrows and letting the
        // `Writer` steal back its cached expression table and temp_list.
        let BlockContext {
            cached, temp_list, ..
        } = context;
        self.saved_cached = cached;
        self.temp_list = temp_list;

        function.to_words(&mut self.logical_layout.function_definitions);

        if let Some(EntryPointContext {
            mesh_state: Some(ref mesh_state),
            ..
        }) = function.entry_point_context
        {
            self.write_mesh_shader_wrapper(mesh_state, function_id)
        } else if let Some(EntryPointContext {
            task_payload_variable_id: Some(tp),
            ..
        }) = function.entry_point_context
        {
            self.write_task_shader_wrapper(tp, function_id)
        } else {
            Ok(function_id)
        }
    }

    fn write_execution_mode(
        &mut self,
        function_id: Word,
        mode: spirv::ExecutionMode,
    ) -> Result<(), Error> {
        //self.check(mode.required_capabilities())?;
        Instruction::execution_mode(function_id, mode, &[])
            .to_words(&mut self.logical_layout.execution_modes);
        Ok(())
    }

    // TODO Move to instructions module
    fn write_entry_point(
        &mut self,
        entry_point: &crate::EntryPoint,
        info: &FunctionInfo,
        ir_module: &crate::Module,
        debug_info: &Option,
    ) -> Result<Instruction, Error> {
        let mut interface_ids = Vec::new();

        let function_id = self.write_function(
            &entry_point.function,
            info,
            ir_module,
            Some(FunctionInterface {
                varying_ids: &mut interface_ids,
                stage: entry_point.stage,
                task_payload: entry_point.task_payload,
                mesh_info: entry_point.mesh_info.clone(),
                workgroup_size: entry_point.workgroup_size,
            }),
            debug_info,
        )?;

        let exec_model = match entry_point.stage {
            crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
            crate::ShaderStage::Fragment => {
                self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
                match entry_point.early_depth_test {
                    Some(crate::EarlyDepthTest::Force) => {
                        self.write_execution_mode(
                            function_id,
                            spirv::ExecutionMode::EarlyFragmentTests,
                        )?;
                    }
                    Some(crate::EarlyDepthTest::Allow { conservative }) => {
                        // TODO: Consider emitting EarlyAndLateFragmentTestsAMD here, if available.
                        // https://github.khronos.org/SPIRV-Registry/extensions/AMD/SPV_AMD_shader_early_and_late_fragment_tests.html
                        // This permits early depth tests even if the shader writes to a storage
                        // binding
                        match conservative {
                            crate::ConservativeDepth::GreaterEqual => self.write_execution_mode(
                                function_id,
                                spirv::ExecutionMode::DepthGreater,
                            )?,
                            crate::ConservativeDepth::LessEqual => self.write_execution_mode(
                                function_id,
                                spirv::ExecutionMode::DepthLess,
                            )?,
                            crate::ConservativeDepth::Unchanged => self.write_execution_mode(
                                function_id,
                                spirv::ExecutionMode::DepthUnchanged,
                            )?,
                        }
                    }
                    None => {}
                }
                if let Some(ref result) = entry_point.function.result {
                    if contains_builtin(
                        result.binding.as_ref(),
                        result.ty,
                        &ir_module.types,
                        crate::BuiltIn::FragDepth,
                    ) {
                        self.write_execution_mode(
                            function_id,
                            spirv::ExecutionMode::DepthReplacing,
                        )?;
                    }
                }
                spirv::ExecutionModel::Fragment
            }
            crate::ShaderStage::Compute => {
                let execution_mode = spirv::ExecutionMode::LocalSize;
                Instruction::execution_mode(
                    function_id,
                    execution_mode,
                    &entry_point.workgroup_size,
                )
                .to_words(&mut self.logical_layout.execution_modes);
                spirv::ExecutionModel::GLCompute
            }
            crate::ShaderStage::Task => {
                let execution_mode = spirv::ExecutionMode::LocalSize;
                Instruction::execution_mode(
                    function_id,
                    execution_mode,
                    &entry_point.workgroup_size,
                )
                .to_words(&mut self.logical_layout.execution_modes);
                spirv::ExecutionModel::TaskEXT
            }
            crate::ShaderStage::Mesh => {
                let execution_mode = spirv::ExecutionMode::LocalSize;
                Instruction::execution_mode(
                    function_id,
                    execution_mode,
                    &entry_point.workgroup_size,
                )
                .to_words(&mut self.logical_layout.execution_modes);
                let mesh_info = entry_point.mesh_info.as_ref().unwrap();
                Instruction::execution_mode(
                    function_id,
                    match mesh_info.topology {
                        crate::MeshOutputTopology::Points => spirv::ExecutionMode::OutputPoints,
                        crate::MeshOutputTopology::Lines => spirv::ExecutionMode::OutputLinesEXT,
                        crate::MeshOutputTopology::Triangles => {
                            spirv::ExecutionMode::OutputTrianglesEXT
                        }
                    },
                    &[],
                )
                .to_words(&mut self.logical_layout.execution_modes);
                Instruction::execution_mode(
                    function_id,
                    spirv::ExecutionMode::OutputVertices,
                    core::slice::from_ref(&mesh_info.max_vertices),
                )
                .to_words(&mut self.logical_layout.execution_modes);
                Instruction::execution_mode(
                    function_id,
                    spirv::ExecutionMode::OutputPrimitivesEXT,
                    core::slice::from_ref(&mesh_info.max_primitives),
                )
                .to_words(&mut self.logical_layout.execution_modes);
                spirv::ExecutionModel::MeshEXT
            }
            crate::ShaderStage::RayGeneration
            | crate::ShaderStage::AnyHit
            | crate::ShaderStage::ClosestHit
            | crate::ShaderStage::Miss => unreachable!(),
        };
        //self.check(exec_model.required_capabilities())?;

        Ok(Instruction::entry_point(
            exec_model,
            function_id,
            &entry_point.name,
            interface_ids.as_slice(),
        ))
    }

    fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
        use crate::ScalarKind as Sk;

        let bits = (scalar.width * BITS_PER_BYTE) as u32;
        match scalar.kind {
            Sk::Sint | Sk::Uint => {
                let signedness = if scalar.kind == Sk::Sint {
                    super::instructions::Signedness::Signed
                } else {
                    super::instructions::Signedness::Unsigned
                };
                let cap = match bits {
                    8 => Some(spirv::Capability::Int8),
                    16 => Some(spirv::Capability::Int16),
                    64 => Some(spirv::Capability::Int64),
                    _ => None,
                };
                if let Some(cap) = cap {
                    self.capabilities_used.insert(cap);
                }
                Instruction::type_int(id, bits, signedness)
            }
            Sk::Float => {
                if bits == 64 {
                    self.capabilities_used.insert(spirv::Capability::Float64);
                }
                if bits == 16 {
                    self.capabilities_used.insert(spirv::Capability::Float16);
                    self.capabilities_used
                        .insert(spirv::Capability::StorageBuffer16BitAccess);
                    self.capabilities_used
                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
                    if self.use_storage_input_output_16 {
                        self.capabilities_used
                            .insert(spirv::Capability::StorageInputOutput16);
                    }
                }
                Instruction::type_float(id, bits)
            }
            Sk::Bool => Instruction::type_bool(id),
            Sk::AbstractInt | Sk::AbstractFloat => {
                unreachable!("abstract types should never reach the backend");
            }
        }
    }

    fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
        match *inner {
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let sampled = match class {
                    crate::ImageClass::Sampled { .. } => true,
                    crate::ImageClass::Depth { .. } => true,
                    crate::ImageClass::Storage { format, .. } => {
                        self.request_image_format_capabilities(format.into())?;
                        false
                    }
                    crate::ImageClass::External => unimplemented!(),
                };

                match dim {
                    crate::ImageDimension::D1 => {
                        if sampled {
                            self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
                        } else {
                            self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
                        }
                    }
                    crate::ImageDimension::Cube if arrayed => {
                        if sampled {
                            self.require_any(
                                "sampled cube array images",
                                &[spirv::Capability::SampledCubeArray],
                            )?;
                        } else {
                            self.require_any(
                                "cube array storage images",
                                &[spirv::Capability::ImageCubeArray],
                            )?;
                        }
                    }
                    _ => {}
                }
            }
            crate::TypeInner::AccelerationStructure { .. } => {
                self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?;
            }
            crate::TypeInner::RayQuery { .. } => {
                self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
            }
            crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
                self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
            }
            crate::TypeInner::Atomic(crate::Scalar {
                width: 4,
                kind: crate::ScalarKind::Float,
            }) => {
                self.require_any(
                    "32 bit floating-point atomics",
                    &[spirv::Capability::AtomicFloat32AddEXT],
                )?;
                self.use_extension("SPV_EXT_shader_atomic_float_add");
            }
            // 16 bit floating-point support requires Float16 capability
            crate::TypeInner::Matrix {
                scalar: crate::Scalar::F16,
                ..
            }
            | crate::TypeInner::Vector {
                scalar: crate::Scalar::F16,
                ..
            }
            | crate::TypeInner::Scalar(crate::Scalar::F16) => {
                self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
                self.use_extension("SPV_KHR_16bit_storage");
            }
            // Cooperative types and ops
            crate::TypeInner::CooperativeMatrix { .. } => {
                self.require_any(
                    "cooperative matrix",
                    &[spirv::Capability::CooperativeMatrixKHR],
                )?;
                self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?;
                self.use_extension("SPV_KHR_cooperative_matrix");
                self.use_extension("SPV_KHR_vulkan_memory_model");
            }
            _ => {}
        }
        Ok(())
    }

    fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) {
        let instruction = match numeric {
            NumericType::Scalar(scalar) => self.make_scalar(id, scalar),
            NumericType::Vector { size, scalar } => {
                let scalar_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
                Instruction::type_vector(id, scalar_id, size)
            }
            NumericType::Matrix {
                columns,
                rows,
                scalar,
            } => {
                let column_id =
                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
                Instruction::type_matrix(id, column_id, columns)
            }
        };

        instruction.to_words(&mut self.logical_layout.declarations);
    }

    fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) {
        let instruction = match coop {
            CooperativeType::Matrix {
                columns,
                rows,
                scalar,
                role,
            } => {
                let scalar_id =
                    self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar)));
                let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
                let columns_id = self.get_index_constant(columns as u32);
                let rows_id = self.get_index_constant(rows as u32);
                let role_id =
                    self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32);
                Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id)
            }
        };

        instruction.to_words(&mut self.logical_layout.declarations);
    }

    fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
        let instruction = match local_ty {
            LocalType::Numeric(numeric) => {
                self.write_numeric_type_declaration_local(id, numeric);
                return;
            }
            LocalType::Cooperative(coop) => {
                self.write_cooperative_type_declaration_local(id, coop);
                return;
            }
            LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
            LocalType::Image(image) => {
                let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
                let type_id = self.get_localtype_id(local_type);
                Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
            }
            LocalType::Sampler => Instruction::type_sampler(id),
            LocalType::SampledImage { image_type_id } => {
                Instruction::type_sampled_image(id, image_type_id)
            }
            LocalType::BindingArray { base, size } => {
                let inner_ty = self.get_handle_type_id(base);
                let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
                Instruction::type_array(id, inner_ty, scalar_id)
            }
            LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
            LocalType::RayQuery => Instruction::type_ray_query(id),
        };

        instruction.to_words(&mut self.logical_layout.declarations);
    }

    fn write_type_declaration_arena(
        &mut self,
        module: &crate::Module,
        handle: Handle<crate::Type>,
    ) -> Result<Word, Error> {
        let ty = &module.types[handle];
        // If it's a type that needs SPIR-V capabilities, request them now.
        // This needs to happen regardless of the LocalType lookup succeeding,
        // because some types which map to the same LocalType have different
        // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
        self.request_type_capabilities(&ty.inner)?;
        let id = if let Some(local) = self.localtype_from_inner(&ty.inner) {
            // This type can be represented as a `LocalType`, so check if we've
            // already written an instruction for it. If not, do so now, with
            // `write_type_declaration_local`.
            match self.lookup_type.entry(LookupType::Local(local)) {
                // We already have an id for this `LocalType`.
                Entry::Occupied(e) => *e.get(),

                // It's a type we haven't seen before.
                Entry::Vacant(e) => {
                    let id = self.id_gen.next();
                    e.insert(id);

                    self.write_type_declaration_local(id, local);

                    id
                }
            }
        } else {
            use spirv::Decoration;

            let id = self.id_gen.next();
            let instruction = match ty.inner {
                crate::TypeInner::Array { base, size, stride } => {
                    self.decorate(id, Decoration::ArrayStride, &[stride]);

                    let type_id = self.get_handle_type_id(base);
                    match size.resolve(module.to_ctx())? {
                        crate::proc::IndexableLength::Known(length) => {
                            let length_id = self.get_index_constant(length);
                            Instruction::type_array(id, type_id, length_id)
                        }
                        crate::proc::IndexableLength::Dynamic => {
                            Instruction::type_runtime_array(id, type_id)
                        }
                    }
                }
                crate::TypeInner::BindingArray { base, size } => {
                    let type_id = self.get_handle_type_id(base);
                    match size.resolve(module.to_ctx())? {
                        crate::proc::IndexableLength::Known(length) => {
                            let length_id = self.get_index_constant(length);
                            Instruction::type_array(id, type_id, length_id)
                        }
                        crate::proc::IndexableLength::Dynamic => {
                            Instruction::type_runtime_array(id, type_id)
                        }
                    }
                }
                crate::TypeInner::Struct {
                    ref members,
                    span: _,
                } => {
                    let mut has_runtime_array = false;
                    let mut member_ids = Vec::with_capacity(members.len());
                    for (index, member) in members.iter().enumerate() {
                        let member_ty = &module.types[member.ty];
                        match member_ty.inner {
                            crate::TypeInner::Array {
                                base: _,
                                size: crate::ArraySize::Dynamic,
                                stride: _,
                            } => {
                                has_runtime_array = true;
                            }
                            _ => (),
                        }
                        self.decorate_struct_member(id, index, member, &module.types)?;
                        let member_id = self.get_handle_type_id(member.ty);
                        member_ids.push(member_id);
                    }
                    if has_runtime_array {
                        self.decorate(id, Decoration::Block, &[]);
                    }
                    Instruction::type_struct(id, member_ids.as_slice())
                }

                // These all have `LocalType` representations, so they should have been
                // handled by `write_type_declaration_local` above.
                crate::TypeInner::Scalar(_)
                | crate::TypeInner::Atomic(_)
                | crate::TypeInner::Vector { .. }
                | crate::TypeInner::Matrix { .. }
                | crate::TypeInner::CooperativeMatrix { .. }
                | crate::TypeInner::Pointer { .. }
                | crate::TypeInner::ValuePointer { .. }
                | crate::TypeInner::Image { .. }
                | crate::TypeInner::Sampler { .. }
                | crate::TypeInner::AccelerationStructure { .. }
                | crate::TypeInner::RayQuery { .. } => unreachable!(),
            };

            instruction.to_words(&mut self.logical_layout.declarations);
            id
        };

        // Add this handle as a new alias for that type.
        self.lookup_type.insert(LookupType::Handle(handle), id);

        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(ref name) = ty.name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        Ok(id)
    }

    /// Writes a std140 layout compatible type declaration for a type. Returns
    /// the ID of the declared type, or `None` if no declaration is required.
    ///
    /// This should be called for any type for which there exists a
    /// [`GlobalVariable`] in the [`Uniform`] address space. If the type already
    /// adheres to std140 layout rules it will return without declaring any
    /// types. If the type contains another type which requires a std140
    /// compatible type declaration, it will recursively call itself.
    ///
    /// When `handle` refers to a [`TypeInner::Matrix`] with 2 rows, the
    /// declared type will be an `OpTypeStruct` containing an `OpTypeVector` for
    /// each of the matrix's columns.
    ///
    /// When `handle` refers to a [`TypeInner::Array`] whose base type is a
    /// matrix with 2 rows, this will declare an `OpTypeArray` whose element
    /// type is the matrix's corresponding std140 compatible type.
    ///
    /// When `handle` refers to a [`TypeInner::Struct`] and any of its members
    /// require a std140 compatible type declaration, this will declare a new
    /// struct with the following rules:
    /// * Struct or array members will be declared with their std140 compatible
    ///   type declaration, if one is required.
    /// * Two-row matrix members will have each of their columns hoisted
    ///   directly into the struct as 2-component vector members.
    /// * All other members will be declared with their normal type.
    ///
    /// Note that this means the Naga IR index of a struct member may not match
    /// the index in the generated SPIR-V. The mapping can be obtained via
    /// `Std140CompatTypeInfo::member_indices`.
    ///
    /// [`GlobalVariable`]: crate::GlobalVariable
    /// [`Uniform`]: crate::AddressSpace::Uniform
    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
    /// [`TypeInner::Array`]: crate::TypeInner::Array
    /// [`TypeInner::Struct`]: crate::TypeInner::Struct
    fn write_std140_compat_type_declaration(
        &mut self,
        module: &crate::Module,
        handle: Handle<crate::Type>,
    ) -> Result<Option<Word>, Error> {
        if let Some(std140_type_info) = self.std140_compat_uniform_types.get(&handle) {
            return Ok(Some(std140_type_info.type_id));
        }

        let type_inner = &module.types[handle].inner;
        let std140_type_id = match *type_inner {
            crate::TypeInner::Matrix {
                columns,
                rows: rows @ crate::VectorSize::Bi,
                scalar,
            } => {
                let std140_type_id = self.id_gen.next();
                let mut member_type_ids: ArrayVec<Word, 4> = ArrayVec::new();
                let column_type_id =
                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
                for column in 0..columns as u32 {
                    member_type_ids.push(column_type_id);
                    self.annotations.push(Instruction::member_decorate(
                        std140_type_id,
                        column,
                        spirv::Decoration::Offset,
                        &[column * rows as u32 * scalar.width as u32],
                    ));
                    if self.flags.contains(WriterFlags::DEBUG) {
                        self.debugs.push(Instruction::member_name(
                            std140_type_id,
                            column,
                            &format!("col{column}"),
                        ));
                    }
                }
                Instruction::type_struct(std140_type_id, &member_type_ids)
                    .to_words(&mut self.logical_layout.declarations);
                self.std140_compat_uniform_types.insert(
                    handle,
                    Std140CompatTypeInfo {
                        type_id: std140_type_id,
                        member_indices: Vec::new(),
                    },
                );
                Some(std140_type_id)
            }
            crate::TypeInner::Array { base, size, stride } => {
                match self.write_std140_compat_type_declaration(module, base)? {
                    Some(std140_base_type_id) => {
                        let std140_type_id = self.id_gen.next();
                        self.decorate(std140_type_id, spirv::Decoration::ArrayStride, &[stride]);
                        let instruction = match size.resolve(module.to_ctx())?
                        {
                            crate::proc::IndexableLength::Known(length) => {
                                let length_id = self.get_index_constant(length);
                                Instruction::type_array(
                                    std140_type_id,
                                    std140_base_type_id,
                                    length_id,
                                )
                            }
                            crate::proc::IndexableLength::Dynamic => {
                                unreachable!()
                            }
                        };
                        instruction.to_words(&mut self.logical_layout.declarations);
                        self.std140_compat_uniform_types.insert(
                            handle,
                            Std140CompatTypeInfo {
                                type_id: std140_type_id,
                                member_indices: Vec::new(),
                            },
                        );
                        Some(std140_type_id)
                    }
                    None => None,
                }
            }
            crate::TypeInner::Struct { ref members, .. } => {
                let mut needs_std140_type = false;
                for member in members {
                    match module.types[member.ty].inner {
                        // We don't need to write a std140 type for the matrix itself as
                        // it will be decomposed into the parent struct. As a result, the
                        // struct does need a std140 type, however.
                        crate::TypeInner::Matrix {
                            rows: crate::VectorSize::Bi,
                            ..
                        } => needs_std140_type = true,
                        // If an array member needs a std140 type, because it is an array
                        // (of an array, etc) of `matCx2`s, then the struct also needs
                        // a std140 type which uses the std140 type for this member.
                        crate::TypeInner::Array { .. }
                            if self
                                .write_std140_compat_type_declaration(module, member.ty)?
                                .is_some() =>
                        {
                            needs_std140_type = true;
                        }
                        _ => {}
                    }
                }

                if needs_std140_type {
                    let std140_type_id = self.id_gen.next();
                    let mut member_ids = Vec::new();
                    let mut member_indices = Vec::new();
                    let mut next_index = 0;
                    for member in members {
                        member_indices.push(next_index);
                        match module.types[member.ty].inner {
                            crate::TypeInner::Matrix {
                                columns,
                                rows: rows @ crate::VectorSize::Bi,
                                scalar,
                            } => {
                                let vector_type_id =
                                    self.get_numeric_type_id(NumericType::Vector {
                                        size: rows,
                                        scalar,
                                    });
                                for column in 0..columns as u32 {
                                    self.annotations.push(Instruction::member_decorate(
                                        std140_type_id,
                                        next_index,
                                        spirv::Decoration::Offset,
                                        &[member.offset
                                            + column * rows as u32 * scalar.width as u32],
                                    ));
                                    if self.flags.contains(WriterFlags::DEBUG) {
                                        if let Some(ref name) = member.name {
                                            self.debugs.push(Instruction::member_name(
                                                std140_type_id,
                                                next_index,
                                                &format!("{name}_col{column}"),
                                            ));
                                        }
                                    }
                                    member_ids.push(vector_type_id);
                                    next_index += 1;
                                }
                            }
                            _ => {
                                let member_id =
                                    match self.std140_compat_uniform_types.get(&member.ty) {
                                        Some(std140_member_type_info) => {
                                            self.annotations.push(Instruction::member_decorate(
                                                std140_type_id,
                                                next_index,
                                                spirv::Decoration::Offset,
                                                &[member.offset],
                                            ));
                                            if self.flags.contains(WriterFlags::DEBUG) {
                                                if let Some(ref name) = member.name {
                                                    self.debugs.push(Instruction::member_name(
                                                        std140_type_id,
                                                        next_index,
                                                        name,
                                                    ));
                                                }
                                            }
                                            std140_member_type_info.type_id
                                        }
                                        None => {
                                            self.decorate_struct_member(
                                                std140_type_id,
                                                next_index as usize,
                                                member,
                                                &module.types,
                                            )?;
                                            self.get_handle_type_id(member.ty)
                                        }
                                    };
                                member_ids.push(member_id);
                                next_index += 1;
                            }
                        }
                    }
                    Instruction::type_struct(std140_type_id, &member_ids)
                        .to_words(&mut self.logical_layout.declarations);
                    self.std140_compat_uniform_types.insert(
                        handle,
                        Std140CompatTypeInfo {
                            type_id: std140_type_id,
                            member_indices,
                        },
                    );
                    Some(std140_type_id)
                } else {
                    None
                }
            }
            _ => None,
        };

        if let Some(std140_type_id) = std140_type_id {
            if self.flags.contains(WriterFlags::DEBUG) {
                let name = format!(
                    "std140_{:?}",
                    handle.for_debug(&module.types)
                );
                self.debugs.push(Instruction::name(std140_type_id, &name));
            }
        }

        Ok(std140_type_id)
    }

    fn request_image_format_capabilities(
        &mut self,
        format: spirv::ImageFormat,
    ) -> Result<(), Error> {
        use spirv::ImageFormat as If;
        match format {
            If::Rg32f
            | If::Rg16f
            | If::R11fG11fB10f
            | If::R16f
            | If::Rgba16
            | If::Rgb10A2
            | If::Rg16
            | If::Rg8
            | If::R16
            | If::R8
            | If::Rgba16Snorm
            | If::Rg16Snorm
            | If::Rg8Snorm
            | If::R16Snorm
            | If::R8Snorm
            | If::Rg32i
            | If::Rg16i
            | If::Rg8i
            | If::R16i
            | If::R8i
            | If::Rgb10a2ui
            | If::Rg32ui
            | If::Rg16ui
            | If::Rg8ui
            | If::R16ui
            | If::R8ui => self.require_any(
                "storage image format",
                &[spirv::Capability::StorageImageExtendedFormats],
            ),
            If::R64ui | If::R64i => {
                self.use_extension("SPV_EXT_shader_image_int64");
                self.require_any(
                    "64-bit integer storage image format",
                    &[spirv::Capability::Int64ImageEXT],
                )
            }
            If::Unknown
            | If::Rgba32f
            | If::Rgba16f
            | If::R32f
            | If::Rgba8
            | If::Rgba8Snorm
            | If::Rgba32i
            | If::Rgba16i
            | If::Rgba8i
            | If::R32i
            | If::Rgba32ui
            | If::Rgba16ui
            | If::Rgba8ui
            | If::R32ui => Ok(()),
        }
    }

    pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
        self.get_constant_scalar(crate::Literal::U32(index))
    }

    pub(super) fn get_constant_scalar_with(
        &mut self,
        value: u8,
        scalar: crate::Scalar,
    ) -> Result<Word, Error> {
        Ok(
            self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
                Error::Validation("Unexpected kind and/or width for Literal"),
            )?),
        )
    }

    pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
        let scalar = CachedConstant::Literal(value.into());
        if let Some(&id) = self.cached_constants.get(&scalar) {
            return id;
        }
        let id = self.id_gen.next();
        self.write_constant_scalar(id, &value, None);
        self.cached_constants.insert(scalar, id);
        id
    }

    fn write_constant_scalar(
        &mut self,
        id: Word,
        value: &crate::Literal,
        debug_name: Option<&String>,
    ) {
        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(name) = debug_name {
                self.debugs.push(Instruction::name(id, name));
            }
        }
        let type_id = self.get_numeric_type_id(NumericType::Scalar(value.scalar()));
        let instruction = match *value {
            crate::Literal::F64(value) => {
                let bits = value.to_bits();
                Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
            }
            crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
            crate::Literal::F16(value) => {
                let low = value.to_bits();
                Instruction::constant_16bit(type_id, id, low as u32)
            }
            crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
            crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
            crate::Literal::U64(value) => {
                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
            }
            crate::Literal::I64(value) => {
                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
            }
            crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
            crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
                unreachable!("Abstract types should not appear in IR presented to backends");
            }
        };

        instruction.to_words(&mut self.logical_layout.declarations);
    }

    pub(super) fn get_constant_composite(
        &mut self,
        ty: LookupType,
        constituent_ids: &[Word],
    ) -> Word {
        let composite = CachedConstant::Composite {
            ty,
            constituent_ids: constituent_ids.to_vec(),
        };
        if let Some(&id) = self.cached_constants.get(&composite) {
            return id;
        }
        let id = self.id_gen.next();
        self.write_constant_composite(id, ty, constituent_ids, None);
        self.cached_constants.insert(composite, id);
        id
    }

    fn write_constant_composite(
        &mut self,
        id: Word,
        ty: LookupType,
        constituent_ids: &[Word],
        debug_name: Option<&String>,
    ) {
        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(name) = debug_name {
                self.debugs.push(Instruction::name(id, name));
            }
        }
        let type_id = self.get_type_id(ty);
        Instruction::constant_composite(type_id, id, constituent_ids)
            .to_words(&mut self.logical_layout.declarations);
    }

    pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
        let null = CachedConstant::ZeroValue(type_id);
        if let Some(&id) = self.cached_constants.get(&null) {
            return id;
        }
        let id = self.write_constant_null(type_id);
        self.cached_constants.insert(null, id);
        id
    }

    pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
        let null_id = self.id_gen.next();
        Instruction::constant_null(type_id, null_id)
            .to_words(&mut self.logical_layout.declarations);
        null_id
    }

    fn write_constant_expr(
        &mut self,
        handle: Handle<crate::Expression>,
        ir_module: &crate::Module,
        mod_info: &ModuleInfo,
    ) -> Result<Word, Error> {
        let id = match ir_module.global_expressions[handle] {
            crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
            crate::Expression::Constant(constant) => {
                let constant = &ir_module.constants[constant];
                self.constant_ids[constant.init]
            }
            crate::Expression::ZeroValue(ty) => {
                let type_id = self.get_handle_type_id(ty);
                self.get_constant_null(type_id)
            }
            crate::Expression::Compose { ty, ref components } => {
                let component_ids: Vec<_> = crate::proc::flatten_compose(
                    ty,
                    components,
                    &ir_module.global_expressions,
                    &ir_module.types,
                )
                .map(|component| self.constant_ids[component])
                .collect();
                self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
            }
            crate::Expression::Splat { size, value } => {
                let value_id = self.constant_ids[value];
                let component_ids = &[value_id; 4][..size as usize];

                let ty = self.get_expression_lookup_type(&mod_info[handle]);

                self.get_constant_composite(ty, component_ids)
            }
            _ => {
                return Err(Error::Override);
            }
        };

        self.constant_ids[handle] = id;

        Ok(id)
    }

    pub(super) fn write_control_barrier(
        &mut self,
        flags: crate::Barrier,
        body: &mut Vec<Instruction>,
    ) {
        let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
            spirv::Scope::Device
        } else if flags.contains(crate::Barrier::SUB_GROUP) {
            spirv::Scope::Subgroup
        } else {
            spirv::Scope::Workgroup
        };
        let mut semantics =
            spirv::MemorySemantics::ACQUIRE_RELEASE;
        semantics.set(
            spirv::MemorySemantics::UNIFORM_MEMORY,
            flags.contains(crate::Barrier::STORAGE),
        );
        semantics.set(
            spirv::MemorySemantics::WORKGROUP_MEMORY,
            flags.contains(crate::Barrier::WORK_GROUP),
        );
        semantics.set(
            spirv::MemorySemantics::SUBGROUP_MEMORY,
            flags.contains(crate::Barrier::SUB_GROUP),
        );
        semantics.set(
            spirv::MemorySemantics::IMAGE_MEMORY,
            flags.contains(crate::Barrier::TEXTURE),
        );
        let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
            self.get_index_constant(spirv::Scope::Subgroup as u32)
        } else {
            self.get_index_constant(spirv::Scope::Workgroup as u32)
        };
        let mem_scope_id = self.get_index_constant(memory_scope as u32);
        let semantics_id = self.get_index_constant(semantics.bits());
        body.push(Instruction::control_barrier(
            exec_scope_id,
            mem_scope_id,
            semantics_id,
        ));
    }

    pub(super) fn write_memory_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
        semantics.set(
            spirv::MemorySemantics::UNIFORM_MEMORY,
            flags.contains(crate::Barrier::STORAGE),
        );
        semantics.set(
            spirv::MemorySemantics::WORKGROUP_MEMORY,
            flags.contains(crate::Barrier::WORK_GROUP),
        );
        semantics.set(
            spirv::MemorySemantics::SUBGROUP_MEMORY,
            flags.contains(crate::Barrier::SUB_GROUP),
        );
        semantics.set(
            spirv::MemorySemantics::IMAGE_MEMORY,
            flags.contains(crate::Barrier::TEXTURE),
        );
        let mem_scope_id = if flags.contains(crate::Barrier::STORAGE) {
            self.get_index_constant(spirv::Scope::Device as u32)
        } else if flags.contains(crate::Barrier::SUB_GROUP) {
            self.get_index_constant(spirv::Scope::Subgroup as u32)
        } else {
            self.get_index_constant(spirv::Scope::Workgroup as u32)
        };
        let semantics_id = self.get_index_constant(semantics.bits());
        block
            .body
            .push(Instruction::memory_barrier(mem_scope_id, semantics_id));
    }

    fn generate_workgroup_vars_init_block(
        &mut self,
        entry_id: Word,
        ir_module: &crate::Module,
        info: &FunctionInfo,
        local_invocation_index: Option<Word>,
        interface: &mut FunctionInterface,
        function: &mut Function,
    ) -> Option<Word> {
        let body = ir_module
            .global_variables
            .iter()
            .filter(|&(handle, var)| {
                let task_exception = (var.space == crate::AddressSpace::TaskPayload)
                    && interface.stage == crate::ShaderStage::Task;
                !info[handle].is_empty()
                    && (var.space == crate::AddressSpace::WorkGroup || task_exception)
            })
            .map(|(handle, var)| {
                // It's safe to use `var_id` here, not `access_id`, because only
                // variables in the `Uniform` and `StorageBuffer` address spaces
                // get wrapped, and we're initializing `WorkGroup` variables.
                let var_id = self.global_variables[handle].var_id;
                let var_type_id = self.get_handle_type_id(var.ty);
                let init_word = self.get_constant_null(var_type_id);
                Instruction::store(var_id, init_word, None)
            })
            .collect::<Vec<_>>();

        if body.is_empty() {
            return None;
        }

        let mut pre_if_block = Block::new(entry_id);

        let local_invocation_index = if let Some(local_invocation_index) = local_invocation_index {
            local_invocation_index
        } else {
            let varying_id = self.id_gen.next();
            let class = spirv::StorageClass::Input;
            let u32_ty_id = self.get_u32_type_id();
            let pointer_type_id = self.get_pointer_type_id(u32_ty_id, class);

            Instruction::variable(pointer_type_id, varying_id, class, None)
                .to_words(&mut self.logical_layout.declarations);

            self.decorate(
                varying_id,
                spirv::Decoration::BuiltIn,
                &[spirv::BuiltIn::LocalInvocationIndex as u32],
            );

            interface.varying_ids.push(varying_id);
            let id = self.id_gen.next();
            pre_if_block
                .body
                .push(Instruction::load(u32_ty_id, id, varying_id, None));

            id
        };

        let zero_id = self.get_constant_scalar(crate::Literal::U32(0));

        let eq_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            self.get_bool_type_id(),
            eq_id,
            local_invocation_index,
            zero_id,
        ));

        let merge_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));

        let accept_id = self.id_gen.next();
        function.consume(
            pre_if_block,
            Instruction::branch_conditional(eq_id, accept_id, merge_id),
        );

        let accept_block = Block {
            label_id: accept_id,
            body,
        };
        function.consume(accept_block, Instruction::branch(merge_id));

        let mut post_if_block = Block::new(merge_id);

        self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block.body);

        let next_id = self.id_gen.next();
        function.consume(post_if_block, Instruction::branch(next_id));
        Some(next_id)
    }

    /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
    ///
    /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
    /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
    /// interface is represented by global variables in the `Input` and `Output`
    /// storage classes, with decorations indicating which builtin or location
    /// each variable corresponds to.
    ///
    /// This function emits a single global `OpVariable` for a single value from
    /// the interface, and adds appropriate decorations to indicate which
    /// builtin or location it represents, how it should be interpolated, and so
    /// on. The `class` argument gives the variable's SPIR-V storage class,
    /// which should be either [`Input`] or [`Output`].
    ///
    /// [`Binding`]: crate::Binding
    /// [`Function`]: crate::Function
    /// [`EntryPoint`]: crate::EntryPoint
    /// [`Input`]: spirv::StorageClass::Input
    /// [`Output`]: spirv::StorageClass::Output
    fn write_varying(
        &mut self,
        ir_module: &crate::Module,
        stage: crate::ShaderStage,
        class: spirv::StorageClass,
        debug_name: Option<&str>,
        ty: Handle<crate::Type>,
        binding: &crate::Binding,
    ) -> Result<Word, Error> {
        let id = self.id_gen.next();
        let ty_inner = &ir_module.types[ty].inner;
        let needs_polyfill = self.needs_f16_polyfill(ty_inner);

        let pointer_type_id = if needs_polyfill {
            let f32_value_local =
                super::f16_polyfill::F16IoPolyfill::create_polyfill_type(ty_inner)
                    .expect("needs_polyfill returned true but create_polyfill_type returned None");

            let f32_type_id = self.get_localtype_id(f32_value_local);
            let ptr_id = self.get_pointer_type_id(f32_type_id, class);
            self.io_f16_polyfills.register_io_var(id, f32_type_id);

            ptr_id
        } else {
            self.get_handle_pointer_type_id(ty, class)
        };

        Instruction::variable(pointer_type_id, id, class, None)
            .to_words(&mut self.logical_layout.declarations);

        if self
            .flags
            .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
        {
            if let Some(name) = debug_name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        let binding = self.map_binding(ir_module, stage, class, ty, binding)?;
        self.write_binding(id, binding);

        Ok(id)
    }

    pub fn write_binding(&mut self, id: Word, binding: BindingDecorations) {
        match binding {
            BindingDecorations::None => (),
            BindingDecorations::BuiltIn(bi, others) => {
                self.decorate(id, spirv::Decoration::BuiltIn, &[bi as u32]);
                for other in others {
                    self.decorate(id, other, &[]);
                }
            }
            BindingDecorations::Location {
                location,
                others,
                blend_src,
            } => {
                self.decorate(id, spirv::Decoration::Location, &[location]);
                for other in others {
                    self.decorate(id, other, &[]);
                }
                if let Some(blend_src) = blend_src {
                    self.decorate(id, spirv::Decoration::Index, &[blend_src]);
                }
            }
        }
    }

    pub fn write_binding_struct_member(
        &mut self,
        struct_id: Word,
        member_idx: Word,
        binding_info: BindingDecorations,
    ) {
        match binding_info {
            BindingDecorations::None => (),
            BindingDecorations::BuiltIn(bi, others) => {
                self.annotations.push(Instruction::member_decorate(
                    struct_id,
                    member_idx,
                    spirv::Decoration::BuiltIn,
                    &[bi as Word],
                ));
                for other in others {
                    self.annotations.push(Instruction::member_decorate(
                        struct_id,
                        member_idx,
                        other,
                        &[],
                    ));
                }
            }
            BindingDecorations::Location {
                location,
                others,
                blend_src,
            } => {
                self.annotations.push(Instruction::member_decorate(
                    struct_id,
                    member_idx,
                    spirv::Decoration::Location,
                    &[location],
                ));
                for other in others {
                    self.annotations.push(Instruction::member_decorate(
                        struct_id,
                        member_idx,
                        other,
                        &[],
                    ));
                }
                if let Some(blend_src) = blend_src {
                    self.annotations.push(Instruction::member_decorate(
                        struct_id,
                        member_idx,
                        spirv::Decoration::Index,
                        &[blend_src],
                    ));
                }
            }
        }
    }

    pub fn map_binding(
        &mut self,
        ir_module: &crate::Module,
        stage: crate::ShaderStage,
        class: spirv::StorageClass,
        ty: Handle<crate::Type>,
        binding: &crate::Binding,
    ) -> Result<BindingDecorations, Error> {
        use spirv::BuiltIn;
        use spirv::Decoration;
        match *binding {
            crate::Binding::Location {
                location,
                interpolation,
                sampling,
                blend_src,
                per_primitive,
            } => {
                let mut others = ArrayVec::new();

                let no_decorations =
                    // VUID-StandaloneSpirv-Flat-06202
                    // > The Flat, NoPerspective, Sample, and Centroid decorations
                    // > must not be used on variables with the Input storage class in a vertex shader
                    (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
                    // VUID-StandaloneSpirv-Flat-06201
                    // > The Flat, NoPerspective, Sample, and Centroid decorations
                    // > must not be used on variables with the Output storage class in a fragment shader
                    (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);

                if !no_decorations {
                    match interpolation {
                        // Perspective-correct interpolation is the default in SPIR-V.
None | Some(crate::Interpolation::Perspective) => (), Some(crate::Interpolation::Flat) => { others.push(Decoration::Flat); } Some(crate::Interpolation::Linear) => { others.push(Decoration::NoPerspective); } Some(crate::Interpolation::PerVertex) => { others.push(Decoration::PerVertexKHR); self.require_any( "`per_vertex` interpolation", &[spirv::Capability::FragmentBarycentricKHR], )?; self.use_extension("SPV_KHR_fragment_shader_barycentric"); } } match sampling { // Center sampling is the default in SPIR-V. None | Some( crate::Sampling::Center | crate::Sampling::First | crate::Sampling::Either, ) => (), Some(crate::Sampling::Centroid) => { others.push(Decoration::Centroid); } Some(crate::Sampling::Sample) => { self.require_any( "per-sample interpolation", &[spirv::Capability::SampleRateShading], )?; others.push(Decoration::Sample); } } } if per_primitive && stage == crate::ShaderStage::Fragment { others.push(Decoration::PerPrimitiveEXT); } Ok(BindingDecorations::Location { location, others, blend_src, }) } crate::Binding::BuiltIn(built_in) => { use crate::BuiltIn as Bi; let mut others = ArrayVec::new(); let built_in = match built_in { Bi::Position { invariant } => { if invariant { others.push(Decoration::Invariant); } if class == spirv::StorageClass::Output { BuiltIn::Position } else { BuiltIn::FragCoord } } Bi::ViewIndex => { self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?; BuiltIn::ViewIndex } // vertex Bi::BaseInstance => BuiltIn::BaseInstance, Bi::BaseVertex => BuiltIn::BaseVertex, Bi::ClipDistance => { self.require_any( "`clip_distance` built-in", &[spirv::Capability::ClipDistance], )?; BuiltIn::ClipDistance } Bi::CullDistance => { self.require_any( "`cull_distance` built-in", &[spirv::Capability::CullDistance], )?; BuiltIn::CullDistance } Bi::InstanceIndex => BuiltIn::InstanceIndex, Bi::PointSize => BuiltIn::PointSize, Bi::VertexIndex => BuiltIn::VertexIndex, Bi::DrawIndex => { self.use_extension("SPV_KHR_shader_draw_parameters"); 
self.require_any( "`draw_index` built-in", &[spirv::Capability::DrawParameters], )?; BuiltIn::DrawIndex } // fragment Bi::FragDepth => BuiltIn::FragDepth, Bi::PointCoord => BuiltIn::PointCoord, Bi::FrontFacing => BuiltIn::FrontFacing, Bi::PrimitiveIndex => { // Geometry shader capability is required for primitive index self.require_any( "`primitive_index` built-in", &[spirv::Capability::Geometry], )?; if stage == crate::ShaderStage::Mesh { others.push(Decoration::PerPrimitiveEXT); } BuiltIn::PrimitiveId } Bi::Barycentric { perspective } => { self.require_any( "`barycentric` built-in", &[spirv::Capability::FragmentBarycentricKHR], )?; self.use_extension("SPV_KHR_fragment_shader_barycentric"); if perspective { BuiltIn::BaryCoordKHR } else { BuiltIn::BaryCoordNoPerspKHR } } Bi::SampleIndex => { self.require_any( "`sample_index` built-in", &[spirv::Capability::SampleRateShading], )?; BuiltIn::SampleId } Bi::SampleMask => BuiltIn::SampleMask, // compute Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId, Bi::LocalInvocationId => BuiltIn::LocalInvocationId, Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex, Bi::WorkGroupId => BuiltIn::WorkgroupId, Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, // Subgroup Bi::NumSubgroups => { self.require_any( "`num_subgroups` built-in", &[spirv::Capability::GroupNonUniform], )?; BuiltIn::NumSubgroups } Bi::SubgroupId => { self.require_any( "`subgroup_id` built-in", &[spirv::Capability::GroupNonUniform], )?; BuiltIn::SubgroupId } Bi::SubgroupSize => { self.require_any( "`subgroup_size` built-in", &[ spirv::Capability::GroupNonUniform, spirv::Capability::SubgroupBallotKHR, ], )?; BuiltIn::SubgroupSize } Bi::SubgroupInvocationId => { self.require_any( "`subgroup_invocation_id` built-in", &[ spirv::Capability::GroupNonUniform, spirv::Capability::SubgroupBallotKHR, ], )?; BuiltIn::SubgroupLocalInvocationId } Bi::CullPrimitive => { others.push(Decoration::PerPrimitiveEXT);
BuiltIn::CullPrimitiveEXT } Bi::PointIndex => BuiltIn::PrimitivePointIndicesEXT, Bi::LineIndices => BuiltIn::PrimitiveLineIndicesEXT, Bi::TriangleIndices => BuiltIn::PrimitiveTriangleIndicesEXT, // No decoration; `EmitMeshTasksEXT` is emitted at function return instead Bi::MeshTaskSize => return Ok(BindingDecorations::None), // These aren't normal builtins and don't occur in function output Bi::VertexCount | Bi::Vertices | Bi::PrimitiveCount | Bi::Primitives => { unreachable!() } Bi::RayInvocationId | Bi::NumRayInvocations | Bi::InstanceCustomData | Bi::GeometryIndex | Bi::WorldRayOrigin | Bi::WorldRayDirection | Bi::ObjectRayOrigin | Bi::ObjectRayDirection | Bi::RayTmin | Bi::RayTCurrentMax | Bi::ObjectToWorld | Bi::WorldToObject | Bi::HitKind => unreachable!(), }; use crate::ScalarKind as Sk; // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`: // // > Any variable with integer or double-precision floating- // > point type and with Input storage class in a fragment // > shader, must be decorated Flat if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment { let is_flat = match ir_module.types[ty].inner { crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Bool => true, Sk::Float => false, Sk::AbstractInt | Sk::AbstractFloat => { return Err(Error::Validation( "Abstract types should not appear in IR presented to backends", )) } }, _ => false, }; if is_flat { others.push(Decoration::Flat); } } Ok(BindingDecorations::BuiltIn(built_in, others)) } } } /// Load an IO variable, converting from `f32` to `f16` if polyfill is active. /// Returns the id of the loaded value matching `target_type_id`.
pub(super) fn load_io_with_f16_polyfill( &mut self, body: &mut Vec<Instruction>, varying_id: Word, target_type_id: Word, ) -> Word { let tmp = self.id_gen.next(); if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) { body.push(Instruction::load(f32_ty, tmp, varying_id, None)); let converted = self.id_gen.next(); super::f16_polyfill::F16IoPolyfill::emit_f32_to_f16_conversion( tmp, target_type_id, converted, body, ); converted } else { body.push(Instruction::load(target_type_id, tmp, varying_id, None)); tmp } } /// Store an IO variable, converting from `f16` to `f32` if polyfill is active. pub(super) fn store_io_with_f16_polyfill( &mut self, body: &mut Vec<Instruction>, varying_id: Word, value_id: Word, ) { if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) { let converted = self.id_gen.next(); super::f16_polyfill::F16IoPolyfill::emit_f16_to_f32_conversion( value_id, f32_ty, converted, body, ); body.push(Instruction::store(varying_id, converted, None)); } else { body.push(Instruction::store(varying_id, value_id, None)); } } fn write_global_variable( &mut self, ir_module: &crate::Module, global_variable: &crate::GlobalVariable, ) -> Result<Word, Error> { use spirv::Decoration; let id = self.id_gen.next(); let class = map_storage_class(global_variable.space); //self.check(class.required_capabilities())?; if global_variable .memory_decorations .contains(crate::MemoryDecorations::COHERENT) { self.decorate(id, Decoration::Coherent, &[]); } if global_variable .memory_decorations .contains(crate::MemoryDecorations::VOLATILE) { self.decorate(id, Decoration::Volatile, &[]); } if self.flags.contains(WriterFlags::DEBUG) { if let Some(ref name) = global_variable.name { self.debugs.push(Instruction::name(id, name)); } } let storage_access = match global_variable.space { crate::AddressSpace::Storage { access } => Some(access), _ => match ir_module.types[global_variable.ty].inner { crate::TypeInner::Image { class: crate::ImageClass::Storage { access, .. }, ..
} => Some(access), _ => None, }, }; if let Some(storage_access) = storage_access { if !storage_access.contains(crate::StorageAccess::LOAD) { self.decorate(id, Decoration::NonReadable, &[]); } if !storage_access.contains(crate::StorageAccess::STORE) { self.decorate(id, Decoration::NonWritable, &[]); } } // Note: we should be able to substitute `binding_array`, // but there is still code that tries to register the pre-substituted type, // and it is failing on 0. let mut substitute_inner_type_lookup = None; if let Some(ref res_binding) = global_variable.binding { let bind_target = self.resolve_resource_binding(res_binding)?; self.decorate(id, Decoration::DescriptorSet, &[bind_target.descriptor_set]); self.decorate(id, Decoration::Binding, &[bind_target.binding]); if let Some(remapped_binding_array_size) = bind_target.binding_array_size { if let crate::TypeInner::BindingArray { base, .. } = ir_module.types[global_variable.ty].inner { let binding_array_type_id = self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size: remapped_binding_array_size, })); substitute_inner_type_lookup = Some(LookupType::Local(LocalType::Pointer { base: binding_array_type_id, class, })); } } }; let init_word = global_variable .init .map(|constant| self.constant_ids[constant]); let inner_type_id = self.get_type_id( substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)), ); // generate the wrapping structure if needed let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) { let wrapper_type_id = self.id_gen.next(); self.decorate(wrapper_type_id, Decoration::Block, &[]); match self.std140_compat_uniform_types.get(&global_variable.ty) { Some(std140_type_info) if global_variable.space == crate::AddressSpace::Uniform => { self.annotations.push(Instruction::member_decorate( wrapper_type_id, 0, Decoration::Offset, &[0], )); Instruction::type_struct(wrapper_type_id, &[std140_type_info.type_id]) .to_words(&mut self.logical_layout.declarations); 
} _ => { let member = crate::StructMember { name: None, ty: global_variable.ty, binding: None, offset: 0, }; self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?; Instruction::type_struct(wrapper_type_id, &[inner_type_id]) .to_words(&mut self.logical_layout.declarations); } } let pointer_type_id = self.id_gen.next(); Instruction::type_pointer(pointer_type_id, class, wrapper_type_id) .to_words(&mut self.logical_layout.declarations); pointer_type_id } else { // This is a global variable in the Storage address space. The only // way it could have `global_needs_wrapper() == false` is if it has // a runtime-sized or binding array. // Runtime-sized arrays were decorated when iterating through struct content. // Binding arrays still need their base type decorated with `Block`, which is done here. if let crate::AddressSpace::Storage { .. } = global_variable.space { match ir_module.types[global_variable.ty].inner { crate::TypeInner::BindingArray { base, .. } => { let ty = &ir_module.types[base]; let mut should_decorate = true; // Check if the type has a runtime array. // A bare runtime array is rejected by validation, // so only structs can contain runtime arrays. if let crate::TypeInner::Struct { ref members, .. } = ty.inner { // only the last member in a struct can be dynamically sized if let Some(last_member) = members.last() { if let &crate::TypeInner::Array { size: crate::ArraySize::Dynamic, ..
} = &ir_module.types[last_member.ty].inner { should_decorate = false; } } } if should_decorate { let decorated_id = self.get_handle_type_id(base); self.decorate(decorated_id, Decoration::Block, &[]); } } _ => (), }; } if substitute_inner_type_lookup.is_some() { inner_type_id } else { self.get_handle_pointer_type_id(global_variable.ty, class) } }; let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) { (crate::AddressSpace::Private, _) | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => { init_word.or_else(|| Some(self.get_constant_null(inner_type_id))) } _ => init_word, }; Instruction::variable(pointer_type_id, id, class, init_word) .to_words(&mut self.logical_layout.declarations); Ok(id) } /// Write the necessary decorations for a struct member. /// /// Emit decorations for the `index`'th member of the struct type /// designated by `struct_id`, described by `member`. fn decorate_struct_member( &mut self, struct_id: Word, index: usize, member: &crate::StructMember, arena: &UniqueArena<crate::Type>, ) -> Result<(), Error> { use spirv::Decoration; self.annotations.push(Instruction::member_decorate( struct_id, index as u32, Decoration::Offset, &[member.offset], )); if self.flags.contains(WriterFlags::DEBUG) { if let Some(ref name) = member.name { self.debugs .push(Instruction::member_name(struct_id, index as u32, name)); } } // Matrices and (potentially nested) arrays of matrices both require decorations, // so "see through" any arrays to determine if they're needed. let mut member_array_subty_inner = &arena[member.ty].inner; while let crate::TypeInner::Array { base, ..
} = *member_array_subty_inner { member_array_subty_inner = &arena[base].inner; } if let crate::TypeInner::Matrix { columns: _, rows, scalar, } = *member_array_subty_inner { let byte_stride = Alignment::from(rows) * scalar.width as u32; self.annotations.push(Instruction::member_decorate( struct_id, index as u32, Decoration::ColMajor, &[], )); self.annotations.push(Instruction::member_decorate( struct_id, index as u32, Decoration::MatrixStride, &[byte_stride], )); } Ok(()) } pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word { match self .lookup_function_type .entry(lookup_function_type.clone()) { Entry::Occupied(e) => *e.get(), Entry::Vacant(_) => { let id = self.id_gen.next(); let instruction = Instruction::type_function( id, lookup_function_type.return_type_id, &lookup_function_type.parameter_type_ids, ); instruction.to_words(&mut self.logical_layout.declarations); self.lookup_function_type.insert(lookup_function_type, id); id } } } const fn write_physical_layout(&mut self) { self.physical_layout.bound = self.id_gen.0 + 1; } fn write_logical_layout( &mut self, ir_module: &crate::Module, mod_info: &ModuleInfo, ep_index: Option<usize>, debug_info: &Option<DebugInfo>, ) -> Result<(), Error> { fn has_view_index_check( ir_module: &crate::Module, binding: Option<&crate::Binding>, ty: Handle<crate::Type>, ) -> bool { match ir_module.types[ty].inner { crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| { has_view_index_check(ir_module, member.binding.as_ref(), member.ty) }), _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)), } } let has_storage_buffers = ir_module .global_variables .iter() .any(|(_, var)| match var.space { crate::AddressSpace::Storage { ..
} => true, _ => false, }); let has_view_index = ir_module .entry_points .iter() .flat_map(|entry| entry.function.arguments.iter()) .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty)); let mut has_ray_query = ir_module.special_types.ray_desc.is_some() | ir_module.special_types.ray_intersection.is_some(); let has_vertex_return = ir_module.special_types.ray_vertex_return.is_some(); for (_, &crate::Type { ref inner, .. }) in ir_module.types.iter() { // spirv does not know whether these have vertex return - that is done by us if let &crate::TypeInner::AccelerationStructure { .. } | &crate::TypeInner::RayQuery { .. } = inner { has_ray_query = true } } if self.physical_layout.version < 0x10300 && has_storage_buffers { // enable the storage buffer class on < SPV-1.3 Instruction::extension("SPV_KHR_storage_buffer_storage_class") .to_words(&mut self.logical_layout.extensions); } if has_view_index { Instruction::extension("SPV_KHR_multiview") .to_words(&mut self.logical_layout.extensions) } if has_ray_query { Instruction::extension("SPV_KHR_ray_query") .to_words(&mut self.logical_layout.extensions) } if has_vertex_return { Instruction::extension("SPV_KHR_ray_tracing_position_fetch") .to_words(&mut self.logical_layout.extensions); } if ir_module.uses_mesh_shaders() { self.use_extension("SPV_EXT_mesh_shader"); self.require_any("Mesh Shaders", &[spirv::Capability::MeshShadingEXT])?; let lang_version = self.lang_version(); if lang_version.0 <= 1 && lang_version.1 < 4 { return Err(Error::SpirvVersionTooLow(1, 4)); } } Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations); Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450") .to_words(&mut self.logical_layout.ext_inst_imports); let mut debug_info_inner = None; if self.flags.contains(WriterFlags::DEBUG) { if let Some(debug_info) = debug_info.as_ref() { let source_file_id = self.id_gen.next(); self.debugs .push(Instruction::string(debug_info.file_name, 
source_file_id)); debug_info_inner = Some(DebugInfoInner { source_code: debug_info.source_code, source_file_id, }); self.debugs.append(&mut Instruction::source_auto_continued( debug_info.language, 0, &debug_info_inner, )); } } // write all types for (handle, _) in ir_module.types.iter() { self.write_type_declaration_arena(ir_module, handle)?; } // write std140 layout compatible types required by uniforms for (_, var) in ir_module.global_variables.iter() { if var.space == crate::AddressSpace::Uniform { self.write_std140_compat_type_declaration(ir_module, var.ty)?; } } // write all const-expressions as constants self.constant_ids .resize(ir_module.global_expressions.len(), 0); for (handle, _) in ir_module.global_expressions.iter() { self.write_constant_expr(handle, ir_module, mod_info)?; } debug_assert!(self.constant_ids.iter().all(|&id| id != 0)); // write the name of constants on their respective const-expression initializer if self.flags.contains(WriterFlags::DEBUG) { for (_, constant) in ir_module.constants.iter() { if let Some(ref name) = constant.name { let id = self.constant_ids[constant.init]; self.debugs.push(Instruction::name(id, name)); } } } // write all global variables for (handle, var) in ir_module.global_variables.iter() { // If a single entry point was specified, only write `OpVariable` instructions // for the globals it actually uses. Emit dummies for the others, // to preserve the indices in `global_variables`. 
let gvar = match ep_index { Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => { GlobalVariable::dummy() } _ => { let id = self.write_global_variable(ir_module, var)?; GlobalVariable::new(id) } }; self.global_variables.insert(handle, gvar); } // write all functions for (handle, ir_function) in ir_module.functions.iter() { let info = &mod_info[handle]; if let Some(index) = ep_index { let ep_info = mod_info.get_entry_point(index); // If this function uses globals that we omitted from the SPIR-V // because the entry point and its callees didn't use them, // then we must skip it. if !ep_info.dominates_global_use(info) { log::debug!("Skip function {:?}", ir_function.name); continue; } // Skip functions that are not compatible with this entry point's stage. // // When validation is enabled, it rejects modules whose entry points try to call // incompatible functions, so if we got this far, then any functions incompatible // with our selected entry point must not be used. // // When validation is disabled, `fun_info.available_stages` is always just // `ShaderStages::all()`, so this will write all functions in the module, and // the downstream GLSL compiler will catch any problems.
if !info.available_stages.contains(ep_info.available_stages) { continue; } } let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?; self.lookup_function.insert(handle, id); } // write all entry points, or only the selected one for (index, ir_ep) in ir_module.entry_points.iter().enumerate() { if ep_index.is_some() && ep_index != Some(index) { continue; } let info = mod_info.get_entry_point(index); let ep_instruction = self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?; ep_instruction.to_words(&mut self.logical_layout.entry_points); } for capability in self.capabilities_used.iter() { Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities); } for extension in self.extensions_used.iter() { Instruction::extension(extension).to_words(&mut self.logical_layout.extensions); } if ir_module.entry_points.is_empty() { // SPIR-V doesn't like modules without entry points Instruction::capability(spirv::Capability::Linkage) .to_words(&mut self.logical_layout.capabilities); } let addressing_model = spirv::AddressingModel::Logical; let memory_model = if self .capabilities_used .contains(&spirv::Capability::VulkanMemoryModel) { spirv::MemoryModel::Vulkan } else { spirv::MemoryModel::GLSL450 }; //self.check(addressing_model.required_capabilities())?; //self.check(memory_model.required_capabilities())?; Instruction::memory_model(addressing_model, memory_model) .to_words(&mut self.logical_layout.memory_model); for debug_string in self.debug_strings.iter() { debug_string.to_words(&mut self.logical_layout.debugs); } if self.flags.contains(WriterFlags::DEBUG) { for debug in self.debugs.iter() { debug.to_words(&mut self.logical_layout.debugs); } } for annotation in self.annotations.iter() { annotation.to_words(&mut self.logical_layout.annotations); } Ok(()) } pub fn write( &mut self, ir_module: &crate::Module, info: &ModuleInfo, pipeline_options: Option<&PipelineOptions>, debug_info: &Option<DebugInfo>, words: &mut Vec<Word>, ) -> Result<(), Error> {
self.reset(); // Try to find the entry point and corresponding index let ep_index = match pipeline_options { Some(po) => { let index = ir_module .entry_points .iter() .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name) .ok_or(Error::EntryPointNotFound)?; Some(index) } None => None, }; self.write_logical_layout(ir_module, info, ep_index, debug_info)?; self.write_physical_layout(); self.physical_layout.in_words(words); self.logical_layout.in_words(words); Ok(()) } /// Return the set of capabilities used by the last module written. pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> { &self.capabilities_used } pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> { self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?; self.use_extension("SPV_EXT_descriptor_indexing"); self.decorate(id, spirv::Decoration::NonUniform, &[]); Ok(()) } pub(super) fn needs_f16_polyfill(&self, ty_inner: &crate::TypeInner) -> bool { self.io_f16_polyfills.needs_polyfill(ty_inner) } pub(super) fn write_debug_printf( &mut self, block: &mut Block, string: &str, format_params: &[Word], ) { if self.debug_printf.is_none() { self.use_extension("SPV_KHR_non_semantic_info"); let import_id = self.id_gen.next(); Instruction::ext_inst_import(import_id, "NonSemantic.DebugPrintf") .to_words(&mut self.logical_layout.ext_inst_imports); self.debug_printf = Some(import_id) } let import_id = self.debug_printf.unwrap(); let string_id = self.id_gen.next(); self.debug_strings .push(Instruction::string(string, string_id)); let mut operands = Vec::with_capacity(1 + format_params.len()); operands.push(string_id); operands.extend(format_params.iter()); let print_id = self.id_gen.next(); block.body.push(Instruction::ext_inst( import_id, 1, self.void_type, print_id, &operands, )); } } #[test] fn test_write_physical_layout() { let mut writer = Writer::new(&Options::default()).unwrap(); assert_eq!(writer.physical_layout.bound, 0);
writer.write_physical_layout(); assert_eq!(writer.physical_layout.bound, 3); } naga-29.0.3/src/back/wgsl/mod.rs000064400000000000000000000050051046102023000144150ustar 00000000000000/*! Backend for [WGSL][wgsl] (WebGPU Shading Language). [wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html */ mod polyfill; mod writer; use alloc::format; use alloc::string::String; use thiserror::Error; pub use writer::{Writer, WriterFlags}; use crate::common::wgsl; #[derive(Error, Debug)] pub enum Error { #[error(transparent)] FmtError(#[from] core::fmt::Error), #[error("{0}")] Custom(String), #[error("{0}")] Unimplemented(String), // TODO: Error used only during development #[error("Unsupported relational function: {0:?}")] UnsupportedRelationalFunction(crate::RelationalFunction), #[error("Unsupported {kind}: {value}")] Unsupported { /// What kind of unsupported thing this is: interpolation, builtin, etc. kind: &'static str, /// The debug form of the Naga IR value that this backend can't express. value: String, }, } impl Error { /// Produce an [`Unsupported`] error for `value`. 
/// /// [`Unsupported`]: Error::Unsupported fn unsupported<T: core::fmt::Debug>(kind: &'static str, value: T) -> Error { Error::Unsupported { kind, value: format!("{value:?}"), } } } trait ToWgslIfImplemented { fn to_wgsl_if_implemented(self) -> Result<&'static str, Error>; } impl<T> ToWgslIfImplemented for T where T: wgsl::TryToWgsl + core::fmt::Debug + Copy, { fn to_wgsl_if_implemented(self) -> Result<&'static str, Error> { self.try_to_wgsl() .ok_or_else(|| Error::unsupported(T::DESCRIPTION, self)) } } pub fn write_string( module: &crate::Module, info: &crate::valid::ModuleInfo, flags: WriterFlags, ) -> Result<String, Error> { let mut w = Writer::new(String::new(), flags); w.write(module, info)?; let output = w.finish(); Ok(output) } impl crate::AtomicFunction { const fn to_wgsl(self) -> &'static str { match self { Self::Add => "Add", Self::Subtract => "Sub", Self::And => "And", Self::InclusiveOr => "Or", Self::ExclusiveOr => "Xor", Self::Min => "Min", Self::Max => "Max", Self::Exchange { compare: None } => "Exchange", Self::Exchange { .. } => "CompareExchangeWeak", } } } pub const fn supported_capabilities() -> crate::valid::Capabilities { // WGSL regurgitation supports almost everything, though browser webgpu can't parse most of these.
use crate::valid::Capabilities as Caps; Caps::all() } naga-29.0.3/src/back/wgsl/polyfill/inverse/inverse_2x2_f16.wgsl000064400000000000000000000004221046102023000223130ustar 00000000000000fn _naga_inverse_2x2_f16(m: mat2x2<f16>) -> mat2x2<f16> { var adj: mat2x2<f16>; adj[0][0] = m[1][1]; adj[0][1] = -m[0][1]; adj[1][0] = -m[1][0]; adj[1][1] = m[0][0]; let det: f16 = m[0][0] * m[1][1] - m[1][0] * m[0][1]; return adj * (1 / det); }naga-29.0.3/src/back/wgsl/polyfill/inverse/inverse_2x2_f32.wgsl000064400000000000000000000004221046102023000223110ustar 00000000000000fn _naga_inverse_2x2_f32(m: mat2x2<f32>) -> mat2x2<f32> { var adj: mat2x2<f32>; adj[0][0] = m[1][1]; adj[0][1] = -m[0][1]; adj[1][0] = -m[1][0]; adj[1][1] = m[0][0]; let det: f32 = m[0][0] * m[1][1] - m[1][0] * m[0][1]; return adj * (1 / det); }naga-29.0.3/src/back/wgsl/polyfill/inverse/inverse_3x3_f16.wgsl000064400000000000000000000015031046102023000223160ustar 00000000000000fn _naga_inverse_3x3_f16(m: mat3x3<f16>) -> mat3x3<f16> { var adj: mat3x3<f16>; adj[0][0] = (m[1][1] * m[2][2] - m[2][1] * m[1][2]); adj[1][0] = - (m[1][0] * m[2][2] - m[2][0] * m[1][2]); adj[2][0] = (m[1][0] * m[2][1] - m[2][0] * m[1][1]); adj[0][1] = - (m[0][1] * m[2][2] - m[2][1] * m[0][2]); adj[1][1] = (m[0][0] * m[2][2] - m[2][0] * m[0][2]); adj[2][1] = - (m[0][0] * m[2][1] - m[2][0] * m[0][1]); adj[0][2] = (m[0][1] * m[1][2] - m[1][1] * m[0][2]); adj[1][2] = - (m[0][0] * m[1][2] - m[1][0] * m[0][2]); adj[2][2] = (m[0][0] * m[1][1] - m[1][0] * m[0][1]); let det: f16 = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1]) - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0]) + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])); return adj * (1 / det); }naga-29.0.3/src/back/wgsl/polyfill/inverse/inverse_3x3_f32.wgsl000064400000000000000000000015031046102023000223140ustar 00000000000000fn _naga_inverse_3x3_f32(m: mat3x3<f32>) -> mat3x3<f32> { var adj: mat3x3<f32>; adj[0][0] = (m[1][1] * m[2][2] - m[2][1] * m[1][2]); adj[1][0] = - (m[1][0] * m[2][2] - m[2][0] * m[1][2]); adj[2][0] = (m[1][0] *
m[2][1] - m[2][0] * m[1][1]); adj[0][1] = - (m[0][1] * m[2][2] - m[2][1] * m[0][2]); adj[1][1] = (m[0][0] * m[2][2] - m[2][0] * m[0][2]); adj[2][1] = - (m[0][0] * m[2][1] - m[2][0] * m[0][1]); adj[0][2] = (m[0][1] * m[1][2] - m[1][1] * m[0][2]); adj[1][2] = - (m[0][0] * m[1][2] - m[1][0] * m[0][2]); adj[2][2] = (m[0][0] * m[1][1] - m[1][0] * m[0][1]); let det: f32 = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1]) - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0]) + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])); return adj * (1 / det); }naga-29.0.3/src/back/wgsl/polyfill/inverse/inverse_4x4_f16.wgsl000064400000000000000000000056161046102023000223310ustar 00000000000000fn _naga_inverse_4x4_f16(m: mat4x4<f16>) -> mat4x4<f16> { let sub_factor00: f16 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; let sub_factor01: f16 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; let sub_factor02: f16 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; let sub_factor03: f16 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; let sub_factor04: f16 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; let sub_factor05: f16 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; let sub_factor06: f16 = m[1][2] * m[3][3] - m[3][2] * m[1][3]; let sub_factor07: f16 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; let sub_factor08: f16 = m[1][1] * m[3][2] - m[3][1] * m[1][2]; let sub_factor09: f16 = m[1][0] * m[3][3] - m[3][0] * m[1][3]; let sub_factor10: f16 = m[1][0] * m[3][2] - m[3][0] * m[1][2]; let sub_factor11: f16 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; let sub_factor12: f16 = m[1][0] * m[3][1] - m[3][0] * m[1][1]; let sub_factor13: f16 = m[1][2] * m[2][3] - m[2][2] * m[1][3]; let sub_factor14: f16 = m[1][1] * m[2][3] - m[2][1] * m[1][3]; let sub_factor15: f16 = m[1][1] * m[2][2] - m[2][1] * m[1][2]; let sub_factor16: f16 = m[1][0] * m[2][3] - m[2][0] * m[1][3]; let sub_factor17: f16 = m[1][0] * m[2][2] - m[2][0] * m[1][2]; let sub_factor18: f16 = m[1][0] * m[2][1] - m[2][0] * m[1][1]; var adj: mat4x4<f16>; adj[0][0] = (m[1][1] * sub_factor00 - m[1][2] *
sub_factor01 + m[1][3] * sub_factor02); adj[1][0] = - (m[1][0] * sub_factor00 - m[1][2] * sub_factor03 + m[1][3] * sub_factor04); adj[2][0] = (m[1][0] * sub_factor01 - m[1][1] * sub_factor03 + m[1][3] * sub_factor05); adj[3][0] = - (m[1][0] * sub_factor02 - m[1][1] * sub_factor04 + m[1][2] * sub_factor05); adj[0][1] = - (m[0][1] * sub_factor00 - m[0][2] * sub_factor01 + m[0][3] * sub_factor02); adj[1][1] = (m[0][0] * sub_factor00 - m[0][2] * sub_factor03 + m[0][3] * sub_factor04); adj[2][1] = - (m[0][0] * sub_factor01 - m[0][1] * sub_factor03 + m[0][3] * sub_factor05); adj[3][1] = (m[0][0] * sub_factor02 - m[0][1] * sub_factor04 + m[0][2] * sub_factor05); adj[0][2] = (m[0][1] * sub_factor06 - m[0][2] * sub_factor07 + m[0][3] * sub_factor08); adj[1][2] = - (m[0][0] * sub_factor06 - m[0][2] * sub_factor09 + m[0][3] * sub_factor10); adj[2][2] = (m[0][0] * sub_factor11 - m[0][1] * sub_factor09 + m[0][3] * sub_factor12); adj[3][2] = - (m[0][0] * sub_factor08 - m[0][1] * sub_factor10 + m[0][2] * sub_factor12); adj[0][3] = - (m[0][1] * sub_factor13 - m[0][2] * sub_factor14 + m[0][3] * sub_factor15); adj[1][3] = (m[0][0] * sub_factor13 - m[0][2] * sub_factor16 + m[0][3] * sub_factor17); adj[2][3] = - (m[0][0] * sub_factor14 - m[0][1] * sub_factor16 + m[0][3] * sub_factor18); adj[3][3] = (m[0][0] * sub_factor15 - m[0][1] * sub_factor17 + m[0][2] * sub_factor18); let det = (m[0][0] * adj[0][0] + m[0][1] * adj[1][0] + m[0][2] * adj[2][0] + m[0][3] * adj[3][0]); return adj * (1 / det); }naga-29.0.3/src/back/wgsl/polyfill/inverse/inverse_4x4_f32.wgsl000064400000000000000000000056161046102023000223270ustar 00000000000000fn _naga_inverse_4x4_f32(m: mat4x4<f32>) -> mat4x4<f32> { let sub_factor00: f32 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; let sub_factor01: f32 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; let sub_factor02: f32 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; let sub_factor03: f32 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; let sub_factor04: f32 = m[2][0] * m[3][2] - m[3][0] * m[2][2];
let sub_factor05: f32 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; let sub_factor06: f32 = m[1][2] * m[3][3] - m[3][2] * m[1][3]; let sub_factor07: f32 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; let sub_factor08: f32 = m[1][1] * m[3][2] - m[3][1] * m[1][2]; let sub_factor09: f32 = m[1][0] * m[3][3] - m[3][0] * m[1][3]; let sub_factor10: f32 = m[1][0] * m[3][2] - m[3][0] * m[1][2]; let sub_factor11: f32 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; let sub_factor12: f32 = m[1][0] * m[3][1] - m[3][0] * m[1][1]; let sub_factor13: f32 = m[1][2] * m[2][3] - m[2][2] * m[1][3]; let sub_factor14: f32 = m[1][1] * m[2][3] - m[2][1] * m[1][3]; let sub_factor15: f32 = m[1][1] * m[2][2] - m[2][1] * m[1][2]; let sub_factor16: f32 = m[1][0] * m[2][3] - m[2][0] * m[1][3]; let sub_factor17: f32 = m[1][0] * m[2][2] - m[2][0] * m[1][2]; let sub_factor18: f32 = m[1][0] * m[2][1] - m[2][0] * m[1][1]; var adj: mat4x4<f32>; adj[0][0] = (m[1][1] * sub_factor00 - m[1][2] * sub_factor01 + m[1][3] * sub_factor02); adj[1][0] = - (m[1][0] * sub_factor00 - m[1][2] * sub_factor03 + m[1][3] * sub_factor04); adj[2][0] = (m[1][0] * sub_factor01 - m[1][1] * sub_factor03 + m[1][3] * sub_factor05); adj[3][0] = - (m[1][0] * sub_factor02 - m[1][1] * sub_factor04 + m[1][2] * sub_factor05); adj[0][1] = - (m[0][1] * sub_factor00 - m[0][2] * sub_factor01 + m[0][3] * sub_factor02); adj[1][1] = (m[0][0] * sub_factor00 - m[0][2] * sub_factor03 + m[0][3] * sub_factor04); adj[2][1] = - (m[0][0] * sub_factor01 - m[0][1] * sub_factor03 + m[0][3] * sub_factor05); adj[3][1] = (m[0][0] * sub_factor02 - m[0][1] * sub_factor04 + m[0][2] * sub_factor05); adj[0][2] = (m[0][1] * sub_factor06 - m[0][2] * sub_factor07 + m[0][3] * sub_factor08); adj[1][2] = - (m[0][0] * sub_factor06 - m[0][2] * sub_factor09 + m[0][3] * sub_factor10); adj[2][2] = (m[0][0] * sub_factor11 - m[0][1] * sub_factor09 + m[0][3] * sub_factor12); adj[3][2] = - (m[0][0] * sub_factor08 - m[0][1] * sub_factor10 + m[0][2] * sub_factor12); adj[0][3] = - (m[0][1] *
sub_factor13 - m[0][2] * sub_factor14 + m[0][3] * sub_factor15); adj[1][3] = (m[0][0] * sub_factor13 - m[0][2] * sub_factor16 + m[0][3] * sub_factor17); adj[2][3] = - (m[0][0] * sub_factor14 - m[0][1] * sub_factor16 + m[0][3] * sub_factor18); adj[3][3] = (m[0][0] * sub_factor15 - m[0][1] * sub_factor17 + m[0][2] * sub_factor18); let det = (m[0][0] * adj[0][0] + m[0][1] * adj[1][0] + m[0][2] * adj[2][0] + m[0][3] * adj[3][0]); return adj * (1 / det); }naga-29.0.3/src/back/wgsl/polyfill/mod.rs000064400000000000000000000044201046102023000162470ustar 00000000000000use crate::{ScalarKind, TypeInner, VectorSize}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub struct InversePolyfill { pub fun_name: &'static str, pub source: &'static str, } impl InversePolyfill { pub fn find_overload(ty: &TypeInner) -> Option<InversePolyfill> { let &TypeInner::Matrix { columns, rows, scalar, } = ty else { return None; }; if columns != rows || scalar.kind != ScalarKind::Float { return None; }; Self::polyfill_overload(columns, scalar.width) } const fn polyfill_overload( dimension: VectorSize, width: crate::Bytes, ) -> Option<InversePolyfill> { const INVERSE_2X2_F32: &str = include_str!("inverse/inverse_2x2_f32.wgsl"); const INVERSE_3X3_F32: &str = include_str!("inverse/inverse_3x3_f32.wgsl"); const INVERSE_4X4_F32: &str = include_str!("inverse/inverse_4x4_f32.wgsl"); const INVERSE_2X2_F16: &str = include_str!("inverse/inverse_2x2_f16.wgsl"); const INVERSE_3X3_F16: &str = include_str!("inverse/inverse_3x3_f16.wgsl"); const INVERSE_4X4_F16: &str = include_str!("inverse/inverse_4x4_f16.wgsl"); match (dimension, width) { (VectorSize::Bi, 4) => Some(InversePolyfill { fun_name: "_naga_inverse_2x2_f32", source: INVERSE_2X2_F32, }), (VectorSize::Tri, 4) => Some(InversePolyfill { fun_name: "_naga_inverse_3x3_f32", source: INVERSE_3X3_F32, }), (VectorSize::Quad, 4) => Some(InversePolyfill { fun_name: "_naga_inverse_4x4_f32", source: INVERSE_4X4_F32, }), (VectorSize::Bi, 2) => Some(InversePolyfill { fun_name: 
"_naga_inverse_2x2_f16", source: INVERSE_2X2_F16, }), (VectorSize::Tri, 2) => Some(InversePolyfill { fun_name: "_naga_inverse_3x3_f16", source: INVERSE_3X3_F16, }), (VectorSize::Quad, 2) => Some(InversePolyfill { fun_name: "_naga_inverse_4x4_f16", source: INVERSE_4X4_F16, }), _ => None, } } } naga-29.0.3/src/back/wgsl/writer.rs000064400000000000000000002456411046102023000151660ustar 00000000000000use alloc::{ format, string::{String, ToString}, vec, vec::Vec, }; use core::fmt::Write; use super::Error; use super::ToWgslIfImplemented as _; use crate::{back::wgsl::polyfill::InversePolyfill, common::wgsl::TypeContext}; use crate::{ back::{self, Baked}, common::{ self, wgsl::{address_space_str, ToWgsl, TryToWgsl}, }, proc::{self, NameKey}, valid, Handle, Module, ShaderStage, TypeInner, }; /// Shorthand result used internally by the backend type BackendResult = Result<(), Error>; /// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes) enum Attribute { Binding(u32), BuiltIn(crate::BuiltIn), Group(u32), Invariant, Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>), Location(u32), BlendSrc(u32), Stage(ShaderStage), WorkGroupSize([u32; 3]), MeshStage(String), TaskPayload(String), PerPrimitive, IncomingRayPayload(String), } /// The WGSL form that `write_expr_with_indirection` should use to render a Naga /// expression. /// /// Sometimes a Naga `Expression` alone doesn't provide enough information to /// choose the right rendering for it in WGSL. For example, one natural WGSL /// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since /// `LocalVariable` produces a pointer to the local variable's storage. But when /// rendering a `Store` statement, the `pointer` operand must be the left-hand /// side of a WGSL assignment, so the proper rendering is `x`. /// /// The caller of `write_expr_with_indirection` must provide an `Indirection` value /// to indicate how ambiguous expressions should be rendered. 
#[derive(Clone, Copy, Debug)] enum Indirection { /// Render pointer-construction expressions as WGSL `ptr`-typed expressions. /// /// This is the right choice for most cases. Whenever a Naga pointer /// expression is not the `pointer` operand of a `Load` or `Store`, it /// must be a WGSL pointer expression. Ordinary, /// Render pointer-construction expressions as WGSL reference-typed /// expressions. /// /// For example, this is the right choice for the `pointer` operand when /// rendering a `Store` statement as a WGSL assignment. Reference, } bitflags::bitflags! { #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct WriterFlags: u32 { /// Always annotate the type information instead of inferring. const EXPLICIT_TYPES = 0x1; } } pub struct Writer<W> { out: W, flags: WriterFlags, names: crate::FastHashMap<NameKey, String>, namer: proc::Namer, named_expressions: crate::NamedExpressions, required_polyfills: crate::FastIndexSet<InversePolyfill>, } impl<W: Write> Writer<W> { pub fn new(out: W, flags: WriterFlags) -> Self { Writer { out, flags, names: crate::FastHashMap::default(), namer: proc::Namer::default(), named_expressions: crate::NamedExpressions::default(), required_polyfills: crate::FastIndexSet::default(), } } fn reset(&mut self, module: &Module) { self.names.clear(); self.namer.reset( module, &crate::keywords::wgsl::RESERVED_SET, &crate::keywords::wgsl::BUILTIN_IDENTIFIER_SET, // an identifier must not start with two underscores proc::CaseInsensitiveKeywordSet::empty(), &["__", "_naga"], &mut self.names, ); self.named_expressions.clear(); self.required_polyfills.clear(); } /// Determine if `ty` is the Naga IR representation of a WGSL builtin type. /// /// Return true if `ty` refers to the Naga IR form of a WGSL builtin type /// like `__atomic_compare_exchange_result`. 
/// /// Even though the module may use the type, the WGSL backend should avoid /// emitting a definition for it, since it is [predeclared] in WGSL. /// /// This also covers types like [`NagaExternalTextureParams`], which other /// backends use to lower WGSL constructs like external textures to their /// implementations. WGSL can express these directly, so the types need not /// be emitted. /// /// [predeclared]: https://www.w3.org/TR/WGSL/#predeclared /// [`NagaExternalTextureParams`]: crate::ir::SpecialTypes::external_texture_params fn is_builtin_wgsl_struct(&self, module: &Module, ty: Handle<crate::Type>) -> bool { module .special_types .predeclared_types .values() .any(|t| *t == ty) || Some(ty) == module.special_types.external_texture_params || Some(ty) == module.special_types.external_texture_transfer_function } pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult { self.reset(module); // Write all `enable` declarations self.write_enable_declarations(module)?; // Write all structs for (handle, ty) in module.types.iter() { if let TypeInner::Struct { ref members, .. 
} = ty.inner { { if !self.is_builtin_wgsl_struct(module, handle) { self.write_struct(module, handle, members)?; writeln!(self.out)?; } } } } // Write all named constants let mut constants = module .constants .iter() .filter(|&(_, c)| c.name.is_some()) .peekable(); while let Some((handle, _)) = constants.next() { self.write_global_constant(module, handle)?; // Add extra newline for readability on last iteration if constants.peek().is_none() { writeln!(self.out)?; } } // Write all overrides let mut overrides = module.overrides.iter().peekable(); while let Some((handle, _)) = overrides.next() { self.write_override(module, handle)?; // Add extra newline for readability on last iteration if overrides.peek().is_none() { writeln!(self.out)?; } } // Write all globals for (ty, global) in module.global_variables.iter() { self.write_global(module, global, ty)?; } if !module.global_variables.is_empty() { // Add extra newline for readability writeln!(self.out)?; } // Write all regular functions for (handle, function) in module.functions.iter() { let fun_info = &info[handle]; let func_ctx = back::FunctionCtx { ty: back::FunctionType::Function(handle), info: fun_info, expressions: &function.expressions, named_expressions: &function.named_expressions, }; // Write the function self.write_function(module, function, &func_ctx)?; writeln!(self.out)?; } // Write all entry points for (index, ep) in module.entry_points.iter().enumerate() { let attributes = match ep.stage { ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)], ShaderStage::Compute => vec![ Attribute::Stage(ShaderStage::Compute), Attribute::WorkGroupSize(ep.workgroup_size), ], ShaderStage::Mesh => { let mesh_output_name = module.global_variables [ep.mesh_info.as_ref().unwrap().output_variable] .name .clone() .unwrap(); let mut mesh_attrs = vec![ Attribute::MeshStage(mesh_output_name), Attribute::WorkGroupSize(ep.workgroup_size), ]; if let Some(task_payload) = ep.task_payload { let payload_name = 
module.global_variables[task_payload].name.clone().unwrap(); mesh_attrs.push(Attribute::TaskPayload(payload_name)); } mesh_attrs } ShaderStage::Task => { let payload_name = module.global_variables[ep.task_payload.unwrap()] .name .clone() .unwrap(); vec![ Attribute::Stage(ShaderStage::Task), Attribute::TaskPayload(payload_name), Attribute::WorkGroupSize(ep.workgroup_size), ] } ShaderStage::RayGeneration => vec![Attribute::Stage(ShaderStage::RayGeneration)], ShaderStage::AnyHit | ShaderStage::ClosestHit | ShaderStage::Miss => { let payload_name = module.global_variables[ep.incoming_ray_payload.unwrap()] .name .clone() .unwrap(); vec![ Attribute::Stage(ep.stage), Attribute::IncomingRayPayload(payload_name), ] } }; self.write_attributes(&attributes)?; // Add a newline after attribute writeln!(self.out)?; let func_ctx = back::FunctionCtx { ty: back::FunctionType::EntryPoint(index as u16), info: info.get_entry_point(index), expressions: &ep.function.expressions, named_expressions: &ep.function.named_expressions, }; self.write_function(module, &ep.function, &func_ctx)?; if index < module.entry_points.len() - 1 { writeln!(self.out)?; } } // Write any polyfills that were required. for polyfill in &self.required_polyfills { writeln!(self.out)?; write!(self.out, "{}", polyfill.source)?; writeln!(self.out)?; } Ok(()) } /// Helper method which writes all the `enable` declarations /// needed for a module. fn write_enable_declarations(&mut self, module: &Module) -> BackendResult { #[derive(Default)] struct RequiredEnabled { f16: bool, dual_source_blending: bool, clip_distances: bool, mesh_shaders: bool, primitive_index: bool, cooperative_matrix: bool, draw_index: bool, ray_tracing_pipeline: bool, } let mut needed = RequiredEnabled { mesh_shaders: module.uses_mesh_shaders(), ..Default::default() }; let check_binding = |binding: &crate::Binding, needed: &mut RequiredEnabled| match *binding { crate::Binding::Location { blend_src: Some(_), .. 
} => { needed.dual_source_blending = true; } crate::Binding::BuiltIn(crate::BuiltIn::ClipDistance) => { needed.clip_distances = true; } crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveIndex) => { needed.primitive_index = true; } crate::Binding::Location { per_primitive: true, .. } => { needed.mesh_shaders = true; } crate::Binding::BuiltIn(crate::BuiltIn::DrawIndex) => needed.draw_index = true, crate::Binding::BuiltIn( crate::BuiltIn::RayInvocationId | crate::BuiltIn::NumRayInvocations | crate::BuiltIn::InstanceCustomData | crate::BuiltIn::GeometryIndex | crate::BuiltIn::WorldRayOrigin | crate::BuiltIn::WorldRayDirection | crate::BuiltIn::ObjectRayOrigin | crate::BuiltIn::ObjectRayDirection | crate::BuiltIn::RayTmin | crate::BuiltIn::RayTCurrentMax | crate::BuiltIn::ObjectToWorld | crate::BuiltIn::WorldToObject, ) => { needed.ray_tracing_pipeline = true; } _ => {} }; // Determine which `enable` declarations are needed for (_, ty) in module.types.iter() { match ty.inner { TypeInner::Scalar(scalar) | TypeInner::Vector { scalar, .. } | TypeInner::Matrix { scalar, .. } => { needed.f16 |= scalar == crate::Scalar::F16; } TypeInner::Struct { ref members, .. } => { for binding in members.iter().filter_map(|m| m.binding.as_ref()) { check_binding(binding, &mut needed); } } TypeInner::CooperativeMatrix { .. } => { needed.cooperative_matrix = true; } TypeInner::AccelerationStructure { .. 
} => { needed.ray_tracing_pipeline = true; } _ => {} } } for ep in &module.entry_points { if let Some(res) = ep.function.result.as_ref().and_then(|a| a.binding.as_ref()) { check_binding(res, &mut needed); } for arg in ep .function .arguments .iter() .filter_map(|a| a.binding.as_ref()) { check_binding(arg, &mut needed); } } if module.global_variables.iter().any(|gv| { gv.1.space == crate::AddressSpace::IncomingRayPayload || gv.1.space == crate::AddressSpace::RayPayload }) { needed.ray_tracing_pipeline = true; } if module.entry_points.iter().any(|ep| { matches!( ep.stage, ShaderStage::RayGeneration | ShaderStage::AnyHit | ShaderStage::ClosestHit | ShaderStage::Miss ) }) { needed.ray_tracing_pipeline = true; } // Write required declarations let mut any_written = false; if needed.f16 { writeln!(self.out, "enable f16;")?; any_written = true; } if needed.dual_source_blending { writeln!(self.out, "enable dual_source_blending;")?; any_written = true; } if needed.clip_distances { writeln!(self.out, "enable clip_distances;")?; any_written = true; } if needed.mesh_shaders { writeln!(self.out, "enable wgpu_mesh_shader;")?; any_written = true; } if needed.draw_index { writeln!(self.out, "enable draw_index;")?; any_written = true; } if needed.primitive_index { writeln!(self.out, "enable primitive_index;")?; any_written = true; } if needed.cooperative_matrix { writeln!(self.out, "enable wgpu_cooperative_matrix;")?; any_written = true; } if needed.ray_tracing_pipeline { writeln!(self.out, "enable wgpu_ray_tracing_pipeline;")?; any_written = true; } if any_written { // Empty line for 
readability writeln!(self.out)?; } Ok(()) } /// Helper method used to write /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions) /// /// # Notes /// Ends in a newline fn write_function( &mut self, module: &Module, func: &crate::Function, func_ctx: &back::FunctionCtx<'_>, ) -> BackendResult { let func_name = match func_ctx.ty { back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)], back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)], }; // Write function name write!(self.out, "fn {func_name}(")?; // Write function arguments for (index, arg) in func.arguments.iter().enumerate() { // Write argument attribute if a binding is present if let Some(ref binding) = arg.binding { self.write_attributes(&map_binding_to_attribute(binding))?; } // Write argument name let argument_name = &self.names[&func_ctx.argument_key(index as u32)]; write!(self.out, "{argument_name}: ")?; // Write argument type self.write_type(module, arg.ty)?; if index < func.arguments.len() - 1 { // Add a separator between args write!(self.out, ", ")?; } } write!(self.out, ")")?; // Write function return type if let Some(ref result) = func.result { write!(self.out, " -> ")?; if let Some(ref binding) = result.binding { self.write_attributes(&map_binding_to_attribute(binding))?; } self.write_type(module, result.ty)?; } write!(self.out, " {{")?; writeln!(self.out)?; // Write function local variables for (handle, local) in func.local_variables.iter() { // Write indentation (only for readability) write!(self.out, "{}", back::INDENT)?; // Write the `var` keyword and the local name write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?; // Write the local type self.write_type(module, local.ty)?; // Write the local initializer if needed if let Some(init) = local.init { // Only write the `=` if there is an initializer // The leading and trailing spaces aren't needed but help with readability write!(self.out, " = ")?; 
// Write the initializer expression // `write_expr` adds no leading or trailing space/newline self.write_expr(module, init, func_ctx)?; } // Finish the local with `;` and add a newline (only for readability) writeln!(self.out, ";")? } if !func.local_variables.is_empty() { writeln!(self.out)?; } // Write the function body (statement list) for sta in func.body.iter() { // The indentation should always be 1 when writing the function body self.write_stmt(module, sta, func_ctx, back::Level(1))?; } writeln!(self.out, "}}")?; self.named_expressions.clear(); Ok(()) } /// Helper method to write a list of attributes, each followed by a space fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult { for attribute in attributes { match *attribute { Attribute::Location(id) => write!(self.out, "@location({id}) ")?, Attribute::BlendSrc(blend_src) => write!(self.out, "@blend_src({blend_src}) ")?, Attribute::BuiltIn(builtin_attrib) => { let builtin = builtin_attrib.to_wgsl_if_implemented()?; write!(self.out, "@builtin({builtin}) ")?; } Attribute::Stage(shader_stage) => { let stage_str = match shader_stage { ShaderStage::Vertex => "vertex", ShaderStage::Fragment => "fragment", ShaderStage::Compute => "compute", ShaderStage::Task => "task", // Mesh shaders are handled by `Attribute::MeshStage`, so this arm should never be hit. 
ShaderStage::Mesh => unreachable!(), ShaderStage::RayGeneration => "ray_generation", ShaderStage::AnyHit => "any_hit", ShaderStage::ClosestHit => "closest_hit", ShaderStage::Miss => "miss", }; write!(self.out, "@{stage_str} ")?; } Attribute::WorkGroupSize(size) => { write!( self.out, "@workgroup_size({}, {}, {}) ", size[0], size[1], size[2] )?; } Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?, Attribute::Group(id) => write!(self.out, "@group({id}) ")?, Attribute::Invariant => write!(self.out, "@invariant ")?, Attribute::Interpolate(interpolation, sampling) => { if sampling.is_some() && sampling != Some(crate::Sampling::Center) { let interpolation = interpolation .unwrap_or(crate::Interpolation::Perspective) .to_wgsl(); let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl(); write!(self.out, "@interpolate({interpolation}, {sampling}) ")?; } else if interpolation.is_some() && interpolation != Some(crate::Interpolation::Perspective) { let interpolation = interpolation .unwrap_or(crate::Interpolation::Perspective) .to_wgsl(); write!(self.out, "@interpolate({interpolation}) ")?; } } Attribute::MeshStage(ref name) => { write!(self.out, "@mesh({name}) ")?; } Attribute::TaskPayload(ref payload_name) => { write!(self.out, "@payload({payload_name}) ")?; } Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?, Attribute::IncomingRayPayload(ref payload_name) => { write!(self.out, "@incoming_payload({payload_name}) ")?; } }; } Ok(()) } /// Write the full declaration of a struct type. /// /// Write out a definition of the struct type referred to by /// `handle` in `module`. The output will be an instance of the /// `struct_decl` production in the WGSL grammar. /// /// Use `members` as the list of `handle`'s members. (This /// function is usually called after matching a `TypeInner`, so /// the callers already have the members at hand.) 
fn write_struct( &mut self, module: &Module, handle: Handle<crate::Type>, members: &[crate::StructMember], ) -> BackendResult { write!(self.out, "struct {}", self.names[&NameKey::Type(handle)])?; write!(self.out, " {{")?; writeln!(self.out)?; for (index, member) in members.iter().enumerate() { // The indentation is only for readability write!(self.out, "{}", back::INDENT)?; if let Some(ref binding) = member.binding { self.write_attributes(&map_binding_to_attribute(binding))?; } // Write struct member name and type let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; write!(self.out, "{member_name}: ")?; self.write_type(module, member.ty)?; write!(self.out, ",")?; writeln!(self.out)?; } writeln!(self.out, "}}")?; Ok(()) } fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult { // This actually can't be factored out into a nice constructor method, // because the borrow checker needs to be able to see that the borrows // of `self.names` and `self.out` are disjoint. let type_context = WriterTypeContext { module, names: &self.names, }; type_context.write_type(ty, &mut self.out)?; Ok(()) } fn write_type_resolution( &mut self, module: &Module, resolution: &proc::TypeResolution, ) -> BackendResult { // This actually can't be factored out into a nice constructor method, // because the borrow checker needs to be able to see that the borrows // of `self.names` and `self.out` are disjoint. 
let type_context = WriterTypeContext { module, names: &self.names, }; type_context.write_type_resolution(resolution, &mut self.out)?; Ok(()) } /// Helper method used to write statements /// /// # Notes /// Always adds a newline fn write_stmt( &mut self, module: &Module, stmt: &crate::Statement, func_ctx: &back::FunctionCtx<'_>, level: back::Level, ) -> BackendResult { use crate::{Expression, Statement}; match *stmt { Statement::Emit(ref range) => { for handle in range.clone() { let info = &func_ctx.info[handle]; let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) { // The front end provides names for all variables before writing starts, // but we emit them one at a time, so each name must be re-registered with the namer here. // Otherwise, we could accidentally write a variable name instead of the full expression. // The namer also sanitizes the name, which keeps the backend from emitting identifiers that collide with reserved keywords. Some(self.namer.call(name)) } else { let expr = &func_ctx.expressions[handle]; let min_ref_count = expr.bake_ref_count(); // Force some expressions to be baked into `let` bindings to help with readability let required_baking_expr = match *expr { Expression::ImageLoad { .. } | Expression::ImageQuery { .. } | Expression::ImageSample { .. 
} => true, _ => false, }; if min_ref_count <= info.ref_count || required_baking_expr { Some(Baked(handle).to_string()) } else { None } }; if let Some(name) = expr_name { write!(self.out, "{level}")?; self.start_named_expr(module, handle, func_ctx, &name)?; self.write_expr(module, handle, func_ctx)?; self.named_expressions.insert(handle, name); writeln!(self.out, ";")?; } } } // TODO: copy-paste from glsl-out Statement::If { condition, ref accept, ref reject, } => { write!(self.out, "{level}")?; write!(self.out, "if ")?; self.write_expr(module, condition, func_ctx)?; writeln!(self.out, " {{")?; let l2 = level.next(); for sta in accept { // Increase indentation to help with readability self.write_stmt(module, sta, func_ctx, l2)?; } // If there are no statements in the reject block we skip writing it // This is only for readability if !reject.is_empty() { writeln!(self.out, "{level}}} else {{")?; for sta in reject { // Increase indentation to help with readability self.write_stmt(module, sta, func_ctx, l2)?; } } writeln!(self.out, "{level}}}")? } Statement::Return { value } => { write!(self.out, "{level}")?; write!(self.out, "return")?; if let Some(return_value) = value { // The leading space is important write!(self.out, " ")?; self.write_expr(module, return_value, func_ctx)?; } writeln!(self.out, ";")?; } // TODO: copy-paste from glsl-out Statement::Kill => { write!(self.out, "{level}")?; writeln!(self.out, "discard;")? 
} Statement::Store { pointer, value } => { write!(self.out, "{level}")?; let is_atomic_pointer = func_ctx .resolve_type(pointer, &module.types) .is_atomic_pointer(&module.types); if is_atomic_pointer { write!(self.out, "atomicStore(")?; self.write_expr(module, pointer, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, value, func_ctx)?; write!(self.out, ")")?; } else { self.write_expr_with_indirection( module, pointer, func_ctx, Indirection::Reference, )?; write!(self.out, " = ")?; self.write_expr(module, value, func_ctx)?; } writeln!(self.out, ";")? } Statement::Call { function, ref arguments, result, } => { write!(self.out, "{level}")?; if let Some(expr) = result { let name = Baked(expr).to_string(); self.start_named_expr(module, expr, func_ctx, &name)?; self.named_expressions.insert(expr, name); } let func_name = &self.names[&NameKey::Function(function)]; write!(self.out, "{func_name}(")?; for (index, &argument) in arguments.iter().enumerate() { if index != 0 { write!(self.out, ", ")?; } self.write_expr(module, argument, func_ctx)?; } writeln!(self.out, ");")? } Statement::Atomic { pointer, ref fun, value, result, } => { write!(self.out, "{level}")?; if let Some(result) = result { let res_name = Baked(result).to_string(); self.start_named_expr(module, result, func_ctx, &res_name)?; self.named_expressions.insert(result, res_name); } let fun_str = fun.to_wgsl(); write!(self.out, "atomic{fun_str}(")?; self.write_expr(module, pointer, func_ctx)?; if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { write!(self.out, ", ")?; self.write_expr(module, cmp, func_ctx)?; } write!(self.out, ", ")?; self.write_expr(module, value, func_ctx)?; writeln!(self.out, ");")? 
} Statement::ImageAtomic { image, coordinate, array_index, ref fun, value, } => { write!(self.out, "{level}")?; let fun_str = fun.to_wgsl(); write!(self.out, "textureAtomic{fun_str}(")?; self.write_expr(module, image, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, coordinate, func_ctx)?; if let Some(array_index_expr) = array_index { write!(self.out, ", ")?; self.write_expr(module, array_index_expr, func_ctx)?; } write!(self.out, ", ")?; self.write_expr(module, value, func_ctx)?; writeln!(self.out, ");")?; } Statement::WorkGroupUniformLoad { pointer, result } => { write!(self.out, "{level}")?; // TODO: Obey named expressions here. let res_name = Baked(result).to_string(); self.start_named_expr(module, result, func_ctx, &res_name)?; self.named_expressions.insert(result, res_name); write!(self.out, "workgroupUniformLoad(")?; self.write_expr(module, pointer, func_ctx)?; writeln!(self.out, ");")?; } Statement::ImageStore { image, coordinate, array_index, value, } => { write!(self.out, "{level}")?; write!(self.out, "textureStore(")?; self.write_expr(module, image, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, coordinate, func_ctx)?; if let Some(array_index_expr) = array_index { write!(self.out, ", ")?; self.write_expr(module, array_index_expr, func_ctx)?; } write!(self.out, ", ")?; self.write_expr(module, value, func_ctx)?; writeln!(self.out, ");")?; } // TODO: copy-paste from glsl-out Statement::Block(ref block) => { write!(self.out, "{level}")?; writeln!(self.out, "{{")?; for sta in block.iter() { // Increase the indentation to help with readability self.write_stmt(module, sta, func_ctx, level.next())? } writeln!(self.out, "{level}}}")? 
} Statement::Switch { selector, ref cases, } => { // Start the switch write!(self.out, "{level}")?; write!(self.out, "switch ")?; self.write_expr(module, selector, func_ctx)?; writeln!(self.out, " {{")?; let l2 = level.next(); let mut new_case = true; for case in cases { if case.fall_through && !case.body.is_empty() { // TODO: we could do the same workaround as we did for the HLSL backend return Err(Error::Unimplemented( "fall-through switch case block".into(), )); } match case.value { crate::SwitchValue::I32(value) => { if new_case { write!(self.out, "{l2}case ")?; } write!(self.out, "{value}")?; } crate::SwitchValue::U32(value) => { if new_case { write!(self.out, "{l2}case ")?; } write!(self.out, "{value}u")?; } crate::SwitchValue::Default => { if new_case { if case.fall_through { write!(self.out, "{l2}case ")?; } else { write!(self.out, "{l2}")?; } } write!(self.out, "default")?; } } new_case = !case.fall_through; if case.fall_through { write!(self.out, ", ")?; } else { writeln!(self.out, ": {{")?; } for sta in case.body.iter() { self.write_stmt(module, sta, func_ctx, l2.next())?; } if !case.fall_through { writeln!(self.out, "{l2}}}")?; } } writeln!(self.out, "{level}}}")? 
} Statement::Loop { ref body, ref continuing, break_if, } => { write!(self.out, "{level}")?; writeln!(self.out, "loop {{")?; let l2 = level.next(); for sta in body.iter() { self.write_stmt(module, sta, func_ctx, l2)?; } // The continuing is optional so we don't need to write it if // it is empty, but the `break if` counts as a continuing statement // so even if `continuing` is empty we must generate it if a // `break if` exists if !continuing.is_empty() || break_if.is_some() { writeln!(self.out, "{l2}continuing {{")?; for sta in continuing.iter() { self.write_stmt(module, sta, func_ctx, l2.next())?; } // The `break if` is always the last // statement of the `continuing` block if let Some(condition) = break_if { // The trailing space is important write!(self.out, "{}break if ", l2.next())?; self.write_expr(module, condition, func_ctx)?; // Close the `break if` statement writeln!(self.out, ";")?; } writeln!(self.out, "{l2}}}")?; } writeln!(self.out, "{level}}}")? } Statement::Break => { writeln!(self.out, "{level}break;")?; } Statement::Continue => { writeln!(self.out, "{level}continue;")?; } Statement::ControlBarrier(barrier) | Statement::MemoryBarrier(barrier) => { if barrier.contains(crate::Barrier::STORAGE) { writeln!(self.out, "{level}storageBarrier();")?; } if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}workgroupBarrier();")?; } if barrier.contains(crate::Barrier::SUB_GROUP) { writeln!(self.out, "{level}subgroupBarrier();")?; } if barrier.contains(crate::Barrier::TEXTURE) { writeln!(self.out, "{level}textureBarrier();")?; } } Statement::RayQuery { .. 
} => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); self.start_named_expr(module, result, func_ctx, &res_name)?; self.named_expressions.insert(result, res_name); write!(self.out, "subgroupBallot(")?; if let Some(predicate) = predicate { self.write_expr(module, predicate, func_ctx)?; } writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { op, collective_op, argument, result, } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); self.start_named_expr(module, result, func_ctx, &res_name)?; self.named_expressions.insert(result, res_name); match (collective_op, op) { (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { write!(self.out, "subgroupAll(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { write!(self.out, "subgroupAny(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { write!(self.out, "subgroupAdd(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { write!(self.out, "subgroupMul(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { write!(self.out, "subgroupMax(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { write!(self.out, "subgroupMin(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { write!(self.out, "subgroupAnd(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { write!(self.out, "subgroupOr(")? } (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { write!(self.out, "subgroupXor(")? } (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { write!(self.out, "subgroupExclusiveAdd(")? } (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { write!(self.out, "subgroupExclusiveMul(")? 
                    }
                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
                        write!(self.out, "subgroupInclusiveAdd(")?
                    }
                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
                        write!(self.out, "subgroupInclusiveMul(")?
                    }
                    _ => unimplemented!(),
                }
                self.write_expr(module, argument, func_ctx)?;
                writeln!(self.out, ");")?;
            }
            Statement::SubgroupGather {
                mode,
                argument,
                result,
            } => {
                write!(self.out, "{level}")?;
                let res_name = Baked(result).to_string();
                self.start_named_expr(module, result, func_ctx, &res_name)?;
                self.named_expressions.insert(result, res_name);

                match mode {
                    crate::GatherMode::BroadcastFirst => {
                        write!(self.out, "subgroupBroadcastFirst(")?;
                    }
                    crate::GatherMode::Broadcast(_) => {
                        write!(self.out, "subgroupBroadcast(")?;
                    }
                    crate::GatherMode::Shuffle(_) => {
                        write!(self.out, "subgroupShuffle(")?;
                    }
                    crate::GatherMode::ShuffleDown(_) => {
                        write!(self.out, "subgroupShuffleDown(")?;
                    }
                    crate::GatherMode::ShuffleUp(_) => {
                        write!(self.out, "subgroupShuffleUp(")?;
                    }
                    crate::GatherMode::ShuffleXor(_) => {
                        write!(self.out, "subgroupShuffleXor(")?;
                    }
                    crate::GatherMode::QuadBroadcast(_) => {
                        write!(self.out, "quadBroadcast(")?;
                    }
                    crate::GatherMode::QuadSwap(direction) => match direction {
                        crate::Direction::X => {
                            write!(self.out, "quadSwapX(")?;
                        }
                        crate::Direction::Y => {
                            write!(self.out, "quadSwapY(")?;
                        }
                        crate::Direction::Diagonal => {
                            write!(self.out, "quadSwapDiagonal(")?;
                        }
                    },
                }
                self.write_expr(module, argument, func_ctx)?;
                match mode {
                    crate::GatherMode::BroadcastFirst => {}
                    crate::GatherMode::Broadcast(index)
                    | crate::GatherMode::Shuffle(index)
                    | crate::GatherMode::ShuffleDown(index)
                    | crate::GatherMode::ShuffleUp(index)
                    | crate::GatherMode::ShuffleXor(index)
                    | crate::GatherMode::QuadBroadcast(index) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, index, func_ctx)?;
                    }
                    crate::GatherMode::QuadSwap(_) => {}
                }
                writeln!(self.out, ");")?;
            }
            Statement::CooperativeStore { target, ref data } => {
                let suffix = if data.row_major { "T" } else { "" };
                write!(self.out, "{level}coopStore{suffix}(")?;
                self.write_expr(module, target, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, data.pointer, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, data.stride, func_ctx)?;
                writeln!(self.out, ");")?
            }
            Statement::RayPipelineFunction(fun) => match fun {
                crate::RayPipelineFunction::TraceRay {
                    acceleration_structure,
                    descriptor,
                    payload,
                } => {
                    write!(self.out, "{level}traceRay(")?;
                    self.write_expr(module, acceleration_structure, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_expr(module, descriptor, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_expr(module, payload, func_ctx)?;
                    writeln!(self.out, ");")?
                }
            },
        }

        Ok(())
    }

    /// Return the sort of indirection that `expr`'s plain form evaluates to.
    ///
    /// An expression's 'plain form' is the most general rendition of that
    /// expression into WGSL, lacking `&` or `*` operators:
    ///
    /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference
    ///   to the local variable's storage.
    ///
    /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a
    ///   reference to the global variable's storage. However, globals in the
    ///   `Handle` address space are immutable, and `GlobalVariable` expressions for
    ///   those produce the value directly, not a pointer to it. Such
    ///   `GlobalVariable` expressions are `Ordinary`.
    ///
    /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a
    ///   pointer. If they are applied directly to a composite value, they are
    ///   `Ordinary`.
    ///
    /// Note that `FunctionArgument` expressions are never `Reference`, even when
    /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the
    /// argument's value directly, so any pointer it produces is merely the value
    /// passed by the caller.
    fn plain_form_indirection(
        &self,
        expr: Handle<crate::Expression>,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
    ) -> Indirection {
        use crate::Expression as Ex;

        // Named expressions are `let` expressions, which apply the Load Rule,
        // so if their type is a Naga pointer, then that must be a WGSL pointer
        // as well.
        if self.named_expressions.contains_key(&expr) {
            return Indirection::Ordinary;
        }

        match func_ctx.expressions[expr] {
            Ex::LocalVariable(_) => Indirection::Reference,
            Ex::GlobalVariable(handle) => {
                let global = &module.global_variables[handle];
                match global.space {
                    crate::AddressSpace::Handle => Indirection::Ordinary,
                    _ => Indirection::Reference,
                }
            }
            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
                let base_ty = func_ctx.resolve_type(base, &module.types);
                match *base_ty {
                    TypeInner::Pointer { .. } | TypeInner::ValuePointer { .. } => {
                        Indirection::Reference
                    }
                    _ => Indirection::Ordinary,
                }
            }
            _ => Indirection::Ordinary,
        }
    }

    fn start_named_expr(
        &mut self,
        module: &Module,
        handle: Handle<crate::Expression>,
        func_ctx: &back::FunctionCtx,
        name: &str,
    ) -> BackendResult {
        // Write variable name
        write!(self.out, "let {name}")?;
        if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
            write!(self.out, ": ")?;
            // Write variable type
            self.write_type_resolution(module, &func_ctx.info[handle].ty)?;
        }

        write!(self.out, " = ")?;
        Ok(())
    }

    /// Write the ordinary WGSL form of `expr`.
    ///
    /// See `write_expr_with_indirection` for details.
    fn write_expr(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        func_ctx: &back::FunctionCtx<'_>,
    ) -> BackendResult {
        self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
    }

    /// Write `expr` as a WGSL expression with the requested indirection.
    ///
    /// In terms of the WGSL grammar, the resulting expression is a
    /// `singular_expression`. It may be parenthesized. This makes it suitable
    /// for use as the operand of a unary or binary operator without worrying
    /// about precedence.
    ///
    /// This does not produce newlines or indentation.
    ///
    /// The `requested` argument indicates (roughly) whether Naga
    /// `Pointer`-valued expressions represent WGSL references or pointers. See
    /// `Indirection` for details.
    fn write_expr_with_indirection(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        func_ctx: &back::FunctionCtx<'_>,
        requested: Indirection,
    ) -> BackendResult {
        // If the plain form of the expression is not what we need, emit the
        // operator necessary to correct that.
        let plain = self.plain_form_indirection(expr, module, func_ctx);
        log::trace!(
            "expression {:?}={:?} is {:?}, expected {:?}",
            expr,
            func_ctx.expressions[expr],
            plain,
            requested,
        );
        match (requested, plain) {
            (Indirection::Ordinary, Indirection::Reference) => {
                write!(self.out, "(&")?;
                self.write_expr_plain_form(module, expr, func_ctx, plain)?;
                write!(self.out, ")")?;
            }
            (Indirection::Reference, Indirection::Ordinary) => {
                write!(self.out, "(*")?;
                self.write_expr_plain_form(module, expr, func_ctx, plain)?;
                write!(self.out, ")")?;
            }
            (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
        }

        Ok(())
    }

    fn write_const_expression(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        arena: &crate::Arena<crate::Expression>,
    ) -> BackendResult {
        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
            writer.write_const_expression(module, expr, arena)
        })
    }

    fn write_possibly_const_expression<E>(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        expressions: &crate::Arena<crate::Expression>,
        write_expression: E,
    ) -> BackendResult
    where
        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
    {
        use crate::Expression;

        match expressions[expr] {
            Expression::Literal(literal) => match literal {
                crate::Literal::F16(value) => write!(self.out, "{value}h")?,
                crate::Literal::F32(value) => write!(self.out, "{value}f")?,
                crate::Literal::U32(value) => write!(self.out, "{value}u")?,
                crate::Literal::I32(value) => {
                    // `-2147483648i` is not valid WGSL. The most negative `i32`
                    // value can only be expressed in WGSL using AbstractInt and
                    // a unary negation operator.
                    if value == i32::MIN {
                        write!(self.out, "i32({value})")?;
                    } else {
                        write!(self.out, "{value}i")?;
                    }
                }
                crate::Literal::Bool(value) => write!(self.out, "{value}")?,
                crate::Literal::F64(value) => write!(self.out, "{value:?}lf")?,
                crate::Literal::I64(value) => {
                    // `-9223372036854775808li` is not valid WGSL. Nor can we simply use the
                    // AbstractInt trick above, as AbstractInt also cannot represent
                    // `9223372036854775808`. Instead construct the second most negative
                    // AbstractInt, subtract one from it, then cast to i64.
                    if value == i64::MIN {
                        write!(self.out, "i64({} - 1)", value + 1)?;
                    } else {
                        write!(self.out, "{value}li")?;
                    }
                }
                crate::Literal::U64(value) => write!(self.out, "{value:?}lu")?,
                crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
                    return Err(Error::Custom(
                        "Abstract types should not appear in IR presented to backends".into(),
                    ));
                }
            },
            Expression::Constant(handle) => {
                let constant = &module.constants[handle];
                if constant.name.is_some() {
                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
                } else {
                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
                }
            }
            Expression::ZeroValue(ty) => {
                self.write_type(module, ty)?;
                write!(self.out, "()")?;
            }
            Expression::Compose { ty, ref components } => {
                self.write_type(module, ty)?;
                write!(self.out, "(")?;
                for (index, component) in components.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }
                    write_expression(self, *component)?;
                }
                write!(self.out, ")")?
            }
            Expression::Splat { size, value } => {
                let size = common::vector_size_str(size);
                write!(self.out, "vec{size}(")?;
                write_expression(self, value)?;
                write!(self.out, ")")?;
            }
            Expression::Override(handle) => {
                write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
            }
            _ => unreachable!(),
        }

        Ok(())
    }

    /// Write the 'plain form' of `expr`.
    ///
    /// An expression's 'plain form' is the most general rendition of that
    /// expression into WGSL, lacking `&` or `*` operators.
    /// The plain forms of `LocalVariable(x)` and `GlobalVariable(g)` are simply
    /// `x` and `g`. Such Naga expressions represent both WGSL pointers and
    /// references; it's the caller's responsibility to distinguish those cases
    /// appropriately.
    fn write_expr_plain_form(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        func_ctx: &back::FunctionCtx<'_>,
        indirection: Indirection,
    ) -> BackendResult {
        use crate::Expression;

        if let Some(name) = self.named_expressions.get(&expr) {
            write!(self.out, "{name}")?;
            return Ok(());
        }

        let expression = &func_ctx.expressions[expr];

        // Write the plain WGSL form of a Naga expression.
        //
        // The plain form of `LocalVariable` and `GlobalVariable` expressions is
        // simply the variable name; `*` and `&` operators are never emitted.
        //
        // The plain forms of `Access` and `AccessIndex` expressions are WGSL
        // `postfix_expression` forms for member/component access and
        // subscripting.
        match *expression {
            Expression::Literal(_)
            | Expression::Constant(_)
            | Expression::ZeroValue(_)
            | Expression::Compose { .. }
            | Expression::Splat { .. } => {
                self.write_possibly_const_expression(
                    module,
                    expr,
                    func_ctx.expressions,
                    |writer, expr| writer.write_expr(module, expr, func_ctx),
                )?;
            }
            Expression::Override(handle) => {
                write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
            }
            Expression::FunctionArgument(pos) => {
                let name_key = func_ctx.argument_key(pos);
                let name = &self.names[&name_key];
                write!(self.out, "{name}")?;
            }
            Expression::Binary { op, left, right } => {
                write!(self.out, "(")?;
                self.write_expr(module, left, func_ctx)?;
                write!(self.out, " {} ", back::binary_operation_str(op))?;
                self.write_expr(module, right, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::Access { base, index } => {
                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
                write!(self.out, "[")?;
                self.write_expr(module, index, func_ctx)?;
                write!(self.out, "]")?
            }
            Expression::AccessIndex { base, index } => {
                let base_ty_res = &func_ctx.info[base].ty;
                let mut resolved = base_ty_res.inner_with(&module.types);

                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;

                let base_ty_handle = match *resolved {
                    TypeInner::Pointer { base, space: _ } => {
                        resolved = &module.types[base].inner;
                        Some(base)
                    }
                    _ => base_ty_res.handle(),
                };

                match *resolved {
                    TypeInner::Vector { .. } => {
                        // Write vector access as a swizzle
                        write!(self.out, ".{}", back::COMPONENTS[index as usize])?
                    }
                    TypeInner::Matrix { .. }
                    | TypeInner::Array { .. }
                    | TypeInner::BindingArray { .. }
                    | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
                    TypeInner::Struct { .. } => {
                        // `base_ty_handle` is always `Some` when the type is a
                        // `Struct`. That is not true for other types, so we can
                        // only unwrap it inside this match arm.
                        let ty = base_ty_handle.unwrap();

                        write!(
                            self.out,
                            ".{}",
                            &self.names[&NameKey::StructMember(ty, index)]
                        )?
                    }
                    ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
                }
            }
            Expression::ImageSample {
                image,
                sampler,
                gather: None,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge,
            } => {
                use crate::SampleLevel as Sl;

                let suffix_cmp = match depth_ref {
                    Some(_) => "Compare",
                    None => "",
                };

                let suffix_level = match level {
                    Sl::Auto => "",
                    Sl::Zero if clamp_to_edge => "BaseClampToEdge",
                    Sl::Zero | Sl::Exact(_) => "Level",
                    Sl::Bias(_) => "Bias",
                    Sl::Gradient { .. } => "Grad",
                };

                write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;

                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }

                if let Some(depth_ref) = depth_ref {
                    write!(self.out, ", ")?;
                    self.write_expr(module, depth_ref, func_ctx)?;
                }

                match level {
                    Sl::Auto => {}
                    Sl::Zero => {
                        // Level 0 is implied for depth comparison and BaseClampToEdge
                        if depth_ref.is_none() && !clamp_to_edge {
                            write!(self.out, ", 0.0")?;
                        }
                    }
                    Sl::Exact(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Bias(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Gradient { x, y } => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, x, func_ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(module, y, func_ctx)?;
                    }
                }

                if let Some(offset) = offset {
                    write!(self.out, ", ")?;
                    self.write_const_expression(module, offset, func_ctx.expressions)?;
                }

                write!(self.out, ")")?;
            }
            Expression::ImageSample {
                image,
                sampler,
                gather: Some(component),
                coordinate,
                array_index,
                offset,
                level: _,
                depth_ref,
                clamp_to_edge: _,
            } => {
                let suffix_cmp = match depth_ref {
                    Some(_) => "Compare",
                    None => "",
                };

                write!(self.out, "textureGather{suffix_cmp}(")?;
                match *func_ctx.resolve_type(image, &module.types) {
                    TypeInner::Image {
                        class: crate::ImageClass::Depth { multi: _ },
                        ..
                    } => {}
                    _ => {
                        write!(self.out, "{}, ", component as u8)?;
                    }
                }
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;

                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }

                if let Some(depth_ref) = depth_ref {
                    write!(self.out, ", ")?;
                    self.write_expr(module, depth_ref, func_ctx)?;
                }

                if let Some(offset) = offset {
                    write!(self.out, ", ")?;
                    self.write_const_expression(module, offset, func_ctx.expressions)?;
                }

                write!(self.out, ")")?;
            }
            Expression::ImageQuery { image, query } => {
                use crate::ImageQuery as Iq;

                let texture_function = match query {
                    Iq::Size { .. } => "textureDimensions",
                    Iq::NumLevels => "textureNumLevels",
                    Iq::NumLayers => "textureNumLayers",
                    Iq::NumSamples => "textureNumSamples",
                };

                write!(self.out, "{texture_function}(")?;
                self.write_expr(module, image, func_ctx)?;
                if let Iq::Size { level: Some(level) } = query {
                    write!(self.out, ", ")?;
                    self.write_expr(module, level, func_ctx)?;
                };
                write!(self.out, ")")?;
            }
            Expression::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                write!(self.out, "textureLoad(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;
                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }
                if let Some(index) = sample.or(level) {
                    write!(self.out, ", ")?;
                    self.write_expr(module, index, func_ctx)?;
                }
                write!(self.out, ")")?;
            }
            Expression::GlobalVariable(handle) => {
                let name = &self.names[&NameKey::GlobalVariable(handle)];
                write!(self.out, "{name}")?;
            }
            Expression::As {
                expr,
                kind,
                convert,
            } => {
                let inner = func_ctx.resolve_type(expr, &module.types);
                match *inner {
                    TypeInner::Matrix {
                        columns,
                        rows,
                        scalar,
                    } => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(scalar.width),
                        };
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        write!(
                            self.out,
                            "mat{}x{}<{}>",
                            common::vector_size_str(columns),
                            common::vector_size_str(rows),
                            scalar_kind_str
                        )?;
                    }
                    TypeInner::Vector {
                        size,
                        scalar: crate::Scalar { width, .. },
                    } => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(width),
                        };
                        let vector_size_str = common::vector_size_str(size);
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        if convert.is_some() {
                            write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
                        } else {
                            write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
                        }
                    }
                    TypeInner::Scalar(crate::Scalar { width, .. }) => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(width),
                        };
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        if convert.is_some() {
                            write!(self.out, "{scalar_kind_str}")?
                        } else {
                            write!(self.out, "bitcast<{scalar_kind_str}>")?
                        }
                    }
                    _ => {
                        return Err(Error::Unimplemented(format!(
                            "write_expr expression::as {inner:?}"
                        )));
                    }
                };
                write!(self.out, "(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::Load { pointer } => {
                let is_atomic_pointer = func_ctx
                    .resolve_type(pointer, &module.types)
                    .is_atomic_pointer(&module.types);

                if is_atomic_pointer {
                    write!(self.out, "atomicLoad(")?;
                    self.write_expr(module, pointer, func_ctx)?;
                    write!(self.out, ")")?;
                } else {
                    self.write_expr_with_indirection(
                        module,
                        pointer,
                        func_ctx,
                        Indirection::Reference,
                    )?;
                }
            }
            Expression::LocalVariable(handle) => {
                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
            }
            Expression::ArrayLength(expr) => {
                write!(self.out, "arrayLength(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?;
            }

            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                use crate::MathFunction as Mf;

                enum Function {
                    Regular(&'static str),
                    InversePolyfill(InversePolyfill),
                }

                let function = match fun.try_to_wgsl() {
                    Some(name) => Function::Regular(name),
                    None => match fun {
                        Mf::Inverse => {
                            let ty = func_ctx.resolve_type(arg, &module.types);
                            let Some(overload) = InversePolyfill::find_overload(ty) else {
                                return Err(Error::unsupported("math function", fun));
                            };

                            Function::InversePolyfill(overload)
                        }
                        _ => return Err(Error::unsupported("math function", fun)),
                    },
                };

                match function {
                    Function::Regular(fun_name) => {
                        write!(self.out, "{fun_name}(")?;
                        self.write_expr(module, arg, func_ctx)?;
                        for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
                            write!(self.out, ", ")?;
                            self.write_expr(module, arg, func_ctx)?;
                        }
                        write!(self.out, ")")?
                    }

                    Function::InversePolyfill(inverse) => {
                        write!(self.out, "{}(", inverse.fun_name)?;
                        self.write_expr(module, arg, func_ctx)?;
                        write!(self.out, ")")?;
                        self.required_polyfills.insert(inverse);
                    }
                }
            }

            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                self.write_expr(module, vector, func_ctx)?;
                write!(self.out, ".")?;
                for &sc in pattern[..size as usize].iter() {
                    self.out.write_char(back::COMPONENTS[sc as usize])?;
                }
            }
            Expression::Unary { op, expr } => {
                let unary = match op {
                    crate::UnaryOperator::Negate => "-",
                    crate::UnaryOperator::LogicalNot => "!",
                    crate::UnaryOperator::BitwiseNot => "~",
                };

                write!(self.out, "{unary}(")?;
                self.write_expr(module, expr, func_ctx)?;

                write!(self.out, ")")?
            }

            Expression::Select {
                condition,
                accept,
                reject,
            } => {
                write!(self.out, "select(")?;
                self.write_expr(module, reject, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, accept, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, condition, func_ctx)?;
                write!(self.out, ")")?
            }
            Expression::Derivative { axis, ctrl, expr } => {
                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
                let op = match (axis, ctrl) {
                    (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
                    (Axis::X, Ctrl::Fine) => "dpdxFine",
                    (Axis::X, Ctrl::None) => "dpdx",
                    (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
                    (Axis::Y, Ctrl::Fine) => "dpdyFine",
                    (Axis::Y, Ctrl::None) => "dpdy",
                    (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
                    (Axis::Width, Ctrl::Fine) => "fwidthFine",
                    (Axis::Width, Ctrl::None) => "fwidth",
                };
                write!(self.out, "{op}(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?
            }
            Expression::Relational { fun, argument } => {
                use crate::RelationalFunction as Rf;

                let fun_name = match fun {
                    Rf::All => "all",
                    Rf::Any => "any",
                    _ => return Err(Error::UnsupportedRelationalFunction(fun)),
                };
                write!(self.out, "{fun_name}(")?;

                self.write_expr(module, argument, func_ctx)?;

                write!(self.out, ")")?
            }
            // Not supported yet
            Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => unreachable!(),
            // Nothing to do here, since the call expression is already cached
            Expression::CallResult(_)
            | Expression::AtomicResult { .. }
            | Expression::RayQueryProceedResult
            | Expression::SubgroupBallotResult
            | Expression::SubgroupOperationResult { .. }
            | Expression::WorkGroupUniformLoadResult { .. } => {}
            Expression::CooperativeLoad {
                columns,
                rows,
                role,
                ref data,
            } => {
                let suffix = if data.row_major { "T" } else { "" };
                let scalar = func_ctx.info[data.pointer]
                    .ty
                    .inner_with(&module.types)
                    .pointer_base_type()
                    .unwrap()
                    .inner_with(&module.types)
                    .scalar()
                    .unwrap();
                write!(
                    self.out,
                    "coopLoad{suffix}<coop_mat{}x{}<{},{:?}>>(",
                    columns as u32,
                    rows as u32,
                    scalar.try_to_wgsl().unwrap(),
                    role,
                )?;
                self.write_expr(module, data.pointer, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, data.stride, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::CooperativeMultiplyAdd { a, b, c } => {
                write!(self.out, "coopMultiplyAdd(")?;
                self.write_expr(module, a, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, b, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, c, func_ctx)?;
                write!(self.out, ")")?;
            }
        }

        Ok(())
    }

    /// Helper method used to write global variables
    ///
    /// # Notes
    /// Always adds a newline
    fn write_global(
        &mut self,
        module: &Module,
        global: &crate::GlobalVariable,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        // Write group and binding attributes if present
        if let Some(ref binding) = global.binding {
            self.write_attributes(&[
                Attribute::Group(binding.group),
                Attribute::Binding(binding.binding),
            ])?;
            writeln!(self.out)?;
        }

        if global
            .memory_decorations
            .contains(crate::MemoryDecorations::COHERENT)
        {
            write!(self.out, "@coherent ")?;
        }
        if global
            .memory_decorations
            .contains(crate::MemoryDecorations::VOLATILE)
        {
            write!(self.out, "@volatile ")?;
        }

        // First write global name and address space if supported
        write!(self.out, "var")?;
        let (address, maybe_access) = address_space_str(global.space);
        if let Some(space) = address {
            write!(self.out, "<{space}")?;
            if let Some(access) = maybe_access {
                write!(self.out, ", {access}")?;
            }
            write!(self.out, ">")?;
        }
        write!(
            self.out,
            " {}: ",
            &self.names[&NameKey::GlobalVariable(handle)]
        )?;

        // Write global type
        self.write_type(module, global.ty)?;

        // Write initializer
        if let Some(init) = global.init {
            write!(self.out, " = ")?;
            self.write_const_expression(module, init, &module.global_expressions)?;
        }

        // End with semicolon
        writeln!(self.out, ";")?;

        Ok(())
    }

    /// Helper method used to write global constants
    ///
    /// # Notes
    /// Ends in a newline
    fn write_global_constant(
        &mut self,
        module: &Module,
        handle: Handle<crate::Constant>,
    ) -> BackendResult {
        let name = &self.names[&NameKey::Constant(handle)];
        // First write only constant name
        write!(self.out, "const {name}: ")?;
        self.write_type(module, module.constants[handle].ty)?;
        write!(self.out, " = ")?;
        let init = module.constants[handle].init;
        self.write_const_expression(module, init, &module.global_expressions)?;
        writeln!(self.out, ";")?;

        Ok(())
    }

    /// Helper method used to write overrides
    ///
    /// # Notes
    /// Ends in a newline
    fn write_override(
        &mut self,
        module: &Module,
        handle: Handle<crate::Override>,
    ) -> BackendResult {
        let override_ = &module.overrides[handle];
        let name = &self.names[&NameKey::Override(handle)];

        // Write @id attribute if present
        if let Some(id) = override_.id {
            write!(self.out, "@id({id}) ")?;
        }

        // Write override declaration
        write!(self.out, "override {name}: ")?;
        self.write_type(module, override_.ty)?;

        // Write initializer if present
        if let Some(init) = override_.init {
            write!(self.out, " = ")?;
            self.write_const_expression(module, init, &module.global_expressions)?;
        }

        writeln!(self.out, ";")?;

        Ok(())
    }

    // See https://github.com/rust-lang/rust-clippy/issues/4979.
    pub fn finish(self) -> W {
        self.out
    }
}

struct WriterTypeContext<'m> {
    module: &'m Module,
    names: &'m crate::FastHashMap<NameKey, String>,
}

impl TypeContext for WriterTypeContext<'_> {
    fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type {
        &self.module.types[handle]
    }

    fn type_name(&self, handle: Handle<crate::Type>) -> &str {
        self.names[&NameKey::Type(handle)].as_str()
    }

    fn write_unnamed_struct<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
        unreachable!("the WGSL back end should always provide type handles");
    }

    fn write_override<W: Write>(
        &self,
        handle: Handle<crate::Override>,
        out: &mut W,
    ) -> core::fmt::Result {
        write!(out, "{}", self.names[&NameKey::Override(handle)])
    }

    fn write_non_wgsl_inner<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
        unreachable!("backends should only be passed validated modules");
    }

    fn write_non_wgsl_scalar<W: Write>(&self, _: crate::Scalar, _: &mut W) -> core::fmt::Result {
        unreachable!("backends should only be passed validated modules");
    }
}

fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
    match *binding {
        crate::Binding::BuiltIn(built_in) => {
            if let crate::BuiltIn::Position { invariant: true } = built_in {
                vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
            } else {
                vec![Attribute::BuiltIn(built_in)]
            }
        }
        crate::Binding::Location {
            location,
            interpolation,
            sampling,
            blend_src: None,
            per_primitive,
        } => {
            let mut attrs = vec![
                Attribute::Location(location),
                Attribute::Interpolate(interpolation, sampling),
            ];
            if per_primitive {
                attrs.push(Attribute::PerPrimitive);
            }
            attrs
        }
        crate::Binding::Location {
            location,
            interpolation,
            sampling,
            blend_src: Some(blend_src),
            per_primitive,
        } => {
            let mut attrs = vec![
                Attribute::Location(location),
                Attribute::BlendSrc(blend_src),
                Attribute::Interpolate(interpolation, sampling),
            ];
            if per_primitive {
                attrs.push(Attribute::PerPrimitive);
            }
            attrs
        }
    }
}
naga-29.0.3/src/common/diagnostic_debug.rs
//! Displaying Naga IR terms in debugging output.
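
// Editor's sketch (hypothetical, not part of naga): a minimal,
// self-contained illustration of the wrapper-plus-extension-trait pattern
// that `DiagnosticDebug` and `ForDebug` in this file appear to follow. A
// newtype wrapper owns the value being displayed, a `Debug` impl on the
// wrapper chooses the rendering, and a trait whose default method does the
// wrapping lets callers write `value.for_debug()` inline. `ToyScalar`,
// `ToyDebug` and `ToyForDebug` are invented stand-ins for naga's `Scalar`,
// `DiagnosticDebug` and `ForDebug`. The WGSL-style spelling here is
// simplified to `f`/`i` plus a bit width.
#[allow(dead_code)]
mod diagnostic_wrapper_sketch {
    use core::fmt;

    /// Stand-in for an IR scalar: float or signed integer, width in bytes.
    #[derive(Clone, Copy)]
    pub struct ToyScalar {
        pub is_float: bool,
        pub width: u8,
    }

    /// Newtype wrapper whose `Debug` impl renders WGSL-like type names
    /// (`f32`, `i64`, ...) instead of a derived field dump.
    pub struct ToyDebug<T>(pub T);

    impl fmt::Debug for ToyDebug<ToyScalar> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let prefix = if self.0.is_float { 'f' } else { 'i' };
            // Widths are stored in bytes; type names use bits.
            write!(f, "{}{}", prefix, u32::from(self.0.width) * 8)
        }
    }

    /// Extension trait: the default method does the wrapping, so an empty
    /// `impl` block is enough for a type to opt in.
    pub trait ToyForDebug: Sized {
        fn for_debug(self) -> ToyDebug<Self> {
            ToyDebug(self)
        }
    }

    impl ToyForDebug for ToyScalar {}
}
// Usage: `format!("{:?}", ToyScalar { is_float: true, width: 4 }.for_debug())`
// produces "f32". The design keeps `Debug` on the raw type untouched while
// offering a second, language-aware rendering.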

#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
use crate::common::wgsl::TypeContext;

use crate::proc::TypeResolution;
use crate::{Handle, Scalar, Type, TypeInner, UniqueArena};

use core::fmt;

/// A wrapper for displaying Naga IR terms in debugging output.
///
/// This is like [`DiagnosticDisplay`], but requires weaker context
/// and produces correspondingly lower-fidelity output. For example,
/// this cannot show the override names for override-sized array
/// lengths.
///
/// [`DiagnosticDisplay`]: super::DiagnosticDisplay
pub struct DiagnosticDebug<T>(pub T);

impl fmt::Debug for DiagnosticDebug<(Handle<Type>, &UniqueArena<Type>)> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (handle, ctx) = self.0;

        #[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
        ctx.write_type(handle, f)?;

        #[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
        {
            let _ = ctx;
            write!(f, "{handle:?}")?;
        }

        Ok(())
    }
}

impl fmt::Debug for DiagnosticDebug<(&TypeInner, &UniqueArena<Type>)> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (inner, ctx) = self.0;

        #[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
        ctx.write_type_inner(inner, f)?;

        #[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
        {
            let _ = ctx;
            write!(f, "{inner:?}")?;
        }

        Ok(())
    }
}

impl fmt::Debug for DiagnosticDebug<(&TypeResolution, &UniqueArena<Type>)> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (resolution, ctx) = self.0;

        #[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
        ctx.write_type_resolution(resolution, f)?;

        #[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
        {
            let _ = ctx;
            write!(f, "{resolution:?}")?;
        }

        Ok(())
    }
}

impl fmt::Debug for DiagnosticDebug<Scalar> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let scalar = self.0;

        #[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
        f.write_str(&crate::common::wgsl::TryToWgsl::to_wgsl_for_diagnostics(
            scalar,
        ))?;

        #[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
        write!(f, "{scalar:?}")?;

        Ok(())
    }
}

pub trait ForDebug: Sized {
    /// Format this type using [`core::fmt::Debug`].
    ///
    /// Return a value that implements the [`core::fmt::Debug`] trait
    /// by displaying `self` in a language-appropriate way. For
    /// example:
    ///
    ///     # use naga::common::ForDebug;
    ///     # let scalar: naga::Scalar = naga::Scalar::F32;
    ///     log::debug!("My scalar: {:?}", scalar.for_debug());
    fn for_debug(self) -> DiagnosticDebug<Self> {
        DiagnosticDebug(self)
    }
}

impl ForDebug for Scalar {}

pub trait ForDebugWithTypes: Sized {
    /// Format this type using [`core::fmt::Debug`].
    ///
    /// Given an arena to look up type handles in, return a value that
    /// implements the [`core::fmt::Debug`] trait by displaying `self`
    /// in a language-appropriate way. For example:
    ///
    ///     # use naga::{Span, Type, TypeInner, Scalar, UniqueArena};
    ///     # use naga::common::ForDebugWithTypes;
    ///     # let mut types = UniqueArena::<Type>::default();
    ///     # let inner = TypeInner::Scalar(Scalar::F32);
    ///     # let span = Span::UNDEFINED;
    ///     # let handle = types.insert(Type { name: None, inner }, span);
    ///     log::debug!("My type: {:?}", handle.for_debug(&types));
    fn for_debug(self, types: &UniqueArena<Type>) -> DiagnosticDebug<(Self, &UniqueArena<Type>)> {
        DiagnosticDebug((self, types))
    }
}

impl ForDebugWithTypes for Handle<Type> {}
impl ForDebugWithTypes for &TypeInner {}
impl ForDebugWithTypes for &TypeResolution {}
naga-29.0.3/src/common/diagnostic_display.rs
//! Displaying Naga IR terms in diagnostic output.

use crate::proc::{GlobalCtx, Rule, TypeResolution};
use crate::{Handle, Scalar, Type};

#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
use crate::common::wgsl::TypeContext;

use core::fmt;

/// A wrapper for displaying Naga IR terms in diagnostic output.
///
/// For some Naga IR type `T`, `DiagnosticDisplay<T>` implements
/// [`core::fmt::Display`] in a way that displays values of type `T`
/// appropriately for diagnostic messages presented to human readers.
///
/// For example, the implementation of [`Display`] for
/// `DiagnosticDisplay<Scalar>` formats the type represented by the
/// given [`Scalar`] appropriately for users.
///
/// Some types like `Handle<Type>` require contextual information like
/// a type arena to be displayed. In such cases, we implement [`Display`]
/// for a type like `DiagnosticDisplay<(Handle<Type>, GlobalCtx)>`, where
/// the [`GlobalCtx`] type provides the necessary context.
///
/// Do not implement this type for [`TypeInner`], as that does not
/// have enough information to display struct types correctly.
///
/// If you only need debugging output, [`DiagnosticDebug`] uses
/// easier-to-obtain context types but still does a good enough job
/// for logging or debugging.
///
/// [`Display`]: core::fmt::Display
/// [`GlobalCtx`]: crate::proc::GlobalCtx
/// [`TypeInner`]: crate::ir::TypeInner
/// [`DiagnosticDebug`]: super::DiagnosticDebug
///
/// ## Language-sensitive diagnostics
///
/// Diagnostic output ought to depend on the source language from
/// which the IR was produced: diagnostics resulting from processing
/// GLSL code should use GLSL type syntax, for example. That means
/// that `DiagnosticDisplay` ought to include some indication of which
/// notation to use.
///
/// For the moment, only WGSL output is implemented, so
/// `DiagnosticDisplay` lacks any support for this (#7268). However,
/// the plan is that all language-independent code in Naga should use
/// `DiagnosticDisplay` wherever appropriate, such that when its
/// definition is expanded to include some indication of the right
/// source language to use, any use site that does not supply this
/// indication will provoke a compile-time error.
pub struct DiagnosticDisplay<T>(pub T);

impl fmt::Display for DiagnosticDisplay<(&TypeResolution, GlobalCtx<'_>)> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (resolution, ctx) = self.0;

        #[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
        ctx.write_type_resolution(resolution, f)?;

        #[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
        {
            let _ = ctx;
            write!(f, "{resolution:?}")?;
        }

        Ok(())
    }
}

impl fmt::Display for DiagnosticDisplay<(Handle<Type>, GlobalCtx<'_>)> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (handle, ref ctx) = self.0;

        #[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
        ctx.write_type(handle, f)?;

        #[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
        {
            let _ = ctx;
            write!(f, "{handle:?}")?;
        }

        Ok(())
    }
}

impl fmt::Display for DiagnosticDisplay<(&str, &Rule, GlobalCtx<'_>)> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (name, rule, ref ctx) = self.0;

        #[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
        ctx.write_type_rule(name, rule, f)?;

        #[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
        {
            let _ = ctx;
            write!(f, "{name}({:?}) -> {:?}", rule.arguments, rule.conclusion)?;
        }

        Ok(())
    }
}

impl fmt::Display for DiagnosticDisplay<Scalar> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let scalar = self.0;

        #[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
        f.write_str(&crate::common::wgsl::TryToWgsl::to_wgsl_for_diagnostics(
            scalar,
        ))?;

        #[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
        write!(f, "{scalar:?}")?;

        Ok(())
    }
}
naga-29.0.3/src/common/mod.rs
//! Code common to the front and back ends for specific languages.
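
// Editor's sketch (hypothetical, not part of naga): how the naming scheme
// in the `predeclared` module turns a function name, an optional vector
// size and a scalar width into a struct name such as
// `__frexp_result_vec3_f32`. `write_result_struct_name` is an invented
// stand-in for the private `frexp_mod_name`. It takes the vector size as a
// plain component count and the width in bytes, and writes into any
// `core::fmt::Write` sink so that it builds without `alloc`.
#[allow(dead_code)]
mod predeclared_name_sketch {
    use core::fmt::{self, Write};

    pub fn write_result_struct_name<W: Write>(
        out: &mut W,
        function: &str,
        vector_size: Option<u8>,
        width_bytes: u8,
    ) -> fmt::Result {
        // The width is stored in bytes; the generated name uses bits.
        let bits = 8 * u32::from(width_bytes);
        match vector_size {
            Some(size) => write!(out, "__{function}_result_vec{size}_f{bits}"),
            None => write!(out, "__{function}_result_f{bits}"),
        }
    }
}
// Usage: writing ("frexp", Some(3), 4) into a `String` yields
// "__frexp_result_vec3_f32"; ("modf", None, 8) yields "__modf_result_f64".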
mod diagnostic_debug; mod diagnostic_display; pub mod predeclared; pub mod wgsl; pub use diagnostic_debug::{DiagnosticDebug, ForDebug, ForDebugWithTypes}; pub use diagnostic_display::DiagnosticDisplay; // Re-exported here for backwards compatibility pub use super::proc::vector_size_str; naga-29.0.3/src/common/predeclared.rs000064400000000000000000000017521046102023000155310ustar 00000000000000//! Generating names for predeclared types. use crate::ir; use alloc::format; use alloc::string::String; impl ir::PredeclaredType { pub fn struct_name(&self) -> String { use crate::PredeclaredType as Pt; match *self { Pt::AtomicCompareExchangeWeakResult(scalar) => { format!( "__atomic_compare_exchange_result<{:?},{}>", scalar.kind, scalar.width, ) } Pt::ModfResult { size, scalar } => frexp_mod_name("modf", size, scalar), Pt::FrexpResult { size, scalar } => frexp_mod_name("frexp", size, scalar), } } } fn frexp_mod_name(function: &str, size: Option<ir::VectorSize>, scalar: ir::Scalar) -> String { let bits = 8 * scalar.width; match size { Some(size) => { let size = size as u8; format!("__{function}_result_vec{size}_f{bits}") } None => format!("__{function}_result_f{bits}"), } } naga-29.0.3/src/common/wgsl/diagnostics.rs000064400000000000000000000043641046102023000165440ustar 00000000000000//! WGSL diagnostic filters and severities. use core::fmt::{self, Display, Formatter}; use crate::diagnostic_filter::{ FilterableTriggeringRule, Severity, StandardFilterableTriggeringRule, }; impl Severity { const ERROR: &'static str = "error"; const WARNING: &'static str = "warning"; const INFO: &'static str = "info"; const OFF: &'static str = "off"; /// Convert from a sentinel word in WGSL into its associated [`Severity`], if possible.
pub fn from_wgsl_ident(s: &str) -> Option<Self> { Some(match s { Self::ERROR => Self::Error, Self::WARNING => Self::Warning, Self::INFO => Self::Info, Self::OFF => Self::Off, _ => return None, }) } } pub struct DisplayFilterableTriggeringRule<'a>(&'a FilterableTriggeringRule); impl Display for DisplayFilterableTriggeringRule<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let &Self(inner) = self; match *inner { FilterableTriggeringRule::Standard(rule) => write!(f, "{}", rule.to_wgsl_ident()), FilterableTriggeringRule::Unknown(ref rule) => write!(f, "{rule}"), FilterableTriggeringRule::User(ref rules) => { let &[ref seg1, ref seg2] = rules.as_ref(); write!(f, "{seg1}.{seg2}") } } } } impl FilterableTriggeringRule { /// [`Display`] this rule's identifiers in WGSL. pub const fn display_wgsl_ident(&self) -> impl Display + '_ { DisplayFilterableTriggeringRule(self) } } impl StandardFilterableTriggeringRule { const DERIVATIVE_UNIFORMITY: &'static str = "derivative_uniformity"; /// Convert from a sentinel word in WGSL into its associated /// [`StandardFilterableTriggeringRule`], if possible. pub fn from_wgsl_ident(s: &str) -> Option<Self> { Some(match s { Self::DERIVATIVE_UNIFORMITY => Self::DerivativeUniformity, _ => return None, }) } /// Maps this [`StandardFilterableTriggeringRule`] into the sentinel word associated with it in /// WGSL. pub const fn to_wgsl_ident(self) -> &'static str { match self { Self::DerivativeUniformity => Self::DERIVATIVE_UNIFORMITY, } } } naga-29.0.3/src/common/wgsl/mod.rs000064400000000000000000000003541046102023000150070ustar 00000000000000//! Code shared between the WGSL front and back ends. mod diagnostics; mod to_wgsl; mod types; pub use diagnostics::DisplayFilterableTriggeringRule; pub use to_wgsl::{address_space_str, ToWgsl, TryToWgsl}; pub use types::TypeContext; naga-29.0.3/src/common/wgsl/to_wgsl.rs000064400000000000000000000342471046102023000157160ustar 00000000000000//! Generating WGSL source code for Naga IR types.
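The sentinel-word mapping and the borrowing `Display` adapter in `diagnostics.rs` can be exercised in isolation. This sketch copies the pattern into a stand-alone enum. The `Severity` here is a hypothetical duplicate, not Naga's type. Its `to_wgsl_ident` inverse is an assumption modeled on `StandardFilterableTriggeringRule::to_wgsl_ident`, and `DisplayUserRule` and `display_user_rule` are likewise illustrative names. The four words are the severities accepted by WGSL's `diagnostic` directive.

```rust
use std::fmt::{self, Display, Formatter};

// Hypothetical stand-alone copy of the pattern; not Naga's `Severity`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Severity {
    Error,
    Warning,
    Info,
    Off,
}

impl Severity {
    // Keeping the WGSL spellings in associated consts lets parsing and
    // printing share a single source of truth.
    const ERROR: &'static str = "error";
    const WARNING: &'static str = "warning";
    const INFO: &'static str = "info";
    const OFF: &'static str = "off";

    // `&str` consts can be used directly as match patterns.
    fn from_wgsl_ident(s: &str) -> Option<Self> {
        Some(match s {
            Self::ERROR => Self::Error,
            Self::WARNING => Self::Warning,
            Self::INFO => Self::Info,
            Self::OFF => Self::Off,
            _ => return None,
        })
    }

    // Hypothetical inverse mapping (assumption, see lead-in).
    const fn to_wgsl_ident(self) -> &'static str {
        match self {
            Self::Error => Self::ERROR,
            Self::Warning => Self::WARNING,
            Self::Info => Self::INFO,
            Self::Off => Self::OFF,
        }
    }
}

// Borrowing adapter in the style of `DisplayFilterableTriggeringRule`:
// it formats a two-segment user rule name as `seg1.seg2` on demand.
struct DisplayUserRule<'a>(&'a [String; 2]);

impl Display for DisplayUserRule<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let [seg1, seg2] = self.0;
        write!(f, "{seg1}.{seg2}")
    }
}

// Returning `impl Display + '_` hides the adapter type from callers.
fn display_user_rule(rule: &[String; 2]) -> impl Display + '_ {
    DisplayUserRule(rule)
}

fn main() {
    for word in ["error", "warning", "info", "off", "fatal"] {
        match Severity::from_wgsl_ident(word) {
            Some(sev) => println!("{word} -> {sev:?} -> {}", sev.to_wgsl_ident()),
            None => println!("{word}: not a severity"),
        }
    }
    let rule = ["my_tool".to_string(), "my_rule".to_string()];
    println!("{}", display_user_rule(&rule)); // my_tool.my_rule
}
```

The adapter allocates nothing until it is actually formatted. That is presumably why the original returns `impl Display` rather than a `String`.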
use alloc::format; use alloc::string::{String, ToString}; /// Types that can return the WGSL source representation of their /// values as a `'static` string. /// /// This trait is specifically for types whose WGSL forms are simple /// enough that they can always be returned as a static string. /// /// - If only some values have a WGSL representation, consider /// implementing [`TryToWgsl`] instead. /// /// - If a type's WGSL form requires dynamic formatting, so that /// returning a `&'static str` isn't feasible, consider implementing /// [`core::fmt::Display`] on some wrapper type instead. pub trait ToWgsl: Sized { /// Return the WGSL source code representation of `self`. fn to_wgsl(self) -> &'static str; } /// Types that may be able to return the WGSL source representation /// for their values as a `'static` string. /// /// This trait is specifically for types whose values are either /// simple enough that their WGSL form can be represented as a static /// string, or aren't representable in WGSL at all. /// /// - If all values in the type have `&'static str` representations in /// WGSL, consider implementing [`ToWgsl`] instead. /// /// - If a type's WGSL form requires dynamic formatting, so that /// returning a `&'static str` isn't feasible, consider implementing /// [`core::fmt::Display`] on some wrapper type instead. pub trait TryToWgsl: Sized { /// Return the WGSL form of `self` as a `'static` string. /// /// If `self` doesn't have a representation in WGSL (standard or /// as extended by Naga), then return `None`. fn try_to_wgsl(self) -> Option<&'static str>; /// What kind of WGSL thing `Self` represents. const DESCRIPTION: &'static str; /// Return the WGSL form of `self` as appropriate for diagnostics. /// /// If `self` can be expressed in WGSL, return that form as a /// [`String`]. Otherwise, return some representation of `self` /// that is appropriate for use in diagnostic messages.
/// /// The default implementation of this function falls back to /// `self`'s [`Debug`] form. /// /// [`Debug`]: core::fmt::Debug fn to_wgsl_for_diagnostics(self) -> String where Self: core::fmt::Debug + Copy, { match self.try_to_wgsl() { Some(static_string) => static_string.to_string(), None => format!("{{non-WGSL {} {self:?}}}", Self::DESCRIPTION), } } } impl TryToWgsl for crate::MathFunction { const DESCRIPTION: &'static str = "math function"; fn try_to_wgsl(self) -> Option<&'static str> { use crate::MathFunction as Mf; Some(match self { Mf::Abs => "abs", Mf::Min => "min", Mf::Max => "max", Mf::Clamp => "clamp", Mf::Saturate => "saturate", Mf::Cos => "cos", Mf::Cosh => "cosh", Mf::Sin => "sin", Mf::Sinh => "sinh", Mf::Tan => "tan", Mf::Tanh => "tanh", Mf::Acos => "acos", Mf::Asin => "asin", Mf::Atan => "atan", Mf::Atan2 => "atan2", Mf::Asinh => "asinh", Mf::Acosh => "acosh", Mf::Atanh => "atanh", Mf::Radians => "radians", Mf::Degrees => "degrees", Mf::Ceil => "ceil", Mf::Floor => "floor", Mf::Round => "round", Mf::Fract => "fract", Mf::Trunc => "trunc", Mf::Modf => "modf", Mf::Frexp => "frexp", Mf::Ldexp => "ldexp", Mf::Exp => "exp", Mf::Exp2 => "exp2", Mf::Log => "log", Mf::Log2 => "log2", Mf::Pow => "pow", Mf::Dot => "dot", Mf::Dot4I8Packed => "dot4I8Packed", Mf::Dot4U8Packed => "dot4U8Packed", Mf::Cross => "cross", Mf::Distance => "distance", Mf::Length => "length", Mf::Normalize => "normalize", Mf::FaceForward => "faceForward", Mf::Reflect => "reflect", Mf::Refract => "refract", Mf::Sign => "sign", Mf::Fma => "fma", Mf::Mix => "mix", Mf::Step => "step", Mf::SmoothStep => "smoothstep", Mf::Sqrt => "sqrt", Mf::InverseSqrt => "inverseSqrt", Mf::Transpose => "transpose", Mf::Determinant => "determinant", Mf::QuantizeToF16 => "quantizeToF16", Mf::CountTrailingZeros => "countTrailingZeros", Mf::CountLeadingZeros => "countLeadingZeros", Mf::CountOneBits => "countOneBits", Mf::ReverseBits => "reverseBits", Mf::ExtractBits => "extractBits", Mf::InsertBits => 
"insertBits", Mf::FirstTrailingBit => "firstTrailingBit", Mf::FirstLeadingBit => "firstLeadingBit", Mf::Pack4x8snorm => "pack4x8snorm", Mf::Pack4x8unorm => "pack4x8unorm", Mf::Pack2x16snorm => "pack2x16snorm", Mf::Pack2x16unorm => "pack2x16unorm", Mf::Pack2x16float => "pack2x16float", Mf::Pack4xI8 => "pack4xI8", Mf::Pack4xU8 => "pack4xU8", Mf::Pack4xI8Clamp => "pack4xI8Clamp", Mf::Pack4xU8Clamp => "pack4xU8Clamp", Mf::Unpack4x8snorm => "unpack4x8snorm", Mf::Unpack4x8unorm => "unpack4x8unorm", Mf::Unpack2x16snorm => "unpack2x16snorm", Mf::Unpack2x16unorm => "unpack2x16unorm", Mf::Unpack2x16float => "unpack2x16float", Mf::Unpack4xI8 => "unpack4xI8", Mf::Unpack4xU8 => "unpack4xU8", // Non-standard math functions. Mf::Inverse | Mf::Outer => return None, }) } } impl TryToWgsl for crate::BuiltIn { const DESCRIPTION: &'static str = "builtin value"; fn try_to_wgsl(self) -> Option<&'static str> { use crate::BuiltIn as Bi; Some(match self { Bi::Position { .. } => "position", Bi::ViewIndex => "view_index", Bi::InstanceIndex => "instance_index", Bi::VertexIndex => "vertex_index", Bi::ClipDistance => "clip_distances", Bi::FragDepth => "frag_depth", Bi::FrontFacing => "front_facing", Bi::PrimitiveIndex => "primitive_index", Bi::DrawIndex => "draw_index", Bi::Barycentric { perspective: true } => "barycentric", Bi::Barycentric { perspective: false } => "barycentric_no_perspective", Bi::SampleIndex => "sample_index", Bi::SampleMask => "sample_mask", Bi::GlobalInvocationId => "global_invocation_id", Bi::LocalInvocationId => "local_invocation_id", Bi::LocalInvocationIndex => "local_invocation_index", Bi::WorkGroupId => "workgroup_id", Bi::NumWorkGroups => "num_workgroups", Bi::NumSubgroups => "num_subgroups", Bi::SubgroupId => "subgroup_id", Bi::SubgroupSize => "subgroup_size", Bi::SubgroupInvocationId => "subgroup_invocation_id", // Non-standard built-ins. 
Bi::MeshTaskSize => "mesh_task_size", Bi::TriangleIndices => "triangle_indices", Bi::LineIndices => "line_indices", Bi::PointIndex => "point_index", Bi::Vertices => "vertices", Bi::Primitives => "primitives", Bi::VertexCount => "vertex_count", Bi::PrimitiveCount => "primitive_count", Bi::CullPrimitive => "cull_primitive", Bi::RayInvocationId => "ray_invocation_id", Bi::NumRayInvocations => "num_ray_invocations", Bi::InstanceCustomData => "instance_custom_data", Bi::GeometryIndex => "geometry_index", Bi::WorldRayOrigin => "world_ray_origin", Bi::WorldRayDirection => "world_ray_direction", Bi::ObjectRayOrigin => "object_ray_origin", Bi::ObjectRayDirection => "object_ray_direction", Bi::RayTmin => "ray_t_min", Bi::RayTCurrentMax => "ray_t_current_max", Bi::ObjectToWorld => "object_to_world", Bi::WorldToObject => "world_to_object", Bi::HitKind => "hit_kind", Bi::BaseInstance | Bi::BaseVertex | Bi::CullDistance | Bi::PointSize | Bi::PointCoord | Bi::WorkGroupSize => return None, }) } } impl ToWgsl for crate::Interpolation { fn to_wgsl(self) -> &'static str { match self { crate::Interpolation::Perspective => "perspective", crate::Interpolation::Linear => "linear", crate::Interpolation::Flat => "flat", crate::Interpolation::PerVertex => "per_vertex", } } } impl ToWgsl for crate::Sampling { fn to_wgsl(self) -> &'static str { match self { crate::Sampling::Center => "center", crate::Sampling::Centroid => "centroid", crate::Sampling::Sample => "sample", crate::Sampling::First => "first", crate::Sampling::Either => "either", } } } impl ToWgsl for crate::StorageFormat { fn to_wgsl(self) -> &'static str { use crate::StorageFormat as Sf; match self { Sf::R8Unorm => "r8unorm", Sf::R8Snorm => "r8snorm", Sf::R8Uint => "r8uint", Sf::R8Sint => "r8sint", Sf::R16Uint => "r16uint", Sf::R16Sint => "r16sint", Sf::R16Float => "r16float", Sf::Rg8Unorm => "rg8unorm", Sf::Rg8Snorm => "rg8snorm", Sf::Rg8Uint => "rg8uint", Sf::Rg8Sint => "rg8sint", Sf::R32Uint => "r32uint", Sf::R32Sint => 
"r32sint", Sf::R32Float => "r32float", Sf::Rg16Uint => "rg16uint", Sf::Rg16Sint => "rg16sint", Sf::Rg16Float => "rg16float", Sf::Rgba8Unorm => "rgba8unorm", Sf::Rgba8Snorm => "rgba8snorm", Sf::Rgba8Uint => "rgba8uint", Sf::Rgba8Sint => "rgba8sint", Sf::Bgra8Unorm => "bgra8unorm", Sf::Rgb10a2Uint => "rgb10a2uint", Sf::Rgb10a2Unorm => "rgb10a2unorm", Sf::Rg11b10Ufloat => "rg11b10ufloat", Sf::R64Uint => "r64uint", Sf::Rg32Uint => "rg32uint", Sf::Rg32Sint => "rg32sint", Sf::Rg32Float => "rg32float", Sf::Rgba16Uint => "rgba16uint", Sf::Rgba16Sint => "rgba16sint", Sf::Rgba16Float => "rgba16float", Sf::Rgba32Uint => "rgba32uint", Sf::Rgba32Sint => "rgba32sint", Sf::Rgba32Float => "rgba32float", Sf::R16Unorm => "r16unorm", Sf::R16Snorm => "r16snorm", Sf::Rg16Unorm => "rg16unorm", Sf::Rg16Snorm => "rg16snorm", Sf::Rgba16Unorm => "rgba16unorm", Sf::Rgba16Snorm => "rgba16snorm", } } } impl TryToWgsl for crate::Scalar { const DESCRIPTION: &'static str = "scalar type"; fn try_to_wgsl(self) -> Option<&'static str> { use crate::Scalar; Some(match self { Scalar::F16 => "f16", Scalar::F32 => "f32", Scalar::F64 => "f64", Scalar::I32 => "i32", Scalar::U32 => "u32", Scalar::I64 => "i64", Scalar::U64 => "u64", Scalar::BOOL => "bool", _ => return None, }) } fn to_wgsl_for_diagnostics(self) -> String { match self.try_to_wgsl() { Some(static_string) => static_string.to_string(), None => match self.kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint | crate::ScalarKind::Float | crate::ScalarKind::Bool => format!("{{non-WGSL scalar {self:?}}}"), crate::ScalarKind::AbstractInt => "{AbstractInt}".to_string(), crate::ScalarKind::AbstractFloat => "{AbstractFloat}".to_string(), }, } } } impl ToWgsl for crate::CooperativeRole { fn to_wgsl(self) -> &'static str { match self { Self::A => "A", Self::B => "B", Self::C => "C", } } } impl ToWgsl for crate::ImageDimension { fn to_wgsl(self) -> &'static str { match self { Self::D1 => "1d", Self::D2 => "2d", Self::D3 => "3d", Self::Cube => "cube", } 
} } /// Return the WGSL address space and access mode strings for `space`. /// /// Why don't we implement [`ToWgsl`] for [`AddressSpace`]? /// /// In WGSL, the full form of a pointer type is `ptr<AS, T, AM>`, where: /// - `AS` is the address space, /// - `T` is the store type, and /// - `AM` is the access mode. /// /// Since the type `T` intervenes between the address space and the /// access mode, there isn't really any individual WGSL grammar /// production that corresponds to an [`AddressSpace`], so [`ToWgsl`] /// is too simple-minded for this case. /// /// Furthermore, we want to write `var<AS, AM> foo: T` for most address /// spaces, but we want to just write `var foo: T` for handle types. /// /// [`AddressSpace`]: crate::AddressSpace pub const fn address_space_str( space: crate::AddressSpace, ) -> (Option<&'static str>, Option<&'static str>) { use crate::AddressSpace as As; ( Some(match space { As::Private => "private", As::Uniform => "uniform", As::Storage { access } => { if access.contains(crate::StorageAccess::ATOMIC) { return (Some("storage"), Some("atomic")); } else if access.contains(crate::StorageAccess::STORE) { return (Some("storage"), Some("read_write")); } else { "storage" } } As::Immediate => "immediate", As::WorkGroup => "workgroup", As::Handle => return (None, None), As::Function => "function", As::TaskPayload => "task_payload", As::IncomingRayPayload => "incoming_ray_payload", As::RayPayload => "ray_payload", }), None, ) } naga-29.0.3/src/common/wgsl/types.rs000064400000000000000000000423021046102023000153730ustar 00000000000000//! Code for formatting Naga IR types as WGSL source code. use super::{address_space_str, ToWgsl, TryToWgsl}; use crate::common; use crate::proc::TypeResolution; use crate::{Handle, Scalar, TypeInner}; use alloc::string::String; use core::fmt::Write; /// A context for printing Naga IR types as WGSL. /// /// This trait's default methods [`write_type`] and /// [`write_type_inner`] do the work of formatting types as WGSL.
/// Implementors must provide the remaining methods, to customize /// behavior for the context at hand. /// /// For example, the WGSL backend would provide an implementation of /// [`type_name`] that handles hygienic renaming, whereas the WGSL /// front end would simply show the name that was given in the source. /// /// [`write_type`]: TypeContext::write_type /// [`write_type_inner`]: TypeContext::write_type_inner /// [`type_name`]: TypeContext::type_name pub trait TypeContext { /// Return the [`Type`] referred to by `handle`. /// /// [`Type`]: crate::Type fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type; /// Return the name to be used for the type referred to by /// `handle`. fn type_name(&self, handle: Handle<crate::Type>) -> &str; /// Write the WGSL form of `override` to `out`. fn write_override<W: Write>( &self, r#override: Handle<crate::Override>, out: &mut W, ) -> core::fmt::Result; /// Write a [`TypeInner::Struct`] for which we are unable to find a name. /// /// The names of struct types are only available if we have `Handle<Type>`, /// not from [`TypeInner`]. For logging and debugging, it's fine to just /// write something helpful to the developer, but for generating WGSL, /// this should be unreachable. fn write_unnamed_struct<W: Write>(&self, inner: &TypeInner, out: &mut W) -> core::fmt::Result; /// Write a [`TypeInner`] that has no representation as WGSL source, /// even including Naga extensions. /// /// A backend might implement this with a call to the [`unreachable!`] /// macro, since backends are allowed to assume that the module has passed /// validation. /// /// The default implementation is appropriate for generating type names to /// appear in error messages. It punts to `TypeInner`'s [`core::fmt::Debug`] /// implementation, since it's probably best to show the user something they /// can act on.
fn write_non_wgsl_inner<W: Write>(&self, inner: &TypeInner, out: &mut W) -> core::fmt::Result { write!(out, "{{non-WGSL Naga type {inner:?}}}") } /// Write a [`Scalar`] that has no representation as WGSL source, /// even including Naga extensions. /// /// A backend might implement this with a call to the [`unreachable!`] /// macro, since backends are allowed to assume that the module has passed /// validation. /// /// The default implementation is appropriate for generating type names to /// appear in error messages. It punts to `Scalar`'s [`core::fmt::Debug`] /// implementation, since it's probably best to show the user something they /// can act on. fn write_non_wgsl_scalar<W: Write>(&self, scalar: Scalar, out: &mut W) -> core::fmt::Result { match scalar.kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint | crate::ScalarKind::Float | crate::ScalarKind::Bool => write!(out, "{{non-WGSL Naga scalar {scalar:?}}}"), // The abstract types are kind of an odd quasi-WGSL category: // they are definitely part of the spec, but they are not expressible // in WGSL itself. So we want to call them out by name in error messages, // but the WGSL backend should never generate these. crate::ScalarKind::AbstractInt => out.write_str("{AbstractInt}"), crate::ScalarKind::AbstractFloat => out.write_str("{AbstractFloat}"), } } /// Write the type referred to by `handle` as it would appear in a value's declaration. /// /// That is, write it as it would appear in a `var`, `let`, etc. /// declaration, or in a function's argument list. Struct types are written by name. fn write_type<W: Write>(&self, handle: Handle<crate::Type>, out: &mut W) -> core::fmt::Result { let ty = self.lookup_type(handle); match ty.inner { TypeInner::Struct { .. } => out.write_str(self.type_name(handle))?, ref other => self.write_type_inner(other, out)?, } Ok(()) } /// Write the [`TypeInner`] `inner` as it would appear in a value's declaration. /// /// Write `inner` as it would appear in a `var`, `let`, etc. /// declaration, or in a function's argument list.
/// /// Note that this cannot handle writing [`Struct`] types: those /// must be referred to by name, but the name isn't available in /// [`TypeInner`]. /// /// [`Struct`]: TypeInner::Struct fn write_type_inner<W: Write>(&self, inner: &TypeInner, out: &mut W) -> core::fmt::Result { match try_write_type_inner(self, inner, out) { Ok(()) => Ok(()), Err(WriteTypeError::Format(err)) => Err(err), Err(WriteTypeError::NonWgsl) => self.write_non_wgsl_inner(inner, out), } } /// Write the [`Scalar`] `scalar` as a WGSL type. fn write_scalar<W: Write>(&self, scalar: Scalar, out: &mut W) -> core::fmt::Result { match scalar.try_to_wgsl() { Some(string) => out.write_str(string), None => self.write_non_wgsl_scalar(scalar, out), } } /// Write the [`TypeResolution`] `resolution` as a WGSL type. fn write_type_resolution<W: Write>( &self, resolution: &TypeResolution, out: &mut W, ) -> core::fmt::Result { match *resolution { TypeResolution::Handle(handle) => self.write_type(handle, out), TypeResolution::Value(ref inner) => self.write_type_inner(inner, out), } } fn write_type_conclusion<W: Write>( &self, conclusion: &crate::proc::Conclusion, out: &mut W, ) -> core::fmt::Result { use crate::proc::Conclusion as Co; match *conclusion { Co::Value(ref inner) => self.write_type_inner(inner, out), Co::Predeclared(ref predeclared) => out.write_str(&predeclared.struct_name()), } } fn write_type_rule<W: Write>( &self, name: &str, rule: &crate::proc::Rule, out: &mut W, ) -> core::fmt::Result { write!(out, "fn {name}(")?; for (i, arg) in rule.arguments.iter().enumerate() { if i > 0 { out.write_str(", ")?; } self.write_type_resolution(arg, out)?
} out.write_str(") -> ")?; self.write_type_conclusion(&rule.conclusion, out)?; Ok(()) } fn type_to_string(&self, handle: Handle<crate::Type>) -> String { let mut buf = String::new(); self.write_type(handle, &mut buf).unwrap(); buf } fn type_resolution_to_string(&self, resolution: &TypeResolution) -> String { let mut buf = String::new(); self.write_type_resolution(resolution, &mut buf).unwrap(); buf } fn type_rule_to_string(&self, name: &str, rule: &crate::proc::Rule) -> String { let mut buf = String::new(); self.write_type_rule(name, rule, &mut buf).unwrap(); buf } } fn try_write_type_inner<C, W>(ctx: &C, inner: &TypeInner, out: &mut W) -> Result<(), WriteTypeError> where C: TypeContext + ?Sized, W: Write, { match *inner { TypeInner::Vector { size, scalar } => { write!(out, "vec{}<", common::vector_size_str(size))?; ctx.write_scalar(scalar, out)?; out.write_str(">")?; } TypeInner::Sampler { comparison: false } => { write!(out, "sampler")?; } TypeInner::Sampler { comparison: true } => { write!(out, "sampler_comparison")?; } TypeInner::Image { dim, arrayed, class, } => { // More about texture types: https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type use crate::ImageClass as Ic; let dim_str = dim.to_wgsl(); let arrayed_str = if arrayed { "_array" } else { "" }; match class { Ic::Sampled { kind, multi } => { let multisampled_str = if multi { "multisampled_" } else { "" }; write!(out, "texture_{multisampled_str}{dim_str}{arrayed_str}<")?; ctx.write_scalar(Scalar { kind, width: 4 }, out)?; out.write_str(">")?; } Ic::Depth { multi } => { let multisampled_str = if multi { "multisampled_" } else { "" }; write!( out, "texture_depth_{multisampled_str}{dim_str}{arrayed_str}" )?; } Ic::Storage { format, access } => { let format_str = format.to_wgsl(); let access_str = if access.contains(crate::StorageAccess::ATOMIC) { ",atomic" } else if access .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE) { ",read_write" } else if access.contains(crate::StorageAccess::LOAD) {
",read" } else { ",write" }; write!( out, "texture_storage_{dim_str}{arrayed_str}<{format_str}{access_str}>" )?; } Ic::External => { write!(out, "texture_external")?; } } } TypeInner::Scalar(scalar) => { ctx.write_scalar(scalar, out)?; } TypeInner::Atomic(scalar) => { out.write_str("atomic<")?; ctx.write_scalar(scalar, out)?; out.write_str(">")?; } TypeInner::Array { base, size, stride: _, } => { // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types // array<A, 3> -- Constant array // array<A> -- Dynamic array write!(out, "array<")?; match size { crate::ArraySize::Constant(len) => { ctx.write_type(base, out)?; write!(out, ", {len}")?; } crate::ArraySize::Pending(r#override) => { ctx.write_type(base, out)?; out.write_str(", ")?; ctx.write_override(r#override, out)?; } crate::ArraySize::Dynamic => { ctx.write_type(base, out)?; } } write!(out, ">")?; } TypeInner::BindingArray { base, size } => { // More info https://github.com/gpuweb/gpuweb/issues/2105 write!(out, "binding_array<")?; match size { crate::ArraySize::Constant(len) => { ctx.write_type(base, out)?; write!(out, ", {len}")?; } crate::ArraySize::Pending(r#override) => { ctx.write_type(base, out)?; out.write_str(", ")?; ctx.write_override(r#override, out)?; } crate::ArraySize::Dynamic => { ctx.write_type(base, out)?; } } write!(out, ">")?; } TypeInner::Matrix { columns, rows, scalar, } => { write!( out, "mat{}x{}<", common::vector_size_str(columns), common::vector_size_str(rows), )?; ctx.write_scalar(scalar, out)?; out.write_str(">")?; } TypeInner::CooperativeMatrix { columns, rows, scalar, role, } => { write!( out, "coop_mat{}x{}<{},{}>", columns as u32, rows as u32, scalar.try_to_wgsl().unwrap_or_default(), role.to_wgsl(), )?; } TypeInner::Pointer { base, space } => { let (address, maybe_access) = address_space_str(space); // Everything but `AddressSpace::Handle` gives us an `address` name, but // Naga IR never produces pointers to handles, so it doesn't matter much // how we write such a type. Just write it as the base type alone.
if let Some(space) = address { write!(out, "ptr<{space}, ")?; } ctx.write_type(base, out)?; if address.is_some() { if let Some(access) = maybe_access { write!(out, ", {access}")?; } write!(out, ">")?; } } TypeInner::ValuePointer { size: None, scalar, space, } => { let (address, maybe_access) = address_space_str(space); if let Some(space) = address { write!(out, "ptr<{space}, ")?; ctx.write_scalar(scalar, out)?; if let Some(access) = maybe_access { write!(out, ", {access}")?; } write!(out, ">")?; } else { return Err(WriteTypeError::NonWgsl); } } TypeInner::ValuePointer { size: Some(size), scalar, space, } => { let (address, maybe_access) = address_space_str(space); if let Some(space) = address { write!(out, "ptr<{}, vec{}<", space, common::vector_size_str(size),)?; ctx.write_scalar(scalar, out)?; out.write_str(">")?; if let Some(access) = maybe_access { write!(out, ", {access}")?; } write!(out, ">")?; } else { return Err(WriteTypeError::NonWgsl); } } TypeInner::AccelerationStructure { vertex_return } => { let caps = if vertex_return { "<vertex_return>" } else { "" }; write!(out, "acceleration_structure{caps}")? } TypeInner::Struct { .. } => { ctx.write_unnamed_struct(inner, out)?; } TypeInner::RayQuery { vertex_return } => { let caps = if vertex_return { "<vertex_return>" } else { "" }; write!(out, "ray_query{caps}")? } } Ok(()) } /// Error type returned by `try_write_type_inner`. /// /// This type is private to the module. enum WriteTypeError { Format(core::fmt::Error), NonWgsl, } impl From<core::fmt::Error> for WriteTypeError { fn from(err: core::fmt::Error) -> Self { Self::Format(err) } } /// Format types as WGSL based on a [`GlobalCtx`]. /// /// This is probably good enough for diagnostic output, but it has some /// limitations: /// /// - It does not apply the [`Namer`] renamings that avoid name collisions. /// /// - It generates invalid WGSL for anonymous struct types.
/// /// - It doesn't write the lengths of override-expression-sized arrays /// correctly, unless the expression is just the override identifier. /// /// [`GlobalCtx`]: crate::proc::GlobalCtx /// [`Namer`]: crate::proc::Namer impl TypeContext for crate::proc::GlobalCtx<'_> { fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type { &self.types[handle] } fn type_name(&self, handle: Handle<crate::Type>) -> &str { self.types[handle] .name .as_deref() .unwrap_or("{anonymous type}") } fn write_unnamed_struct<W: Write>(&self, _: &TypeInner, out: &mut W) -> core::fmt::Result { write!(out, "{{unnamed struct}}") } fn write_override<W: Write>( &self, handle: Handle<crate::Override>, out: &mut W, ) -> core::fmt::Result { match self.overrides[handle].name { Some(ref name) => out.write_str(name), None => write!(out, "{{anonymous override {handle:?}}}"), } } } /// Format types as WGSL based on a `UniqueArena<Type>`. /// /// This is probably only good enough for logging: /// /// - It does not apply any kind of [`Namer`] renamings. /// /// - It generates invalid WGSL for anonymous struct types. /// /// - It doesn't write override-sized arrays properly. /// /// [`Namer`]: crate::proc::Namer impl TypeContext for crate::UniqueArena<crate::Type> { fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type { &self[handle] } fn type_name(&self, handle: Handle<crate::Type>) -> &str { self[handle].name.as_deref().unwrap_or("{anonymous type}") } fn write_unnamed_struct<W: Write>(&self, inner: &TypeInner, out: &mut W) -> core::fmt::Result { write!(out, "{{unnamed struct {inner:?}}}") } fn write_override<W: Write>( &self, handle: Handle<crate::Override>, out: &mut W, ) -> core::fmt::Result { write!(out, "{{override {handle:?}}}") } } naga-29.0.3/src/compact/expressions.rs000064400000000000000000000415761046102023000160070ustar 00000000000000use super::{HandleMap, HandleSet, ModuleMap}; use crate::arena::{Arena, Handle}; pub struct ExpressionTracer<'tracer> { pub constants: &'tracer Arena<crate::Constant>, pub overrides: &'tracer Arena<crate::Override>, /// The arena in which we are currently tracing expressions.
pub expressions: &'tracer Arena<crate::Expression>, /// The used map for `types`. pub types_used: &'tracer mut HandleSet<crate::Type>, /// The used map for global variables. pub global_variables_used: &'tracer mut HandleSet<crate::GlobalVariable>, /// The used map for `constants`. pub constants_used: &'tracer mut HandleSet<crate::Constant>, /// The used map for `overrides`. pub overrides_used: &'tracer mut HandleSet<crate::Override>, /// The used set for `expressions`. /// /// This points to whatever arena holds the expressions we are /// currently tracing: either a function's expression arena, or /// the module's constant expression arena. pub expressions_used: &'tracer mut HandleSet<crate::Expression>, /// The used set for the module's `global_expressions` arena. /// /// If `None`, we are already tracing the constant expressions, /// and `expressions_used` already refers to their handle set. pub global_expressions_used: Option<&'tracer mut HandleSet<crate::Expression>>, } impl ExpressionTracer<'_> { /// Propagate usage through `self.expressions`, starting with `self.expressions_used`. /// /// `self.expressions` is either a function's [expression arena][fe] or /// the module's [global expression arena][ce]. Treat `self.expressions_used` /// as the initial set of "known live" expressions, and follow through to /// identify all transitively used expressions. /// /// Mark types, constants, overrides, global variables, and constant expressions /// used directly by `self.expressions` as used. Items used indirectly are not /// marked. /// /// [fe]: crate::Function::expressions /// [ce]: crate::Module::global_expressions pub fn trace_expressions(&mut self) { log::trace!( "entering trace_expressions of {}", if self.global_expressions_used.is_some() { "function expressions" } else { "const expressions" } ); // We don't need recursion or a work list. Because an // expression may only refer to other expressions that precede // it in the arena, it suffices to make a single pass over the // arena from back to front, marking the referents of used // expressions as used themselves. for (handle, expr) in self.expressions.iter().rev() { // If this expression isn't used, it doesn't matter what it uses.
if !self.expressions_used.contains(handle) { continue; } log::trace!("tracing new expression {expr:?}"); self.trace_expression(expr); } } pub fn trace_expression(&mut self, expr: &crate::Expression) { use crate::Expression as Ex; match *expr { // Expressions that do not contain handles that need to be traced. Ex::Literal(_) | Ex::FunctionArgument(_) | Ex::LocalVariable(_) | Ex::SubgroupBallotResult | Ex::RayQueryProceedResult => {} // Expressions can refer to constants and overrides, which can refer // in turn to expressions, which complicates our nice one-pass // algorithm. But since constants and overrides don't refer to each // other directly, only via expressions, we can get around this by // looking *through* each constant/override and marking its // initializer expression as used immediately. Since `expr` refers // to the constant/override, which then refers to the initializer, // the initializer must precede `expr` in the arena, so we know we // have yet to visit the initializer, so it's not too late to mark // it. 
Ex::Constant(handle) => { self.constants_used.insert(handle); let constant = &self.constants[handle]; self.types_used.insert(constant.ty); match self.global_expressions_used { Some(ref mut used) => used.insert(constant.init), None => self.expressions_used.insert(constant.init), }; } Ex::Override(handle) => { self.overrides_used.insert(handle); let r#override = &self.overrides[handle]; self.types_used.insert(r#override.ty); if let Some(init) = r#override.init { match self.global_expressions_used { Some(ref mut used) => used.insert(init), None => self.expressions_used.insert(init), }; } } Ex::ZeroValue(ty) => { self.types_used.insert(ty); } Ex::Compose { ty, ref components } => { self.types_used.insert(ty); self.expressions_used .insert_iter(components.iter().cloned()); } Ex::Access { base, index } => self.expressions_used.insert_iter([base, index]), Ex::AccessIndex { base, index: _ } => { self.expressions_used.insert(base); } Ex::Splat { size: _, value } => { self.expressions_used.insert(value); } Ex::Swizzle { size: _, vector, pattern: _, } => { self.expressions_used.insert(vector); } Ex::GlobalVariable(handle) => { self.global_variables_used.insert(handle); } Ex::Load { pointer } => { self.expressions_used.insert(pointer); } Ex::ImageSample { image, sampler, gather: _, coordinate, array_index, offset, ref level, depth_ref, clamp_to_edge: _, } => { self.expressions_used .insert_iter([image, sampler, coordinate]); self.expressions_used.insert_iter(array_index); self.expressions_used.insert_iter(offset); use crate::SampleLevel as Sl; match *level { Sl::Auto | Sl::Zero => {} Sl::Exact(expr) | Sl::Bias(expr) => { self.expressions_used.insert(expr); } Sl::Gradient { x, y } => self.expressions_used.insert_iter([x, y]), } self.expressions_used.insert_iter(depth_ref); } Ex::ImageLoad { image, coordinate, array_index, sample, level, } => { self.expressions_used.insert(image); self.expressions_used.insert(coordinate); self.expressions_used.insert_iter(array_index); 
self.expressions_used.insert_iter(sample); self.expressions_used.insert_iter(level); } Ex::ImageQuery { image, ref query } => { self.expressions_used.insert(image); use crate::ImageQuery as Iq; match *query { Iq::Size { level } => self.expressions_used.insert_iter(level), Iq::NumLevels | Iq::NumLayers | Iq::NumSamples => {} } } Ex::RayQueryVertexPositions { query, committed: _, } => { self.expressions_used.insert(query); } Ex::Unary { op: _, expr } => { self.expressions_used.insert(expr); } Ex::Binary { op: _, left, right } => { self.expressions_used.insert_iter([left, right]); } Ex::Select { condition, accept, reject, } => self .expressions_used .insert_iter([condition, accept, reject]), Ex::Derivative { axis: _, ctrl: _, expr, } => { self.expressions_used.insert(expr); } Ex::Relational { fun: _, argument } => { self.expressions_used.insert(argument); } Ex::Math { fun: _, arg, arg1, arg2, arg3, } => { self.expressions_used.insert(arg); self.expressions_used.insert_iter(arg1); self.expressions_used.insert_iter(arg2); self.expressions_used.insert_iter(arg3); } Ex::As { expr, kind: _, convert: _, } => { self.expressions_used.insert(expr); } Ex::ArrayLength(expr) => { self.expressions_used.insert(expr); } // `CallResult` expressions do contain a function handle, but any used // `CallResult` expression should have an associated `ir::Statement::Call` // that we will trace. Ex::CallResult(_) => {} Ex::AtomicResult { ty, comparison: _ } | Ex::WorkGroupUniformLoadResult { ty } | Ex::SubgroupOperationResult { ty } => { self.types_used.insert(ty); } Ex::RayQueryGetIntersection { query, committed: _, } => { self.expressions_used.insert(query); } Ex::CooperativeLoad { ref data, .. } => { self.expressions_used.insert(data.pointer); self.expressions_used.insert(data.stride); } Ex::CooperativeMultiplyAdd { a, b, c } => { self.expressions_used.insert(a); self.expressions_used.insert(b); self.expressions_used.insert(c); } } } } impl ModuleMap { /// Fix up all handles in `expr`. 
/// /// Use the expression handle remappings in `operand_map`, and all /// other mappings from `self`. pub fn adjust_expression( &self, expr: &mut crate::Expression, operand_map: &HandleMap<crate::Expression>, ) { let adjust = |expr: &mut Handle<crate::Expression>| { operand_map.adjust(expr); }; use crate::Expression as Ex; match *expr { // Expressions that do not contain handles that need to be adjusted. Ex::Literal(_) | Ex::FunctionArgument(_) | Ex::LocalVariable(_) | Ex::SubgroupBallotResult | Ex::RayQueryProceedResult => {} // Expressions that contain handles that need to be adjusted. Ex::Constant(ref mut constant) => self.constants.adjust(constant), Ex::Override(ref mut r#override) => self.overrides.adjust(r#override), Ex::ZeroValue(ref mut ty) => self.types.adjust(ty), Ex::Compose { ref mut ty, ref mut components, } => { self.types.adjust(ty); for component in components { adjust(component); } } Ex::Access { ref mut base, ref mut index, } => { adjust(base); adjust(index); } Ex::AccessIndex { ref mut base, index: _, } => adjust(base), Ex::Splat { size: _, ref mut value, } => adjust(value), Ex::Swizzle { size: _, ref mut vector, pattern: _, } => adjust(vector), Ex::GlobalVariable(ref mut handle) => self.globals.adjust(handle), Ex::Load { ref mut pointer } => adjust(pointer), Ex::ImageSample { ref mut image, ref mut sampler, gather: _, ref mut coordinate, ref mut array_index, ref mut offset, ref mut level, ref mut depth_ref, clamp_to_edge: _, } => { adjust(image); adjust(sampler); adjust(coordinate); operand_map.adjust_option(array_index); operand_map.adjust_option(offset); self.adjust_sample_level(level, operand_map); operand_map.adjust_option(depth_ref); } Ex::ImageLoad { ref mut image, ref mut coordinate, ref mut array_index, ref mut sample, ref mut level, } => { adjust(image); adjust(coordinate); operand_map.adjust_option(array_index); operand_map.adjust_option(sample); operand_map.adjust_option(level); } Ex::ImageQuery { ref mut image, ref mut query, } => { adjust(image);
self.adjust_image_query(query, operand_map); } Ex::Unary { op: _, ref mut expr, } => adjust(expr), Ex::Binary { op: _, ref mut left, ref mut right, } => { adjust(left); adjust(right); } Ex::Select { ref mut condition, ref mut accept, ref mut reject, } => { adjust(condition); adjust(accept); adjust(reject); } Ex::Derivative { axis: _, ctrl: _, ref mut expr, } => adjust(expr), Ex::Relational { fun: _, ref mut argument, } => adjust(argument), Ex::Math { fun: _, ref mut arg, ref mut arg1, ref mut arg2, ref mut arg3, } => { adjust(arg); operand_map.adjust_option(arg1); operand_map.adjust_option(arg2); operand_map.adjust_option(arg3); } Ex::As { ref mut expr, kind: _, convert: _, } => adjust(expr), Ex::CallResult(ref mut function) => { self.functions.adjust(function); } Ex::AtomicResult { ref mut ty, comparison: _, } => self.types.adjust(ty), Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty), Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty), Ex::ArrayLength(ref mut expr) => adjust(expr), Ex::RayQueryGetIntersection { ref mut query, committed: _, } => adjust(query), Ex::RayQueryVertexPositions { ref mut query, committed: _, } => adjust(query), Ex::CooperativeLoad { ref mut data, .. 
} => { adjust(&mut data.pointer); adjust(&mut data.stride); } Ex::CooperativeMultiplyAdd { ref mut a, ref mut b, ref mut c, } => { adjust(a); adjust(b); adjust(c); } } } fn adjust_sample_level( &self, level: &mut crate::SampleLevel, operand_map: &HandleMap<crate::Expression>, ) { let adjust = |expr: &mut Handle<crate::Expression>| operand_map.adjust(expr); use crate::SampleLevel as Sl; match *level { Sl::Auto | Sl::Zero => {} Sl::Exact(ref mut expr) => adjust(expr), Sl::Bias(ref mut expr) => adjust(expr), Sl::Gradient { ref mut x, ref mut y, } => { adjust(x); adjust(y); } } } fn adjust_image_query( &self, query: &mut crate::ImageQuery, operand_map: &HandleMap<crate::Expression>, ) { use crate::ImageQuery as Iq; match *query { Iq::Size { ref mut level } => operand_map.adjust_option(level), Iq::NumLevels | Iq::NumLayers | Iq::NumSamples => {} } } } naga-29.0.3/src/compact/functions.rs000064400000000000000000000102221046102023000154150ustar 00000000000000use super::arena::HandleSet; use super::{FunctionMap, ModuleMap}; pub struct FunctionTracer<'a> { pub function: &'a crate::Function, pub constants: &'a crate::Arena<crate::Constant>, pub overrides: &'a crate::Arena<crate::Override>, pub functions_pending: &'a mut HandleSet<crate::Function>, pub functions_used: &'a mut HandleSet<crate::Function>, pub types_used: &'a mut HandleSet<crate::Type>, pub global_variables_used: &'a mut HandleSet<crate::GlobalVariable>, pub constants_used: &'a mut HandleSet<crate::Constant>, pub overrides_used: &'a mut HandleSet<crate::Override>, pub global_expressions_used: &'a mut HandleSet<crate::Expression>, /// Function-local expressions used.
pub expressions_used: HandleSet<crate::Expression>, } impl FunctionTracer<'_> { pub fn trace_call(&mut self, function: crate::Handle<crate::Function>) { if !self.functions_used.contains(function) { self.functions_used.insert(function); self.functions_pending.insert(function); } } pub fn trace(&mut self) { for argument in self.function.arguments.iter() { self.types_used.insert(argument.ty); } if let Some(ref result) = self.function.result { self.types_used.insert(result.ty); } for (_, local) in self.function.local_variables.iter() { self.types_used.insert(local.ty); if let Some(init) = local.init { self.expressions_used.insert(init); } } // Treat named expressions as alive, for the sake of our test suite, // which uses `let blah = expr;` to exercise lots of things. for (&value, _name) in &self.function.named_expressions { self.expressions_used.insert(value); } self.trace_block(&self.function.body); // Given that `trace_block` has marked the expressions used // directly by statements, walk the arena to find all // expressions used, directly or indirectly.
self.as_expression().trace_expressions(); } const fn as_expression(&mut self) -> super::expressions::ExpressionTracer<'_> { super::expressions::ExpressionTracer { constants: self.constants, overrides: self.overrides, expressions: &self.function.expressions, types_used: self.types_used, global_variables_used: self.global_variables_used, constants_used: self.constants_used, overrides_used: self.overrides_used, expressions_used: &mut self.expressions_used, global_expressions_used: Some(&mut self.global_expressions_used), } } } impl FunctionMap { pub fn compact( &self, function: &mut crate::Function, module_map: &ModuleMap, reuse: &mut crate::NamedExpressions, ) { assert!(reuse.is_empty()); for argument in function.arguments.iter_mut() { module_map.types.adjust(&mut argument.ty); } if let Some(ref mut result) = function.result { module_map.types.adjust(&mut result.ty); } for (_, local) in function.local_variables.iter_mut() { log::trace!("adjusting local variable {:?}", local.name); module_map.types.adjust(&mut local.ty); if let Some(ref mut init) = local.init { self.expressions.adjust(init); } } // Drop unused expressions, reusing existing storage. function.expressions.retain_mut(|handle, expr| { if self.expressions.used(handle) { module_map.adjust_expression(expr, &self.expressions); true } else { false } }); // Adjust named expressions. for (mut handle, name) in function.named_expressions.drain(..) { self.expressions.adjust(&mut handle); reuse.insert(handle, name); } core::mem::swap(&mut function.named_expressions, reuse); assert!(reuse.is_empty()); // Adjust statements. self.adjust_body(function, &module_map.functions); } } naga-29.0.3/src/compact/handle_set_map.rs000064400000000000000000000114651046102023000163620ustar 00000000000000use alloc::vec::Vec; use crate::arena::{Arena, Handle, HandleSet, Range}; type Index = crate::non_max_u32::NonMaxU32; /// A map keyed by handles. 
/// /// In most cases, this is used to map from old handle indices to new, /// compressed handle indices. #[derive(Debug)] pub struct HandleMap<T, U = Index> { /// The indices assigned to handles in the compacted module. /// /// If `new_index[i]` is `Some(n)`, then `n` is the `Index` of the /// compacted `Handle` corresponding to the pre-compacted `Handle` /// whose index is `i`. new_index: Vec<Option<U>>, /// This type is indexed by values of type `T`. as_keys: core::marker::PhantomData<T>, } impl<T, U> HandleMap<T, U> { pub fn with_capacity(capacity: usize) -> Self { Self { new_index: Vec::with_capacity(capacity), as_keys: core::marker::PhantomData, } } pub fn get(&self, handle: Handle<T>) -> Option<&U> { self.new_index.get(handle.index()).unwrap_or(&None).as_ref() } pub fn insert(&mut self, handle: Handle<T>, value: U) -> Option<U> { if self.new_index.len() <= handle.index() { self.new_index.resize_with(handle.index() + 1, || None); } self.new_index[handle.index()].replace(value) } } impl<T: 'static> HandleMap<T> { pub fn from_set(set: HandleSet<T>) -> Self { let mut next_index = Index::new(0).unwrap(); Self { new_index: set .all_possible() .map(|handle| { if set.contains(handle) { // This handle will be retained in the compacted version, // so assign it a new index. let this = next_index; next_index = next_index.checked_add(1).unwrap(); Some(this) } else { // This handle will be omitted in the compacted version. None } }) .collect(), as_keys: core::marker::PhantomData, } } /// Return true if `old` is used in the compacted module. pub fn used(&self, old: Handle<T>) -> bool { self.new_index[old.index()].is_some() } /// Return the counterpart to `old` in the compacted module. /// /// If we thought `old` wouldn't be used in the compacted module, return /// `None`.
pub fn try_adjust(&self, old: Handle<T>) -> Option<Handle<T>> { log::trace!( "adjusting {} handle [{}] -> [{:?}]", core::any::type_name::<T>(), old.index(), self.new_index[old.index()] ); self.new_index[old.index()].map(Handle::new) } /// Return the counterpart to `old` in the compacted module. /// /// If we thought `old` wouldn't be used in the compacted module, panic. pub fn adjust(&self, handle: &mut Handle<T>) { *handle = self.try_adjust(*handle).unwrap(); } /// Like `adjust`, but for optional handles. pub fn adjust_option(&self, handle: &mut Option<Handle<T>>) { if let Some(ref mut handle) = *handle { self.adjust(handle); } } /// Shrink `range` to include only used handles. /// /// Fortunately, compaction doesn't arbitrarily scramble the expressions /// in the arena, but instead preserves the order of the elements while /// squeezing out unused ones. That means that a contiguous range in the /// pre-compacted arena always maps to a contiguous range in the /// post-compacted arena. So we just need to adjust the endpoints. /// /// Compaction may have eliminated the endpoints themselves. /// /// Use `compacted_arena` to bounds-check the result. pub fn adjust_range(&self, range: &mut Range<T>, compacted_arena: &Arena<T>) { let mut index_range = range.index_range(); let compacted; if let Some(first) = index_range.find_map(|i| self.new_index[i as usize]) { // The first call to `find_map` mutated `index_range` to hold the // remainder of the original range, which is exactly the range we need // to search for the new last handle. if let Some(last) = index_range.rev().find_map(|i| self.new_index[i as usize]) { // Build an end-exclusive range, given the two included indices // `first` and `last`. compacted = first.get()..last.get() + 1; } else { // The range contains only a single live handle, which // we identified with the first `find_map` call.
compacted = first.get()..first.get() + 1; } } else { compacted = 0..0; }; *range = Range::from_index_range(compacted, compacted_arena); } } naga-29.0.3/src/compact/mod.rs000064400000000000000000001330671046102023000142010ustar 00000000000000mod expressions; mod functions; mod handle_set_map; mod statements; mod types; use alloc::vec::Vec; use crate::{ arena::{self, HandleSet}, compact::functions::FunctionTracer, ir, }; use handle_set_map::HandleMap; #[cfg(test)] use alloc::{format, string::ToString}; /// Configuration option for [`compact`]. See [`compact`] for details. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum KeepUnused { No, Yes, } impl From<KeepUnused> for bool { fn from(keep_unused: KeepUnused) -> Self { match keep_unused { KeepUnused::No => false, KeepUnused::Yes => true, } } } /// Remove most unused objects from `module`, which must be valid. /// /// Always removes the following unused objects: /// - anonymous types, overrides, and constants /// - abstract-typed constants /// - expressions /// /// If `keep_unused` is `Yes`, the following are never considered unused; /// otherwise, they will also be removed if unused: /// - functions /// - global variables /// - named types and overrides /// /// The following are never removed: /// - named constants with a concrete type /// - special types /// - entry points /// - within an entry point or a used function: /// - arguments /// - local variables /// - named expressions /// /// After removing items according to the rules above, all handles in the /// remaining objects are adjusted as necessary. When `KeepUnused` is `Yes`, the /// resulting module should have all the named objects (except abstract-typed /// constants) present in the original, and those objects should be functionally /// identical. When `KeepUnused` is `No`, the resulting module should have the /// entry points present in the original, and those entry points should be /// functionally identical.
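A minimal sketch of the index remapping that `HandleMap::from_set`, `try_adjust`, and `adjust_range` perform, assuming plain `usize`/`u32` indices in place of naga's `Handle<T>` and `NonMaxU32`. The `IndexRemap` type and its methods are hypothetical illustrations, not naga API. Used slots receive consecutive new indices in their original order, which is why a contiguous range only needs its first and last surviving elements recomputed.

```rust
use std::ops::Range;

/// Hypothetical, simplified stand-in for `HandleMap`: `new_index[old]` is
/// the compacted index of slot `old`, or `None` if the slot is dropped.
struct IndexRemap {
    new_index: Vec<Option<u32>>,
}

impl IndexRemap {
    /// Give the used slots consecutive new indices, preserving their order.
    fn from_used(used: &[bool]) -> Self {
        let mut next = 0u32;
        let new_index = used
            .iter()
            .map(|&is_used| {
                if is_used {
                    let this = next;
                    next += 1;
                    Some(this)
                } else {
                    None
                }
            })
            .collect();
        IndexRemap { new_index }
    }

    /// The compacted index of `old`, or `None` if it was dropped.
    fn try_adjust(&self, old: usize) -> Option<u32> {
        self.new_index[old]
    }

    /// Shrink the half-open range `old` to the compacted range covering its
    /// surviving elements: find the first survivor scanning forward, then
    /// the last survivor scanning backward over what remains.
    fn adjust_range(&self, old: Range<usize>) -> Range<u32> {
        let mut indices = old;
        let Some(first) = indices.find_map(|i| self.new_index[i]) else {
            // No survivors: collapse to an empty range.
            return 0..0;
        };
        // `find_map` consumed everything up to `first`; search the rest
        // from the back. If nothing else survives, `first` is also last.
        let last = indices
            .rev()
            .find_map(|i| self.new_index[i])
            .unwrap_or(first);
        first..last + 1
    }
}
```

Because order is preserved, the `Vec<Option<u32>>` table is all the state this remap needs; the real `adjust_range` additionally bounds-checks its result against the compacted arena.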
/// /// # Panics /// /// If `module` would not pass validation, this may panic. pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { // The trickiest part of compaction is determining what is used and what is // not. Once we have computed that correctly, it's easy enough to call // `retain_mut` on each arena, drop unused elements, and fix up the handles // in what's left. // // For every compactable arena in a `Module`, whether global to the `Module` // or local to a function or entry point, the `ModuleTracer` type holds a // bitmap indicating which elements of that arena are used. Our task is to // populate those bitmaps correctly. // // First, we mark everything that is considered used by definition, as // described in this function's documentation. // // Since functions and entry points are considered used by definition, we // traverse their statement trees, and mark the referents of all handles // appearing in those statements as used. // // Once we've marked which elements of an arena are referred to directly by // handles elsewhere (for example, which of a function's expressions are // referred to by handles in its body statements), we can mark all the other // arena elements that are used indirectly in a single pass, traversing the // arena from back to front. Since Naga allows arena elements to refer only // to prior elements, we know that by the time we reach an element, all // other elements that could possibly refer to it have already been visited. // Thus, if the present element has not been marked as used, then it is // definitely unused, and compaction can remove it. Otherwise, the element // is used and must be retained, so we must mark everything it refers to. // // The final step is to mark the global expressions and types, which must be // traversed simultaneously; see `ModuleTracer::type_expression_tandem`'s // documentation for details. 
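The back-to-front marking pass described above can be sketched on a toy arena, assuming each element simply lists the indices of the earlier elements it refers to. `Expr` and `mark_used` are hypothetical stand-ins for naga's `Expression` arena and `ExpressionTracer`. Because references only point backwards, a single reverse scan settles every element's liveness.

```rust
/// Hypothetical toy arena element: the indices of the earlier elements it
/// refers to. (naga's real `Expression` enum is much richer.)
struct Expr {
    operands: Vec<usize>,
}

/// Mark every element reachable from `roots` in one back-to-front pass.
///
/// Relies on the invariant that elements only refer to *earlier* elements:
/// when we reach index `i`, everything that could refer to it has already
/// been visited, so `used[i]` is final.
fn mark_used(arena: &[Expr], roots: &[usize]) -> Vec<bool> {
    let mut used = vec![false; arena.len()];
    // Elements that are "used by definition".
    for &root in roots {
        used[root] = true;
    }
    for i in (0..arena.len()).rev() {
        if !used[i] {
            continue; // Nothing later refers to it: definitely unused.
        }
        // "Trace" element `i`: mark its operands before we reach them.
        for &operand in &arena[i].operands {
            debug_assert!(operand < i, "operands must precede their users");
            used[operand] = true;
        }
    }
    used
}
```

The cost is linear in the arena size plus the number of references, matching the goal of visiting each element once.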
// // # A definition and a rule of thumb // // In this module, to "trace" something is to mark everything else it refers // to as used, on the assumption that the thing itself is used. For example, // to trace an `Expression` is to mark its subexpressions as used, as well // as any types, constants, overrides, etc. that it refers to. This is what // `ExpressionTracer::trace_expression` does. // // Given that we want to visit each thing only once (to keep compaction // linear in the size of the module), this definition of "trace" implies // that things that are not "used by definition" must be marked as used // *before* we trace them. // // Thus, whenever you are marking something as used, it's a good idea to ask // yourself how you know that thing will be traced in the future. If you're // not sure, then you could be marking it too late to be noticed. The thing // itself will be retained by compaction, but since it will not be traced, // anything it refers to could be compacted away. let mut module_tracer = ModuleTracer::new(module); // Observe what each entry point actually uses.
log::trace!("tracing entry points"); let entry_point_maps = module .entry_points .iter() .map(|e| { log::trace!("tracing entry point {:?}", e.function.name); if let Some(sizes) = e.workgroup_size_overrides { for size in sizes.iter().filter_map(|x| *x) { module_tracer.global_expressions_used.insert(size); } } if let Some(task_payload) = e.task_payload { module_tracer.global_variables_used.insert(task_payload); } if let Some(ref mesh_info) = e.mesh_info { module_tracer .global_variables_used .insert(mesh_info.output_variable); module_tracer .types_used .insert(mesh_info.vertex_output_type); module_tracer .types_used .insert(mesh_info.primitive_output_type); if let Some(max_vertices_override) = mesh_info.max_vertices_override { module_tracer .global_expressions_used .insert(max_vertices_override); } if let Some(max_primitives_override) = mesh_info.max_primitives_override { module_tracer .global_expressions_used .insert(max_primitives_override); } } if e.stage == crate::ShaderStage::Task || e.stage == crate::ShaderStage::Mesh { // Mesh shaders always need a u32 type, as it is e.g. the type of some // expressions. We tolerate its absence here because compaction is // infallible, but the module will fail validation. if let Some(u32_type) = module.types.iter().find_map(|tuple| { (tuple.1.inner == crate::TypeInner::Scalar(crate::Scalar::U32)) .then_some(tuple.0) }) { module_tracer.types_used.insert(u32_type); } } let mut used = module_tracer.as_function(&e.function); used.trace(); FunctionMap::from(used) }) .collect::<Vec<_>>(); // Observe which types, constant expressions, constants, and expressions // each function uses, and produce maps for each function from // pre-compaction to post-compaction expression handles. // // The function tracing logic here works in conjunction with // `FunctionTracer::trace_call`, which, when tracing a `Statement::Call` // to a function not already identified as used, adds the called function // to both `functions_used` and `functions_pending`.
// // Called functions are required to appear before their callers in the // functions arena (recursion is disallowed). We have already traced the // entry point(s) and added any functions called directly by the entry // point(s) to `functions_pending`. We proceed by repeatedly tracing the // last function in `functions_pending`. By an inductive argument, any // functions after the last function in `functions_pending` must be unused. // // When `KeepUnused` is active, we simply mark all functions as pending, // and then trace all of them. log::trace!("tracing functions"); let mut function_maps = HandleMap::with_capacity(module.functions.len()); if keep_unused.into() { module_tracer.functions_used.add_all(); module_tracer.functions_pending.add_all(); } while let Some(handle) = module_tracer.functions_pending.pop() { let function = &module.functions[handle]; log::trace!("tracing function {function:?}"); let mut function_tracer = module_tracer.as_function(function); function_tracer.trace(); function_maps.insert(handle, FunctionMap::from(function_tracer)); } // We treat all special types as used by definition. log::trace!("tracing special types"); module_tracer.trace_special_types(&module.special_types); log::trace!("tracing global variables"); if keep_unused.into() { module_tracer.global_variables_used.add_all(); } for global in module_tracer.global_variables_used.iter() { log::trace!("tracing global {:?}", module.global_variables[global].name); module_tracer .types_used .insert(module.global_variables[global].ty); if let Some(init) = module.global_variables[global].init { module_tracer.global_expressions_used.insert(init); } } // We treat all named constants as used by definition, unless they have an // abstract type as we do not want those reaching the validator. 
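A sketch of the `functions_pending` worklist, assuming functions are plain indices and `callees[f]` lists what `f` calls (always lower indices, since recursion is disallowed). `trace_call` and `used_functions` are hypothetical names that mirror the shape of `FunctionTracer::trace_call` and the loop above. A `BTreeSet` with `pop_last` stands in for "trace the last function in `functions_pending`".

```rust
use std::collections::BTreeSet;

/// Mark `f` as used and schedule it for tracing, unless already seen.
fn trace_call(f: usize, used: &mut [bool], pending: &mut BTreeSet<usize>) {
    if !used[f] {
        used[f] = true;
        pending.insert(f);
    }
}

/// `callees[f]` lists the functions that `f` calls; every callee has a
/// smaller index than its caller. Returns which functions are reachable
/// from the calls made by the entry points.
fn used_functions(callees: &[Vec<usize>], entry_point_calls: &[usize]) -> Vec<bool> {
    let mut used = vec![false; callees.len()];
    let mut pending = BTreeSet::new();
    for &f in entry_point_calls {
        trace_call(f, &mut used, &mut pending);
    }
    // Always trace the highest-indexed pending function. Tracing `f` can
    // only schedule lower-indexed callees, so once an index has been
    // passed over, nothing can make it used later.
    while let Some(f) = pending.pop_last() {
        for &callee in &callees[f] {
            trace_call(callee, &mut used, &mut pending);
        }
    }
    used
}
```

Each function is traced at most once, because `trace_call` only schedules functions not yet marked as used.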
log::trace!("tracing named constants"); for (handle, constant) in module.constants.iter() { if constant.name.is_none() || module.types[constant.ty].inner.is_abstract(&module.types) { continue; } log::trace!("tracing constant {:?}", constant.name.as_ref().unwrap()); module_tracer.constants_used.insert(handle); module_tracer.types_used.insert(constant.ty); module_tracer.global_expressions_used.insert(constant.init); } if keep_unused.into() { // Treat all named overrides as used. for (handle, r#override) in module.overrides.iter() { if r#override.name.is_some() && module_tracer.overrides_used.insert(handle) { module_tracer.types_used.insert(r#override.ty); if let Some(init) = r#override.init { module_tracer.global_expressions_used.insert(init); } } } // Treat all named types as used. for (handle, ty) in module.types.iter() { if ty.name.is_some() { module_tracer.types_used.insert(handle); } } } module_tracer.type_expression_tandem(); // Now that we know what is used and what is never touched, // produce maps from the `Handle`s that appear in `module` now to // the corresponding `Handle`s that will refer to the same items // in the compacted module. let module_map = ModuleMap::from(module_tracer); // Drop unused types from the type arena. // // `FastIndexSet`s don't have an underlying Vec that we can // steal, compact in place, and then rebuild the `FastIndexSet` // from. So we have to rebuild the type arena from scratch. log::trace!("compacting types"); let mut new_types = arena::UniqueArena::new(); for (old_handle, mut ty, span) in module.types.drain_all() { if let Some(expected_new_handle) = module_map.types.try_adjust(old_handle) { module_map.adjust_type(&mut ty); let actual_new_handle = new_types.insert(ty, span); assert_eq!(actual_new_handle, expected_new_handle); } } module.types = new_types; log::trace!("adjusting special types"); module_map.adjust_special_types(&mut module.special_types); // Drop unused constant expressions, reusing existing storage. 
log::trace!("adjusting constant expressions"); module.global_expressions.retain_mut(|handle, expr| { if module_map.global_expressions.used(handle) { module_map.adjust_expression(expr, &module_map.global_expressions); true } else { false } }); // Drop unused constants in place, reusing existing storage. log::trace!("adjusting constants"); module.constants.retain_mut(|handle, constant| { if module_map.constants.used(handle) { module_map.types.adjust(&mut constant.ty); module_map.global_expressions.adjust(&mut constant.init); true } else { false } }); // Drop unused overrides in place, reusing existing storage. log::trace!("adjusting overrides"); module.overrides.retain_mut(|handle, r#override| { if module_map.overrides.used(handle) { module_map.types.adjust(&mut r#override.ty); if let Some(ref mut init) = r#override.init { module_map.global_expressions.adjust(init); } true } else { false } }); // Adjust workgroup_size_overrides log::trace!("adjusting workgroup_size_overrides"); for e in module.entry_points.iter_mut() { if let Some(sizes) = e.workgroup_size_overrides.as_mut() { for size in sizes.iter_mut() { if let Some(expr) = size.as_mut() { module_map.global_expressions.adjust(expr); } } } } // Drop unused global variables, reusing existing storage. // Adjust used global variables' types and initializers. log::trace!("adjusting global variables"); module.global_variables.retain_mut(|handle, global| { if module_map.globals.used(handle) { log::trace!("retaining global variable {:?}", global.name); module_map.types.adjust(&mut global.ty); if let Some(ref mut init) = global.init { module_map.global_expressions.adjust(init); } true } else { log::trace!("dropping global variable {:?}", global.name); false } }); // Adjust doc comments if let Some(ref mut doc_comments) = module.doc_comments { module_map.adjust_doc_comments(doc_comments.as_mut()); } // Temporary storage to help us reuse allocations of existing // named expression tables. 
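The allocation-reuse trick for named expression tables (drain the old map into an empty scratch map under the new keys, then swap) can be sketched with a `HashMap<usize, String>` standing in for `NamedExpressions`. `remap_names` and its `remap` table are hypothetical. Note that naga's real map is insertion-ordered, which a `HashMap` is not.

```rust
use std::collections::HashMap;

/// Re-key `names` through `remap` (old index -> new index), reusing the
/// allocation of the caller-provided, empty `scratch` map.
fn remap_names(
    names: &mut HashMap<usize, String>,
    remap: &[Option<usize>],
    scratch: &mut HashMap<usize, String>,
) {
    assert!(scratch.is_empty());
    // `drain` empties `names` but keeps its allocated capacity.
    for (old, name) in names.drain() {
        let new = remap[old].expect("named expressions are always retained");
        scratch.insert(new, name);
    }
    // Install the rebuilt map; `scratch` now holds the old map, empty but
    // still allocated, ready for the next function.
    std::mem::swap(names, scratch);
    assert!(scratch.is_empty());
}
```

Passing the same `scratch` map to every call means the per-function tables share a small pool of allocations instead of allocating a fresh map each time.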
let mut reused_named_expressions = crate::NamedExpressions::default(); // Drop unused functions. Compact and adjust used functions. module.functions.retain_mut(|handle, function| { if let Some(map) = function_maps.get(handle) { log::trace!("retaining and compacting function {:?}", function.name); map.compact(function, &module_map, &mut reused_named_expressions); true } else { log::trace!("dropping function {:?}", function.name); false } }); // Compact each entry point. for (entry, map) in module.entry_points.iter_mut().zip(entry_point_maps.iter()) { log::trace!("compacting entry point {:?}", entry.function.name); map.compact( &mut entry.function, &module_map, &mut reused_named_expressions, ); if let Some(ref mut task_payload) = entry.task_payload { module_map.globals.adjust(task_payload); } if let Some(ref mut mesh_info) = entry.mesh_info { module_map.globals.adjust(&mut mesh_info.output_variable); module_map.types.adjust(&mut mesh_info.vertex_output_type); module_map .types .adjust(&mut mesh_info.primitive_output_type); if let Some(ref mut max_vertices_override) = mesh_info.max_vertices_override { module_map.global_expressions.adjust(max_vertices_override); } if let Some(ref mut max_primitives_override) = mesh_info.max_primitives_override { module_map .global_expressions .adjust(max_primitives_override); } } } } struct ModuleTracer<'module> { module: &'module crate::Module, /// The subset of functions in `functions_used` that have not yet been /// traced. 
functions_pending: HandleSet<crate::Function>, functions_used: HandleSet<crate::Function>, types_used: HandleSet<crate::Type>, global_variables_used: HandleSet<crate::GlobalVariable>, constants_used: HandleSet<crate::Constant>, overrides_used: HandleSet<crate::Override>, global_expressions_used: HandleSet<crate::Expression>, } impl<'module> ModuleTracer<'module> { fn new(module: &'module crate::Module) -> Self { Self { module, functions_pending: HandleSet::for_arena(&module.functions), functions_used: HandleSet::for_arena(&module.functions), types_used: HandleSet::for_arena(&module.types), global_variables_used: HandleSet::for_arena(&module.global_variables), constants_used: HandleSet::for_arena(&module.constants), overrides_used: HandleSet::for_arena(&module.overrides), global_expressions_used: HandleSet::for_arena(&module.global_expressions), } } fn trace_special_types(&mut self, special_types: &crate::SpecialTypes) { let crate::SpecialTypes { ref ray_desc, ref ray_intersection, ref ray_vertex_return, ref predeclared_types, ref external_texture_params, ref external_texture_transfer_function, } = *special_types; if let Some(ray_desc) = *ray_desc { self.types_used.insert(ray_desc); } if let Some(ray_intersection) = *ray_intersection { self.types_used.insert(ray_intersection); } if let Some(ray_vertex_return) = *ray_vertex_return { self.types_used.insert(ray_vertex_return); } // The `external_texture_params` type is generated purely as a // convenience to the backends. While it will never actually be used in // the IR, it must be marked as used so that it survives compaction. if let Some(external_texture_params) = *external_texture_params { self.types_used.insert(external_texture_params); } if let Some(external_texture_transfer_function) = *external_texture_transfer_function { self.types_used.insert(external_texture_transfer_function); } for (_, &handle) in predeclared_types { self.types_used.insert(handle); } } /// Traverse types and global expressions in tandem to determine which are used.
/// /// Assuming that all types and global expressions used by other parts of /// the module have been added to [`types_used`] and /// [`global_expressions_used`], expand those sets to include all types and /// global expressions reachable from those. /// /// [`types_used`]: ModuleTracer::types_used /// [`global_expressions_used`]: ModuleTracer::global_expressions_used fn type_expression_tandem(&mut self) { // For each type T, compute the latest global expression E that T and // its predecessors refer to. Given the ordering rules on types and // global expressions in valid modules, we can do this with a single // forward scan of the type arena. The rules further imply that T can // only be referred to by expressions after E. let mut max_dep = Vec::with_capacity(self.module.types.len()); let mut previous = None; for (_handle, ty) in self.module.types.iter() { previous = core::cmp::max( previous, match ty.inner { crate::TypeInner::Array { size, .. } | crate::TypeInner::BindingArray { size, .. } => match size { crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => None, crate::ArraySize::Pending(handle) => self.module.overrides[handle].init, }, _ => None, }, ); max_dep.push(previous); } // Visit types and global expressions from youngest to oldest. // // The outer loop visits types. Before visiting each type, the inner // loop ensures that all global expressions that could possibly refer to // it have been visited. And since the inner loop stops at the latest // expression that the type could possibly refer to, we know that we // have previously visited any types that might refer to each expression // we visit. // // This lets us assume that any type or expression that is *not* marked // as used by the time we visit it is genuinely unused, and can be // ignored.
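The tandem traversal can be sketched on toy arenas, assuming a type refers to at most one earlier type and one expression, and an expression to at most one earlier expression and one type. `Ty`, `GExpr`, `trace_expr`, and `mark_tandem` are hypothetical. The `max_dep` prefix maximum and the rule "visit every expression after `max_dep[t]` before visiting type `t`" follow the comments above.

```rust
/// Hypothetical toy type: an optional earlier base type, plus an optional
/// global expression (think: an array length given by an override's init).
struct Ty {
    base: Option<usize>,
    expr: Option<usize>,
}

/// Hypothetical toy global expression: an optional earlier operand and an
/// optional type (think: a zero value of that type).
struct GExpr {
    operand: Option<usize>,
    ty: Option<usize>,
}

fn trace_expr(e: &GExpr, types_used: &mut [bool], exprs_used: &mut [bool]) {
    if let Some(op) = e.operand {
        exprs_used[op] = true;
    }
    if let Some(ty) = e.ty {
        types_used[ty] = true;
    }
}

/// Expand the `*_used` sets to everything reachable from them, assuming
/// an expression refers to a type only if it comes after the latest
/// expression that type (or any earlier type) refers to.
fn mark_tandem(types: &[Ty], exprs: &[GExpr], types_used: &mut [bool], exprs_used: &mut [bool]) {
    // max_dep[t]: latest expression referred to by type `t` or any earlier
    // type. `Option`'s ordering puts `None` below every `Some`.
    let mut max_dep = Vec::with_capacity(types.len());
    let mut previous: Option<usize> = None;
    for ty in types {
        previous = previous.max(ty.expr);
        max_dep.push(previous);
    }

    // Expressions at indices >= `next` have already been visited.
    let mut next = exprs.len();
    for t in (0..types.len()).rev() {
        // First visit every expression that could refer to type `t`.
        let stop = max_dep[t].map_or(0, |d| d + 1);
        while next > stop {
            next -= 1;
            if exprs_used[next] {
                trace_expr(&exprs[next], types_used, exprs_used);
            }
        }
        // Now `types_used[t]` is final, so tracing `t` is safe.
        if types_used[t] {
            if let Some(base) = types[t].base {
                types_used[base] = true;
            }
            if let Some(expr) = types[t].expr {
                exprs_used[expr] = true;
            }
        }
    }
    // Visit any remaining expressions.
    while next > 0 {
        next -= 1;
        if exprs_used[next] {
            trace_expr(&exprs[next], types_used, exprs_used);
        }
    }
}
```

For example, if type 1 is an array whose length comes from expression 0, and expression 1 is a zero value of type 1, then marking only expression 1 as used ends up marking both types and expressions 0 and 1.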
let mut exprs = self.module.global_expressions.iter().rev().peekable(); for ((ty_handle, ty), dep) in self.module.types.iter().zip(max_dep).rev() { while let Some((expr_handle, expr)) = exprs.next_if(|&(h, _)| Some(h) > dep) { if self.global_expressions_used.contains(expr_handle) { self.as_const_expression().trace_expression(expr); } } if self.types_used.contains(ty_handle) { self.as_type().trace_type(ty); } } // Visit any remaining expressions. for (expr_handle, expr) in exprs { if self.global_expressions_used.contains(expr_handle) { self.as_const_expression().trace_expression(expr); } } } const fn as_type(&mut self) -> types::TypeTracer<'_> { types::TypeTracer { overrides: &self.module.overrides, types_used: &mut self.types_used, expressions_used: &mut self.global_expressions_used, overrides_used: &mut self.overrides_used, } } const fn as_const_expression(&mut self) -> expressions::ExpressionTracer<'_> { expressions::ExpressionTracer { constants: &self.module.constants, overrides: &self.module.overrides, expressions: &self.module.global_expressions, types_used: &mut self.types_used, global_variables_used: &mut self.global_variables_used, constants_used: &mut self.constants_used, expressions_used: &mut self.global_expressions_used, overrides_used: &mut self.overrides_used, global_expressions_used: None, } } pub fn as_function<'tracer>( &'tracer mut self, function: &'tracer crate::Function, ) -> FunctionTracer<'tracer> { FunctionTracer { function, constants: &self.module.constants, overrides: &self.module.overrides, functions_pending: &mut self.functions_pending, functions_used: &mut self.functions_used, types_used: &mut self.types_used, global_variables_used: &mut self.global_variables_used, constants_used: &mut self.constants_used, overrides_used: &mut self.overrides_used, global_expressions_used: &mut self.global_expressions_used, expressions_used: HandleSet::for_arena(&function.expressions), } } } struct ModuleMap { functions: HandleMap<crate::Function>, types: HandleMap<crate::Type>,
globals: HandleMap, constants: HandleMap, overrides: HandleMap, global_expressions: HandleMap, } impl From> for ModuleMap { fn from(used: ModuleTracer) -> Self { ModuleMap { functions: HandleMap::from_set(used.functions_used), types: HandleMap::from_set(used.types_used), globals: HandleMap::from_set(used.global_variables_used), constants: HandleMap::from_set(used.constants_used), overrides: HandleMap::from_set(used.overrides_used), global_expressions: HandleMap::from_set(used.global_expressions_used), } } } impl ModuleMap { fn adjust_special_types(&self, special: &mut crate::SpecialTypes) { let crate::SpecialTypes { ref mut ray_desc, ref mut ray_intersection, ref mut ray_vertex_return, ref mut predeclared_types, ref mut external_texture_params, ref mut external_texture_transfer_function, } = *special; if let Some(ref mut ray_desc) = *ray_desc { self.types.adjust(ray_desc); } if let Some(ref mut ray_intersection) = *ray_intersection { self.types.adjust(ray_intersection); } if let Some(ref mut ray_vertex_return) = *ray_vertex_return { self.types.adjust(ray_vertex_return); } if let Some(ref mut external_texture_params) = *external_texture_params { self.types.adjust(external_texture_params); } if let Some(ref mut external_texture_transfer_function) = *external_texture_transfer_function { self.types.adjust(external_texture_transfer_function); } for handle in predeclared_types.values_mut() { self.types.adjust(handle); } } fn adjust_doc_comments(&self, doc_comments: &mut ir::DocComments) { let crate::DocComments { module: _, types: ref mut doc_types, struct_members: ref mut doc_struct_members, entry_points: _, functions: ref mut doc_functions, constants: ref mut doc_constants, global_variables: ref mut doc_globals, } = *doc_comments; log::trace!("adjusting doc comments for types"); for (mut ty, doc_comment) in core::mem::take(doc_types) { if !self.types.used(ty) { continue; } self.types.adjust(&mut ty); doc_types.insert(ty, doc_comment); } log::trace!("adjusting doc 
comments for struct members"); for ((mut ty, index), doc_comment) in core::mem::take(doc_struct_members) { if !self.types.used(ty) { continue; } self.types.adjust(&mut ty); doc_struct_members.insert((ty, index), doc_comment); } log::trace!("adjusting doc comments for functions"); for (mut handle, doc_comment) in core::mem::take(doc_functions) { if !self.functions.used(handle) { continue; } self.functions.adjust(&mut handle); doc_functions.insert(handle, doc_comment); } log::trace!("adjusting doc comments for constants"); for (mut constant, doc_comment) in core::mem::take(doc_constants) { if !self.constants.used(constant) { continue; } self.constants.adjust(&mut constant); doc_constants.insert(constant, doc_comment); } log::trace!("adjusting doc comments for globals"); for (mut handle, doc_comment) in core::mem::take(doc_globals) { if !self.globals.used(handle) { continue; } self.globals.adjust(&mut handle); doc_globals.insert(handle, doc_comment); } } } struct FunctionMap { expressions: HandleMap<crate::Expression>, } impl From<FunctionTracer<'_>> for FunctionMap { fn from(used: FunctionTracer) -> Self { FunctionMap { expressions: HandleMap::from_set(used.expressions_used), } } } #[test] fn type_expression_interdependence() { let mut module: crate::Module = Default::default(); let u32 = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Uint, width: 4, }), }, crate::Span::default(), ); let expr = module.global_expressions.append( crate::Expression::Literal(crate::Literal::U32(0)), crate::Span::default(), ); let type_needs_expression = |module: &mut crate::Module, handle| { let override_handle = module.overrides.append( crate::Override { name: None, id: None, ty: u32, init: Some(handle), }, crate::Span::default(), ); module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Array { base: u32, size: crate::ArraySize::Pending(override_handle), stride: 4, }, }, crate::Span::default(), ) }; let expression_needs_type =
|module: &mut crate::Module, handle| { module .global_expressions .append(crate::Expression::ZeroValue(handle), crate::Span::default()) }; let expression_needs_expression = |module: &mut crate::Module, handle| { module.global_expressions.append( crate::Expression::Load { pointer: handle }, crate::Span::default(), ) }; let type_needs_type = |module: &mut crate::Module, handle| { module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Array { base: handle, size: crate::ArraySize::Dynamic, stride: 0, }, }, crate::Span::default(), ) }; let mut type_name_counter = 0; let mut type_needed = |module: &mut crate::Module, handle| { let name = Some(format!("type{type_name_counter}")); type_name_counter += 1; module.types.insert( crate::Type { name, inner: crate::TypeInner::Array { base: handle, size: crate::ArraySize::Dynamic, stride: 0, }, }, crate::Span::default(), ) }; let mut override_name_counter = 0; let mut expression_needed = |module: &mut crate::Module, handle| { let name = Some(format!("override{override_name_counter}")); override_name_counter += 1; module.overrides.append( crate::Override { name, id: None, ty: u32, init: Some(handle), }, crate::Span::default(), ) }; let cmp_modules = |mod0: &crate::Module, mod1: &crate::Module| { (mod0.types.iter().collect::<Vec<_>>() == mod1.types.iter().collect::<Vec<_>>()) && (mod0.global_expressions.iter().collect::<Vec<_>>() == mod1.global_expressions.iter().collect::<Vec<_>>()) }; // borrow checker breaks without the tmp variables as of Rust 1.83.0 let expr_end = type_needs_expression(&mut module, expr); let ty_trace = type_needs_type(&mut module, expr_end); let expr_init = expression_needs_type(&mut module, ty_trace); expression_needed(&mut module, expr_init); let ty_end = expression_needs_type(&mut module, u32); let expr_trace = expression_needs_expression(&mut module, ty_end); let ty_init = type_needs_expression(&mut module, expr_trace); type_needed(&mut module, ty_init); let untouched = module.clone(); compact(&mut module,
KeepUnused::Yes); assert!(cmp_modules(&module, &untouched)); let unused_expr = module.global_expressions.append( crate::Expression::Literal(crate::Literal::U32(1)), crate::Span::default(), ); type_needs_expression(&mut module, unused_expr); assert!(!cmp_modules(&module, &untouched)); compact(&mut module, KeepUnused::Yes); assert!(cmp_modules(&module, &untouched)); } #[test] fn array_length_override() { let mut module: crate::Module = Default::default(); let ty_bool = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::BOOL), }, crate::Span::default(), ); let ty_u32 = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::U32), }, crate::Span::default(), ); let one = module.global_expressions.append( crate::Expression::Literal(crate::Literal::U32(1)), crate::Span::default(), ); let _unused_override = module.overrides.append( crate::Override { name: None, id: Some(40), ty: ty_u32, init: None, }, crate::Span::default(), ); let o = module.overrides.append( crate::Override { name: None, id: Some(42), ty: ty_u32, init: Some(one), }, crate::Span::default(), ); let _ty_array = module.types.insert( crate::Type { name: Some("array".to_string()), inner: crate::TypeInner::Array { base: ty_bool, size: crate::ArraySize::Pending(o), stride: 4, }, }, crate::Span::default(), ); let mut validator = super::valid::Validator::new( super::valid::ValidationFlags::all(), super::valid::Capabilities::all(), ); assert!(validator.validate(&module).is_ok()); compact(&mut module, KeepUnused::Yes); assert!(validator.validate(&module).is_ok()); } /// Test mutual references between types and expressions via override /// lengths. 
#[test] fn array_length_override_mutual() { use crate::Expression as Ex; use crate::Scalar as Sc; use crate::TypeInner as Ti; let nowhere = crate::Span::default(); let mut module = crate::Module::default(); let ty_u32 = module.types.insert( crate::Type { name: None, inner: Ti::Scalar(Sc::U32), }, nowhere, ); // This type is only referred to by the override's init // expression, so if we visit that too early, this type will be // removed incorrectly. let ty_i32 = module.types.insert( crate::Type { name: None, inner: Ti::Scalar(Sc::I32), }, nowhere, ); // An override that the other override's init can refer to. let first_override = module.overrides.append( crate::Override { name: None, // so it is not considered used by definition id: Some(41), ty: ty_i32, init: None, }, nowhere, ); // Initializer expression for the override: // // (first_override + 0) as u32 // // The `first_override` makes it an override expression; the `0` // gets a use of `ty_i32` in there; and the `as` makes it match // the type of `second_override` without actually making // `second_override` point at `ty_i32` directly. let first_override_expr = module .global_expressions .append(Ex::Override(first_override), nowhere); let zero = module .global_expressions .append(Ex::ZeroValue(ty_i32), nowhere); let sum = module.global_expressions.append( Ex::Binary { op: crate::BinaryOperator::Add, left: first_override_expr, right: zero, }, nowhere, ); let init = module.global_expressions.append( Ex::As { expr: sum, kind: crate::ScalarKind::Uint, convert: None, }, nowhere, ); // Override that serves as the array's length. let second_override = module.overrides.append( crate::Override { name: None, // so it is not considered used by definition id: Some(42), ty: ty_u32, init: Some(init), }, nowhere, ); // Array type that uses the override as its length. // Since this is named, it is considered used by definition.
let _ty_array = module.types.insert( crate::Type { name: Some("delicious_array".to_string()), inner: Ti::Array { base: ty_u32, size: crate::ArraySize::Pending(second_override), stride: 4, }, }, nowhere, ); let mut validator = super::valid::Validator::new( super::valid::ValidationFlags::all(), super::valid::Capabilities::all(), ); assert!(validator.validate(&module).is_ok()); compact(&mut module, KeepUnused::Yes); assert!(validator.validate(&module).is_ok()); } #[test] fn array_length_expression() { let mut module: crate::Module = Default::default(); let ty_u32 = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::U32), }, crate::Span::default(), ); let _unused_zero = module.global_expressions.append( crate::Expression::Literal(crate::Literal::U32(0)), crate::Span::default(), ); let one = module.global_expressions.append( crate::Expression::Literal(crate::Literal::U32(1)), crate::Span::default(), ); let override_one = module.overrides.append( crate::Override { name: None, id: None, ty: ty_u32, init: Some(one), }, crate::Span::default(), ); let _ty_array = module.types.insert( crate::Type { name: Some("array".to_string()), inner: crate::TypeInner::Array { base: ty_u32, size: crate::ArraySize::Pending(override_one), stride: 4, }, }, crate::Span::default(), ); let mut validator = super::valid::Validator::new( super::valid::ValidationFlags::all(), super::valid::Capabilities::all(), ); assert!(validator.validate(&module).is_ok()); compact(&mut module, KeepUnused::Yes); assert!(validator.validate(&module).is_ok()); } #[test] fn global_expression_override() { let mut module: crate::Module = Default::default(); let ty_u32 = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::U32), }, crate::Span::default(), ); // This will only be retained if we trace the initializers // of overrides referred to by `Expression::Override` // in global expressions. 
let expr1 = module.global_expressions.append( crate::Expression::Literal(crate::Literal::U32(1)), crate::Span::default(), ); // This will only be traced via a global `Expression::Override`. let o = module.overrides.append( crate::Override { name: None, id: Some(42), ty: ty_u32, init: Some(expr1), }, crate::Span::default(), ); // This is retained by _p. let expr2 = module .global_expressions .append(crate::Expression::Override(o), crate::Span::default()); // Since this is named, it will be retained. let _p = module.overrides.append( crate::Override { name: Some("p".to_string()), id: None, ty: ty_u32, init: Some(expr2), }, crate::Span::default(), ); let mut validator = super::valid::Validator::new( super::valid::ValidationFlags::all(), super::valid::Capabilities::all(), ); assert!(validator.validate(&module).is_ok()); compact(&mut module, KeepUnused::Yes); assert!(validator.validate(&module).is_ok()); } #[test] fn local_expression_override() { let mut module: crate::Module = Default::default(); let ty_u32 = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::U32), }, crate::Span::default(), ); // This will only be retained if we trace the initializers // of overrides referred to by `Expression::Override` in a function. let expr1 = module.global_expressions.append( crate::Expression::Literal(crate::Literal::U32(1)), crate::Span::default(), ); // This will be removed by compaction. let _unused_override = module.overrides.append( crate::Override { name: None, id: Some(41), ty: ty_u32, init: None, }, crate::Span::default(), ); // This will only be traced via an `Expression::Override` in a function. let o = module.overrides.append( crate::Override { name: None, id: Some(42), ty: ty_u32, init: Some(expr1), }, crate::Span::default(), ); let mut fun = crate::Function { result: Some(crate::FunctionResult { ty: ty_u32, binding: None, }), ..crate::Function::default() }; // This is used by the `Return` statement. 
let o_expr = fun .expressions .append(crate::Expression::Override(o), crate::Span::default()); fun.body.push( crate::Statement::Return { value: Some(o_expr), }, crate::Span::default(), ); module.functions.append(fun, crate::Span::default()); let mut validator = super::valid::Validator::new( super::valid::ValidationFlags::all(), super::valid::Capabilities::all(), ); assert!(validator.validate(&module).is_ok()); compact(&mut module, KeepUnused::Yes); assert!(validator.validate(&module).is_ok()); } #[test] fn unnamed_constant_type() { let mut module = crate::Module::default(); let nowhere = crate::Span::default(); // This type is used only by the unnamed constant. let ty_u32 = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::U32), }, nowhere, ); // This type is used by the named constant. let ty_vec_u32 = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::U32, }, }, nowhere, ); let unnamed_init = module .global_expressions .append(crate::Expression::Literal(crate::Literal::U32(0)), nowhere); let unnamed_constant = module.constants.append( crate::Constant { name: None, ty: ty_u32, init: unnamed_init, }, nowhere, ); // The named constant is initialized using a Splat expression, to // give the named constant a type distinct from the unnamed // constant's. 
let unnamed_constant_expr = module .global_expressions .append(crate::Expression::Constant(unnamed_constant), nowhere); let named_init = module.global_expressions.append( crate::Expression::Splat { size: crate::VectorSize::Bi, value: unnamed_constant_expr, }, nowhere, ); let _named_constant = module.constants.append( crate::Constant { name: Some("totally_named".to_string()), ty: ty_vec_u32, init: named_init, }, nowhere, ); let mut validator = super::valid::Validator::new( super::valid::ValidationFlags::all(), super::valid::Capabilities::all(), ); assert!(validator.validate(&module).is_ok()); compact(&mut module, KeepUnused::Yes); assert!(validator.validate(&module).is_ok()); } #[test] fn unnamed_override_type() { let mut module = crate::Module::default(); let nowhere = crate::Span::default(); // This type is used only by the unnamed override. let ty_u32 = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::U32), }, nowhere, ); // This type is used by the named override. let ty_i32 = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::I32), }, nowhere, ); let unnamed_init = module .global_expressions .append(crate::Expression::Literal(crate::Literal::U32(0)), nowhere); let unnamed_override = module.overrides.append( crate::Override { name: None, id: Some(42), ty: ty_u32, init: Some(unnamed_init), }, nowhere, ); // The named override is initialized using an `As` expression, to // give the named override a type distinct from the unnamed // override's.
let unnamed_override_expr = module .global_expressions .append(crate::Expression::Override(unnamed_override), nowhere); let named_init = module.global_expressions.append( crate::Expression::As { expr: unnamed_override_expr, kind: crate::ScalarKind::Sint, convert: None, }, nowhere, ); let _named_override = module.overrides.append( crate::Override { name: Some("totally_named".to_string()), id: None, ty: ty_i32, init: Some(named_init), }, nowhere, ); let mut validator = super::valid::Validator::new( super::valid::ValidationFlags::all(), super::valid::Capabilities::all(), ); assert!(validator.validate(&module).is_ok()); compact(&mut module, KeepUnused::Yes); assert!(validator.validate(&module).is_ok()); } naga-29.0.3/src/compact/statements.rs000064400000000000000000000435621046102023000156110ustar 00000000000000use alloc::{vec, vec::Vec}; use super::functions::FunctionTracer; use super::FunctionMap; use crate::arena::Handle; use crate::compact::handle_set_map::HandleMap; impl FunctionTracer<'_> { pub fn trace_block(&mut self, block: &[crate::Statement]) { let mut worklist: Vec<&[crate::Statement]> = vec![block]; while let Some(last) = worklist.pop() { for stmt in last { use crate::Statement as St; match *stmt { St::Emit(ref _range) => { // If we come across a statement that actually uses an // expression in this range, it'll get traced from // there. But since evaluating expressions has no // effect, we don't need to assume that everything // emitted is live. 
} St::Block(ref block) => worklist.push(block), St::If { condition, ref accept, ref reject, } => { self.expressions_used.insert(condition); worklist.push(accept); worklist.push(reject); } St::Switch { selector, ref cases, } => { self.expressions_used.insert(selector); for case in cases { worklist.push(&case.body); } } St::Loop { ref body, ref continuing, break_if, } => { if let Some(break_if) = break_if { self.expressions_used.insert(break_if); } worklist.push(body); worklist.push(continuing); } St::Return { value: Some(value) } => { self.expressions_used.insert(value); } St::Store { pointer, value } => { self.expressions_used.insert(pointer); self.expressions_used.insert(value); } St::ImageStore { image, coordinate, array_index, value, } => { self.expressions_used.insert(image); self.expressions_used.insert(coordinate); if let Some(array_index) = array_index { self.expressions_used.insert(array_index); } self.expressions_used.insert(value); } St::Atomic { pointer, ref fun, value, result, } => { self.expressions_used.insert(pointer); self.trace_atomic_function(fun); self.expressions_used.insert(value); if let Some(result) = result { self.expressions_used.insert(result); } } St::ImageAtomic { image, coordinate, array_index, fun: _, value, } => { self.expressions_used.insert(image); self.expressions_used.insert(coordinate); if let Some(array_index) = array_index { self.expressions_used.insert(array_index); } self.expressions_used.insert(value); } St::WorkGroupUniformLoad { pointer, result } => { self.expressions_used.insert(pointer); self.expressions_used.insert(result); } St::Call { function, ref arguments, result, } => { self.trace_call(function); for expr in arguments { self.expressions_used.insert(*expr); } if let Some(result) = result { self.expressions_used.insert(result); } } St::RayQuery { query, ref fun } => { self.expressions_used.insert(query); self.trace_ray_query_function(fun); } St::SubgroupBallot { result, predicate } => { if let Some(predicate) = 
predicate { self.expressions_used.insert(predicate); } self.expressions_used.insert(result); } St::SubgroupCollectiveOperation { op: _, collective_op: _, argument, result, } => { self.expressions_used.insert(argument); self.expressions_used.insert(result); } St::SubgroupGather { mode, argument, result, } => { match mode { crate::GatherMode::BroadcastFirst => {} crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) | crate::GatherMode::ShuffleXor(index) | crate::GatherMode::QuadBroadcast(index) => { self.expressions_used.insert(index); } crate::GatherMode::QuadSwap(_) => {} } self.expressions_used.insert(argument); self.expressions_used.insert(result); } St::CooperativeStore { target, ref data } => { self.expressions_used.insert(target); self.expressions_used.insert(data.pointer); self.expressions_used.insert(data.stride); } St::RayPipelineFunction(func) => match func { crate::RayPipelineFunction::TraceRay { acceleration_structure, descriptor, payload, } => { self.expressions_used.insert(acceleration_structure); self.expressions_used.insert(descriptor); self.expressions_used.insert(payload); } }, // Trivial statements. 
St::Break | St::Continue | St::Kill | St::ControlBarrier(_) | St::MemoryBarrier(_) | St::Return { value: None } => {} } } } } fn trace_atomic_function(&mut self, fun: &crate::AtomicFunction) { use crate::AtomicFunction as Af; match *fun { Af::Exchange { compare: Some(expr), } => { self.expressions_used.insert(expr); } Af::Exchange { compare: None } | Af::Add | Af::Subtract | Af::And | Af::ExclusiveOr | Af::InclusiveOr | Af::Min | Af::Max => {} } } fn trace_ray_query_function(&mut self, fun: &crate::RayQueryFunction) { use crate::RayQueryFunction as Qf; match *fun { Qf::Initialize { acceleration_structure, descriptor, } => { self.expressions_used.insert(acceleration_structure); self.expressions_used.insert(descriptor); } Qf::Proceed { result } => { self.expressions_used.insert(result); } Qf::GenerateIntersection { hit_t } => { self.expressions_used.insert(hit_t); } Qf::ConfirmIntersection => {} Qf::Terminate => {} } } } impl FunctionMap { /// Adjust statements in the body of `function`. /// /// Adjusts expressions using `self.expressions`, and adjusts calls to other /// functions using `function_map`. 
pub fn adjust_body( &self, function: &mut crate::Function, function_map: &HandleMap<crate::Function>, ) { let block = &mut function.body; let mut worklist: Vec<&mut [crate::Statement]> = vec![block]; let adjust = |handle: &mut Handle<crate::Expression>| { self.expressions.adjust(handle); }; while let Some(last) = worklist.pop() { for stmt in last { use crate::Statement as St; match *stmt { St::Emit(ref mut range) => { self.expressions.adjust_range(range, &function.expressions); } St::Block(ref mut block) => worklist.push(block), St::If { ref mut condition, ref mut accept, ref mut reject, } => { adjust(condition); worklist.push(accept); worklist.push(reject); } St::Switch { ref mut selector, ref mut cases, } => { adjust(selector); for case in cases { worklist.push(&mut case.body); } } St::Loop { ref mut body, ref mut continuing, ref mut break_if, } => { if let Some(ref mut break_if) = *break_if { adjust(break_if); } worklist.push(body); worklist.push(continuing); } St::Return { value: Some(ref mut value), } => adjust(value), St::Store { ref mut pointer, ref mut value, } => { adjust(pointer); adjust(value); } St::ImageStore { ref mut image, ref mut coordinate, ref mut array_index, ref mut value, } => { adjust(image); adjust(coordinate); if let Some(ref mut array_index) = *array_index { adjust(array_index); } adjust(value); } St::Atomic { ref mut pointer, ref mut fun, ref mut value, ref mut result, } => { adjust(pointer); self.adjust_atomic_function(fun); adjust(value); if let Some(ref mut result) = *result { adjust(result); } } St::ImageAtomic { ref mut image, ref mut coordinate, ref mut array_index, fun: _, ref mut value, } => { adjust(image); adjust(coordinate); if let Some(ref mut array_index) = *array_index { adjust(array_index); } adjust(value); } St::WorkGroupUniformLoad { ref mut pointer, ref mut result, } => { adjust(pointer); adjust(result); } St::Call { ref mut function, ref mut arguments, ref mut result, } => { function_map.adjust(function); for expr in arguments { adjust(expr); } if let
Some(ref mut result) = *result { adjust(result); } } St::RayQuery { ref mut query, ref mut fun, } => { adjust(query); self.adjust_ray_query_function(fun); } St::SubgroupBallot { ref mut result, ref mut predicate, } => { if let Some(ref mut predicate) = *predicate { adjust(predicate); } adjust(result); } St::SubgroupCollectiveOperation { op: _, collective_op: _, ref mut argument, ref mut result, } => { adjust(argument); adjust(result); } St::SubgroupGather { ref mut mode, ref mut argument, ref mut result, } => { match *mode { crate::GatherMode::BroadcastFirst => {} crate::GatherMode::Broadcast(ref mut index) | crate::GatherMode::Shuffle(ref mut index) | crate::GatherMode::ShuffleDown(ref mut index) | crate::GatherMode::ShuffleUp(ref mut index) | crate::GatherMode::ShuffleXor(ref mut index) | crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index), crate::GatherMode::QuadSwap(_) => {} } adjust(argument); adjust(result); } St::CooperativeStore { ref mut target, ref mut data, } => { adjust(target); adjust(&mut data.pointer); adjust(&mut data.stride); } St::RayPipelineFunction(ref mut func) => match *func { crate::RayPipelineFunction::TraceRay { ref mut acceleration_structure, ref mut descriptor, ref mut payload, } => { adjust(acceleration_structure); adjust(descriptor); adjust(payload); } }, // Trivial statements. 
St::Break | St::Continue | St::Kill | St::ControlBarrier(_) | St::MemoryBarrier(_) | St::Return { value: None } => {} } } } } fn adjust_atomic_function(&self, fun: &mut crate::AtomicFunction) { use crate::AtomicFunction as Af; match *fun { Af::Exchange { compare: Some(ref mut expr), } => { self.expressions.adjust(expr); } Af::Exchange { compare: None } | Af::Add | Af::Subtract | Af::And | Af::ExclusiveOr | Af::InclusiveOr | Af::Min | Af::Max => {} } } fn adjust_ray_query_function(&self, fun: &mut crate::RayQueryFunction) { use crate::RayQueryFunction as Qf; match *fun { Qf::Initialize { ref mut acceleration_structure, ref mut descriptor, } => { self.expressions.adjust(acceleration_structure); self.expressions.adjust(descriptor); } Qf::Proceed { ref mut result } => { self.expressions.adjust(result); } Qf::GenerateIntersection { ref mut hit_t } => { self.expressions.adjust(hit_t); } Qf::ConfirmIntersection => {} Qf::Terminate => {} } } } naga-29.0.3/src/compact/types.rs000064400000000000000000000067671046102023000145740ustar 00000000000000use super::{HandleSet, ModuleMap}; use crate::Handle; pub struct TypeTracer<'a> { pub overrides: &'a crate::Arena<crate::Override>, pub types_used: &'a mut HandleSet<crate::Type>, pub expressions_used: &'a mut HandleSet<crate::Expression>, pub overrides_used: &'a mut HandleSet<crate::Override>, } impl TypeTracer<'_> { pub fn trace_type(&mut self, ty: &crate::Type) { use crate::TypeInner as Ti; match ty.inner { // Types that do not contain handles. Ti::Scalar { .. } | Ti::Vector { .. } | Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } | Ti::Atomic { .. } | Ti::ValuePointer { .. } | Ti::Image { .. } | Ti::Sampler { .. } | Ti::AccelerationStructure { .. } | Ti::RayQuery { .. } => {} // Types that do contain handles.
Ti::Array { base, size, stride: _, } | Ti::BindingArray { base, size } => { self.types_used.insert(base); match size { crate::ArraySize::Pending(handle) => { self.overrides_used.insert(handle); let r#override = &self.overrides[handle]; self.types_used.insert(r#override.ty); if let Some(expr) = r#override.init { self.expressions_used.insert(expr); } } crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => {} } } Ti::Pointer { base, space: _ } => { self.types_used.insert(base); } Ti::Struct { ref members, span: _, } => { self.types_used.insert_iter(members.iter().map(|m| m.ty)); } } } } impl ModuleMap { pub fn adjust_type(&self, ty: &mut crate::Type) { let adjust = |ty: &mut Handle<crate::Type>| self.types.adjust(ty); use crate::TypeInner as Ti; match ty.inner { // Types that do not contain handles. Ti::Scalar(_) | Ti::Vector { .. } | Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } | Ti::Atomic(_) | Ti::ValuePointer { .. } | Ti::Image { .. } | Ti::Sampler { .. } | Ti::AccelerationStructure { .. } | Ti::RayQuery { .. } => {} // Types that do contain handles. Ti::Pointer { ref mut base, space: _, } => adjust(base), Ti::Array { ref mut base, ref mut size, stride: _, } | Ti::BindingArray { ref mut base, ref mut size, } => { adjust(base); match *size { crate::ArraySize::Pending(ref mut r#override) => { self.overrides.adjust(r#override); } crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => {} } } Ti::Struct { ref mut members, span: _, } => { for member in members { self.types.adjust(&mut member.ty); } } }; } } naga-29.0.3/src/diagnostic_filter.rs000064400000000000000000000224551046102023000154630ustar 00000000000000//! [`DiagnosticFilter`]s and supporting functionality.
use alloc::boxed::Box; use crate::{Arena, Handle}; #[cfg(feature = "wgsl-in")] use crate::FastIndexMap; #[cfg(feature = "wgsl-in")] use crate::Span; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; #[cfg(feature = "deserialize")] use serde::Deserialize; #[cfg(feature = "serialize")] use serde::Serialize; /// A severity set on a [`DiagnosticFilter`]. /// /// #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Severity { Off, Info, Warning, Error, } impl Severity { /// Checks whether this severity is [`Self::Error`]. /// /// Naga does not yet support diagnostic items at lesser severities than /// [`Severity::Error`]. When this is implemented, this method should be deleted, and the /// severity should be used directly for reporting diagnostics. pub(crate) fn report_diag<E>( self, err: E, log_handler: impl FnOnce(E, log::Level), ) -> Result<(), E> { let log_level = match self { Severity::Off => return Ok(()), // NOTE: These severities are not yet reported. Severity::Info => log::Level::Info, Severity::Warning => log::Level::Warn, Severity::Error => return Err(err), }; log_handler(err, log_level); Ok(()) } } /// A filterable triggering rule in a [`DiagnosticFilter`]. /// /// #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum FilterableTriggeringRule { Standard(StandardFilterableTriggeringRule), Unknown(Box<str>), User(Box<[Box<str>; 2]>), } /// A filterable triggering rule in a [`DiagnosticFilter`].
/// /// #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum StandardFilterableTriggeringRule { DerivativeUniformity, } impl StandardFilterableTriggeringRule { /// The default severity associated with this triggering rule. /// /// See for a table of default /// severities. pub(crate) const fn default_severity(self) -> Severity { match self { Self::DerivativeUniformity => Severity::Error, } } } /// A filtering rule that modifies how diagnostics are emitted for shaders. /// /// #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct DiagnosticFilter { pub new_severity: Severity, pub triggering_rule: FilterableTriggeringRule, } /// Determines whether [`DiagnosticFilterMap::add`] should consider full duplicates a conflict. /// /// In WGSL, directive position does not consider this case a conflict, while attribute position /// does. #[cfg(feature = "wgsl-in")] pub(crate) enum ShouldConflictOnFullDuplicate { /// Use this for attributes in WGSL. Yes, /// Use this for directives in WGSL. No, } /// A map from diagnostic filters to their severity and span. /// /// Front ends can use this to collect the set of filters applied to a /// particular language construct, and detect duplicate/conflicting filters. /// /// For example, WGSL has global diagnostic filters that apply to the entire /// module, and diagnostic range filter attributes that apply to a specific /// function, statement, or other smaller construct. The set of filters applied /// to any given construct must not conflict, but they can be overridden by /// filters on other constructs nested within it. 
A front end can use a /// `DiagnosticFilterMap` to collect the filters applied to a single construct, /// using the [`add`] method's error checking to forbid conflicts. /// /// For each filter it contains, a `DiagnosticFilterMap` records the requested /// severity, and the source span of the filter itself. /// /// [`add`]: DiagnosticFilterMap::add #[derive(Clone, Debug, Default)] #[cfg(feature = "wgsl-in")] pub(crate) struct DiagnosticFilterMap(FastIndexMap); #[cfg(feature = "wgsl-in")] impl DiagnosticFilterMap { pub(crate) fn new() -> Self { Self::default() } /// Add the given `diagnostic_filter` parsed at the given `span` to this map. pub(crate) fn add( &mut self, diagnostic_filter: DiagnosticFilter, span: Span, should_conflict_on_full_duplicate: ShouldConflictOnFullDuplicate, ) -> Result<(), ConflictingDiagnosticRuleError> { use indexmap::map::Entry; let &mut Self(ref mut diagnostic_filters) = self; let DiagnosticFilter { new_severity, triggering_rule, } = diagnostic_filter; match diagnostic_filters.entry(triggering_rule.clone()) { Entry::Vacant(entry) => { entry.insert((new_severity, span)); } Entry::Occupied(entry) => { let &(first_severity, first_span) = entry.get(); let should_conflict_on_full_duplicate = match should_conflict_on_full_duplicate { ShouldConflictOnFullDuplicate::Yes => true, ShouldConflictOnFullDuplicate::No => false, }; if first_severity != new_severity || should_conflict_on_full_duplicate { return Err(ConflictingDiagnosticRuleError { triggering_rule_spans: [first_span, span], }); } } } Ok(()) } /// Were any rules specified? pub(crate) fn is_empty(&self) -> bool { let &Self(ref map) = self; map.is_empty() } /// Returns the spans of all contained rules. 
pub(crate) fn spans(&self) -> impl Iterator + '_ { let &Self(ref map) = self; map.iter().map(|(_, &(_, span))| span) } } #[cfg(feature = "wgsl-in")] impl IntoIterator for DiagnosticFilterMap { type Item = (FilterableTriggeringRule, (Severity, Span)); type IntoIter = indexmap::map::IntoIter; fn into_iter(self) -> Self::IntoIter { let Self(this) = self; this.into_iter() } } /// An error returned by [`DiagnosticFilterMap::add`] when it encounters conflicting rules. #[cfg(feature = "wgsl-in")] #[derive(Clone, Debug)] pub(crate) struct ConflictingDiagnosticRuleError { pub triggering_rule_spans: [Span; 2], } /// Represents a single parent-linking node in a tree of [`DiagnosticFilter`]s backed by a /// [`crate::Arena`]. /// /// A single element of a _tree_ of diagnostic filter rules stored in /// [`crate::Module::diagnostic_filters`]. When nodes are built by a front-end, module-applicable /// filter rules are chained together in runs based on parse site. For instance, given the /// following: /// /// - Module-applicable rules `a` and `b`. /// - Rules `c` and `d`, applicable to an entry point called `c_and_d_func`. /// - Rule `e`, applicable to an entry point called `e_func`. /// /// The tree would be represented as follows: /// /// ```text /// a <- b /// ^ /// |- c <- d /// | /// \- e /// ``` /// /// ...where: /// /// - `d` is the first leaf consulted by validation in `c_and_d_func`. /// - `e` is the first leaf consulted by validation in `e_func`. #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct DiagnosticFilterNode { pub inner: DiagnosticFilter, pub parent: Option>, } impl DiagnosticFilterNode { /// Finds the most specific filter rule applicable to `triggering_rule` from the chain of /// diagnostic filter rules in `arena`, starting with `node`, and returns its severity. 
If none /// is found, returns the value of [`StandardFilterableTriggeringRule::default_severity`]. /// /// When `triggering_rule` is not applicable to this node, its ancestors are consulted in turn. pub(crate) fn search( node: Option<Handle<DiagnosticFilterNode>>, arena: &Arena<DiagnosticFilterNode>, triggering_rule: StandardFilterableTriggeringRule, ) -> Severity { let mut next = node; while let Some(handle) = next { let node = &arena[handle]; let &Self { ref inner, parent } = node; let &DiagnosticFilter { triggering_rule: ref rule, new_severity, } = inner; if rule == &FilterableTriggeringRule::Standard(triggering_rule) { return new_severity; } next = parent; } triggering_rule.default_severity() } } naga-29.0.3/src/error.rs000064400000000000000000000145201046102023000131150ustar 00000000000000use alloc::{borrow::Cow, boxed::Box, string::String}; use core::{error::Error, fmt}; #[derive(Clone, Debug)] pub struct ShaderError<E> { /// The source code of the shader. pub source: String, pub label: Option<String>, pub inner: Box<E>, } #[cfg(feature = "wgsl-in")] impl fmt::Display for ShaderError<crate::front::wgsl::ParseError> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let label = self.label.as_deref().unwrap_or_default(); let string = self.inner.emit_to_string(&self.source); write!(f, "\nShader '{label}' parsing {string}") } } #[cfg(feature = "glsl-in")] impl fmt::Display for ShaderError<crate::front::glsl::ParseErrors> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let label = self.label.as_deref().unwrap_or_default(); let string = self.inner.emit_to_string(&self.source); write!(f, "\nShader '{label}' parsing {string}") } } #[cfg(feature = "spv-in")] impl fmt::Display for ShaderError<crate::front::spv::Error> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let label = self.label.as_deref().unwrap_or_default(); let string = self.inner.emit_to_string(&self.source); write!(f, "\nShader '{label}' parsing {string}") } } impl fmt::Display for ShaderError<crate::WithSpan<crate::valid::ValidationError>> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use codespan_reporting::{files::SimpleFile, term}; let label =
self.label.as_deref().unwrap_or_default(); let files = SimpleFile::new(label, replace_control_chars(&self.source)); let config = term::Config::default(); let mut writer = DiagnosticBuffer::new(); writer .emit_to_self(&config, &files, &self.inner.diagnostic()) .expect("cannot write error"); let writer = writer.into_string(); write!(f, "\nShader validation {writer}") } } cfg_if::cfg_if! { if #[cfg(feature = "termcolor")] { type DiagnosticBufferInner = codespan_reporting::term::termcolor::NoColor>; } else if #[cfg(feature = "stderr")] { type DiagnosticBufferInner = alloc::vec::Vec; } else { type DiagnosticBufferInner = String; } } cfg_if::cfg_if! { if #[cfg(all(feature = "stderr", feature = "termcolor"))] { pub(crate) use codespan_reporting::term::termcolor::WriteColor as _ErrorWrite; } else if #[cfg(feature = "stderr")] { pub(crate) use std::io::Write as _ErrorWrite; } } #[cfg(feature = "stderr")] pub(crate) use _ErrorWrite as ErrorWrite; #[cfg(feature = "stderr")] #[cfg_attr( not(any(feature = "spv-in", feature = "glsl-in")), expect( dead_code, reason = "only need `emit_to_writer` with an appropriate front-end." ) )] pub(crate) fn emit_to_writer<'files, F: codespan_reporting::files::Files<'files> + ?Sized>( writer: &mut impl ErrorWrite, config: &codespan_reporting::term::Config, files: &'files F, diagnostic: &codespan_reporting::diagnostic::Diagnostic, ) -> Result<(), codespan_reporting::files::Error> { cfg_if::cfg_if! { if #[cfg(feature = "termcolor")] { codespan_reporting::term::emit_to_write_style(writer, config, files, diagnostic) } else { codespan_reporting::term::emit_to_io_write(writer, config, files, diagnostic) } } } pub(crate) struct DiagnosticBuffer { inner: DiagnosticBufferInner, } impl DiagnosticBuffer { #[cfg_attr( not(feature = "termcolor"), expect( clippy::missing_const_for_fn, reason = "`NoColor::new` isn't `const`, but other `inner`s are." ) )] pub fn new() -> Self { cfg_if::cfg_if! 
{ if #[cfg(feature = "termcolor")] { let inner = codespan_reporting::term::termcolor::NoColor::new(alloc::vec::Vec::new()); } else if #[cfg(feature = "stderr")] { let inner = alloc::vec::Vec::new(); } else { let inner = String::new(); } }; Self { inner } } pub fn emit_to_self<'files, F: codespan_reporting::files::Files<'files> + ?Sized>( &mut self, config: &codespan_reporting::term::Config, files: &'files F, diagnostic: &codespan_reporting::diagnostic::Diagnostic, ) -> Result<(), codespan_reporting::files::Error> { cfg_if::cfg_if! { if #[cfg(feature = "termcolor")] { codespan_reporting::term::emit_to_write_style(&mut self.inner, config, files, diagnostic) } else if #[cfg(feature = "stderr")] { codespan_reporting::term::emit_to_io_write(&mut self.inner, config, files, diagnostic) } else { codespan_reporting::term::emit_to_string(&mut self.inner, config, files, diagnostic) } } } pub fn into_string(self) -> String { let Self { inner } = self; cfg_if::cfg_if! { if #[cfg(feature = "termcolor")] { String::from_utf8(inner.into_inner()).unwrap() } else if #[cfg(feature = "stderr")] { String::from_utf8(inner).unwrap() } else { inner } } } } impl Error for ShaderError where ShaderError: fmt::Display, E: Error + 'static, { fn source(&self) -> Option<&(dyn Error + 'static)> { self.inner.source() } } pub(crate) fn replace_control_chars(s: &str) -> Cow<'_, str> { const REPLACEMENT_CHAR: &str = "\u{FFFD}"; debug_assert_eq!( REPLACEMENT_CHAR.chars().next().unwrap(), char::REPLACEMENT_CHARACTER ); let mut res = Cow::Borrowed(s); let mut offset = 0; while let Some(found_pos) = res[offset..].find(|c: char| c.is_control() && !c.is_whitespace()) { offset += found_pos; let found_len = res[offset..].chars().next().unwrap().len_utf8(); res.to_mut() .replace_range(offset..offset + found_len, REPLACEMENT_CHAR); offset += REPLACEMENT_CHAR.len(); } res } #[test] fn test_replace_control_chars() { // The UTF-8 encoding of \u{0080} is multiple bytes. 
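// Worked trace for the input below (a reading of the loop in `replace_control_chars`):
// - U+0080 is a C1 control character encoded as 2 bytes (0xC2 0x80). It is replaced by
//   U+FFFD, which is 3 bytes, so `offset` advances by `REPLACEMENT_CHAR.len()` (3), not by
//   the 2 bytes that were removed.
// - U+0001 is a 1-byte control character and is replaced the same way.
// - '\n' is a control character but also whitespace, so it is left as-is.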
let input = "Foo\u{0080}Bar\u{0001}Baz\n"; let expected = "Foo\u{FFFD}Bar\u{FFFD}Baz\n"; assert_eq!(replace_control_chars(input), expected); } naga-29.0.3/src/front/atomic_upgrade.rs000064400000000000000000000230561046102023000161030ustar 00000000000000//! Upgrade the types of scalars observed to be accessed as atomics to [`Atomic`] types. //! //! In SPIR-V, atomic operations can be applied to any scalar value, but in Naga //! IR atomic operations can only be applied to values of type [`Atomic`]. Naga //! IR's restriction matches Metal Shading Language and WGSL, so we don't want //! to relax that. Instead, when the SPIR-V front end observes a value being //! accessed using atomic instructions, it promotes the value's type from //! [`Scalar`] to [`Atomic`]. This module implements `Module::upgrade_atomics`, //! the function that makes that change. //! //! Atomics can only appear in global variables in the [`Storage`] and //! [`Workgroup`] address spaces. These variables can either have `Atomic` types //! themselves, or be [`Array`]s of such, or be [`Struct`]s containing such. //! So we only need to change the types of globals and struct fields. //! //! Naga IR [`Load`] expressions and [`Store`] statements can operate directly //! on [`Atomic`] values, retrieving and depositing ordinary [`Scalar`] values, //! so changing the types doesn't have much effect on the code that operates on //! those values. //! //! Future work: //! //! - The GLSL front end could use this transformation as well. //! //! [`Atomic`]: TypeInner::Atomic //! [`Scalar`]: TypeInner::Scalar //! [`Storage`]: crate::AddressSpace::Storage //! [`WorkGroup`]: crate::AddressSpace::WorkGroup //! [`Array`]: TypeInner::Array //! [`Struct`]: TypeInner::Struct //! [`Load`]: crate::Expression::Load //! 
[`Store`]: crate::Statement::Store use alloc::{format, sync::Arc}; use core::sync::atomic::AtomicUsize; use crate::{GlobalVariable, Handle, Module, Type, TypeInner}; #[derive(Clone, Debug, thiserror::Error)] pub enum Error { #[error("encountered an unsupported expression")] Unsupported, #[error("unexpected end of struct field access indices")] UnexpectedEndOfIndices, #[error("encountered unsupported global initializer in an atomic variable")] GlobalInitUnsupported, #[error("expected to find a global variable")] GlobalVariableMissing, #[error("atomic compare exchange requires a scalar base type")] CompareExchangeNonScalarBaseType, } #[derive(Clone, Default)] struct Padding(Arc); impl core::fmt::Display for Padding { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { for _ in 0..self.0.load(core::sync::atomic::Ordering::Relaxed) { f.write_str(" ")?; } Ok(()) } } impl Drop for Padding { fn drop(&mut self) { let _ = self.0.fetch_sub(1, core::sync::atomic::Ordering::Relaxed); } } impl Padding { fn trace(&self, msg: impl core::fmt::Display, t: impl core::fmt::Debug) { format!("{msg} {t:#?}") .split('\n') .for_each(|ln| log::trace!("{self}{ln}")); } fn debug(&self, msg: impl core::fmt::Display, t: impl core::fmt::Debug) { format!("{msg} {t:#?}") .split('\n') .for_each(|ln| log::debug!("{self}{ln}")); } fn inc_padding(&self) -> Padding { let _ = self.0.fetch_add(1, core::sync::atomic::Ordering::Relaxed); self.clone() } } #[derive(Debug, Default)] pub struct Upgrades { /// Global variables that we've accessed using atomic operations. /// /// This includes globals with composite types (arrays, structs) where we've /// only accessed some components (elements, fields) atomically. globals: crate::arena::HandleSet, /// Struct fields that we've accessed using atomic operations. /// /// Each key refers to some [`Struct`] type, and each value is a set of /// the indices of the fields in that struct that have been accessed /// atomically. 
/// /// This includes fields with composite types (arrays, structs) /// of which we've only accessed some components (elements, fields) /// atomically. /// /// [`Struct`]: crate::TypeInner::Struct fields: crate::FastHashMap<Handle<Type>, bit_set::BitSet>, } impl Upgrades { pub fn insert_global(&mut self, global: Handle<GlobalVariable>) { self.globals.insert(global); } pub fn insert_field(&mut self, struct_type: Handle<Type>, field: usize) { self.fields.entry(struct_type).or_default().insert(field); } pub fn is_empty(&self) -> bool { self.globals.is_empty() } } struct UpgradeState<'a> { padding: Padding, module: &'a mut Module, /// A map from old types to their upgraded versions. /// /// This ensures we never try to rebuild a type more than once. upgraded_types: crate::FastHashMap<Handle<Type>, Handle<Type>>, } impl UpgradeState<'_> { fn inc_padding(&self) -> Padding { self.padding.inc_padding() } /// Get a type equivalent to `ty`, but with [`Scalar`] leaves upgraded to [`Atomic`] scalars. /// /// If such a type already exists in `self.module.types`, return its handle. /// Otherwise, construct a new one and return that handle. /// /// If `ty` is a [`Pointer`], [`Array`], or [`BindingArray`], recurse into the /// type and upgrade its leaf types. /// /// If `ty` is a [`Struct`], recurse into it and upgrade only those fields /// whose indices are recorded for that struct in `upgrades.fields`. /// /// The existing type is not affected. /// /// [`Scalar`]: crate::TypeInner::Scalar /// [`Atomic`]: crate::TypeInner::Atomic /// [`Pointer`]: crate::TypeInner::Pointer /// [`Array`]: crate::TypeInner::Array /// [`Struct`]: crate::TypeInner::Struct /// [`BindingArray`]: crate::TypeInner::BindingArray fn upgrade_type( &mut self, ty: Handle<Type>, upgrades: &Upgrades, ) -> Result<Handle<Type>, Error> { let padding = self.inc_padding(); padding.trace("visiting type: ", ty); // If we've already upgraded this type, return the handle we produced at // the time.
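// Illustrative example of the match below (hypothetical types, not taken from a test):
// - A global of type `array<u32, 4>` recorded in `upgrades.globals` goes through the
//   `Array` arm, which recurses into the base. The `Scalar` arm then turns `u32` into
//   `Atomic(u32)`, so the result is a new `array<atomic<u32>, 4>` type with the original
//   size and stride.
// - For a struct `S { a: u32, b: u32 }` whose entry in `upgrades.fields` contains only
//   index 1, only `b` becomes `atomic<u32>`; `a` keeps its type.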
if let Some(&new) = self.upgraded_types.get(&ty) { return Ok(new); } let inner = match self.module.types[ty].inner { TypeInner::Scalar(scalar) => { log::trace!("{padding}hit the scalar leaf, replacing with an atomic"); TypeInner::Atomic(scalar) } TypeInner::Pointer { base, space } => TypeInner::Pointer { base: self.upgrade_type(base, upgrades)?, space, }, TypeInner::Array { base, size, stride } => TypeInner::Array { base: self.upgrade_type(base, upgrades)?, size, stride, }, TypeInner::Struct { ref members, span } => { // If no field or subfield of this struct was ever accessed // atomically, no change is needed. We should never have arrived here. let Some(fields) = upgrades.fields.get(&ty) else { unreachable!("global or field incorrectly flagged as atomically accessed"); }; let mut new_members = members.clone(); for field in fields { new_members[field].ty = self.upgrade_type(new_members[field].ty, upgrades)?; } TypeInner::Struct { members: new_members, span, } } TypeInner::BindingArray { base, size } => TypeInner::BindingArray { base: self.upgrade_type(base, upgrades)?, size, }, _ => return Ok(ty), }; // At this point, we have a `TypeInner` that is the upgraded version of // `ty`. Find a suitable `Type` for this, creating a new one if // necessary, and return its handle. 
let r#type = &self.module.types[ty]; let span = self.module.types.get_span(ty); let new_type = Type { name: r#type.name.clone(), inner, }; padding.debug("ty: ", ty); padding.debug("from: ", r#type); padding.debug("to: ", &new_type); let new_handle = self.module.types.insert(new_type, span); self.upgraded_types.insert(ty, new_handle); Ok(new_handle) } fn upgrade_all(&mut self, upgrades: &Upgrades) -> Result<(), Error> { for handle in upgrades.globals.iter() { let padding = self.inc_padding(); let global = &self.module.global_variables[handle]; padding.trace("visiting global variable: ", handle); padding.trace("var: ", global); if global.init.is_some() { return Err(Error::GlobalInitUnsupported); } let var_ty = global.ty; let new_ty = self.upgrade_type(var_ty, upgrades)?; if new_ty != var_ty { padding.debug("upgrading global variable: ", handle); padding.debug("from ty: ", var_ty); padding.debug("to ty: ", new_ty); self.module.global_variables[handle].ty = new_ty; } } Ok(()) } } impl Module { /// Upgrade the global variables and struct fields recorded in `upgrades` to have [`Atomic`] leaf types.
/// /// [`Atomic`]: TypeInner::Atomic pub(crate) fn upgrade_atomics(&mut self, upgrades: &Upgrades) -> Result<(), Error> { let mut state = UpgradeState { padding: Default::default(), module: self, upgraded_types: crate::FastHashMap::with_capacity_and_hasher( upgrades.fields.len(), Default::default(), ), }; state.upgrade_all(upgrades)?; Ok(()) } } naga-29.0.3/src/front/glsl/ast.rs000064400000000000000000000275151046102023000146540ustar 00000000000000use alloc::{borrow::Cow, string::String, vec::Vec}; use core::fmt; use super::{builtins::MacroCall, Span}; use crate::{ AddressSpace, BinaryOperator, Binding, Constant, Expression, Function, GlobalVariable, Handle, Interpolation, Literal, Override, Sampling, StorageAccess, Type, UnaryOperator, }; #[derive(Debug, Clone, Copy)] pub enum GlobalLookupKind { Variable(Handle), Constant(Handle, Handle), Override(Handle, Handle), BlockSelect(Handle, u32), } #[derive(Debug, Clone, Copy)] pub struct GlobalLookup { pub kind: GlobalLookupKind, pub entry_arg: Option, pub mutable: bool, } #[derive(Debug, Clone)] pub struct ParameterInfo { pub qualifier: ParameterQualifier, /// Whether the parameter should be treated as a depth image instead of a /// sampled image. 
pub depth: bool, } /// How the function is implemented #[derive(Clone, Copy)] pub enum FunctionKind { /// The function is user defined Call(Handle), /// The function is a builtin Macro(MacroCall), } impl fmt::Debug for FunctionKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { Self::Call(_) => write!(f, "Call"), Self::Macro(_) => write!(f, "Macro"), } } } #[derive(Debug)] pub struct Overload { /// Normalized function parameters, modifiers are not applied pub parameters: Vec>, pub parameters_info: Vec, /// How the function is implemented pub kind: FunctionKind, /// Whether this function was already defined or is just a prototype pub defined: bool, /// Whether this overload is the one provided by the language or has /// been redeclared by the user (builtins only) pub internal: bool, /// Whether or not this function returns void (nothing) pub void: bool, } bitflags::bitflags! { /// Tracks the variations of the builtin already generated, this is needed because some /// builtins overloads can't be generated unless explicitly used, since they might cause /// unneeded capabilities to be requested #[derive(Default)] #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct BuiltinVariations: u32 { /// Request the standard overloads const STANDARD = 1 << 0; /// Request overloads that use the double type const DOUBLE = 1 << 1; /// Request overloads that use `samplerCubeArray(Shadow)` const CUBE_TEXTURES_ARRAY = 1 << 2; /// Request overloads that use `sampler2DMSArray` const D2_MULTI_TEXTURES_ARRAY = 1 << 3; } } #[derive(Debug, Default)] pub struct FunctionDeclaration { pub overloads: Vec, /// Tracks the builtin overload variations that were already generated pub variations: BuiltinVariations, } #[derive(Debug)] pub struct EntryArg { pub name: Option, pub binding: Binding, pub handle: Handle, pub storage: StorageQualifier, } #[derive(Debug, Clone)] pub struct VariableReference { pub expr: Handle, /// Whether the variable is of a pointer type (and 
needs loading) or not pub load: bool, /// Whether the value of the variable can be changed or not pub mutable: bool, pub constant: Option<(Handle<Constant>, Handle<Type>)>, pub entry_arg: Option<usize>, } #[derive(Debug, Clone)] pub struct HirExpr { pub kind: HirExprKind, pub meta: Span, } #[derive(Debug, Clone)] pub enum HirExprKind { /// Represents a sequence of expressions. It returns the type and value of the last (i.e. right-most) expression. Sequence { exprs: Vec<Handle<HirExpr>>, }, Access { base: Handle<HirExpr>, index: Handle<HirExpr>, }, Select { base: Handle<HirExpr>, field: String, }, Literal(Literal), Binary { left: Handle<HirExpr>, op: BinaryOperator, right: Handle<HirExpr>, }, Unary { op: UnaryOperator, expr: Handle<HirExpr>, }, Variable(VariableReference), Call(FunctionCall), /// Represents the ternary operator in glsl (`?:`) Conditional { /// The expression that will decide which branch to take, must evaluate to a boolean condition: Handle<HirExpr>, /// The expression that will be evaluated if [`condition`] evaluates to `true` /// /// [`condition`]: Self::Conditional::condition accept: Handle<HirExpr>, /// The expression that will be evaluated if [`condition`] evaluates to `false` /// /// [`condition`]: Self::Conditional::condition reject: Handle<HirExpr>, }, Assign { tgt: Handle<HirExpr>, value: Handle<HirExpr>, }, /// A prefix/postfix operator like `++` PrePostfix { /// The operation to be performed op: BinaryOperator, /// Whether this is a postfix or a prefix postfix: bool, /// The target expression expr: Handle<HirExpr>, }, /// A method call like `what.something(a, b, c)` Method { /// expression the method call applies to (`what` in the example) expr: Handle<HirExpr>, /// the method name (`something` in the example) name: String, /// the arguments to the method (`a`, `b`, and `c` in the example) args: Vec<Handle<HirExpr>>, }, } #[derive(Debug, Hash, PartialEq, Eq)] pub enum QualifierKey<'a> { String(Cow<'a, str>), /// Used for `std140` and `std430` layout qualifiers Layout, /// Used for image formats Format, /// Used for `index` layout qualifiers Index, } #[derive(Debug)] pub enum QualifierValue { None, Uint(u32),
Layout(StructLayout), Format(crate::StorageFormat), } #[derive(Debug, Default)] pub struct TypeQualifiers<'a> { pub span: Span, pub storage: (StorageQualifier, Span), pub invariant: Option, pub interpolation: Option<(Interpolation, Span)>, pub precision: Option<(Precision, Span)>, pub sampling: Option<(Sampling, Span)>, /// Memory qualifiers used in the declaration to set the storage access to be used /// in declarations that support it (storage images and buffers) pub storage_access: Option<(StorageAccess, Span)>, pub layout_qualifiers: crate::FastHashMap, (QualifierValue, Span)>, } impl<'a> TypeQualifiers<'a> { /// Appends `errors` with errors for all unused qualifiers pub fn unused_errors(&self, errors: &mut Vec) { if let Some(meta) = self.invariant { errors.push(super::Error { kind: super::ErrorKind::SemanticError( "Invariant qualifier can only be used in in/out variables".into(), ), meta, }); } if let Some((_, meta)) = self.interpolation { errors.push(super::Error { kind: super::ErrorKind::SemanticError( "Interpolation qualifiers can only be used in in/out variables".into(), ), meta, }); } if let Some((_, meta)) = self.sampling { errors.push(super::Error { kind: super::ErrorKind::SemanticError( "Sampling qualifiers can only be used in in/out variables".into(), ), meta, }); } if let Some((_, meta)) = self.storage_access { errors.push(super::Error { kind: super::ErrorKind::SemanticError( "Memory qualifiers can only be used in storage variables".into(), ), meta, }); } for &(_, meta) in self.layout_qualifiers.values() { errors.push(super::Error { kind: super::ErrorKind::SemanticError("Unexpected qualifier".into()), meta, }); } } /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't /// a [`QualifierValue::Uint`] pub fn uint_layout_qualifier( &mut self, name: &'a str, errors: &mut Vec, ) -> Option { match self .layout_qualifiers .remove(&QualifierKey::String(name.into())) { Some((QualifierValue::Uint(v), _)) => Some(v), Some((_, 
meta)) => { errors.push(super::Error { kind: super::ErrorKind::SemanticError("Qualifier expects a uint value".into()), meta, }); // Return a dummy value instead of `None` to differentiate from // the qualifier not existing, since some parts might require the // qualifier to exist and throwing another error that it doesn't // exist would be unhelpful Some(0) } _ => None, } } /// Removes the layout qualifier named `name`, if it exists, and adds an error if its value /// isn't a [`QualifierValue::None`] pub fn none_layout_qualifier(&mut self, name: &'a str, errors: &mut Vec<super::Error>) -> bool { match self .layout_qualifiers .remove(&QualifierKey::String(name.into())) { Some((QualifierValue::None, _)) => true, Some((_, meta)) => { errors.push(super::Error { kind: super::ErrorKind::SemanticError( "Qualifier doesn't expect a value".into(), ), meta, }); // Return `true` since the qualifier is defined, and adding // another error for it not being defined would be unhelpful true } _ => false, } } } #[derive(Debug, Clone)] pub enum FunctionCallKind { TypeConstructor(Handle<Type>), Function(String), } #[derive(Debug, Clone)] pub struct FunctionCall { pub kind: FunctionCallKind, pub args: Vec<Handle<HirExpr>>, } #[derive(Debug, Clone, Copy, PartialEq)] pub enum StorageQualifier { AddressSpace(AddressSpace), Input, Output, Const, } impl Default for StorageQualifier { fn default() -> Self { StorageQualifier::AddressSpace(AddressSpace::Function) } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StructLayout { Std140, Std430, } // TODO: Encode precision hints in the IR /// A precision hint used in GLSL declarations. /// /// Precision hints can be used to either speed up shader execution or control /// the precision of arithmetic operations. /// /// To use a precision hint simply add it before the type in the declaration. /// ```glsl /// mediump float a; /// ``` /// /// The default when no precision is declared is `highp`, which means that all /// operations use the full width of their declared type.
/// /// For `mediump` and `lowp` operations follow the spir-v /// [`RelaxedPrecision`][RelaxedPrecision] decoration semantics. /// /// [RelaxedPrecision]: https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#_a_id_relaxedprecisionsection_a_relaxed_precision #[derive(Debug, Clone, PartialEq, Copy)] pub enum Precision { /// `lowp` precision Low, /// `mediump` precision Medium, /// `highp` precision High, } #[derive(Debug, Clone, PartialEq, Copy)] pub enum ParameterQualifier { In, Out, InOut, Const, } impl ParameterQualifier { /// Returns true if the argument should be passed as a lhs expression pub const fn is_lhs(&self) -> bool { match *self { ParameterQualifier::Out | ParameterQualifier::InOut => true, _ => false, } } } /// The GLSL profile used by a shader. #[derive(Debug, Clone, Copy, PartialEq)] pub enum Profile { /// The `core` profile, default when no profile is specified. Core, } naga-29.0.3/src/front/glsl/builtins.rs000064400000000000000000002603121046102023000157100ustar 00000000000000use alloc::{vec, vec::Vec}; use super::{ ast::{ BuiltinVariations, FunctionDeclaration, FunctionKind, Overload, ParameterInfo, ParameterQualifier, }, context::Context, Error, ErrorKind, Frontend, Result, }; use crate::{ BinaryOperator, DerivativeAxis as Axis, DerivativeControl as Ctrl, Expression, Handle, ImageClass, ImageDimension as Dim, ImageQuery, MathFunction, Module, RelationalFunction, SampleLevel, Scalar, ScalarKind as Sk, Span, Type, TypeInner, UnaryOperator, VectorSize, }; impl crate::ScalarKind { const fn dummy_storage_format(&self) -> crate::StorageFormat { match *self { Sk::Sint => crate::StorageFormat::R16Sint, Sk::Uint => crate::StorageFormat::R16Uint, _ => crate::StorageFormat::R16Float, } } } impl Module { /// Helper function, to create a function prototype for a builtin fn add_builtin(&mut self, args: Vec, builtin: MacroCall) -> Overload { let mut parameters = Vec::with_capacity(args.len()); let mut parameters_info = 
Vec::with_capacity(args.len()); for arg in args { parameters.push(self.types.insert( Type { name: None, inner: arg, }, Span::default(), )); parameters_info.push(ParameterInfo { qualifier: ParameterQualifier::In, depth: false, }); } Overload { parameters, parameters_info, kind: FunctionKind::Macro(builtin), defined: false, internal: true, void: false, } } } const fn make_coords_arg(number_of_components: usize, kind: Sk) -> TypeInner { let scalar = Scalar { kind, width: 4 }; match number_of_components { 1 => TypeInner::Scalar(scalar), _ => TypeInner::Vector { size: match number_of_components { 2 => VectorSize::Bi, 3 => VectorSize::Tri, _ => VectorSize::Quad, }, scalar, }, } } /// Inject builtins into the declaration /// /// This is done on demand rather than up front, to avoid a large startup cost and /// to avoid increasing memory usage when a builtin isn't needed. pub fn inject_builtin( declaration: &mut FunctionDeclaration, module: &mut Module, name: &str, mut variations: BuiltinVariations, ) { log::trace!( "{} variations: {:?} {:?}", name, variations, declaration.variations ); // Don't regenerate variations that were already injected variations.remove(declaration.variations); declaration.variations |= variations; if variations.contains(BuiltinVariations::STANDARD) { inject_standard_builtins(declaration, module, name) } if variations.contains(BuiltinVariations::DOUBLE) { inject_double_builtin(declaration, module, name) } match name { "texture" | "textureGrad" | "textureGradOffset" | "textureLod" | "textureLodOffset" | "textureOffset" | "textureProj" | "textureProjGrad" | "textureProjGradOffset" | "textureProjLod" | "textureProjLodOffset" | "textureProjOffset" => { let f = |kind, dim, arrayed, multi, shadow| { for bits in 0..=0b11 { let variant = bits & 0b1 != 0; let bias = bits & 0b10 != 0; let (proj, offset, level_type) = match name { // texture(gsampler, gvec P, [float bias]); "texture" => (false, false, TextureLevelType::None), // textureGrad(gsampler, gvec P, gvec dPdx, gvec dPdy); "textureGrad" => (false, false,
TextureLevelType::Grad), // textureGradOffset(gsampler, gvec P, gvec dPdx, gvec dPdy, ivec offset); "textureGradOffset" => (false, true, TextureLevelType::Grad), // textureLod(gsampler, gvec P, float lod); "textureLod" => (false, false, TextureLevelType::Lod), // textureLodOffset(gsampler, gvec P, float lod, ivec offset); "textureLodOffset" => (false, true, TextureLevelType::Lod), // textureOffset(gsampler, gvec P, ivec offset, [float bias]); "textureOffset" => (false, true, TextureLevelType::None), // textureProj(gsampler, gvec+1 P, [float bias]); "textureProj" => (true, false, TextureLevelType::None), // textureProjGrad(gsampler, gvec+1 P, gvec dPdx, gvec dPdy); "textureProjGrad" => (true, false, TextureLevelType::Grad), // textureProjGradOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset); "textureProjGradOffset" => (true, true, TextureLevelType::Grad), // textureProjLod(gsampler, gvec+1 P, float lod); "textureProjLod" => (true, false, TextureLevelType::Lod), // textureProjLodOffset(gsampler, gvec+1 P, float lod, ivec offset); "textureProjLodOffset" => (true, true, TextureLevelType::Lod), // textureProjOffset(gsampler, gvec+1 P, ivec offset, [float bias]); "textureProjOffset" => (true, true, TextureLevelType::None), _ => unreachable!(), }; let builtin = MacroCall::Texture { proj, offset, shadow, level_type, }; // Parse out the variant settings.
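// How the `bits` loop above maps onto overloads (a reading of this closure, not a list
// taken from the GLSL specification):
// - Bit 0 (`variant`) only applies to `textureProj*` on non-shadow samplers. There it
//   selects the alternative form whose coordinate is always a `vec4`.
// - Bit 1 (`bias`) appends a trailing `float bias` argument. It is kept only for functions
//   without an explicit LOD or gradient (`TextureLevelType::None`) on non-shadow samplers.
// Combinations that are invalid, or that Naga does not support yet, are skipped by the
// `continue`s below.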
                    let grad = level_type == TextureLevelType::Grad;
                    let lod = level_type == TextureLevelType::Lod;

                    let supports_variant = proj && !shadow;
                    if variant && !supports_variant {
                        continue;
                    }

                    if bias && !matches!(level_type, TextureLevelType::None) {
                        continue;
                    }

                    // Proj doesn't work with arrayed or Cube
                    if proj && (arrayed || dim == Dim::Cube) {
                        continue;
                    }

                    // texture operations with offset are not supported for cube maps
                    if dim == Dim::Cube && offset {
                        continue;
                    }

                    // sampler2DArrayShadow can't be used in textureLod or in texture with bias
                    if (lod || bias) && arrayed && shadow && dim == Dim::D2 {
                        continue;
                    }

                    // TODO: glsl supports using bias with depth samplers but naga doesn't
                    if bias && shadow {
                        continue;
                    }

                    let class = match shadow {
                        true => ImageClass::Depth { multi },
                        false => ImageClass::Sampled { kind, multi },
                    };

                    let image = TypeInner::Image {
                        dim,
                        arrayed,
                        class,
                    };

                    let num_coords_from_dim = image_dims_to_coords_size(dim).min(3);
                    let mut num_coords = num_coords_from_dim;

                    if shadow && proj {
                        num_coords = 4;
                    } else if dim == Dim::D1 && shadow {
                        num_coords = 3;
                    } else if shadow {
                        num_coords += 1;
                    } else if proj {
                        if variant && num_coords == 4 {
                            // Normal form already has 4 components, no need to have a variant form.
                            continue;
                        } else if variant {
                            num_coords = 4;
                        } else {
                            num_coords += 1;
                        }
                    }

                    if !(dim == Dim::D1 && shadow) {
                        num_coords += arrayed as usize;
                    }

                    // Special case: texture(gsamplerCubeArrayShadow) kicks the shadow compare ref to a separate argument,
                    // since the coordinate vector would otherwise need five components. It also can't take a bias,
                    // nor can it be proj/grad/lod/offset
                    // (presumably because nobody asked for it, and implementation complexity?)
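                    // Worked example of the coordinate count computed above, assuming
                    // `image_dims_to_coords_size` maps `Dim::Cube` to 3 (cube coordinates are 3D):
                    // - samplerCubeArrayShadow: 3 (cube) + 1 (depth reference) + 1 (layer) = 5,
                    //   which is why the reference is moved into its own float argument
                    // - sampler2DArrayShadow: 2 + 1 (depth reference) + 1 (layer) = 4, a vec4
                    // - sampler1DArrayShadow: fixed at 3 by the `dim == Dim::D1 && shadow` branch,
                    //   with no extra component added for the layer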
if num_coords >= 5 { if lod || grad || offset || proj || bias { continue; } debug_assert!(dim == Dim::Cube && shadow && arrayed); } debug_assert!(num_coords <= 5); let vector = make_coords_arg(num_coords, Sk::Float); let mut args = vec![image, vector]; if num_coords == 5 { args.push(TypeInner::Scalar(Scalar::F32)); } match level_type { TextureLevelType::Lod => { args.push(TypeInner::Scalar(Scalar::F32)); } TextureLevelType::Grad => { args.push(make_coords_arg(num_coords_from_dim, Sk::Float)); args.push(make_coords_arg(num_coords_from_dim, Sk::Float)); } TextureLevelType::None => {} }; if offset { args.push(make_coords_arg(num_coords_from_dim, Sk::Sint)); } if bias { args.push(TypeInner::Scalar(Scalar::F32)); } declaration .overloads .push(module.add_builtin(args, builtin)); } }; texture_args_generator(TextureArgsOptions::SHADOW | variations.into(), f) } "textureSize" => { let f = |kind, dim, arrayed, multi, shadow| { let class = match shadow { true => ImageClass::Depth { multi }, false => ImageClass::Sampled { kind, multi }, }; let image = TypeInner::Image { dim, arrayed, class, }; let mut args = vec![image]; if !multi { args.push(TypeInner::Scalar(Scalar::I32)) } declaration .overloads .push(module.add_builtin(args, MacroCall::TextureSize { arrayed })) }; texture_args_generator( TextureArgsOptions::SHADOW | TextureArgsOptions::MULTI | variations.into(), f, ) } "textureQueryLevels" => { let f = |kind, dim, arrayed, multi, shadow| { let class = match shadow { true => ImageClass::Depth { multi }, false => ImageClass::Sampled { kind, multi }, }; let image = TypeInner::Image { dim, arrayed, class, }; declaration .overloads .push(module.add_builtin(vec![image], MacroCall::TextureQueryLevels)) }; texture_args_generator(TextureArgsOptions::SHADOW | variations.into(), f) } "texelFetch" | "texelFetchOffset" => { let offset = "texelFetchOffset" == name; let f = |kind, dim, arrayed, multi, _shadow| { // Cube images aren't supported if let Dim::Cube = dim { return; } let image 
                = TypeInner::Image {
                    dim,
                    arrayed,
                    class: ImageClass::Sampled { kind, multi },
                };

                let dim_value = image_dims_to_coords_size(dim);
                let coordinates = make_coords_arg(dim_value + arrayed as usize, Sk::Sint);

                let mut args = vec![image, coordinates, TypeInner::Scalar(Scalar::I32)];

                if offset {
                    args.push(make_coords_arg(dim_value, Sk::Sint));
                }

                declaration
                    .overloads
                    .push(module.add_builtin(args, MacroCall::ImageLoad { multi }))
            };

            // Don't generate shadow images since they aren't supported
            texture_args_generator(TextureArgsOptions::MULTI | variations.into(), f)
        }
        "imageSize" => {
            let f = |kind: Sk, dim, arrayed, _, _| {
                // Naga doesn't support cube images and their usefulness
                // is questionable, so they won't be supported for now
                if dim == Dim::Cube {
                    return;
                }

                let image = TypeInner::Image {
                    dim,
                    arrayed,
                    class: ImageClass::Storage {
                        format: kind.dummy_storage_format(),
                        access: crate::StorageAccess::empty(),
                    },
                };

                declaration
                    .overloads
                    .push(module.add_builtin(vec![image], MacroCall::TextureSize { arrayed }))
            };

            texture_args_generator(variations.into(), f)
        }
        "imageLoad" => {
            let f = |kind: Sk, dim, arrayed, _, _| {
                // Naga doesn't support cube images and their usefulness
                // is questionable, so they won't be supported for now
                if dim == Dim::Cube {
                    return;
                }

                let image = TypeInner::Image {
                    dim,
                    arrayed,
                    class: ImageClass::Storage {
                        format: kind.dummy_storage_format(),
                        access: crate::StorageAccess::LOAD,
                    },
                };

                let dim_value = image_dims_to_coords_size(dim);
                let mut coord_size = dim_value + arrayed as usize;
                // > Every OpenGL API call that operates on cubemap array
                // > textures takes layer-faces, not array layers
                //
                // So this means that imageCubeArray only takes a three component
                // vector coordinate and the third component is a layer index.
                if Dim::Cube == dim && arrayed {
                    coord_size = 3
                }
                let coordinates = make_coords_arg(coord_size, Sk::Sint);

                let args = vec![image, coordinates];

                declaration
                    .overloads
                    .push(module.add_builtin(args, MacroCall::ImageLoad { multi: false }))
            };

            // Don't generate shadow or multisampled images since they aren't supported
            texture_args_generator(variations.into(), f)
        }
        "imageStore" => {
            let f = |kind: Sk, dim, arrayed, _, _| {
                // Naga doesn't support cube images and their usefulness
                // is questionable, so they won't be supported for now
                if dim == Dim::Cube {
                    return;
                }

                let image = TypeInner::Image {
                    dim,
                    arrayed,
                    class: ImageClass::Storage {
                        format: kind.dummy_storage_format(),
                        access: crate::StorageAccess::STORE,
                    },
                };

                let dim_value = image_dims_to_coords_size(dim);
                let mut coord_size = dim_value + arrayed as usize;
                // > Every OpenGL API call that operates on cubemap array
                // > textures takes layer-faces, not array layers
                //
                // So this means that imageCubeArray only takes a three component
                // vector coordinate and the third component is a layer index.
                if Dim::Cube == dim && arrayed {
                    coord_size = 3
                }
                let coordinates = make_coords_arg(coord_size, Sk::Sint);

                let args = vec![
                    image,
                    coordinates,
                    TypeInner::Vector {
                        size: VectorSize::Quad,
                        scalar: Scalar { kind, width: 4 },
                    },
                ];

                let mut overload = module.add_builtin(args, MacroCall::ImageStore);
                overload.void = true;
                declaration.overloads.push(overload)
            };

            // Don't generate shadow or multisampled images since they aren't supported
            texture_args_generator(variations.into(), f)
        }
        _ => {}
    }
}

/// Injects the builtins into declaration that don't need any special variations
fn inject_standard_builtins(
    declaration: &mut FunctionDeclaration,
    module: &mut Module,
    name: &str,
) {
    // Some samplers (sampler1D, etc...)
    // can be float, int, or uint
    let anykind_sampler = if name.starts_with("sampler") {
        Some((name, Sk::Float))
    } else if name.starts_with("usampler") {
        Some((&name[1..], Sk::Uint))
    } else if name.starts_with("isampler") {
        Some((&name[1..], Sk::Sint))
    } else {
        None
    };

    if let Some((sampler, kind)) = anykind_sampler {
        match sampler {
            "sampler1D" | "sampler1DArray" | "sampler2D" | "sampler2DArray" | "sampler2DMS"
            | "sampler2DMSArray" | "sampler3D" | "samplerCube" | "samplerCubeArray" => {
                declaration.overloads.push(module.add_builtin(
                    vec![
                        TypeInner::Image {
                            dim: match sampler {
                                "sampler1D" | "sampler1DArray" => Dim::D1,
                                "sampler2D" | "sampler2DArray" | "sampler2DMS"
                                | "sampler2DMSArray" => Dim::D2,
                                "sampler3D" => Dim::D3,
                                _ => Dim::Cube,
                            },
                            arrayed: matches!(
                                sampler,
                                "sampler1DArray"
                                    | "sampler2DArray"
                                    | "sampler2DMSArray"
                                    | "samplerCubeArray"
                            ),
                            class: ImageClass::Sampled {
                                kind,
                                multi: matches!(sampler, "sampler2DMS" | "sampler2DMSArray"),
                            },
                        },
                        TypeInner::Sampler { comparison: false },
                    ],
                    MacroCall::Sampler,
                ));
                return;
            }
            _ => (),
        }
    }

    match name {
        // Shadow sampler can only be of kind `Sk::Float`
        "sampler1DShadow"
        | "sampler1DArrayShadow"
        | "sampler2DShadow"
        | "sampler2DArrayShadow"
        | "samplerCubeShadow"
        | "samplerCubeArrayShadow" => {
            let dim = match name {
                "sampler1DShadow" | "sampler1DArrayShadow" => Dim::D1,
                "sampler2DShadow" | "sampler2DArrayShadow" => Dim::D2,
                _ => Dim::Cube,
            };
            let arrayed = matches!(
                name,
                "sampler1DArrayShadow" | "sampler2DArrayShadow" | "samplerCubeArrayShadow"
            );

            for i in 0..2 {
                let ty = TypeInner::Image {
                    dim,
                    arrayed,
                    class: match i {
                        0 => ImageClass::Sampled {
                            kind: Sk::Float,
                            multi: false,
                        },
                        _ => ImageClass::Depth { multi: false },
                    },
                };

                declaration.overloads.push(module.add_builtin(
                    vec![ty, TypeInner::Sampler { comparison: true }],
                    MacroCall::SamplerShadow,
                ))
            }
        }
        "sin" | "exp" | "exp2" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin"
        | "log" | "log2" | "radians" | "degrees" | "asinh" | "acosh" | "atanh" |
"floatBitsToInt" | "floatBitsToUint" | "dFdx" | "dFdxFine" | "dFdxCoarse" | "dFdy" | "dFdyFine" | "dFdyCoarse" | "fwidth" | "fwidthFine" | "fwidthCoarse" => { // bits layout // bit 0 through 1 - dims for bits in 0..0b100 { let size = match bits { 0b00 => None, 0b01 => Some(VectorSize::Bi), 0b10 => Some(VectorSize::Tri), _ => Some(VectorSize::Quad), }; let scalar = Scalar::F32; declaration.overloads.push(module.add_builtin( vec![match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }], match name { "sin" => MacroCall::MathFunction(MathFunction::Sin), "exp" => MacroCall::MathFunction(MathFunction::Exp), "exp2" => MacroCall::MathFunction(MathFunction::Exp2), "sinh" => MacroCall::MathFunction(MathFunction::Sinh), "cos" => MacroCall::MathFunction(MathFunction::Cos), "cosh" => MacroCall::MathFunction(MathFunction::Cosh), "tan" => MacroCall::MathFunction(MathFunction::Tan), "tanh" => MacroCall::MathFunction(MathFunction::Tanh), "acos" => MacroCall::MathFunction(MathFunction::Acos), "asin" => MacroCall::MathFunction(MathFunction::Asin), "log" => MacroCall::MathFunction(MathFunction::Log), "log2" => MacroCall::MathFunction(MathFunction::Log2), "asinh" => MacroCall::MathFunction(MathFunction::Asinh), "acosh" => MacroCall::MathFunction(MathFunction::Acosh), "atanh" => MacroCall::MathFunction(MathFunction::Atanh), "radians" => MacroCall::MathFunction(MathFunction::Radians), "degrees" => MacroCall::MathFunction(MathFunction::Degrees), "floatBitsToInt" => MacroCall::BitCast(Sk::Sint), "floatBitsToUint" => MacroCall::BitCast(Sk::Uint), "dFdxCoarse" => MacroCall::Derivate(Axis::X, Ctrl::Coarse), "dFdyCoarse" => MacroCall::Derivate(Axis::Y, Ctrl::Coarse), "fwidthCoarse" => MacroCall::Derivate(Axis::Width, Ctrl::Coarse), "dFdxFine" => MacroCall::Derivate(Axis::X, Ctrl::Fine), "dFdyFine" => MacroCall::Derivate(Axis::Y, Ctrl::Fine), "fwidthFine" => MacroCall::Derivate(Axis::Width, Ctrl::Fine), "dFdx" => MacroCall::Derivate(Axis::X, 
Ctrl::None), "dFdy" => MacroCall::Derivate(Axis::Y, Ctrl::None), "fwidth" => MacroCall::Derivate(Axis::Width, Ctrl::None), _ => unreachable!(), }, )) } } "intBitsToFloat" | "uintBitsToFloat" => { // bits layout // bit 0 through 1 - dims for bits in 0..0b100 { let size = match bits { 0b00 => None, 0b01 => Some(VectorSize::Bi), 0b10 => Some(VectorSize::Tri), _ => Some(VectorSize::Quad), }; let scalar = match name { "intBitsToFloat" => Scalar::I32, _ => Scalar::U32, }; declaration.overloads.push(module.add_builtin( vec![match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }], MacroCall::BitCast(Sk::Float), )) } } "pow" => { // bits layout // bit 0 through 1 - dims for bits in 0..0b100 { let size = match bits { 0b00 => None, 0b01 => Some(VectorSize::Bi), 0b10 => Some(VectorSize::Tri), _ => Some(VectorSize::Quad), }; let scalar = Scalar::F32; let ty = || match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }; declaration.overloads.push( module .add_builtin(vec![ty(), ty()], MacroCall::MathFunction(MathFunction::Pow)), ) } } "abs" | "sign" => { // bits layout // bit 0 through 1 - dims // bit 2 - float/sint for bits in 0..0b1000 { let size = match bits & 0b11 { 0b00 => None, 0b01 => Some(VectorSize::Bi), 0b10 => Some(VectorSize::Tri), _ => Some(VectorSize::Quad), }; let scalar = match bits >> 2 { 0b0 => Scalar::F32, _ => Scalar::I32, }; let args = vec![match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }]; declaration.overloads.push(module.add_builtin( args, MacroCall::MathFunction(match name { "abs" => MathFunction::Abs, "sign" => MathFunction::Sign, _ => unreachable!(), }), )) } } "bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" | "findLSB" | "findMSB" => { let fun = match name { "bitCount" => MathFunction::CountOneBits, "bitfieldReverse" => MathFunction::ReverseBits, "bitfieldExtract" => MathFunction::ExtractBits, 
                "bitfieldInsert" => MathFunction::InsertBits,
                "findLSB" => MathFunction::FirstTrailingBit,
                "findMSB" => MathFunction::FirstLeadingBit,
                _ => unreachable!(),
            };

            let mc = match fun {
                MathFunction::ExtractBits => MacroCall::BitfieldExtract,
                MathFunction::InsertBits => MacroCall::BitfieldInsert,
                _ => MacroCall::MathFunction(fun),
            };

            // bits layout
            // bit 0 - int/uint
            // bit 1 through 2 - dims
            for bits in 0..0b1000 {
                let scalar = match bits & 0b1 {
                    0b0 => Scalar::I32,
                    _ => Scalar::U32,
                };
                let size = match bits >> 1 {
                    0b00 => None,
                    0b01 => Some(VectorSize::Bi),
                    0b10 => Some(VectorSize::Tri),
                    _ => Some(VectorSize::Quad),
                };
                let ty = || match size {
                    Some(size) => TypeInner::Vector { size, scalar },
                    None => TypeInner::Scalar(scalar),
                };

                let mut args = vec![ty()];

                match fun {
                    MathFunction::ExtractBits => {
                        args.push(TypeInner::Scalar(Scalar::I32));
                        args.push(TypeInner::Scalar(Scalar::I32));
                    }
                    MathFunction::InsertBits => {
                        args.push(ty());
                        args.push(TypeInner::Scalar(Scalar::I32));
                        args.push(TypeInner::Scalar(Scalar::I32));
                    }
                    _ => {}
                }

                // we need to cast the return type of findLsb / findMsb
                let mc = if scalar.kind == Sk::Uint {
                    match mc {
                        MacroCall::MathFunction(MathFunction::FirstTrailingBit) => {
                            MacroCall::FindLsbUint
                        }
                        MacroCall::MathFunction(MathFunction::FirstLeadingBit) => {
                            MacroCall::FindMsbUint
                        }
                        mc => mc,
                    }
                } else {
                    mc
                };

                declaration.overloads.push(module.add_builtin(args, mc))
            }
        }
        "packSnorm4x8" | "packUnorm4x8" | "packSnorm2x16" | "packUnorm2x16" | "packHalf2x16" => {
            let fun = match name {
                "packSnorm4x8" => MathFunction::Pack4x8snorm,
                "packUnorm4x8" => MathFunction::Pack4x8unorm,
                "packSnorm2x16" => MathFunction::Pack2x16snorm,
                "packUnorm2x16" => MathFunction::Pack2x16unorm,
                "packHalf2x16" => MathFunction::Pack2x16float,
                _ => unreachable!(),
            };

            let ty = match fun {
                MathFunction::Pack4x8snorm | MathFunction::Pack4x8unorm => TypeInner::Vector {
                    size: VectorSize::Quad,
                    scalar: Scalar::F32,
                },
                MathFunction::Pack2x16unorm | MathFunction::Pack2x16snorm |
MathFunction::Pack2x16float => TypeInner::Vector { size: VectorSize::Bi, scalar: Scalar::F32, }, _ => unreachable!(), }; let args = vec![ty]; declaration .overloads .push(module.add_builtin(args, MacroCall::MathFunction(fun))); } "unpackSnorm4x8" | "unpackUnorm4x8" | "unpackSnorm2x16" | "unpackUnorm2x16" | "unpackHalf2x16" => { let fun = match name { "unpackSnorm4x8" => MathFunction::Unpack4x8snorm, "unpackUnorm4x8" => MathFunction::Unpack4x8unorm, "unpackSnorm2x16" => MathFunction::Unpack2x16snorm, "unpackUnorm2x16" => MathFunction::Unpack2x16unorm, "unpackHalf2x16" => MathFunction::Unpack2x16float, _ => unreachable!(), }; let args = vec![TypeInner::Scalar(Scalar::U32)]; declaration .overloads .push(module.add_builtin(args, MacroCall::MathFunction(fun))); } "atan" => { // bits layout // bit 0 - atan/atan2 // bit 1 through 2 - dims for bits in 0..0b1000 { let fun = match bits & 0b1 { 0b0 => MathFunction::Atan, _ => MathFunction::Atan2, }; let size = match bits >> 1 { 0b00 => None, 0b01 => Some(VectorSize::Bi), 0b10 => Some(VectorSize::Tri), _ => Some(VectorSize::Quad), }; let scalar = Scalar::F32; let ty = || match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }; let mut args = vec![ty()]; if fun == MathFunction::Atan2 { args.push(ty()) } declaration .overloads .push(module.add_builtin(args, MacroCall::MathFunction(fun))) } } "all" | "any" | "not" => { // bits layout // bit 0 through 1 - dims for bits in 0..0b11 { let size = match bits { 0b00 => VectorSize::Bi, 0b01 => VectorSize::Tri, _ => VectorSize::Quad, }; let args = vec![TypeInner::Vector { size, scalar: Scalar::BOOL, }]; let fun = match name { "all" => MacroCall::Relational(RelationalFunction::All), "any" => MacroCall::Relational(RelationalFunction::Any), "not" => MacroCall::Unary(UnaryOperator::LogicalNot), _ => unreachable!(), }; declaration.overloads.push(module.add_builtin(args, fun)) } } "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" => { 
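            // bits layout
            // 0b0000 through 0b1000 - one overload per (vector size, scalar kind) pair:
            // vec2/vec3/vec4 of float, then of int, then of uint (9 overloads in total)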
for bits in 0..0b1001 { let (size, scalar) = match bits { 0b0000 => (VectorSize::Bi, Scalar::F32), 0b0001 => (VectorSize::Tri, Scalar::F32), 0b0010 => (VectorSize::Quad, Scalar::F32), 0b0011 => (VectorSize::Bi, Scalar::I32), 0b0100 => (VectorSize::Tri, Scalar::I32), 0b0101 => (VectorSize::Quad, Scalar::I32), 0b0110 => (VectorSize::Bi, Scalar::U32), 0b0111 => (VectorSize::Tri, Scalar::U32), _ => (VectorSize::Quad, Scalar::U32), }; let ty = || TypeInner::Vector { size, scalar }; let args = vec![ty(), ty()]; let fun = MacroCall::Binary(match name { "lessThan" => BinaryOperator::Less, "greaterThan" => BinaryOperator::Greater, "lessThanEqual" => BinaryOperator::LessEqual, "greaterThanEqual" => BinaryOperator::GreaterEqual, _ => unreachable!(), }); declaration.overloads.push(module.add_builtin(args, fun)) } } "equal" | "notEqual" => { for bits in 0..0b1100 { let (size, scalar) = match bits { 0b0000 => (VectorSize::Bi, Scalar::F32), 0b0001 => (VectorSize::Tri, Scalar::F32), 0b0010 => (VectorSize::Quad, Scalar::F32), 0b0011 => (VectorSize::Bi, Scalar::I32), 0b0100 => (VectorSize::Tri, Scalar::I32), 0b0101 => (VectorSize::Quad, Scalar::I32), 0b0110 => (VectorSize::Bi, Scalar::U32), 0b0111 => (VectorSize::Tri, Scalar::U32), 0b1000 => (VectorSize::Quad, Scalar::U32), 0b1001 => (VectorSize::Bi, Scalar::BOOL), 0b1010 => (VectorSize::Tri, Scalar::BOOL), _ => (VectorSize::Quad, Scalar::BOOL), }; let ty = || TypeInner::Vector { size, scalar }; let args = vec![ty(), ty()]; let fun = MacroCall::Binary(match name { "equal" => BinaryOperator::Equal, "notEqual" => BinaryOperator::NotEqual, _ => unreachable!(), }); declaration.overloads.push(module.add_builtin(args, fun)) } } "min" | "max" => { // bits layout // bit 0 through 1 - scalar kind // bit 2 through 4 - dims for bits in 0..0b11100 { let scalar = match bits & 0b11 { 0b00 => Scalar::F32, 0b01 => Scalar::I32, 0b10 => Scalar::U32, _ => continue, }; let (size, second_size) = match bits >> 2 { 0b000 => (None, None), 0b001 => 
                        (Some(VectorSize::Bi), None),
                    0b010 => (Some(VectorSize::Tri), None),
                    0b011 => (Some(VectorSize::Quad), None),
                    0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
                    0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
                    _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
                };

                let args = vec![
                    match size {
                        Some(size) => TypeInner::Vector { size, scalar },
                        None => TypeInner::Scalar(scalar),
                    },
                    match second_size {
                        Some(size) => TypeInner::Vector { size, scalar },
                        None => TypeInner::Scalar(scalar),
                    },
                ];

                let fun = match name {
                    "max" => MacroCall::Splatted(MathFunction::Max, size, 1),
                    "min" => MacroCall::Splatted(MathFunction::Min, size, 1),
                    _ => unreachable!(),
                };

                declaration.overloads.push(module.add_builtin(args, fun))
            }
        }
        "mix" => {
            // bits layout
            // bit 0 through 1 - dims
            // bit 2 through 4 - types
            //
            // 0b10010 is the last element since the splatted scalar form (0b10011)
            // would duplicate the plain scalar overload that was already added
            for bits in 0..0b10011 {
                let size = match bits & 0b11 {
                    0b00 => Some(VectorSize::Bi),
                    0b01 => Some(VectorSize::Tri),
                    0b10 => Some(VectorSize::Quad),
                    _ => None,
                };
                let (scalar, splatted, boolean) = match bits >> 2 {
                    0b000 => (Scalar::I32, false, true),
                    0b001 => (Scalar::U32, false, true),
                    0b010 => (Scalar::F32, false, true),
                    0b011 => (Scalar::F32, false, false),
                    _ => (Scalar::F32, true, false),
                };

                let ty = |scalar| match size {
                    Some(size) => TypeInner::Vector { size, scalar },
                    None => TypeInner::Scalar(scalar),
                };
                let args = vec![
                    ty(scalar),
                    ty(scalar),
                    match (boolean, splatted) {
                        (true, _) => ty(Scalar::BOOL),
                        (_, false) => TypeInner::Scalar(scalar),
                        _ => ty(scalar),
                    },
                ];

                declaration.overloads.push(module.add_builtin(
                    args,
                    match boolean {
                        true => MacroCall::MixBoolean,
                        false => MacroCall::Splatted(MathFunction::Mix, size, 2),
                    },
                ))
            }
        }
        "clamp" => {
            // bits layout
            // bit 0 through 1 - float/int/uint
            // bit 2 through 3 - dims
            // bit 4 - splatted
            //
            // 0b11010 is the last element since splatted single elements
            // were already added
            for bits in 0..0b11011 {
                let scalar = match bits &
0b11 { 0b00 => Scalar::F32, 0b01 => Scalar::I32, 0b10 => Scalar::U32, _ => continue, }; let size = match (bits >> 2) & 0b11 { 0b00 => Some(VectorSize::Bi), 0b01 => Some(VectorSize::Tri), 0b10 => Some(VectorSize::Quad), _ => None, }; let splatted = bits & 0b10000 == 0b10000; let base_ty = || match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }; let limit_ty = || match splatted { true => TypeInner::Scalar(scalar), false => base_ty(), }; let args = vec![base_ty(), limit_ty(), limit_ty()]; declaration .overloads .push(module.add_builtin(args, MacroCall::Clamp(size))) } } "barrier" => declaration .overloads .push(module.add_builtin(Vec::new(), MacroCall::Barrier)), // Add common builtins with floats _ => inject_common_builtin(declaration, module, name, 4), } } /// Injects the builtins into declaration that need doubles fn inject_double_builtin(declaration: &mut FunctionDeclaration, module: &mut Module, name: &str) { match name { "abs" | "sign" => { // bits layout // bit 0 through 1 - dims for bits in 0..0b100 { let size = match bits { 0b00 => None, 0b01 => Some(VectorSize::Bi), 0b10 => Some(VectorSize::Tri), _ => Some(VectorSize::Quad), }; let scalar = Scalar::F64; let args = vec![match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }]; declaration.overloads.push(module.add_builtin( args, MacroCall::MathFunction(match name { "abs" => MathFunction::Abs, "sign" => MathFunction::Sign, _ => unreachable!(), }), )) } } "min" | "max" => { // bits layout // bit 0 through 2 - dims for bits in 0..0b111 { let (size, second_size) = match bits { 0b000 => (None, None), 0b001 => (Some(VectorSize::Bi), None), 0b010 => (Some(VectorSize::Tri), None), 0b011 => (Some(VectorSize::Quad), None), 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)), 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)), _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)), }; let scalar = Scalar::F64; let args = 
vec![ match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }, match second_size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }, ]; let fun = match name { "max" => MacroCall::Splatted(MathFunction::Max, size, 1), "min" => MacroCall::Splatted(MathFunction::Min, size, 1), _ => unreachable!(), }; declaration.overloads.push(module.add_builtin(args, fun)) } } "mix" => { // bits layout // bit 0 through 1 - dims // bit 2 through 3 - splatted/boolean // // 0b1010 is the last element since splatted with single elements // is equal to normal single elements for bits in 0..0b1011 { let size = match bits & 0b11 { 0b00 => Some(VectorSize::Quad), 0b01 => Some(VectorSize::Bi), 0b10 => Some(VectorSize::Tri), _ => None, }; let scalar = Scalar::F64; let (splatted, boolean) = match bits >> 2 { 0b00 => (false, false), 0b01 => (false, true), _ => (true, false), }; let ty = |scalar| match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }; let args = vec![ ty(scalar), ty(scalar), match (boolean, splatted) { (true, _) => ty(Scalar::BOOL), (_, false) => TypeInner::Scalar(scalar), _ => ty(scalar), }, ]; declaration.overloads.push(module.add_builtin( args, match boolean { true => MacroCall::MixBoolean, false => MacroCall::Splatted(MathFunction::Mix, size, 2), }, )) } } "clamp" => { // bits layout // bit 0 through 1 - dims // bit 2 - splatted // // 0b110 is the last element since splatted with single elements // is equal to normal single elements for bits in 0..0b111 { let scalar = Scalar::F64; let size = match bits & 0b11 { 0b00 => Some(VectorSize::Bi), 0b01 => Some(VectorSize::Tri), 0b10 => Some(VectorSize::Quad), _ => None, }; let splatted = bits & 0b100 == 0b100; let base_ty = || match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }; let limit_ty = || match splatted { true => TypeInner::Scalar(scalar), false => 
                        base_ty(),
                };
                let args = vec![base_ty(), limit_ty(), limit_ty()];

                declaration
                    .overloads
                    .push(module.add_builtin(args, MacroCall::Clamp(size)))
            }
        }
        "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" | "equal"
        | "notEqual" => {
            let scalar = Scalar::F64;

            for bits in 0..0b11 {
                let size = match bits {
                    0b00 => VectorSize::Bi,
                    0b01 => VectorSize::Tri,
                    _ => VectorSize::Quad,
                };

                let ty = || TypeInner::Vector { size, scalar };
                let args = vec![ty(), ty()];

                let fun = MacroCall::Binary(match name {
                    "lessThan" => BinaryOperator::Less,
                    "greaterThan" => BinaryOperator::Greater,
                    "lessThanEqual" => BinaryOperator::LessEqual,
                    "greaterThanEqual" => BinaryOperator::GreaterEqual,
                    "equal" => BinaryOperator::Equal,
                    "notEqual" => BinaryOperator::NotEqual,
                    _ => unreachable!(),
                });

                declaration.overloads.push(module.add_builtin(args, fun))
            }
        }
        // Add common builtins with doubles
        _ => inject_common_builtin(declaration, module, name, 8),
    }
}

/// Injects the builtins into declaration that can use either floats or doubles
fn inject_common_builtin(
    declaration: &mut FunctionDeclaration,
    module: &mut Module,
    name: &str,
    float_width: crate::Bytes,
) {
    let float_scalar = Scalar {
        kind: Sk::Float,
        width: float_width,
    };

    match name {
        "ceil" | "round" | "roundEven" | "floor" | "fract" | "trunc" | "sqrt" | "inversesqrt"
        | "normalize" | "length" | "isinf" | "isnan" => {
            // bits layout
            // bit 0 through 1 - dims
            for bits in 0..0b100 {
                let size = match bits {
                    0b00 => None,
                    0b01 => Some(VectorSize::Bi),
                    0b10 => Some(VectorSize::Tri),
                    _ => Some(VectorSize::Quad),
                };

                let args = vec![match size {
                    Some(size) => TypeInner::Vector {
                        size,
                        scalar: float_scalar,
                    },
                    None => TypeInner::Scalar(float_scalar),
                }];

                let fun = match name {
                    "ceil" => MacroCall::MathFunction(MathFunction::Ceil),
                    "round" | "roundEven" => MacroCall::MathFunction(MathFunction::Round),
                    "floor" => MacroCall::MathFunction(MathFunction::Floor),
                    "fract" => MacroCall::MathFunction(MathFunction::Fract),
                    "trunc" =>
MacroCall::MathFunction(MathFunction::Trunc), "sqrt" => MacroCall::MathFunction(MathFunction::Sqrt), "inversesqrt" => MacroCall::MathFunction(MathFunction::InverseSqrt), "normalize" => MacroCall::MathFunction(MathFunction::Normalize), "length" => MacroCall::MathFunction(MathFunction::Length), "isinf" => MacroCall::Relational(RelationalFunction::IsInf), "isnan" => MacroCall::Relational(RelationalFunction::IsNan), _ => unreachable!(), }; declaration.overloads.push(module.add_builtin(args, fun)) } } "dot" | "reflect" | "distance" | "ldexp" => { // bits layout // bit 0 through 1 - dims for bits in 0..0b100 { let size = match bits { 0b00 => None, 0b01 => Some(VectorSize::Bi), 0b10 => Some(VectorSize::Tri), _ => Some(VectorSize::Quad), }; let ty = |scalar| match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }; let fun = match name { "dot" => MacroCall::MathFunction(MathFunction::Dot), "reflect" => MacroCall::MathFunction(MathFunction::Reflect), "distance" => MacroCall::MathFunction(MathFunction::Distance), "ldexp" => MacroCall::MathFunction(MathFunction::Ldexp), _ => unreachable!(), }; let second_scalar = match fun { MacroCall::MathFunction(MathFunction::Ldexp) => Scalar::I32, _ => float_scalar, }; declaration .overloads .push(module.add_builtin(vec![ty(float_scalar), ty(second_scalar)], fun)) } } "transpose" => { // bits layout // bit 0 through 3 - dims for bits in 0..0b1001 { let (rows, columns) = match bits { 0b0000 => (VectorSize::Bi, VectorSize::Bi), 0b0001 => (VectorSize::Bi, VectorSize::Tri), 0b0010 => (VectorSize::Bi, VectorSize::Quad), 0b0011 => (VectorSize::Tri, VectorSize::Bi), 0b0100 => (VectorSize::Tri, VectorSize::Tri), 0b0101 => (VectorSize::Tri, VectorSize::Quad), 0b0110 => (VectorSize::Quad, VectorSize::Bi), 0b0111 => (VectorSize::Quad, VectorSize::Tri), _ => (VectorSize::Quad, VectorSize::Quad), }; declaration.overloads.push(module.add_builtin( vec![TypeInner::Matrix { columns, rows, scalar: float_scalar, 
}], MacroCall::MathFunction(MathFunction::Transpose), )) } } "inverse" | "determinant" => { // bits layout // bit 0 through 1 - dims for bits in 0..0b11 { let (rows, columns) = match bits { 0b00 => (VectorSize::Bi, VectorSize::Bi), 0b01 => (VectorSize::Tri, VectorSize::Tri), _ => (VectorSize::Quad, VectorSize::Quad), }; let args = vec![TypeInner::Matrix { columns, rows, scalar: float_scalar, }]; declaration.overloads.push(module.add_builtin( args, MacroCall::MathFunction(match name { "inverse" => MathFunction::Inverse, "determinant" => MathFunction::Determinant, _ => unreachable!(), }), )) } } "mod" | "step" => { // bits layout // bit 0 through 2 - dims for bits in 0..0b111 { let (size, second_size) = match bits { 0b000 => (None, None), 0b001 => (Some(VectorSize::Bi), None), 0b010 => (Some(VectorSize::Tri), None), 0b011 => (Some(VectorSize::Quad), None), 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)), 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)), _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)), }; let mut args = Vec::with_capacity(2); let step = name == "step"; for i in 0..2 { let maybe_size = match i == step as u32 { true => size, false => second_size, }; args.push(match maybe_size { Some(size) => TypeInner::Vector { size, scalar: float_scalar, }, None => TypeInner::Scalar(float_scalar), }) } let fun = match name { "mod" => MacroCall::Mod(size), "step" => MacroCall::Splatted(MathFunction::Step, size, 0), _ => unreachable!(), }; declaration.overloads.push(module.add_builtin(args, fun)) } } // TODO: https://github.com/gfx-rs/naga/issues/2526 // "modf" | "frexp" => { ... 
        // }
        "cross" => {
            let args = vec![
                TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: float_scalar,
                },
                TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: float_scalar,
                },
            ];

            declaration
                .overloads
                .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Cross)))
        }
        "outerProduct" => {
            // bits layout
            // bit 0 through 3 - dims
            for bits in 0..0b1001 {
                let (size1, size2) = match bits {
                    0b0000 => (VectorSize::Bi, VectorSize::Bi),
                    0b0001 => (VectorSize::Bi, VectorSize::Tri),
                    0b0010 => (VectorSize::Bi, VectorSize::Quad),
                    0b0011 => (VectorSize::Tri, VectorSize::Bi),
                    0b0100 => (VectorSize::Tri, VectorSize::Tri),
                    0b0101 => (VectorSize::Tri, VectorSize::Quad),
                    0b0110 => (VectorSize::Quad, VectorSize::Bi),
                    0b0111 => (VectorSize::Quad, VectorSize::Tri),
                    _ => (VectorSize::Quad, VectorSize::Quad),
                };

                let args = vec![
                    TypeInner::Vector {
                        size: size1,
                        scalar: float_scalar,
                    },
                    TypeInner::Vector {
                        size: size2,
                        scalar: float_scalar,
                    },
                ];

                declaration
                    .overloads
                    .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Outer)))
            }
        }
        "faceforward" | "fma" => {
            // bits layout
            // bit 0 through 1 - dims
            for bits in 0..0b100 {
                let size = match bits {
                    0b00 => None,
                    0b01 => Some(VectorSize::Bi),
                    0b10 => Some(VectorSize::Tri),
                    _ => Some(VectorSize::Quad),
                };

                let ty = || match size {
                    Some(size) => TypeInner::Vector {
                        size,
                        scalar: float_scalar,
                    },
                    None => TypeInner::Scalar(float_scalar),
                };
                let args = vec![ty(), ty(), ty()];

                let fun = match name {
                    "faceforward" => MacroCall::MathFunction(MathFunction::FaceForward),
                    "fma" => MacroCall::MathFunction(MathFunction::Fma),
                    _ => unreachable!(),
                };

                declaration.overloads.push(module.add_builtin(args, fun))
            }
        }
        "refract" => {
            // bits layout
            // bit 0 through 1 - dims
            for bits in 0..0b100 {
                let size = match bits {
                    0b00 => None,
                    0b01 => Some(VectorSize::Bi),
                    0b10 => Some(VectorSize::Tri),
                    _ => Some(VectorSize::Quad),
                };

                let ty = || match size {
                    Some(size) => TypeInner::Vector {
                        size,
                        scalar: float_scalar,
                    },
                    None =>
                        TypeInner::Scalar(float_scalar),
                };
                let args = vec![ty(), ty(), TypeInner::Scalar(Scalar::F32)];
                declaration
                    .overloads
                    .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Refract)))
            }
        }
        "smoothstep" => {
            // bit 0 - splatted
            // bit 1 through 2 - dims
            for bits in 0..0b1000 {
                let splatted = bits & 0b1 == 0b1;
                let size = match bits >> 1 {
                    0b00 => None,
                    0b01 => Some(VectorSize::Bi),
                    0b10 => Some(VectorSize::Tri),
                    _ => Some(VectorSize::Quad),
                };

                if splatted && size.is_none() {
                    continue;
                }

                let base_ty = || match size {
                    Some(size) => TypeInner::Vector {
                        size,
                        scalar: float_scalar,
                    },
                    None => TypeInner::Scalar(float_scalar),
                };
                let ty = || match splatted {
                    true => TypeInner::Scalar(float_scalar),
                    false => base_ty(),
                };
                declaration.overloads.push(module.add_builtin(
                    vec![ty(), ty(), base_ty()],
                    MacroCall::SmoothStep { splatted: size },
                ))
            }
        }
        // The function isn't a builtin or we don't yet support it
        _ => {}
    }
}

#[derive(Clone, Copy, PartialEq, Debug)]
pub enum TextureLevelType {
    None,
    Lod,
    Grad,
}

/// A compiler-defined builtin function
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum MacroCall {
    Sampler,
    SamplerShadow,
    Texture {
        proj: bool,
        offset: bool,
        shadow: bool,
        level_type: TextureLevelType,
    },
    TextureSize {
        arrayed: bool,
    },
    TextureQueryLevels,
    ImageLoad {
        multi: bool,
    },
    ImageStore,
    MathFunction(MathFunction),
    FindLsbUint,
    FindMsbUint,
    BitfieldExtract,
    BitfieldInsert,
    Relational(RelationalFunction),
    Unary(UnaryOperator),
    Binary(BinaryOperator),
    Mod(Option<VectorSize>),
    Splatted(MathFunction, Option<VectorSize>, usize),
    MixBoolean,
    Clamp(Option<VectorSize>),
    BitCast(Sk),
    Derivate(Axis, Ctrl),
    Barrier,
    /// SmoothStep needs a separate variant because it might need its inputs
    /// to be splatted depending on the overload
    SmoothStep {
        /// The size of the splat operation, if any
        splatted: Option<VectorSize>,
    },
}

impl MacroCall {
    /// Adds the necessary expressions and statements to the current body and
    /// returns the resulting expression, or `None` if the call produces no value
    /// (as for `ImageStore` and `Barrier`)
    pub fn call(
        &self,
        frontend: &mut
            Frontend,
        ctx: &mut Context,
        args: &mut [Handle<Expression>],
        meta: Span,
    ) -> Result<Option<Handle<Expression>>> {
        Ok(Some(match *self {
            MacroCall::Sampler => {
                ctx.samplers.insert(args[0], args[1]);
                args[0]
            }
            MacroCall::SamplerShadow => {
                sampled_to_depth(ctx, args[0], meta, &mut frontend.errors);
                ctx.invalidate_expression(args[0], meta)?;
                ctx.samplers.insert(args[0], args[1]);
                args[0]
            }
            MacroCall::Texture {
                proj,
                offset,
                shadow,
                level_type,
            } => {
                let mut coords = args[1];

                if proj {
                    let size = match *ctx.resolve_type(coords, meta)? {
                        TypeInner::Vector { size, .. } => size,
                        _ => unreachable!(),
                    };
                    let mut right = ctx.add_expression(
                        Expression::AccessIndex {
                            base: coords,
                            index: size as u32 - 1,
                        },
                        Span::default(),
                    )?;
                    let left = if let VectorSize::Bi = size {
                        ctx.add_expression(
                            Expression::AccessIndex {
                                base: coords,
                                index: 0,
                            },
                            Span::default(),
                        )?
                    } else {
                        let size = match size {
                            VectorSize::Tri => VectorSize::Bi,
                            _ => VectorSize::Tri,
                        };
                        right = ctx.add_expression(
                            Expression::Splat { size, value: right },
                            Span::default(),
                        )?;
                        ctx.vector_resize(size, coords, Span::default())?
                    };
                    coords = ctx.add_expression(
                        Expression::Binary {
                            op: BinaryOperator::Divide,
                            left,
                            right,
                        },
                        Span::default(),
                    )?;
                }

                let extra = args.get(2).copied();
                let comps = frontend.coordinate_components(ctx, args[0], coords, extra, meta)?;

                let mut num_args = 2;

                if comps.used_extra {
                    num_args += 1;
                };

                // Parse out explicit texture level.
                let mut level = match level_type {
                    TextureLevelType::None => SampleLevel::Auto,
                    TextureLevelType::Lod => {
                        num_args += 1;

                        if shadow {
                            log::debug!("Assuming LOD {:?} is zero", args[2],);

                            SampleLevel::Zero
                        } else {
                            SampleLevel::Exact(args[2])
                        }
                    }
                    TextureLevelType::Grad => {
                        num_args += 2;

                        if shadow {
                            log::debug!(
                                "Assuming gradients {:?} and {:?} are not greater than 1",
                                args[2],
                                args[3],
                            );
                            SampleLevel::Zero
                        } else {
                            SampleLevel::Gradient {
                                x: args[2],
                                y: args[3],
                            }
                        }
                    }
                };

                let texture_offset = match offset {
                    true => {
                        let offset_arg = args[num_args];
                        num_args += 1;
                        Some(offset_arg)
                    }
                    false => None,
                };

                // Now go back and look for optional bias arg (if available)
                if let TextureLevelType::None = level_type {
                    level = args
                        .get(num_args)
                        .copied()
                        .map_or(SampleLevel::Auto, SampleLevel::Bias);
                }

                texture_call(ctx, args[0], level, comps, texture_offset, meta)?
            }
            MacroCall::TextureSize { arrayed } => {
                let mut expr = ctx.add_expression(
                    Expression::ImageQuery {
                        image: args[0],
                        query: ImageQuery::Size {
                            level: args.get(1).copied(),
                        },
                    },
                    Span::default(),
                )?;

                if arrayed {
                    let mut components = Vec::with_capacity(4);

                    let size = match *ctx.resolve_type(expr, meta)? {
                        TypeInner::Vector { size: ori_size, .. } => {
                            for index in 0..(ori_size as u32) {
                                components.push(ctx.add_expression(
                                    Expression::AccessIndex { base: expr, index },
                                    Span::default(),
                                )?)
                            }

                            match ori_size {
                                VectorSize::Bi => VectorSize::Tri,
                                _ => VectorSize::Quad,
                            }
                        }
                        _ => {
                            components.push(expr);
                            VectorSize::Bi
                        }
                    };

                    components.push(ctx.add_expression(
                        Expression::ImageQuery {
                            image: args[0],
                            query: ImageQuery::NumLayers,
                        },
                        Span::default(),
                    )?);

                    let ty = ctx.module.types.insert(
                        Type {
                            name: None,
                            inner: TypeInner::Vector {
                                size,
                                scalar: Scalar::U32,
                            },
                        },
                        Span::default(),
                    );

                    expr = ctx.add_expression(Expression::Compose { components, ty }, meta)?
                }

                ctx.add_expression(
                    Expression::As {
                        expr,
                        kind: Sk::Sint,
                        convert: Some(4),
                    },
                    Span::default(),
                )?
            }
            MacroCall::TextureQueryLevels => {
                let expr = ctx.add_expression(
                    Expression::ImageQuery {
                        image: args[0],
                        query: ImageQuery::NumLevels,
                    },
                    Span::default(),
                )?;

                ctx.add_expression(
                    Expression::As {
                        expr,
                        kind: Sk::Sint,
                        convert: Some(4),
                    },
                    Span::default(),
                )?
            }
            MacroCall::ImageLoad { multi } => {
                let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?;
                let (sample, level) = match (multi, args.get(2)) {
                    (_, None) => (None, None),
                    (true, Some(&arg)) => (Some(arg), None),
                    (false, Some(&arg)) => (None, Some(arg)),
                };
                ctx.add_expression(
                    Expression::ImageLoad {
                        image: args[0],
                        coordinate: comps.coordinate,
                        array_index: comps.array_index,
                        sample,
                        level,
                    },
                    Span::default(),
                )?
            }
            MacroCall::ImageStore => {
                let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?;
                ctx.emit_restart();
                ctx.body.push(
                    crate::Statement::ImageStore {
                        image: args[0],
                        coordinate: comps.coordinate,
                        array_index: comps.array_index,
                        value: args[2],
                    },
                    meta,
                );
                return Ok(None);
            }
            MacroCall::MathFunction(fun) => ctx.add_expression(
                Expression::Math {
                    fun,
                    arg: args[0],
                    arg1: args.get(1).copied(),
                    arg2: args.get(2).copied(),
                    arg3: args.get(3).copied(),
                },
                Span::default(),
            )?,
            mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => {
                let fun = match mc {
                    MacroCall::FindLsbUint => MathFunction::FirstTrailingBit,
                    MacroCall::FindMsbUint => MathFunction::FirstLeadingBit,
                    _ => unreachable!(),
                };
                let res = ctx.add_expression(
                    Expression::Math {
                        fun,
                        arg: args[0],
                        arg1: None,
                        arg2: None,
                        arg3: None,
                    },
                    Span::default(),
                )?;
                ctx.add_expression(
                    Expression::As {
                        expr: res,
                        kind: Sk::Sint,
                        convert: Some(4),
                    },
                    Span::default(),
                )?
            }
            MacroCall::BitfieldInsert => {
                let conv_arg_2 = ctx.add_expression(
                    Expression::As {
                        expr: args[2],
                        kind: Sk::Uint,
                        convert: Some(4),
                    },
                    Span::default(),
                )?;
                let conv_arg_3 = ctx.add_expression(
                    Expression::As {
                        expr: args[3],
                        kind: Sk::Uint,
                        convert: Some(4),
                    },
                    Span::default(),
                )?;
                ctx.add_expression(
                    Expression::Math {
                        fun: MathFunction::InsertBits,
                        arg: args[0],
                        arg1: Some(args[1]),
                        arg2: Some(conv_arg_2),
                        arg3: Some(conv_arg_3),
                    },
                    Span::default(),
                )?
            }
            MacroCall::BitfieldExtract => {
                let conv_arg_1 = ctx.add_expression(
                    Expression::As {
                        expr: args[1],
                        kind: Sk::Uint,
                        convert: Some(4),
                    },
                    Span::default(),
                )?;
                let conv_arg_2 = ctx.add_expression(
                    Expression::As {
                        expr: args[2],
                        kind: Sk::Uint,
                        convert: Some(4),
                    },
                    Span::default(),
                )?;
                ctx.add_expression(
                    Expression::Math {
                        fun: MathFunction::ExtractBits,
                        arg: args[0],
                        arg1: Some(conv_arg_1),
                        arg2: Some(conv_arg_2),
                        arg3: None,
                    },
                    Span::default(),
                )?
            }
            MacroCall::Relational(fun) => ctx.add_expression(
                Expression::Relational {
                    fun,
                    argument: args[0],
                },
                Span::default(),
            )?,
            MacroCall::Unary(op) => {
                ctx.add_expression(Expression::Unary { op, expr: args[0] }, Span::default())?
            }
            MacroCall::Binary(op) => ctx.add_expression(
                Expression::Binary {
                    op,
                    left: args[0],
                    right: args[1],
                },
                Span::default(),
            )?,
            MacroCall::Mod(size) => {
                ctx.implicit_splat(&mut args[1], meta, size)?;

                // x - y * floor(x / y)

                let div = ctx.add_expression(
                    Expression::Binary {
                        op: BinaryOperator::Divide,
                        left: args[0],
                        right: args[1],
                    },
                    Span::default(),
                )?;
                let floor = ctx.add_expression(
                    Expression::Math {
                        fun: MathFunction::Floor,
                        arg: div,
                        arg1: None,
                        arg2: None,
                        arg3: None,
                    },
                    Span::default(),
                )?;
                let mult = ctx.add_expression(
                    Expression::Binary {
                        op: BinaryOperator::Multiply,
                        left: floor,
                        right: args[1],
                    },
                    Span::default(),
                )?;
                ctx.add_expression(
                    Expression::Binary {
                        op: BinaryOperator::Subtract,
                        left: args[0],
                        right: mult,
                    },
                    Span::default(),
                )?
            }
            MacroCall::Splatted(fun, size, i) => {
                ctx.implicit_splat(&mut args[i], meta, size)?;

                ctx.add_expression(
                    Expression::Math {
                        fun,
                        arg: args[0],
                        arg1: args.get(1).copied(),
                        arg2: args.get(2).copied(),
                        arg3: args.get(3).copied(),
                    },
                    Span::default(),
                )?
            }
            MacroCall::MixBoolean => ctx.add_expression(
                Expression::Select {
                    condition: args[2],
                    accept: args[1],
                    reject: args[0],
                },
                Span::default(),
            )?,
            MacroCall::Clamp(size) => {
                ctx.implicit_splat(&mut args[1], meta, size)?;
                ctx.implicit_splat(&mut args[2], meta, size)?;

                ctx.add_expression(
                    Expression::Math {
                        fun: MathFunction::Clamp,
                        arg: args[0],
                        arg1: args.get(1).copied(),
                        arg2: args.get(2).copied(),
                        arg3: args.get(3).copied(),
                    },
                    Span::default(),
                )?
            }
            MacroCall::BitCast(kind) => ctx.add_expression(
                Expression::As {
                    expr: args[0],
                    kind,
                    convert: None,
                },
                Span::default(),
            )?,
            MacroCall::Derivate(axis, ctrl) => ctx.add_expression(
                Expression::Derivative {
                    axis,
                    ctrl,
                    expr: args[0],
                },
                Span::default(),
            )?,
            MacroCall::Barrier => {
                ctx.emit_restart();
                ctx.body.push(
                    crate::Statement::ControlBarrier(crate::Barrier::all()),
                    meta,
                );
                return Ok(None);
            }
            MacroCall::SmoothStep { splatted } => {
                ctx.implicit_splat(&mut args[0], meta, splatted)?;
                ctx.implicit_splat(&mut args[1], meta, splatted)?;

                ctx.add_expression(
                    Expression::Math {
                        fun: MathFunction::SmoothStep,
                        arg: args[0],
                        arg1: args.get(1).copied(),
                        arg2: args.get(2).copied(),
                        arg3: None,
                    },
                    Span::default(),
                )?
            }
        }))
    }
}

fn texture_call(
    ctx: &mut Context,
    image: Handle<Expression>,
    level: SampleLevel,
    comps: CoordComponents,
    offset: Option<Handle<Expression>>,
    meta: Span,
) -> Result<Handle<Expression>> {
    if let Some(sampler) = ctx.samplers.get(&image).copied() {
        let mut array_index = comps.array_index;

        if let Some(ref mut array_index_expr) = array_index {
            ctx.conversion(array_index_expr, meta, Scalar::I32)?;
        }

        Ok(ctx.add_expression(
            Expression::ImageSample {
                image,
                sampler,
                gather: None, //TODO
                coordinate: comps.coordinate,
                array_index,
                offset,
                level,
                depth_ref: comps.depth_ref,
                clamp_to_edge: false,
            },
            meta,
        )?)
    } else {
        Err(Error {
            kind: ErrorKind::SemanticError("Bad call".into()),
            meta,
        })
    }
}

/// Helper struct for texture calls with the separate components from the vector argument
///
/// Obtained by calling [`coordinate_components`](Frontend::coordinate_components)
#[derive(Debug)]
struct CoordComponents {
    coordinate: Handle<Expression>,
    depth_ref: Option<Handle<Expression>>,
    array_index: Option<Handle<Expression>>,
    used_extra: bool,
}

impl Frontend {
    /// Helper function for texture calls, splits the vector argument into its components
    fn coordinate_components(
        &mut self,
        ctx: &mut Context,
        image: Handle<Expression>,
        coord: Handle<Expression>,
        extra: Option<Handle<Expression>>,
        meta: Span,
    ) -> Result<CoordComponents> {
        if let TypeInner::Image {
            dim,
            arrayed,
            class,
        } = *ctx.resolve_type(image, meta)?
        {
            let image_size = match dim {
                Dim::D1 => None,
                Dim::D2 => Some(VectorSize::Bi),
                Dim::D3 => Some(VectorSize::Tri),
                Dim::Cube => Some(VectorSize::Tri),
            };
            let coord_size = match *ctx.resolve_type(coord, meta)? {
                TypeInner::Vector { size, .. } => Some(size),
                _ => None,
            };
            let (shadow, storage) = match class {
                ImageClass::Depth { .. } => (true, false),
                ImageClass::Storage { .. } => (false, true),
                ImageClass::Sampled { .. } => (false, false),
                ImageClass::External => unreachable!(),
            };

            let coordinate = match (image_size, coord_size) {
                (Some(size), Some(coord_s)) if size != coord_s => {
                    ctx.vector_resize(size, coord, Span::default())?
                }
                (None, Some(_)) => ctx.add_expression(
                    Expression::AccessIndex {
                        base: coord,
                        index: 0,
                    },
                    Span::default(),
                )?,
                _ => coord,
            };

            let mut coord_index = image_size.map_or(1, |s| s as u32);

            let array_index = if arrayed && !(storage && dim == Dim::Cube) {
                let index = coord_index;
                coord_index += 1;

                Some(ctx.add_expression(
                    Expression::AccessIndex { base: coord, index },
                    Span::default(),
                )?)
            } else {
                None
            };
            let mut used_extra = false;
            let depth_ref = match shadow {
                true => {
                    let index = coord_index;

                    if index == 4 {
                        used_extra = true;
                        extra
                    } else {
                        Some(ctx.add_expression(
                            Expression::AccessIndex { base: coord, index },
                            Span::default(),
                        )?)
                    }
                }
                false => None,
            };

            Ok(CoordComponents {
                coordinate,
                depth_ref,
                array_index,
                used_extra,
            })
        } else {
            self.errors.push(Error {
                kind: ErrorKind::SemanticError("Type is not an image".into()),
                meta,
            });

            Ok(CoordComponents {
                coordinate: coord,
                depth_ref: None,
                array_index: None,
                used_extra: false,
            })
        }
    }
}

/// Helper function to cast an expression holding a sampled image to a
/// depth image.
pub fn sampled_to_depth(
    ctx: &mut Context,
    image: Handle<Expression>,
    meta: Span,
    errors: &mut Vec<Error>,
) {
    // Get a mutable reference to the type handle of the underlying image storage
    let ty = match ctx[image] {
        Expression::GlobalVariable(handle) => &mut ctx.module.global_variables.get_mut(handle).ty,
        Expression::FunctionArgument(i) => {
            // Mark the function argument as carrying a depth texture
            ctx.parameters_info[i as usize].depth = true;
            // NOTE: We need to later also change the parameter type
            &mut ctx.arguments[i as usize].ty
        }
        _ => {
            // Only globals and function arguments are allowed to carry an image
            return errors.push(Error {
                kind: ErrorKind::SemanticError("Not a valid texture expression".into()),
                meta,
            });
        }
    };

    match ctx.module.types[*ty].inner {
        // Update the image class to depth if it isn't already
        TypeInner::Image {
            class,
            dim,
            arrayed,
        } => match class {
            ImageClass::Sampled { multi, .. } => {
                *ty = ctx.module.types.insert(
                    Type {
                        name: None,
                        inner: TypeInner::Image {
                            dim,
                            arrayed,
                            class: ImageClass::Depth { multi },
                        },
                    },
                    Span::default(),
                )
            }
            ImageClass::Depth { .. } => {}
            // Other image classes aren't allowed to be transformed to depth
            ImageClass::Storage {
                ..
            } => errors.push(Error {
                kind: ErrorKind::SemanticError("Not a texture".into()),
                meta,
            }),
            ImageClass::External => unreachable!(),
        },
        _ => errors.push(Error {
            kind: ErrorKind::SemanticError("Not a texture".into()),
            meta,
        }),
    };

    // Copy the handle to allow borrowing the `ctx` again
    let ty = *ty;

    // If the image was passed through a function argument we also need to change
    // the corresponding parameter
    if let Expression::FunctionArgument(i) = ctx[image] {
        ctx.parameters[i as usize] = ty;
    }
}

bitflags::bitflags! {
    /// Controls which image variants [`texture_args_generator`] generates
    struct TextureArgsOptions: u32 {
        /// Generates multisampled variants of images
        const MULTI = 1 << 0;
        /// Generates shadow variants of images
        const SHADOW = 1 << 1;
        /// Generates standard images
        const STANDARD = 1 << 2;
        /// Generates cube arrayed images
        const CUBE_ARRAY = 1 << 3;
        /// Generates multisampled 2D arrayed images
        const D2_MULTI_ARRAY = 1 << 4;
    }
}

impl From<BuiltinVariations> for TextureArgsOptions {
    fn from(variations: BuiltinVariations) -> Self {
        let mut options = TextureArgsOptions::empty();
        if variations.contains(BuiltinVariations::STANDARD) {
            options |= TextureArgsOptions::STANDARD
        }
        if variations.contains(BuiltinVariations::CUBE_TEXTURES_ARRAY) {
            options |= TextureArgsOptions::CUBE_ARRAY
        }
        if variations.contains(BuiltinVariations::D2_MULTI_TEXTURES_ARRAY) {
            options |= TextureArgsOptions::D2_MULTI_ARRAY
        }
        options
    }
}

/// Helper function to generate the image components for texture/image builtins
///
/// Calls the passed function `f` with:
/// ```text
/// f(ScalarKind, ImageDimension, arrayed, multi, shadow)
/// ```
///
/// `options` controls extra image variants generation like multisampling and depth,
/// see the struct documentation
fn texture_args_generator(
    options: TextureArgsOptions,
    mut f: impl FnMut(crate::ScalarKind, Dim, bool, bool, bool),
) {
    for kind in [Sk::Float, Sk::Uint, Sk::Sint].iter().copied() {
        for dim in [Dim::D1, Dim::D2, Dim::D3, Dim::Cube].iter().copied() {
            for arrayed in [
                false,
                true,
            ].iter().copied() {
                if dim == Dim::Cube && arrayed {
                    if !options.contains(TextureArgsOptions::CUBE_ARRAY) {
                        continue;
                    }
                } else if Dim::D2 == dim
                    && options.contains(TextureArgsOptions::MULTI)
                    && arrayed
                    && options.contains(TextureArgsOptions::D2_MULTI_ARRAY)
                {
                    // multisampling for sampler2DMSArray
                    f(kind, dim, arrayed, true, false);
                } else if !options.contains(TextureArgsOptions::STANDARD) {
                    continue;
                }

                f(kind, dim, arrayed, false, false);

                // 3D images can be neither arrayed nor shadow,
                // so we break out early; this way arrayed will always
                // be false and we won't hit the shadow branch
                if let Dim::D3 = dim {
                    break;
                }

                if Dim::D2 == dim && options.contains(TextureArgsOptions::MULTI) && !arrayed {
                    // multisampling
                    f(kind, dim, arrayed, true, false);
                }

                if Sk::Float == kind && options.contains(TextureArgsOptions::SHADOW) {
                    // shadow
                    f(kind, dim, arrayed, false, true);
                }
            }
        }
    }
}

/// Helper function used to convert an image dimension into an integer representing the
/// number of components needed for the coordinates vector (1 means scalar instead of vector)
const fn image_dims_to_coords_size(dim: Dim) -> usize {
    match dim {
        Dim::D1 => 1,
        Dim::D2 => 2,
        _ => 3,
    }
}
naga-29.0.3/src/front/glsl/context.rs000064400000000000000001672121046102023000155500ustar 00000000000000
use alloc::{format, string::String, vec::Vec};
use core::ops::Index;

use super::{
    ast::{
        GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
        VariableReference,
    },
    error::{Error, ErrorKind},
    types::{scalar_components, type_power},
    Frontend, Result,
};
use crate::{
    front::Typifier, proc::Emitter, proc::Layouter, AddressSpace, Arena, BinaryOperator, Block,
    Expression, FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction,
    Scalar, Span, Statement, Type, TypeInner, VectorSize,
};

/// The position of an expression, used while lowering
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ExprPos {
    /// The expression is in the
    /// left-hand side of an assignment
    Lhs,
    /// The expression is in the right-hand side of an assignment
    Rhs,
    /// The expression is an array being indexed, needed to allow constant
    /// arrays to be dynamically indexed
    AccessBase {
        /// The index is a constant
        constant_index: bool,
    },
}

impl ExprPos {
    /// Returns `self` unchanged if the current position is `Lhs` or a
    /// non-constant `AccessBase`; otherwise returns an `AccessBase` with the
    /// given `constant_index`
    const fn maybe_access_base(&self, constant_index: bool) -> Self {
        match *self {
            ExprPos::Lhs
            | ExprPos::AccessBase {
                constant_index: false,
            } => *self,
            _ => ExprPos::AccessBase { constant_index },
        }
    }
}

#[derive(Debug)]
pub(crate) struct Context<'a> {
    pub expressions: Arena<Expression>,
    pub locals: Arena<LocalVariable>,

    /// The [`FunctionArgument`]s for the final [`crate::Function`].
    ///
    /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types
    /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to
    /// a [`Vector`].
    ///
    /// [`Pointer`]: crate::TypeInner::Pointer
    /// [`Vector`]: crate::TypeInner::Vector
    pub arguments: Vec<FunctionArgument>,

    /// The parameter types given in the source code.
    ///
    /// The `out` and `inout` qualifiers don't affect the types that appear
    /// here. For example, an `inout vec2 a` argument would simply be a
    /// [`Vector`], not a pointer to one.
    ///
    /// [`Vector`]: crate::TypeInner::Vector
    pub parameters: Vec<Handle<Type>>,

    pub parameters_info: Vec<ParameterInfo>,

    pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
    pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,

    pub const_typifier: Typifier,
    pub typifier: Typifier,
    layouter: Layouter,
    emitter: Emitter,
    stmt_ctx: Option<StmtContext>,
    pub body: Block,
    pub module: &'a mut crate::Module,
    pub is_const: bool,
    /// Tracks the expression kind of `Expression`s residing in `self.expressions`
    pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
    /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions`
    pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}

impl<'a> Context<'a> {
    pub fn new(
        frontend: &Frontend,
        module: &'a mut crate::Module,
        is_const: bool,
        global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
    ) -> Result<Self> {
        let mut this = Context {
            expressions: Arena::new(),
            locals: Arena::new(),
            arguments: Vec::new(),

            parameters: Vec::new(),
            parameters_info: Vec::new(),

            symbol_table: crate::front::SymbolTable::default(),
            samplers: FastHashMap::default(),

            const_typifier: Typifier::new(),
            typifier: Typifier::new(),
            layouter: Layouter::default(),
            emitter: Emitter::default(),
            stmt_ctx: Some(StmtContext::new()),
            body: Block::new(),
            module,
            is_const: false,
            local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
            global_expression_kind_tracker,
        };

        this.emit_start();

        for &(ref name, lookup) in frontend.global_variables.iter() {
            this.add_global(name, lookup)?
        }

        this.is_const = is_const;

        Ok(this)
    }

    pub fn new_body<F>(&mut self, cb: F) -> Result<Block>
    where
        F: FnOnce(&mut Self) -> Result<()>,
    {
        self.new_body_with_ret(cb).map(|(b, _)| b)
    }

    pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)>
    where
        F: FnOnce(&mut Self) -> Result<R>,
    {
        self.emit_restart();
        let old_body = core::mem::replace(&mut self.body, Block::new());
        let res = cb(self);
        self.emit_restart();
        let new_body = core::mem::replace(&mut self.body, old_body);
        res.map(|r| (new_body, r))
    }

    pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block>
    where
        F: FnOnce(&mut Self) -> Result<()>,
    {
        self.emit_restart();
        let old_body = core::mem::replace(&mut self.body, body);
        let res = cb(self);
        self.emit_restart();
        let body = core::mem::replace(&mut self.body, old_body);
        res.map(|_| body)
    }

    pub fn add_global(
        &mut self,
        name: &str,
        GlobalLookup {
            kind,
            entry_arg,
            mutable,
        }: GlobalLookup,
    ) -> Result<()> {
        let (expr, load, constant) = match kind {
            GlobalLookupKind::Variable(v) => {
                let span = self.module.global_variables.get_span(v);
                (
                    self.add_expression(Expression::GlobalVariable(v), span)?,
                    self.module.global_variables[v].space != AddressSpace::Handle,
                    None,
                )
            }
            GlobalLookupKind::BlockSelect(handle, index) => {
                let span = self.module.global_variables.get_span(handle);
                let base = self.add_expression(Expression::GlobalVariable(handle), span)?;
                let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?;

                (
                    expr,
                    {
                        let ty = self.module.global_variables[handle].ty;

                        match self.module.types[ty].inner {
                            TypeInner::Struct { ref members, .. } => {
                                if let TypeInner::Array {
                                    size: crate::ArraySize::Dynamic,
                                    ..
                                } = self.module.types[members[index as usize].ty].inner
                                {
                                    false
                                } else {
                                    true
                                }
                            }
                            _ => true,
                        }
                    },
                    None,
                )
            }
            GlobalLookupKind::Constant(v, ty) => {
                let span = self.module.constants.get_span(v);
                (
                    self.add_expression(Expression::Constant(v), span)?,
                    false,
                    Some((v, ty)),
                )
            }
            GlobalLookupKind::Override(v, _ty) => {
                let span = self.module.overrides.get_span(v);
                (
                    self.add_expression(Expression::Override(v), span)?,
                    false,
                    None,
                )
            }
        };

        let var = VariableReference {
            expr,
            load,
            mutable,
            constant,
            entry_arg,
        };

        self.symbol_table.add(name.into(), var);

        Ok(())
    }

    /// Starts the expression emitter
    ///
    /// # Panics
    ///
    /// - If called twice in a row without calling [`emit_end`][Self::emit_end].
    #[inline]
    pub fn emit_start(&mut self) {
        self.emitter.start(&self.expressions)
    }

    /// Emits all the expressions captured by the emitter to the current body
    ///
    /// # Panics
    ///
    /// - If called before calling [`emit_start`].
    /// - If called twice in a row without calling [`emit_start`].
    ///
    /// [`emit_start`]: Self::emit_start
    pub fn emit_end(&mut self) {
        self.body.extend(self.emitter.finish(&self.expressions))
    }

    /// Emits all the expressions captured by the emitter to the current body
    /// and starts the emitter again
    ///
    /// # Panics
    ///
    /// - If called before calling [`emit_start`][Self::emit_start].
    pub fn emit_restart(&mut self) {
        self.emit_end();
        self.emit_start()
    }

    pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
        let mut eval = if self.is_const {
            crate::proc::ConstantEvaluator::for_glsl_module(
                self.module,
                self.global_expression_kind_tracker,
                &mut self.layouter,
            )
        } else {
            crate::proc::ConstantEvaluator::for_glsl_function(
                self.module,
                &mut self.expressions,
                &mut self.local_expression_kind_tracker,
                &mut self.layouter,
                &mut self.emitter,
                &mut self.body,
            )
        };

        eval.try_eval_and_append(expr, meta).map_err(|e| Error {
            kind: e.into(),
            meta,
        })
    }

    /// Adds a variable to the current scope
    ///
    /// Returns the previous variable if a variable with the same name was
    /// already defined, otherwise returns `None`
    pub fn add_local_var(
        &mut self,
        name: String,
        expr: Handle<Expression>,
        mutable: bool,
    ) -> Option<VariableReference> {
        let var = VariableReference {
            expr,
            load: true,
            mutable,
            constant: None,
            entry_arg: None,
        };

        self.symbol_table.add(name, var)
    }

    /// Adds a function argument to the current scope
    pub fn add_function_arg(
        &mut self,
        name_meta: Option<(String, Span)>,
        ty: Handle<Type>,
        qualifier: ParameterQualifier,
    ) -> Result<()> {
        let index = self.arguments.len();
        let mut arg = FunctionArgument {
            name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
            ty,
            binding: None,
        };
        self.parameters.push(ty);

        let opaque = match self.module.types[ty].inner {
            TypeInner::Image { .. }
            | TypeInner::Sampler {
                ..
            } => true,
            _ => false,
        };

        if qualifier.is_lhs() {
            let span = self.module.types.get_span(arg.ty);
            arg.ty = self.module.types.insert(
                Type {
                    name: None,
                    inner: TypeInner::Pointer {
                        base: arg.ty,
                        space: AddressSpace::Function,
                    },
                },
                span,
            )
        }

        self.arguments.push(arg);

        self.parameters_info.push(ParameterInfo {
            qualifier,
            depth: false,
        });

        if let Some((name, meta)) = name_meta {
            let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?;
            let mutable = qualifier != ParameterQualifier::Const && !opaque;
            let load = qualifier.is_lhs();

            let var = if mutable && !load {
                let handle = self.locals.append(
                    LocalVariable {
                        name: Some(name.clone()),
                        ty,
                        init: None,
                    },
                    meta,
                );
                let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?;

                self.emit_restart();

                self.body.push(
                    Statement::Store {
                        pointer: local_expr,
                        value: expr,
                    },
                    meta,
                );

                VariableReference {
                    expr: local_expr,
                    load: true,
                    mutable,
                    constant: None,
                    entry_arg: None,
                }
            } else {
                VariableReference {
                    expr,
                    load,
                    mutable,
                    constant: None,
                    entry_arg: None,
                }
            };

            self.symbol_table.add(name, var);
        }

        Ok(())
    }

    /// Returns a [`StmtContext`] to be used in parsing and lowering
    ///
    /// # Panics
    ///
    /// - If more than one [`StmtContext`] is active at the same time or if the
    ///   previous call didn't use it in lowering.
    #[must_use]
    pub const fn stmt_ctx(&mut self) -> StmtContext {
        self.stmt_ctx.take().unwrap()
    }

    /// Lowers a [`HirExpr`] which might produce an [`Expression`].
    ///
    /// Consumes a [`StmtContext`], returning it to the context so that it can be
    /// used again later.
    pub fn lower(
        &mut self,
        mut stmt: StmtContext,
        frontend: &mut Frontend,
        expr: Handle<HirExpr>,
        pos: ExprPos,
    ) -> Result<(Option<Handle<Expression>>, Span)> {
        let res = self.lower_inner(&stmt, frontend, expr, pos);

        stmt.hir_exprs.clear();
        self.stmt_ctx = Some(stmt);

        res
    }

    /// Similar to [`lower`](Self::lower) but returns an error if the expression
    /// returns void (i.e. doesn't produce an [`Expression`]).
    ///
    /// Consumes a [`StmtContext`], returning it to the context so that it can be
    /// used again later.
    pub fn lower_expect(
        &mut self,
        mut stmt: StmtContext,
        frontend: &mut Frontend,
        expr: Handle<HirExpr>,
        pos: ExprPos,
    ) -> Result<(Handle<Expression>, Span)> {
        let res = self.lower_expect_inner(&stmt, frontend, expr, pos);

        stmt.hir_exprs.clear();
        self.stmt_ctx = Some(stmt);

        res
    }

    /// Internal implementation of [`lower_expect`](Self::lower_expect)
    ///
    /// This method is only public because it's used in
    /// [`function_call`](Frontend::function_call); unless you know what
    /// you're doing, use [`lower_expect`](Self::lower_expect)
    pub fn lower_expect_inner(
        &mut self,
        stmt: &StmtContext,
        frontend: &mut Frontend,
        expr: Handle<HirExpr>,
        pos: ExprPos,
    ) -> Result<(Handle<Expression>, Span)> {
        let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?;

        let expr = match maybe_expr {
            Some(e) => e,
            None => {
                return Err(Error {
                    kind: ErrorKind::SemanticError("Expression returns void".into()),
                    meta,
                })
            }
        };

        Ok((expr, meta))
    }

    fn lower_store(
        &mut self,
        pointer: Handle<Expression>,
        value: Handle<Expression>,
        meta: Span,
    ) -> Result<()> {
        if let Expression::Swizzle {
            size,
            mut vector,
            pattern,
        } = self.expressions[pointer]
        {
            // Stores to swizzled values are not directly supported,
            // so lower them as a series of per-component stores.
            let size = match size {
                VectorSize::Bi => 2,
                VectorSize::Tri => 3,
                VectorSize::Quad => 4,
            };

            if let Expression::Load { pointer } = self.expressions[vector] {
                vector = pointer;
            }

            #[allow(clippy::needless_range_loop)]
            for index in 0..size {
                let dst = self.add_expression(
                    Expression::AccessIndex {
                        base: vector,
                        index: pattern[index].index(),
                    },
                    meta,
                )?;
                let src = self.add_expression(
                    Expression::AccessIndex {
                        base: value,
                        index: index as u32,
                    },
                    meta,
                )?;

                self.emit_restart();

                self.body.push(
                    Statement::Store {
                        pointer: dst,
                        value: src,
                    },
                    meta,
                );
            }
        } else {
            self.emit_restart();

            self.body.push(Statement::Store { pointer, value }, meta);
        }

        Ok(())
    }

    /// Internal implementation of [`lower`](Self::lower)
    fn lower_inner(
        &mut self,
        stmt: &StmtContext,
        frontend: &mut Frontend,
        expr: Handle<HirExpr>,
        pos: ExprPos,
    ) -> Result<(Option<Handle<Expression>>, Span)> {
        let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];

        log::debug!("Lowering {expr:?} (kind {kind:?}, pos {pos:?})");

        let handle = match *kind {
            HirExprKind::Access { base, index } => {
                let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?;
                let maybe_constant_index = match pos {
                    // Don't try to generate `AccessIndex` if in a LHS position, since it
                    // wouldn't produce a pointer.
                    ExprPos::Lhs => None,
                    _ => self
                        .module
                        .to_ctx()
                        .get_const_val_from(index, &self.expressions)
                        .ok(),
                };

                let base = self
                    .lower_expect_inner(
                        stmt,
                        frontend,
                        base,
                        pos.maybe_access_base(maybe_constant_index.is_some()),
                    )?
                    .0;

                let pointer = maybe_constant_index
                    .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta))
                    .unwrap_or_else(|| {
                        self.add_expression(Expression::Access { base, index }, meta)
                    })?;

                if ExprPos::Rhs == pos {
                    let resolved = self.resolve_type(pointer, meta)?;
                    if resolved.pointer_space().is_some() {
                        return Ok((
                            Some(self.add_expression(Expression::Load { pointer }, meta)?),
                            meta,
                        ));
                    }
                }

                pointer
            }
            HirExprKind::Select { base, ref field } => {
                let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0;

                frontend.field_selection(self, pos, base, field, meta)?
            }
            HirExprKind::Literal(literal) if pos != ExprPos::Lhs => {
                self.add_expression(Expression::Literal(literal), meta)?
            }
            HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
                let (mut left, left_meta) =
                    self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?;
                let (mut right, right_meta) =
                    self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?;

                match op {
                    BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
                        self.implicit_conversion(&mut right, right_meta, Scalar::U32)?
                    }
                    _ => self
                        .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?,
                }

                self.typifier_grow(left, left_meta)?;
                self.typifier_grow(right, right_meta)?;

                let left_inner = self.get_type(left);
                let right_inner = self.get_type(right);

                match (left_inner, right_inner) {
                    (
                        &TypeInner::Matrix {
                            columns: left_columns,
                            rows: left_rows,
                            scalar: left_scalar,
                        },
                        &TypeInner::Matrix {
                            columns: right_columns,
                            rows: right_rows,
                            scalar: right_scalar,
                        },
                    ) => {
                        let dimensions_ok = if op == BinaryOperator::Multiply {
                            left_columns == right_rows
                        } else {
                            left_columns == right_columns && left_rows == right_rows
                        };

                        // Check that the dimensions are compatible with the operation
                        // (for multiplication the left column count must equal the
                        // right row count, otherwise both shapes must match) and that
                        // both matrices use the same scalar type
                        if !dimensions_ok || left_scalar != right_scalar {
                            frontend.errors.push(Error {
                                kind: ErrorKind::SemanticError(
                                    format!(
                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
                                    )
                                    .into(),
                                ),
                                meta,
                            })
                        }

                        match op {
                            BinaryOperator::Divide => {
                                // Naga IR doesn't support matrix division so we need to
                                // divide the columns individually and reassemble the matrix
                                let mut components = Vec::with_capacity(left_columns as usize);

                                for index in 0..left_columns as u32 {
                                    // Get the column vectors
                                    let left_vector = self.add_expression(
                                        Expression::AccessIndex { base: left, index },
                                        meta,
                                    )?;
                                    let right_vector = self.add_expression(
                                        Expression::AccessIndex { base: right, index },
                                        meta,
                                    )?;

                                    // Divide the vectors
                                    let column = self.add_expression(
                                        Expression::Binary {
                                            op,
                                            left: left_vector,
                                            right: right_vector,
                                        },
                                        meta,
                                    )?;

                                    components.push(column)
                                }

                                let ty = self.module.types.insert(
                                    Type {
                                        name: None,
                                        inner: TypeInner::Matrix {
                                            columns: left_columns,
                                            rows: left_rows,
                                            scalar: left_scalar,
                                        },
                                    },
                                    Span::default(),
                                );

                                // Rebuild the matrix from the divided vectors
                                self.add_expression(Expression::Compose { ty, components }, meta)?
} BinaryOperator::Equal | BinaryOperator::NotEqual => { // Naga IR doesn't support matrix comparisons so we need to // compare the columns individually and then fold them together // // The folding is done using a logical and for equality and // a logical or for inequality let equals = op == BinaryOperator::Equal; let (op, combine, fun) = match equals { true => ( BinaryOperator::Equal, BinaryOperator::LogicalAnd, RelationalFunction::All, ), false => ( BinaryOperator::NotEqual, BinaryOperator::LogicalOr, RelationalFunction::Any, ), }; let mut root = None; for index in 0..left_columns as u32 { // Get the column vectors let left_vector = self.add_expression( Expression::AccessIndex { base: left, index }, meta, )?; let right_vector = self.add_expression( Expression::AccessIndex { base: right, index }, meta, )?; let argument = self.add_expression( Expression::Binary { op, left: left_vector, right: right_vector, }, meta, )?; // The result of comparing two vectors is a boolean vector // so use a relational function like all to get a single // boolean value let compare = self.add_expression( Expression::Relational { fun, argument }, meta, )?; // Fold the result root = Some(match root { Some(right) => self.add_expression( Expression::Binary { op: combine, left: compare, right, }, meta, )?, None => compare, }); } root.unwrap() } _ => { self.add_expression(Expression::Binary { left, op, right }, meta)? } } } (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op { BinaryOperator::Equal | BinaryOperator::NotEqual => { let equals = op == BinaryOperator::Equal; let (op, fun) = match equals { true => (BinaryOperator::Equal, RelationalFunction::All), false => (BinaryOperator::NotEqual, RelationalFunction::Any), }; let argument = self.add_expression(Expression::Binary { op, left, right }, meta)?; self.add_expression(Expression::Relational { fun, argument }, meta)? 
} _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, }, (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op { BinaryOperator::Add | BinaryOperator::Subtract | BinaryOperator::Divide | BinaryOperator::And | BinaryOperator::ExclusiveOr | BinaryOperator::InclusiveOr | BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => { let scalar_vector = self .add_expression(Expression::Splat { size, value: right }, meta)?; self.add_expression( Expression::Binary { op, left, right: scalar_vector, }, meta, )? } _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, }, (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op { BinaryOperator::Add | BinaryOperator::Subtract | BinaryOperator::Divide | BinaryOperator::And | BinaryOperator::ExclusiveOr | BinaryOperator::InclusiveOr => { let scalar_vector = self.add_expression(Expression::Splat { size, value: left }, meta)?; self.add_expression( Expression::Binary { op, left: scalar_vector, right, }, meta, )? 
} _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, }, ( &TypeInner::Scalar(left_scalar), &TypeInner::Matrix { rows, columns, scalar: right_scalar, }, ) => { // Check that the two arguments have the same scalar type if left_scalar != right_scalar { frontend.errors.push(Error { kind: ErrorKind::SemanticError( format!( "Cannot apply operation to {left_inner:?} and {right_inner:?}" ) .into(), ), meta, }) } match op { BinaryOperator::Divide | BinaryOperator::Add | BinaryOperator::Subtract => { // Naga IR doesn't support these matrix-by-scalar operations directly, // so we splat the scalar into a vector, apply the operation to each // column vector, and finally reconstruct the matrix let scalar_vector = self.add_expression( Expression::Splat { size: rows, value: left, }, meta, )?; let mut components = Vec::with_capacity(columns as usize); for index in 0..columns as u32 { // Get the column vector let matrix_column = self.add_expression( Expression::AccessIndex { base: right, index }, meta, )?; // Apply the operation to the splatted vector and // the column vector let column = self.add_expression( Expression::Binary { op, left: scalar_vector, right: matrix_column, }, meta, )?; components.push(column) } let ty = self.module.types.insert( Type { name: None, inner: TypeInner::Matrix { columns, rows, scalar: left_scalar, }, }, Span::default(), ); // Rebuild the matrix from the operation result vectors self.add_expression(Expression::Compose { ty, components }, meta)? } _ => { self.add_expression(Expression::Binary { left, op, right }, meta)?
} } } ( &TypeInner::Matrix { rows, columns, scalar: left_scalar, }, &TypeInner::Scalar(right_scalar), ) => { // Check that the two arguments have the same scalar type if left_scalar != right_scalar { frontend.errors.push(Error { kind: ErrorKind::SemanticError( format!( "Cannot apply operation to {left_inner:?} and {right_inner:?}" ) .into(), ), meta, }) } match op { BinaryOperator::Divide | BinaryOperator::Add | BinaryOperator::Subtract => { // Naga IR doesn't support these matrix-by-scalar operations directly, // so we splat the scalar into a vector, apply the operation to each // column vector, and finally reconstruct the matrix let scalar_vector = self.add_expression( Expression::Splat { size: rows, value: right, }, meta, )?; let mut components = Vec::with_capacity(columns as usize); for index in 0..columns as u32 { // Get the column vector let matrix_column = self.add_expression( Expression::AccessIndex { base: left, index }, meta, )?; // Apply the operation to the splatted vector and // the column vector let column = self.add_expression( Expression::Binary { op, left: matrix_column, right: scalar_vector, }, meta, )?; components.push(column) } let ty = self.module.types.insert( Type { name: None, inner: TypeInner::Matrix { columns, rows, scalar: left_scalar, }, }, Span::default(), ); // Rebuild the matrix from the operation result vectors self.add_expression(Expression::Compose { ty, components }, meta)? } _ => { self.add_expression(Expression::Binary { left, op, right }, meta)? } } } _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, } } HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => { let expr = self .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)? .0; if let TypeInner::Matrix { scalar, .. } = *self.resolve_type(expr, meta)? { // Naga IR doesn't support matrix negation, so we need to turn it into // multiplication by scalar -1.
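// Illustrative sketch of the lowering below (a comment only, not emitted code):
// for a `mat2 m`, the GLSL expression `-m` becomes
// `Binary { op: Multiply, left: Literal(F32(-1.0)), right: m }`,
// i.e. the scalar-by-matrix product `-1.0 * m`.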
let minus_one = Literal::minus_one(scalar).ok_or_else(|| Error { kind: ErrorKind::SemanticError( format!("Cannot apply operator {op:?} to type {scalar:?}").into(), ), meta, })?; let lhs = self.add_expression(Expression::Literal(minus_one), meta)?; self.add_expression( Expression::Binary { op: BinaryOperator::Multiply, left: lhs, right: expr, }, meta, )? } else { self.add_expression(Expression::Unary { op, expr }, meta)? } } HirExprKind::Variable(ref var) => match pos { ExprPos::Lhs => { if !var.mutable { frontend.errors.push(Error { kind: ErrorKind::SemanticError( "Variable cannot be used in LHS position".into(), ), meta, }) } var.expr } ExprPos::AccessBase { constant_index } => { // If the index isn't constant, all accesses backed by a constant base need // to go through a proxy local variable: constants have a non-pointer type, // but dynamic indexing requires a pointer if !constant_index { if let Some((constant, ty)) = var.constant { let init = self .add_expression(Expression::Constant(constant), Span::default())?; let local = self.locals.append( LocalVariable { name: None, ty, init: Some(init), }, Span::default(), ); self.add_expression(Expression::LocalVariable(local), Span::default())? } else { var.expr } } else { var.expr } } _ if var.load => { self.add_expression(Expression::Load { pointer: var.expr }, meta)? } ExprPos::Rhs => { if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() { self.add_expression(Expression::Constant(constant), meta)? } else { // Check if this is an Override expression in const context if self.is_const { if let Expression::Override(o) = self.expressions[var.expr] { // Need to add the Override expression to the global arena self.add_expression(Expression::Override(o), meta)?
} else { var.expr } } else { var.expr } } } }, HirExprKind::Call(ref call) if pos != ExprPos::Lhs => { let maybe_expr = frontend.function_or_constructor_call( self, stmt, call.kind.clone(), &call.args, meta, )?; return Ok((maybe_expr, meta)); } // `HirExprKind::Conditional` represents the ternary operator in GLSL (`?:`) // // The ternary operator is defined to only evaluate one of the two possible // expressions, which means that its behavior is that of an `if` statement, // and it's merely syntactic sugar for it. HirExprKind::Conditional { condition, accept, reject, } if ExprPos::Lhs != pos => { // Given an expression `a ? b : c`, we need to produce a Naga // statement roughly like: // // var temp; // if a { // temp = convert(b); // } else { // temp = convert(c); // } // // where `convert` stands for type conversions to bring `b` and `c` to // the same type, and then use `temp` to represent the value of the whole // conditional expression in subsequent code. // Lower the condition first, into the current body let condition = self .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)? .0; let (mut accept_body, (mut accept, accept_meta)) = self.new_body_with_ret(|ctx| { // Lower the `true` branch ctx.lower_expect_inner(stmt, frontend, accept, pos) })?; let (mut reject_body, (mut reject, reject_meta)) = self.new_body_with_ret(|ctx| { // Lower the `false` branch ctx.lower_expect_inner(stmt, frontend, reject, pos) })?; // We need to do some custom implicit conversions since the two target expressions // are in different bodies if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = ( // Get the components of both branches and calculate the type power self.expr_scalar_components(accept, accept_meta)? .and_then(|scalar| Some((type_power(scalar)?, scalar))), self.expr_scalar_components(reject, reject_meta)?
.and_then(|scalar| Some((type_power(scalar)?, scalar))), ) { match accept_power.cmp(&reject_power) { core::cmp::Ordering::Less => { accept_body = self.with_body(accept_body, |ctx| { ctx.conversion(&mut accept, accept_meta, reject_scalar)?; Ok(()) })?; } core::cmp::Ordering::Equal => {} core::cmp::Ordering::Greater => { reject_body = self.with_body(reject_body, |ctx| { ctx.conversion(&mut reject, reject_meta, accept_scalar)?; Ok(()) })?; } } } // We need the type of the resulting expression to create the local; // this must be done after implicit conversions to ensure both branches have // the same type. let ty = self.resolve_type_handle(accept, accept_meta)?; // Add the local that will hold the result of our conditional let local = self.locals.append( LocalVariable { name: None, ty, init: None, }, meta, ); let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?; // Add a store to the result variable at the end of each branch accept_body.push( Statement::Store { pointer: local_expr, value: accept, }, accept_meta, ); reject_body.push( Statement::Store { pointer: local_expr, value: reject, }, reject_meta, ); // Finally add the `If` to the main body with the `condition` we lowered // earlier and the branches we prepared. self.body.push( Statement::If { condition, accept: accept_body, reject: reject_body, }, meta, ); // Note: `Expression::Load` must be emitted before it's used, so make // sure the emitter is active here. self.add_expression( Expression::Load { pointer: local_expr, }, meta, )? } HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => { let (pointer, ptr_meta) = self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?; let (mut value, value_meta) = self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?; let ty = match *self.resolve_type(pointer, ptr_meta)? { TypeInner::Pointer { base, ..
} => &self.module.types[base].inner, ref ty => ty, }; if let Some(scalar) = scalar_components(ty) { self.implicit_conversion(&mut value, value_meta, scalar)?; } self.lower_store(pointer, value, meta)?; value } HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => { let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?; let left = if let Expression::Swizzle { .. } = self.expressions[pointer] { pointer } else { self.add_expression(Expression::Load { pointer }, meta)? }; let res = match *self.resolve_type(left, meta)? { TypeInner::Scalar(scalar) => { let ty = TypeInner::Scalar(scalar); Literal::one(scalar).map(|i| (ty, i, None, None)) } TypeInner::Vector { size, scalar } => { let ty = TypeInner::Vector { size, scalar }; Literal::one(scalar).map(|i| (ty, i, Some(size), None)) } TypeInner::Matrix { columns, rows, scalar, } => { let ty = TypeInner::Matrix { columns, rows, scalar, }; Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns))) } _ => None, }; let (ty_inner, literal, rows, columns) = match res { Some(res) => res, None => { frontend.errors.push(Error { kind: ErrorKind::SemanticError( "Increment/decrement only works on scalar/vector/matrix".into(), ), meta, }); return Ok((Some(left), meta)); } }; let mut right = self.add_expression(Expression::Literal(literal), meta)?; // GLSL allows prefix/postfix operations on vectors and matrices, so if the // target is either of them, splat the right-hand operand to the same size // as the target; if the target is a matrix, additionally compose a matrix // from the splatted value.
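// Illustrative sketch of the code below (a comment only, not emitted code):
// for a `vec3 v`, `v++` uses `Splat { size: Tri, value: 1.0 }` as the right
// operand. For a `mat2x3 m` (2 columns, 3 rows), that splatted `vec3(1.0)` is
// repeated twice in a `Compose` to build a mat2x3 of ones, so every component
// of `m` is incremented. The postfix form evaluates to the loaded `left`,
// and the prefix form evaluates to the new `value`.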
if let Some(size) = rows { right = self.add_expression(Expression::Splat { size, value: right }, meta)?; if let Some(cols) = columns { let ty = self.module.types.insert( Type { name: None, inner: ty_inner, }, meta, ); right = self.add_expression( Expression::Compose { ty, components: core::iter::repeat_n(right, cols as usize).collect(), }, meta, )?; } } let value = self.add_expression(Expression::Binary { op, left, right }, meta)?; self.lower_store(pointer, value, meta)?; if postfix { left } else { value } } HirExprKind::Method { expr: object, ref name, ref args, } if ExprPos::Lhs != pos => { let args = args .iter() .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs)) .collect::<Result<Vec<_>>>()?; match name.as_ref() { "length" => { if !args.is_empty() { frontend.errors.push(Error { kind: ErrorKind::SemanticError( ".length() doesn't take any arguments".into(), ), meta, }); } let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0; let array_type = self.resolve_type(lowered_array, meta)?; match *array_type { TypeInner::Array { size: crate::ArraySize::Constant(size), ..
} => { let mut array_length = self.add_expression( Expression::Literal(Literal::U32(size.get())), meta, )?; self.forced_conversion(&mut array_length, meta, Scalar::I32)?; array_length } // let the error be handled in type checking if it's not a dynamic array _ => { let mut array_length = self .add_expression(Expression::ArrayLength(lowered_array), meta)?; self.conversion(&mut array_length, meta, Scalar::I32)?; array_length } } } _ => { return Err(Error { kind: ErrorKind::SemanticError( format!("unknown method '{name}'").into(), ), meta, }); } } } HirExprKind::Sequence { ref exprs } if pos != ExprPos::Lhs => { let mut last_handle = None; for expr in exprs.iter() { let (handle, _) = self.lower_expect_inner(stmt, frontend, *expr, ExprPos::Rhs)?; last_handle = Some(handle); } match last_handle { Some(handle) => handle, None => unreachable!(), } } _ => { return Err(Error { kind: ErrorKind::SemanticError( format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr]) .into(), ), meta, }) } }; log::trace!("Lowered {expr:?}\n\tKind = {kind:?}\n\tPos = {pos:?}\n\tResult = {handle:?}"); Ok((Some(handle), meta)) } pub fn expr_scalar_components( &mut self, expr: Handle<Expression>, meta: Span, ) -> Result<Option<Scalar>> { let ty = self.resolve_type(expr, meta)?; Ok(scalar_components(ty)) } pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> { Ok(self .expr_scalar_components(expr, meta)? .and_then(type_power)) } pub fn conversion( &mut self, expr: &mut Handle<Expression>, meta: Span, scalar: Scalar, ) -> Result<()> { *expr = self.add_expression( Expression::As { expr: *expr, kind: scalar.kind, convert: Some(scalar.width), }, meta, )?; Ok(()) } pub fn implicit_conversion( &mut self, expr: &mut Handle<Expression>, meta: Span, scalar: Scalar, ) -> Result<()> { if let (Some(tgt_power), Some(expr_power)) = (type_power(scalar), self.expr_power(*expr, meta)?)
{ if tgt_power > expr_power { self.conversion(expr, meta, scalar)?; } } Ok(()) } pub fn forced_conversion( &mut self, expr: &mut Handle<Expression>, meta: Span, scalar: Scalar, ) -> Result<()> { if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? { if expr_scalar != scalar { self.conversion(expr, meta, scalar)?; } } Ok(()) } pub fn binary_implicit_conversion( &mut self, left: &mut Handle<Expression>, left_meta: Span, right: &mut Handle<Expression>, right_meta: Span, ) -> Result<()> { let left_components = self.expr_scalar_components(*left, left_meta)?; let right_components = self.expr_scalar_components(*right, right_meta)?; if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = ( left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))), right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))), ) { match left_power.cmp(&right_power) { core::cmp::Ordering::Less => { self.conversion(left, left_meta, right_scalar)?; } core::cmp::Ordering::Equal => {} core::cmp::Ordering::Greater => { self.conversion(right, right_meta, left_scalar)?; } } } Ok(()) } pub fn implicit_splat( &mut self, expr: &mut Handle<Expression>, meta: Span, vector_size: Option<VectorSize>, ) -> Result<()> { let expr_type = self.resolve_type(*expr, meta)?; if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) { *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)?
} Ok(()) } pub fn vector_resize( &mut self, size: VectorSize, vector: Handle<Expression>, meta: Span, ) -> Result<Handle<Expression>> { self.add_expression( Expression::Swizzle { size, vector, pattern: crate::SwizzleComponent::XYZW, }, meta, ) } } impl Index<Handle<Expression>> for Context<'_> { type Output = Expression; fn index(&self, index: Handle<Expression>) -> &Self::Output { if self.is_const { &self.module.global_expressions[index] } else { &self.expressions[index] } } } /// Helper struct passed when parsing expressions /// /// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx) /// and only one of these may be active at any time per context. #[derive(Debug)] pub struct StmtContext { /// An arena of high-level expressions which can be lowered through a /// [`Context`] to Naga's [`Expression`]s pub hir_exprs: Arena<HirExpr>, } impl StmtContext { const fn new() -> Self { StmtContext { hir_exprs: Arena::new(), } } } naga-29.0.3/src/front/glsl/error.rs000064400000000000000000000216631046102023000152140ustar 00000000000000use alloc::{ borrow::Cow, string::{String, ToString}, vec, vec::Vec, }; use codespan_reporting::diagnostic::{Diagnostic, Label}; use codespan_reporting::files::SimpleFile; use codespan_reporting::term; use pp_rs::token::PreprocessorError; use thiserror::Error; use super::token::TokenValue; #[cfg(feature = "stderr")] use crate::error::ErrorWrite; use crate::{error::replace_control_chars, SourceLocation}; use crate::{proc::ConstantEvaluatorError, Span}; fn join_with_comma(list: &[ExpectedToken]) -> String { let mut string = "".to_string(); for (i, val) in list.iter().enumerate() { string.push_str(&val.to_string()); match i { i if i == list.len() - 1 => {} i if i == list.len() - 2 => string.push_str(" or "), _ => string.push_str(", "), } } string } /// One of the expected tokens returned in [`InvalidToken`](ErrorKind::InvalidToken). 
#[derive(Clone, Debug, PartialEq)] pub enum ExpectedToken { /// A specific token was expected. Token(TokenValue), /// A type was expected.
TypeName, /// An identifier was expected. Identifier, /// An integer literal was expected. IntLiteral, /// A float literal was expected. FloatLiteral, /// A boolean literal was expected. BoolLiteral, /// The end of file was expected. Eof, } impl From<TokenValue> for ExpectedToken { fn from(token: TokenValue) -> Self { ExpectedToken::Token(token) } } impl core::fmt::Display for ExpectedToken { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match *self { ExpectedToken::Token(ref token) => write!(f, "{token:?}"), ExpectedToken::TypeName => write!(f, "a type"), ExpectedToken::Identifier => write!(f, "identifier"), ExpectedToken::IntLiteral => write!(f, "integer literal"), ExpectedToken::FloatLiteral => write!(f, "float literal"), ExpectedToken::BoolLiteral => write!(f, "bool literal"), ExpectedToken::Eof => write!(f, "end of file"), } } } /// Information about the cause of an error. #[derive(Clone, Debug, Error)] #[cfg_attr(test, derive(PartialEq))] pub enum ErrorKind { /// The parser encountered an unexpected end of file. #[error("Unexpected end of file")] EndOfFile, /// The shader specified an unsupported or invalid profile. #[error("Invalid profile: {0}")] InvalidProfile(String), /// The shader requested an unsupported or invalid version. #[error("Invalid version: {0}")] InvalidVersion(u64), /// An unexpected token was encountered while parsing. /// /// A list of expected tokens is also returned. #[error("Expected {expected_tokens}, found {found_token:?}", found_token = .0, expected_tokens = join_with_comma(.1))] InvalidToken(TokenValue, Vec<ExpectedToken>), /// A specific feature is not yet implemented. /// /// To help prioritize work, please open an issue in the GitHub issue tracker /// if none exists already, or react to the existing one. #[error("Not implemented: {0}")] NotImplemented(&'static str), /// A reference to a variable that wasn't declared was used.
#[error("Unknown variable: {0}")] UnknownVariable(String), /// A reference to a type that wasn't declared was used. #[error("Unknown type: {0}")] UnknownType(String), /// A reference to a non-existent member of a type was made. #[error("Unknown field: {0}")] UnknownField(String), /// An unknown layout qualifier was used. /// /// If the qualifier does exist, please open an issue in the GitHub issue tracker /// if none exists already, or react to the existing one to help /// prioritize work. #[error("Unknown layout qualifier: {0}")] UnknownLayoutQualifier(String), /// Unsupported matrix of the form matCx2 /// /// Our IR expects matrices of the form matCx2 to have a stride of 8; however, /// matrices in the std140 layout have a stride of at least 16. #[error("unsupported matrix of the form matCx2 (in this case mat{columns}x2) in std140 block layout. See https://github.com/gfx-rs/wgpu/issues/4375")] UnsupportedMatrixWithTwoRowsInStd140 { columns: u8 }, /// Unsupported matrix of the form f16matCxR /// /// Our IR expects matrices of the form f16matCxR to have a stride of 4/8/8 depending on row count; /// however, matrices in the std140 layout have a stride of at least 16. #[error("unsupported matrix of the form f16matCxR (in this case f16mat{columns}x{rows}) in std140 block layout. See https://github.com/gfx-rs/wgpu/issues/4375")] UnsupportedF16MatrixInStd140 { columns: u8, rows: u8 }, /// A variable with the same name already exists in the current scope. #[error("Variable already declared: {0}")] VariableAlreadyDeclared(String), /// A semantic error was detected in the shader. #[error("{0}")] SemanticError(Cow<'static, str>), /// An error was returned by the preprocessor.
#[error("{0:?}")] PreprocessorError(PreprocessorError), /// The parser entered an illegal state and exited. /// /// This is a bug and should be reported in the GitHub issue tracker. #[error("Internal error: {0}")] InternalError(&'static str), } impl From<ConstantEvaluatorError> for ErrorKind { fn from(err: ConstantEvaluatorError) -> Self { ErrorKind::SemanticError(err.to_string().into()) } } /// Error returned during shader parsing. #[derive(Clone, Debug, Error)] #[error("{kind}")] #[cfg_attr(test, derive(PartialEq))] pub struct Error { /// Holds the information about the error itself. pub kind: ErrorKind, /// Holds information about the range of the source code where the error happened. pub meta: Span, } impl Error { /// Returns a [`SourceLocation`] for the error message. pub fn location(&self, source: &str) -> Option<SourceLocation> { Some(self.meta.location(source)) } } /// A collection of errors returned during shader parsing. #[derive(Clone, Debug)] #[cfg_attr(test, derive(PartialEq))] pub struct ParseErrors { pub errors: Vec<Error>, } impl ParseErrors { #[cfg(feature = "stderr")] pub fn emit_to_writer(&self, writer: &mut impl ErrorWrite, source: &str) { self.emit_to_writer_with_path(writer, source, "glsl"); } #[cfg(feature = "stderr")] pub fn emit_to_writer_with_path(&self, writer: &mut impl ErrorWrite, source: &str, path: &str) { let path = path.to_string(); let files = SimpleFile::new(path, replace_control_chars(source)); let config = term::Config::default(); for err in &self.errors { let diagnostic = Self::make_diagnostic(err); crate::error::emit_to_writer(writer, &config, &files, &diagnostic) .expect("cannot write error"); } } /// Emits a summary of the errors to the standard error stream. #[cfg(feature = "stderr")] pub fn emit_to_stderr(&self, source: &str) { self.emit_to_stderr_with_path(source, "glsl") } /// Emits a summary of the errors to the standard error stream. 
#[cfg(feature = "stderr")] pub fn emit_to_stderr_with_path(&self, source: &str, path: &str) { cfg_if::cfg_if!
{ if #[cfg(feature = "termcolor")] { let writer = term::termcolor::StandardStream::stderr(term::termcolor::ColorChoice::Auto); self.emit_to_writer_with_path(&mut writer.lock(), source, path); } else { let writer = std::io::stderr(); self.emit_to_writer_with_path(&mut writer.lock(), source, path); } } } pub fn emit_to_string(&self, source: &str) -> String { self.emit_to_string_with_path(source, "glsl") } pub fn emit_to_string_with_path(&self, source: &str, path: &str) -> String { let path = path.to_string(); let files = SimpleFile::new(path, replace_control_chars(source)); let config = term::Config::default(); let mut writer = crate::error::DiagnosticBuffer::new(); for err in &self.errors { let diagnostic = Self::make_diagnostic(err); writer .emit_to_self(&config, &files, &diagnostic) .expect("cannot write error"); } writer.into_string() } fn make_diagnostic(err: &Error) -> Diagnostic<()> { let mut diagnostic = Diagnostic::error().with_message(err.kind.to_string()); if let Some(range) = err.meta.to_range() { diagnostic = diagnostic.with_labels(vec![Label::primary((), range)]); } diagnostic } } impl core::fmt::Display for ParseErrors { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { self.errors.iter().try_for_each(|e| write!(f, "{e:?}")) } } impl core::error::Error for ParseErrors {} impl From<Vec<Error>> for ParseErrors { fn from(errors: Vec<Error>) -> Self { Self { errors } } } naga-29.0.3/src/front/glsl/functions.rs000064400000000000000000001700611046102023000160700ustar 00000000000000use alloc::{ format, string::{String, ToString}, vec, vec::Vec, }; use core::iter; use super::{ ast::*, builtins::{inject_builtin, sampled_to_depth}, context::{Context, ExprPos, StmtContext}, error::{Error, ErrorKind}, types::scalar_components, Frontend, Result, }; use crate::{ front::glsl::types::type_power, proc::ensure_block_returns, AddressSpace, Block, EntryPoint, Expression, Function, FunctionArgument, FunctionResult, Handle, Literal, LocalVariable, Scalar, 
ScalarKind, Span,
Statement, StructMember, Type, TypeInner, }; /// Struct detailing a store operation that must happen after a function call struct ProxyWrite { /// The store target target: Handle<Expression>, /// A pointer to read the value of the store value: Handle<Expression>, /// An optional conversion to be applied convert: Option<Scalar>, } impl Frontend { pub(crate) fn function_or_constructor_call( &mut self, ctx: &mut Context, stmt: &StmtContext, fc: FunctionCallKind, raw_args: &[Handle<HirExpr>], meta: Span, ) -> Result<Option<Handle<Expression>>> { let args: Vec<_> = raw_args .iter() .map(|e| ctx.lower_expect_inner(stmt, self, *e, ExprPos::Rhs)) .collect::<Result<_>>()?; match fc { FunctionCallKind::TypeConstructor(ty) => { if args.len() == 1 { self.constructor_single(ctx, ty, args[0], meta).map(Some) } else { self.constructor_many(ctx, ty, args, meta).map(Some) } } FunctionCallKind::Function(name) => { self.function_call(ctx, stmt, name, args, raw_args, meta) } } } fn constructor_single( &mut self, ctx: &mut Context, ty: Handle<Type>, (mut value, expr_meta): (Handle<Expression>, Span), meta: Span, ) -> Result<Handle<Expression>> { let expr_type = ctx.resolve_type(value, expr_meta)?; let vector_size = match *expr_type { TypeInner::Vector { size, .. } => Some(size), _ => None, }; let expr_is_bool = expr_type.scalar_kind() == Some(ScalarKind::Bool); // Special case: if casting from a bool, we need to use Select and not As.
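// Illustrative sketch of the special case below (a comment only, not emitted code):
// `vec2(bvec2(true, false))` is lowered to
// `Select { condition: <the bvec2>, accept: Splat { size: Bi, value: 1.0 }, reject: Splat { size: Bi, value: 0.0 } }`,
// which evaluates to `vec2(1.0, 0.0)`. A scalar `float(b)` selects between the
// literals `1.0` and `0.0` directly, because no splat is applied.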
match ctx.module.types[ty].inner.scalar() { Some(result_scalar) if expr_is_bool && result_scalar.kind != ScalarKind::Bool => { let result_scalar = Scalar { width: 4, ..result_scalar }; let l0 = Literal::zero(result_scalar).unwrap(); let l1 = Literal::one(result_scalar).unwrap(); let mut reject = ctx.add_expression(Expression::Literal(l0), expr_meta)?; let mut accept = ctx.add_expression(Expression::Literal(l1), expr_meta)?; ctx.implicit_splat(&mut reject, meta, vector_size)?; ctx.implicit_splat(&mut accept, meta, vector_size)?; let h = ctx.add_expression( Expression::Select { accept, reject, condition: value, }, expr_meta, )?; return Ok(h); } _ => {} } Ok(match ctx.module.types[ty].inner { TypeInner::Vector { size, scalar } if vector_size.is_none() => { ctx.forced_conversion(&mut value, expr_meta, scalar)?; if let TypeInner::Scalar { .. } = *ctx.resolve_type(value, expr_meta)? { ctx.add_expression(Expression::Splat { size, value }, meta)? } else { self.vector_constructor(ctx, ty, size, scalar, &[(value, expr_meta)], meta)? } } TypeInner::Scalar(scalar) => { let mut expr = value; if let TypeInner::Vector { .. } | TypeInner::Matrix { .. } = *ctx.resolve_type(value, expr_meta)? { expr = ctx.add_expression( Expression::AccessIndex { base: expr, index: 0, }, meta, )?; } if let TypeInner::Matrix { .. } = *ctx.resolve_type(value, expr_meta)? { expr = ctx.add_expression( Expression::AccessIndex { base: expr, index: 0, }, meta, )?; } ctx.add_expression( Expression::As { kind: scalar.kind, expr, convert: Some(scalar.width), }, meta, )? } TypeInner::Vector { size, scalar } => { if vector_size != Some(size) { value = ctx.vector_resize(size, value, expr_meta)?; } ctx.add_expression( Expression::As { kind: scalar.kind, expr: value, convert: Some(scalar.width), }, meta, )? } TypeInner::Matrix { columns, rows, scalar, } => self.matrix_one_arg(ctx, ty, columns, rows, scalar, (value, expr_meta), meta)?, TypeInner::Struct { ref members, .. 
} => { let scalar_components = members .first() .and_then(|member| scalar_components(&ctx.module.types[member.ty].inner)); if let Some(scalar) = scalar_components { ctx.implicit_conversion(&mut value, expr_meta, scalar)?; } ctx.add_expression( Expression::Compose { ty, components: vec![value], }, meta, )? } TypeInner::Array { base, .. } => { let scalar_components = scalar_components(&ctx.module.types[base].inner); if let Some(scalar) = scalar_components { ctx.implicit_conversion(&mut value, expr_meta, scalar)?; } ctx.add_expression( Expression::Compose { ty, components: vec![value], }, meta, )? } _ => { self.errors.push(Error { kind: ErrorKind::SemanticError("Bad type constructor".into()), meta, }); value } }) } #[allow(clippy::too_many_arguments)] fn matrix_one_arg( &mut self, ctx: &mut Context, ty: Handle<Type>, columns: crate::VectorSize, rows: crate::VectorSize, element_scalar: Scalar, (mut value, expr_meta): (Handle<Expression>, Span), meta: Span, ) -> Result<Handle<Expression>> { let mut components = Vec::with_capacity(columns as usize); // TODO: casts // `Expression::As` doesn't support matrix width // casts so we need to do some extra work for casts ctx.forced_conversion(&mut value, expr_meta, element_scalar)?; match *ctx.resolve_type(value, expr_meta)? { TypeInner::Scalar(_) => { // If a matrix is constructed with a single scalar value, then that // value is used to initialize all the values along the diagonal of // the matrix; the rest are given zeros. let vector_ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Vector { size: rows, scalar: element_scalar, }, }, meta, ); let zero_literal = Literal::zero(element_scalar).unwrap(); let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?; for i in 0..columns as u32 { components.push( ctx.add_expression( Expression::Compose { ty: vector_ty, components: (0..rows as u32) .map(|r| match r == i { true => value, false => zero, }) .collect(), }, meta, )?, ) } } TypeInner::Matrix { rows: ori_rows, columns: ori_cols, ..
} => { // If a matrix is constructed from a matrix, then each component // (column i, row j) in the result that has a corresponding component // (column i, row j) in the argument will be initialized from there. All // other components will be initialized to the identity matrix. let zero_literal = Literal::zero(element_scalar).unwrap(); let one_literal = Literal::one(element_scalar).unwrap(); let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?; let one = ctx.add_expression(Expression::Literal(one_literal), meta)?; let vector_ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Vector { size: rows, scalar: element_scalar, }, }, meta, ); for i in 0..columns as u32 { if i < ori_cols as u32 { use core::cmp::Ordering; let vector = ctx.add_expression( Expression::AccessIndex { base: value, index: i, }, meta, )?; components.push(match ori_rows.cmp(&rows) { Ordering::Less => { let components = (0..rows as u32) .map(|r| { if r < ori_rows as u32 { ctx.add_expression( Expression::AccessIndex { base: vector, index: r, }, meta, ) } else if r == i { Ok(one) } else { Ok(zero) } }) .collect::>()?; ctx.add_expression( Expression::Compose { ty: vector_ty, components, }, meta, )? 
} Ordering::Equal => vector, Ordering::Greater => ctx.vector_resize(rows, vector, meta)?, }) } else { let compose_expr = Expression::Compose { ty: vector_ty, components: (0..rows as u32) .map(|r| match r == i { true => one, false => zero, }) .collect(), }; let vec = ctx.add_expression(compose_expr, meta)?; components.push(vec) } } } _ => { components = iter::repeat_n(value, columns as usize).collect(); } } ctx.add_expression(Expression::Compose { ty, components }, meta) } fn vector_constructor( &mut self, ctx: &mut Context, ty: Handle, size: crate::VectorSize, scalar: Scalar, args: &[(Handle, Span)], meta: Span, ) -> Result> { let mut components = Vec::with_capacity(size as usize); for (mut arg, expr_meta) in args.iter().copied() { ctx.forced_conversion(&mut arg, expr_meta, scalar)?; if components.len() >= size as usize { break; } match *ctx.resolve_type(arg, expr_meta)? { TypeInner::Scalar { .. } => components.push(arg), TypeInner::Matrix { rows, columns, .. } => { components.reserve(rows as usize * columns as usize); for c in 0..(columns as u32) { let base = ctx.add_expression( Expression::AccessIndex { base: arg, index: c, }, expr_meta, )?; for r in 0..(rows as u32) { components.push(ctx.add_expression( Expression::AccessIndex { base, index: r }, expr_meta, )?) } } } TypeInner::Vector { size: ori_size, .. } => { components.reserve(ori_size as usize); for index in 0..(ori_size as u32) { components.push(ctx.add_expression( Expression::AccessIndex { base: arg, index }, expr_meta, )?) 
} } _ => components.push(arg), } } components.truncate(size as usize); ctx.add_expression(Expression::Compose { ty, components }, meta) } fn constructor_many( &mut self, ctx: &mut Context, ty: Handle, args: Vec<(Handle, Span)>, meta: Span, ) -> Result> { let mut components = Vec::with_capacity(args.len()); let struct_member_data = match ctx.module.types[ty].inner { TypeInner::Matrix { columns, rows, scalar: element_scalar, } => { let mut flattened = Vec::with_capacity(columns as usize * rows as usize); for (mut arg, meta) in args.iter().copied() { ctx.forced_conversion(&mut arg, meta, element_scalar)?; match *ctx.resolve_type(arg, meta)? { TypeInner::Vector { size, .. } => { for i in 0..(size as u32) { flattened.push(ctx.add_expression( Expression::AccessIndex { base: arg, index: i, }, meta, )?) } } _ => flattened.push(arg), } } let ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Vector { size: rows, scalar: element_scalar, }, }, meta, ); for chunk in flattened.chunks(rows as usize) { components.push(ctx.add_expression( Expression::Compose { ty, components: Vec::from(chunk), }, meta, )?) } None } TypeInner::Vector { size, scalar } => { return self.vector_constructor(ctx, ty, size, scalar, &args, meta) } TypeInner::Array { base, .. } => { for (mut arg, meta) in args.iter().copied() { let scalar_components = scalar_components(&ctx.module.types[base].inner); if let Some(scalar) = scalar_components { ctx.implicit_conversion(&mut arg, meta, scalar)?; } components.push(arg) } None } TypeInner::Struct { ref members, .. 
} => Some( members .iter() .map(|member| scalar_components(&ctx.module.types[member.ty].inner)) .collect::>(), ), _ => { return Err(Error { kind: ErrorKind::SemanticError("Constructor: Too many arguments".into()), meta, }) } }; if let Some(struct_member_data) = struct_member_data { for ((mut arg, meta), scalar_components) in args.iter().copied().zip(struct_member_data.iter().copied()) { if let Some(scalar) = scalar_components { ctx.implicit_conversion(&mut arg, meta, scalar)?; } components.push(arg) } } ctx.add_expression(Expression::Compose { ty, components }, meta) } fn function_call( &mut self, ctx: &mut Context, stmt: &StmtContext, name: String, args: Vec<(Handle, Span)>, raw_args: &[Handle], meta: Span, ) -> Result>> { // Grow the typifier to be able to index it later without needing // to hold the context mutably for &(expr, span) in args.iter() { ctx.typifier_grow(expr, span)?; } // Check if the passed arguments require any special variations let mut variations = builtin_required_variations(args.iter().map(|&(expr, _)| ctx.get_type(expr))); // Initiate the declaration if it wasn't previously initialized and inject builtins let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { variations |= BuiltinVariations::STANDARD; Default::default() }); inject_builtin(declaration, ctx.module, &name, variations); // Borrow again but without mutability, at this point a declaration is guaranteed let declaration = self.lookup_function.get(&name).unwrap(); // Possibly contains the overload to be used in the call let mut maybe_overload = None; // The conversions needed for the best analyzed overload, this is initialized all to // `NONE` to make sure that conversions always pass the first time without ambiguity let mut old_conversions = vec![Conversion::None; args.len()]; // Tracks whether the comparison between overloads lead to an ambiguity let mut ambiguous = false; // Iterate over all the available overloads to select either an exact match or a // 
overload which has suitable implicit conversions 'outer: for (overload_idx, overload) in declaration.overloads.iter().enumerate() { // If the overload and the function call don't have the same number of arguments // continue to the next overload if args.len() != overload.parameters.len() { continue; } log::trace!("Testing overload {overload_idx}"); // Stores whether the current overload matches exactly the function call let mut exact = true; // State of the selection // If None we still don't know what is the best overload // If Some(true) the new overload is better // If Some(false) the old overload is better let mut superior = None; // Store the conversions for the current overload so that later they can replace the // conversions used for querying the best overload let mut new_conversions = vec![Conversion::None; args.len()]; // Loop through the overload parameters and check if the current overload is better // compared to the previous best overload. for (i, overload_parameter) in overload.parameters.iter().enumerate() { let call_argument = &args[i]; let parameter_info = &overload.parameters_info[i]; // If the image is used in the overload as a depth texture convert it // before comparing, otherwise exact matches wouldn't be reported if parameter_info.depth { sampled_to_depth(ctx, call_argument.0, call_argument.1, &mut self.errors); ctx.invalidate_expression(call_argument.0, call_argument.1)? } ctx.typifier_grow(call_argument.0, call_argument.1)?; let overload_param_ty = &ctx.module.types[*overload_parameter].inner; let call_arg_ty = ctx.get_type(call_argument.0); log::trace!( "Testing parameter {i}\n\tOverload = {overload_param_ty:?}\n\tCall = {call_arg_ty:?}" ); // Storage images cannot be directly compared since while the access is part of the // type in naga's IR, in glsl they are a qualifier and don't enter in the match as // long as the access needed is satisfied. 
                if let (
                    &TypeInner::Image {
                        class:
                            crate::ImageClass::Storage {
                                format: overload_format,
                                access: overload_access,
                            },
                        dim: overload_dim,
                        arrayed: overload_arrayed,
                    },
                    &TypeInner::Image {
                        class:
                            crate::ImageClass::Storage {
                                format: call_format,
                                access: call_access,
                            },
                        dim: call_dim,
                        arrayed: call_arrayed,
                    },
                ) = (overload_param_ty, call_arg_ty)
                {
                    // The image dimension and arrayedness must match, otherwise
                    // this overload isn't the one we want.
                    let good_size = call_dim == overload_dim && call_arrayed == overload_arrayed;
                    // GLSL requires the formats to match exactly, unless this is a
                    // builtin function overload that hasn't been replaced, in
                    // which case we only check that the formats' scalar types match.
                    let good_format = overload_format == call_format
                        || (overload.internal
                            && Scalar::from(overload_format) == Scalar::from(call_format));
                    if !(good_size && good_format) {
                        continue 'outer;
                    }

                    // A storage access mismatch is an error, but not one that should
                    // make overload matching fail, so we defer the error and consider
                    // the images an exact match.
                    if !call_access.contains(overload_access) {
                        self.errors.push(Error {
                            kind: ErrorKind::SemanticError(
                                format!(
                                    "'{name}': image needs {overload_access:?} access but only {call_access:?} was provided"
                                )
                                .into(),
                            ),
                            meta,
                        });
                    }

                    // The images satisfy the conditions to be considered an exact match
                    new_conversions[i] = Conversion::Exact;
                    continue;
                } else if overload_param_ty == call_arg_ty {
                    // If the types match there's no need to check for conversions, so continue
                    new_conversions[i] = Conversion::Exact;
                    continue;
                }

                // GLSL defines that `inout` parameters follow the conversion rules for
                // both input and output parameters. This means there must be a
                // conversion from the call argument to the function parameter and from
                // the function parameter to the call argument; the only way this is
                // possible is for the conversion to be an identity (i.e.
                // call argument = function parameter)
                if let ParameterQualifier::InOut = parameter_info.qualifier {
                    continue 'outer;
                }

                // The function call argument and the function definition
                // parameter are not equal at this point, so we need to try
                // implicit conversions.
                //
                // There are two cases. If the parameter is a normal input
                // parameter (`in` or `const`), an implicit conversion is made
                // from the call argument to the parameter type. If the
                // parameter is `out`, the opposite is needed: the implicit
                // conversion is made from the parameter type to the call
                // argument.
                let maybe_conversion = if parameter_info.qualifier.is_lhs() {
                    conversion(call_arg_ty, overload_param_ty)
                } else {
                    conversion(overload_param_ty, call_arg_ty)
                };

                let conversion = match maybe_conversion {
                    Some(info) => info,
                    None => continue 'outer,
                };

                // At this point a conversion is needed, so the overload no longer
                // matches the call arguments exactly.
                exact = false;

                // Compare the conversion needed for this parameter with the one
                // needed for the same parameter of the best overload found so far.
                // The value is:
                // - `true` when the new overload's argument has a better conversion
                // - `false` when the old overload's argument has a better conversion
                let best_arg = match (conversion, old_conversions[i]) {
                    // An exact match is always better. We don't need to check this
                    // for the current overload since it was checked earlier.
                    (_, Conversion::Exact) => false,
                    // No overload was analyzed yet, so this one is the best so far.
                    (_, Conversion::None) => true,
                    // A conversion from a float to a double is the best possible conversion.
                    (Conversion::FloatToDouble, _) => true,
                    (_, Conversion::FloatToDouble) => false,
                    // A conversion from an integer to a float is preferred over one
                    // from an integer to a double.
                    (Conversion::IntToFloat, Conversion::IntToDouble) => true,
                    (Conversion::IntToDouble, Conversion::IntToFloat) => false,
                    // This case handles things like no conversion and exact which were
                    // already treated, and other cases in which neither conversion
                    // is better than the other.
                    _ => continue,
                };

                // Record in `superior` which overload this parameter favors. If an
                // earlier parameter favored the other overload, the call is ambiguous.
                match best_arg {
                    true => match superior {
                        Some(false) => ambiguous = true,
                        _ => {
                            superior = Some(true);
                            new_conversions[i] = conversion
                        }
                    },
                    false => match superior {
                        Some(true) => ambiguous = true,
                        _ => superior = Some(false),
                    },
                }
            }

            // The overload matches the function call exactly, so there's no ambiguity
            // (since repeated overloads aren't allowed) and the current overload is
            // selected; no further querying is needed.
            if exact {
                maybe_overload = Some(overload);
                ambiguous = false;
                break;
            }

            match superior {
                // The new overload is better; keep it.
                Some(true) => {
                    maybe_overload = Some(overload);
                    // Replace the conversions
                    old_conversions = new_conversions;
                }
                // The old overload is better; do nothing.
                Some(false) => {}
                // Neither overload was better than the other. This happens when
                // all conversions are ambiguous, in which case the overloads
                // themselves are ambiguous.
                None => {
                    ambiguous = true;
                    // Assign the new overload. This ensures that, in this case of
                    // ambiguity, parsing doesn't end immediately, which allows
                    // further errors to be collected.
                    maybe_overload = Some(overload);
                }
            }
        }

        if ambiguous {
            self.errors.push(Error {
                kind: ErrorKind::SemanticError(
                    format!("Ambiguous best function for '{name}'").into(),
                ),
                meta,
            })
        }

        let overload = maybe_overload.ok_or_else(|| Error {
            kind: ErrorKind::SemanticError(format!("Unknown function '{name}'").into()),
            meta,
        })?;

        let parameters_info = overload.parameters_info.clone();
        let parameters = overload.parameters.clone();
        let is_void = overload.void;
        let kind = overload.kind;

        let mut arguments = Vec::with_capacity(args.len());
        let mut proxy_writes = Vec::new();

        // Iterate through the function call arguments applying transformations as needed
        for (((parameter_info, call_argument), expr), parameter) in parameters_info
            .iter()
            .zip(&args)
            .zip(raw_args)
            .zip(&parameters)
        {
            if parameter_info.qualifier.is_lhs() {
                // Reprocess argument in LHS position
                let (handle, meta) = ctx.lower_expect_inner(stmt, self, *expr, ExprPos::Lhs)?;

                self.process_lhs_argument(
                    ctx,
                    meta,
                    *parameter,
                    parameter_info,
                    handle,
                    call_argument,
                    &mut proxy_writes,
                    &mut arguments,
                )?;

                continue;
            }

            let (mut handle, meta) = *call_argument;

            let scalar_comps = scalar_components(&ctx.module.types[*parameter].inner);

            // Apply implicit conversions as needed
            if let Some(scalar) = scalar_comps {
                ctx.implicit_conversion(&mut handle, meta, scalar)?;
            }

            arguments.push(handle)
        }

        match kind {
            FunctionKind::Call(function) => {
                ctx.emit_end();

                let result = if !is_void {
                    Some(ctx.add_expression(Expression::CallResult(function), meta)?)
                } else {
                    None
                };

                ctx.body.push(
                    Statement::Call {
                        function,
                        arguments,
                        result,
                    },
                    meta,
                );

                ctx.emit_start();

                // Write back all the variables that were scheduled to their original place
                for proxy_write in proxy_writes {
                    let mut value = ctx.add_expression(
                        Expression::Load {
                            pointer: proxy_write.value,
                        },
                        meta,
                    )?;

                    if let Some(scalar) = proxy_write.convert {
                        ctx.conversion(&mut value, meta, scalar)?;
                    }

                    ctx.emit_restart();

                    ctx.body.push(
                        Statement::Store {
                            pointer: proxy_write.target,
                            value,
                        },
                        meta,
                    );
                }

                Ok(result)
            }
            FunctionKind::Macro(builtin) => builtin.call(self, ctx, arguments.as_mut_slice(), meta),
        }
    }

    /// Processes a function call argument that appears in place of an output
    /// parameter.
    #[allow(clippy::too_many_arguments)]
    fn process_lhs_argument(
        &mut self,
        ctx: &mut Context,
        meta: Span,
        parameter_ty: Handle<Type>,
        parameter_info: &ParameterInfo,
        original: Handle<Expression>,
        call_argument: &(Handle<Expression>, Span),
        proxy_writes: &mut Vec<ProxyWrite>,
        arguments: &mut Vec<Handle<Expression>>,
    ) -> Result<()> {
        let original_ty = ctx.resolve_type(original, meta)?;
        let original_pointer_space = original_ty.pointer_space();

        // The type of a possible spill variable needed for a proxy write
        let mut maybe_ty = match *original_ty {
            // If the argument is to be passed as a pointer but the expression's
            // type is a vector, the expression must have been, for example,
            // swizzled, so it must be spilled into a local before the call.
            TypeInner::Vector { size, scalar } => Some(ctx.module.types.insert(
                Type {
                    name: None,
                    inner: TypeInner::Vector { size, scalar },
                },
                Span::default(),
            )),
            // If the argument is a pointer whose address space isn't `Function`, an
            // indirection through a local variable is needed to align the address
            // spaces of the call argument and the overload parameter.
TypeInner::Pointer { base, space } if space != AddressSpace::Function => Some(base), TypeInner::ValuePointer { size, scalar, space, } if space != AddressSpace::Function => { let inner = match size { Some(size) => TypeInner::Vector { size, scalar }, None => TypeInner::Scalar(scalar), }; Some( ctx.module .types .insert(Type { name: None, inner }, Span::default()), ) } _ => None, }; // Since the original expression might be a pointer and we want a value // for the proxy writes, we might need to load the pointer. let value = if original_pointer_space.is_some() { ctx.add_expression(Expression::Load { pointer: original }, Span::default())? } else { original }; ctx.typifier_grow(call_argument.0, call_argument.1)?; let overload_param_ty = &ctx.module.types[parameter_ty].inner; let call_arg_ty = ctx.get_type(call_argument.0); let needs_conversion = call_arg_ty != overload_param_ty; let arg_scalar_comps = scalar_components(call_arg_ty); // Since output parameters also allow implicit conversions from the // parameter to the argument, we need to spill the conversion to a // variable and create a proxy write for the original variable. if needs_conversion { maybe_ty = Some(parameter_ty); } if let Some(ty) = maybe_ty { // Create the spill variable let spill_var = ctx.locals.append( LocalVariable { name: None, ty, init: None, }, Span::default(), ); let spill_expr = ctx.add_expression(Expression::LocalVariable(spill_var), Span::default())?; // If the argument is also copied in we must store the value of the // original variable to the spill variable. 
            if let ParameterQualifier::InOut = parameter_info.qualifier {
                ctx.body.push(
                    Statement::Store {
                        pointer: spill_expr,
                        value,
                    },
                    Span::default(),
                );
            }

            // Add the spill variable as an argument to the function call
            arguments.push(spill_expr);

            let convert = if needs_conversion {
                arg_scalar_comps
            } else {
                None
            };

            // Register the temporary local to be written back to its original
            // place after the function call
            if let Expression::Swizzle {
                size,
                mut vector,
                pattern,
            } = ctx.expressions[original]
            {
                if let Expression::Load { pointer } = ctx.expressions[vector] {
                    vector = pointer;
                }

                for (i, component) in pattern.iter().take(size as usize).enumerate() {
                    let original = ctx.add_expression(
                        Expression::AccessIndex {
                            base: vector,
                            index: *component as u32,
                        },
                        Span::default(),
                    )?;

                    let spill_component = ctx.add_expression(
                        Expression::AccessIndex {
                            base: spill_expr,
                            index: i as u32,
                        },
                        Span::default(),
                    )?;

                    proxy_writes.push(ProxyWrite {
                        target: original,
                        value: spill_component,
                        convert,
                    });
                }
            } else {
                proxy_writes.push(ProxyWrite {
                    target: original,
                    value: spill_expr,
                    convert,
                });
            }
        } else {
            arguments.push(original);
        }

        Ok(())
    }

    pub(crate) fn add_function(
        &mut self,
        mut ctx: Context,
        name: String,
        result: Option<FunctionResult>,
        meta: Span,
    ) {
        ensure_block_returns(&mut ctx.body);

        let void = result.is_none();

        // Check if the passed arguments require any special variations
        let mut variations = builtin_required_variations(
            ctx.parameters
                .iter()
                .map(|&arg| &ctx.module.types[arg].inner),
        );

        // Initialize the declaration if it wasn't previously initialized, and inject builtins
        let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| {
            variations |= BuiltinVariations::STANDARD;
            Default::default()
        });
        inject_builtin(declaration, ctx.module, &name, variations);

        let Context {
            expressions,
            locals,
            arguments,
            parameters,
            parameters_info,
            body,
            module,
            ..
} = ctx; let function = Function { name: Some(name), arguments, result, local_variables: locals, expressions, named_expressions: crate::NamedExpressions::default(), body, diagnostic_filter_leaf: None, }; 'outer: for decl in declaration.overloads.iter_mut() { if parameters.len() != decl.parameters.len() { continue; } for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) { let new_inner = &module.types[*new_parameter].inner; let old_inner = &module.types[*old_parameter].inner; if new_inner != old_inner { continue 'outer; } } if decl.defined { return self.errors.push(Error { kind: ErrorKind::SemanticError("Function already defined".into()), meta, }); } decl.defined = true; decl.parameters_info = parameters_info; match decl.kind { FunctionKind::Call(handle) => *module.functions.get_mut(handle) = function, FunctionKind::Macro(_) => { let handle = module.functions.append(function, meta); decl.kind = FunctionKind::Call(handle) } } return; } let handle = module.functions.append(function, meta); declaration.overloads.push(Overload { parameters, parameters_info, kind: FunctionKind::Call(handle), defined: true, internal: false, void, }); } pub(crate) fn add_prototype( &mut self, ctx: Context, name: String, result: Option, meta: Span, ) { let void = result.is_none(); // Check if the passed arguments require any special variations let mut variations = builtin_required_variations( ctx.parameters .iter() .map(|&arg| &ctx.module.types[arg].inner), ); // Initiate the declaration if it wasn't previously initialized and inject builtins let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { variations |= BuiltinVariations::STANDARD; Default::default() }); inject_builtin(declaration, ctx.module, &name, variations); let Context { arguments, parameters, parameters_info, module, .. 
        } = ctx;

        let function = Function {
            name: Some(name),
            arguments,
            result,
            ..Default::default()
        };

        'outer: for decl in declaration.overloads.iter() {
            if parameters.len() != decl.parameters.len() {
                continue;
            }

            for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) {
                let new_inner = &module.types[*new_parameter].inner;
                let old_inner = &module.types[*old_parameter].inner;

                if new_inner != old_inner {
                    continue 'outer;
                }
            }

            return self.errors.push(Error {
                kind: ErrorKind::SemanticError("Prototype already defined".into()),
                meta,
            });
        }

        let handle = module.functions.append(function, meta);
        declaration.overloads.push(Overload {
            parameters,
            parameters_info,
            kind: FunctionKind::Call(handle),
            defined: false,
            internal: false,
            void,
        });
    }

    /// Create a Naga [`EntryPoint`] that calls the GLSL `main` function.
    ///
    /// We compile the GLSL `main` function as an ordinary Naga [`Function`].
    /// This function synthesizes a Naga [`EntryPoint`] to call that.
    ///
    /// Each GLSL input and output variable (including builtins) becomes a Naga
    /// [`GlobalVariable`] in the [`Private`] address space, which `main` can
    /// access in the usual way.
    ///
    /// The `EntryPoint` we synthesize here has an argument for each GLSL input
    /// variable, and returns a struct with a member for each GLSL output
    /// variable. The entry point contains code to:
    ///
    /// - copy its arguments into the Naga globals representing the GLSL input
    ///   variables,
    ///
    /// - call the Naga `Function` representing the GLSL `main` function, and then
    ///
    /// - build its return value from whatever values the GLSL `main` left in
    ///   the Naga globals representing GLSL `out` variables.
    ///
    /// Upon entry, [`ctx.body`] should contain code, accumulated by prior calls
    /// to [`ParsingContext::parse_external_declaration`][pxd], to initialize
    /// private global variables as needed. This code gets spliced into the
    /// entry point before the call to `main`.
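    ///
    /// As a rough illustration only (not the exact IR Naga produces), consider a
    /// fragment shader that declares:
    ///
    /// ```glsl
    /// layout(location = 0) in vec2 uv;
    /// layout(location = 0) out vec4 color;
    /// void main() { color = vec4(uv, 0.0, 1.0); }
    /// ```
    ///
    /// It yields an entry point that behaves roughly like the WGSL-style sketch
    /// below. The names `uv_global`, `color_global`, `main_fn` and `Out` are
    /// hypothetical, chosen only for this example:
    ///
    /// ```text
    /// struct Out { @location(0) color: vec4<f32> }
    ///
    /// @fragment
    /// fn main(@location(0) uv: vec2<f32>) -> Out {
    ///     uv_global = uv;            // prologue: copy inputs into private globals
    ///     /* global initializers */  // code previously accumulated in `ctx.body`
    ///     main_fn();                 // call the Function compiled from GLSL `main`
    ///     return Out(color_global);  // epilogue: load outputs into the result struct
    /// }
    /// ```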
/// /// [`GlobalVariable`]: crate::GlobalVariable /// [`Private`]: crate::AddressSpace::Private /// [`ctx.body`]: Context::body /// [pxd]: super::ParsingContext::parse_external_declaration pub(crate) fn add_entry_point( &mut self, function: Handle, mut ctx: Context, ) -> Result<()> { let mut arguments = Vec::new(); let body = Block::with_capacity( // global init body ctx.body.len() + // prologue and epilogue self.entry_args.len() * 2 // Call, Emit for composing struct and return + 3, ); let global_init_body = core::mem::replace(&mut ctx.body, body); for arg in self.entry_args.iter() { if arg.storage != StorageQualifier::Input { continue; } let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); ctx.local_expression_kind_tracker .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; ctx.arg_type_walker( arg.name.clone(), arg.binding.clone(), pointer, ty, &mut |ctx, name, pointer, ty, binding| { let idx = arguments.len() as u32; arguments.push(FunctionArgument { name, ty, binding: Some(binding), }); let value = ctx .expressions .append(Expression::FunctionArgument(idx), Default::default()); ctx.local_expression_kind_tracker .insert(value, crate::proc::ExpressionKind::Runtime); ctx.body .push(Statement::Store { pointer, value }, Default::default()); }, )? 
} ctx.body.extend_block(global_init_body); ctx.body.push( Statement::Call { function, arguments: Vec::new(), result: None, }, Default::default(), ); let mut span = 0; let mut members = Vec::new(); let mut components = Vec::new(); for arg in self.entry_args.iter() { if arg.storage != StorageQualifier::Output { continue; } let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); ctx.local_expression_kind_tracker .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; ctx.arg_type_walker( arg.name.clone(), arg.binding.clone(), pointer, ty, &mut |ctx, name, pointer, ty, binding| { members.push(StructMember { name, ty, binding: Some(binding), offset: span, }); span += ctx.module.types[ty].inner.size(ctx.module.to_ctx()); let len = ctx.expressions.len(); let load = ctx .expressions .append(Expression::Load { pointer }, Default::default()); ctx.local_expression_kind_tracker .insert(load, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), ); components.push(load) }, )? } let (ty, value) = if !components.is_empty() { let ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Struct { members, span }, }, Default::default(), ); let len = ctx.expressions.len(); let res = ctx .expressions .append(Expression::Compose { ty, components }, Default::default()); ctx.local_expression_kind_tracker .insert(res, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), ); (Some(ty), Some(res)) } else { (None, None) }; ctx.body .push(Statement::Return { value }, Default::default()); let Context { body, expressions, .. 
} = ctx; ctx.module.entry_points.push(EntryPoint { name: "main".to_string(), stage: self.meta.stage, early_depth_test: Some(crate::EarlyDepthTest::Force) .filter(|_| self.meta.early_fragment_tests), workgroup_size: self.meta.workgroup_size, workgroup_size_overrides: None, function: Function { arguments, expressions, body, result: ty.map(|ty| FunctionResult { ty, binding: None }), ..Default::default() }, mesh_info: None, task_payload: None, incoming_ray_payload: None, }); Ok(()) } } impl Context<'_> { /// Helper function for building the input/output interface of the entry point /// /// Calls `f` with the data of the entry point argument, flattening composite types /// recursively /// /// The passed arguments to the callback are: /// - The ctx /// - The name /// - The pointer expression to the global storage /// - The handle to the type of the entry point argument /// - The binding of the entry point argument fn arg_type_walker( &mut self, name: Option, binding: crate::Binding, pointer: Handle, ty: Handle, f: &mut impl FnMut( &mut Context, Option, Handle, Handle, crate::Binding, ), ) -> Result<()> { match self.module.types[ty].inner { // TODO: Better error reporting // right now we just don't walk the array if the size isn't known at // compile time and let validation catch it TypeInner::Array { base, size: crate::ArraySize::Constant(size), .. } => { let mut location = match binding { crate::Binding::Location { location, .. 
} => location, crate::Binding::BuiltIn(_) => return Ok(()), }; let interpolation = self.module.types[base] .inner .scalar_kind() .map(|kind| match kind { ScalarKind::Float => crate::Interpolation::Perspective, _ => crate::Interpolation::Flat, }); for index in 0..size.get() { let member_pointer = self.add_expression( Expression::AccessIndex { base: pointer, index, }, Span::default(), )?; let binding = crate::Binding::Location { location, interpolation, sampling: None, blend_src: None, per_primitive: false, }; location += 1; self.arg_type_walker(name.clone(), binding, member_pointer, base, f)? } } TypeInner::Struct { ref members, .. } => { let mut location = match binding { crate::Binding::Location { location, .. } => location, crate::Binding::BuiltIn(_) => return Ok(()), }; for (i, member) in members.clone().into_iter().enumerate() { let member_pointer = self.add_expression( Expression::AccessIndex { base: pointer, index: i as u32, }, Span::default(), )?; let binding = match member.binding { Some(binding) => binding, None => { let interpolation = self.module.types[member.ty] .inner .scalar_kind() .map(|kind| match kind { ScalarKind::Float => crate::Interpolation::Perspective, _ => crate::Interpolation::Flat, }); let binding = crate::Binding::Location { location, interpolation, sampling: None, blend_src: None, per_primitive: false, }; location += 1; binding } }; self.arg_type_walker(member.name, binding, member_pointer, member.ty, f)? 
                }
            }
            _ => f(self, name, pointer, ty, binding),
        }

        Ok(())
    }
}

/// Helper enum containing the type of conversion needed for a call
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
enum Conversion {
    /// No conversion needed
    Exact,
    /// Float to double conversion needed
    FloatToDouble,
    /// Int or uint to float conversion needed
    IntToFloat,
    /// Int or uint to double conversion needed
    IntToDouble,
    /// Other type of conversion needed
    Other,
    /// No conversion has been registered yet
    None,
}

/// Helper function that returns the kind of conversion needed from `source` to
/// `target`, or `None` if no conversion is possible.
fn conversion(target: &TypeInner, source: &TypeInner) -> Option<Conversion> {
    use ScalarKind::*;

    // Gather the `ScalarKind` and scalar width from both the target and the source
    let (target_scalar, source_scalar) = match (target, source) {
        // Conversions between scalars are allowed
        (&TypeInner::Scalar(tgt_scalar), &TypeInner::Scalar(src_scalar)) => {
            (tgt_scalar, src_scalar)
        }
        // Conversions between vectors of the same size are allowed
        (
            &TypeInner::Vector {
                size: tgt_size,
                scalar: tgt_scalar,
            },
            &TypeInner::Vector {
                size: src_size,
                scalar: src_scalar,
            },
        ) if tgt_size == src_size => (tgt_scalar, src_scalar),
        // Conversions between matrices of the same size are allowed
        (
            &TypeInner::Matrix {
                rows: tgt_rows,
                columns: tgt_cols,
                scalar: tgt_scalar,
            },
            &TypeInner::Matrix {
                rows: src_rows,
                columns: src_cols,
                scalar: src_scalar,
            },
        ) if tgt_cols == src_cols && tgt_rows == src_rows => (tgt_scalar, src_scalar),
        _ => return None,
    };

    // The source can only be converted into the target if the target's type
    // power is at least as high as the source's.
    let target_power = type_power(target_scalar);
    let source_power = type_power(source_scalar);
    if target_power < source_power {
        return None;
    }

    Some(match (target_scalar, source_scalar) {
        // A conversion from a float to a double is special
        (Scalar::F64, Scalar::F32) => Conversion::FloatToDouble,
        // A conversion from an integer to a float
        // is special
        (
            Scalar::F32,
            Scalar {
                kind: Sint | Uint,
                width: _,
            },
        ) => Conversion::IntToFloat,
        // A conversion from an integer to a double is special
        (
            Scalar::F64,
            Scalar {
                kind: Sint | Uint,
                width: _,
            },
        ) => Conversion::IntToDouble,
        _ => Conversion::Other,
    })
}

/// Helper function returning all the non-standard builtin variations needed
/// to process a function call with the passed arguments
fn builtin_required_variations<'a>(args: impl Iterator<Item = &'a TypeInner>) -> BuiltinVariations {
    let mut variations = BuiltinVariations::empty();

    for ty in args {
        match *ty {
            TypeInner::ValuePointer { scalar, .. }
            | TypeInner::Scalar(scalar)
            | TypeInner::Vector { scalar, .. }
            | TypeInner::Matrix { scalar, .. } => {
                if scalar == Scalar::F64 {
                    variations |= BuiltinVariations::DOUBLE
                }
            }
            TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                if dim == crate::ImageDimension::Cube && arrayed {
                    variations |= BuiltinVariations::CUBE_TEXTURES_ARRAY
                }

                if dim == crate::ImageDimension::D2 && arrayed && class.is_multisampled() {
                    variations |= BuiltinVariations::D2_MULTI_TEXTURES_ARRAY
                }
            }
            _ => {}
        }
    }

    variations
}
naga-29.0.3/src/front/glsl/lex.rs000064400000000000000000000266141046102023000146540ustar 00000000000000use alloc::string::String;

use pp_rs::{
    pp::Preprocessor,
    token::{PreprocessorError, Punct, TokenValue as PPTokenValue},
};

use super::{
    ast::Precision,
    token::{Directive, DirectiveKind, Token, TokenValue},
    types::parse_type,
};
use crate::{FastHashMap, Span, StorageAccess};

#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct LexerResult {
    pub kind: LexerResultKind,
    pub meta: Span,
}

#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LexerResultKind {
    Token(Token),
    Directive(Directive),
    Error(PreprocessorError),
}

pub struct Lexer<'a> {
    pp: Preprocessor<'a>,
}

impl<'a> Lexer<'a> {
    pub fn new(input: &'a str, defines: &'a FastHashMap<String, String>) -> Self {
        let mut pp = Preprocessor::new(input);
        for (define, value) in defines {
            pp.add_define(define, value).unwrap(); //TODO: handle error
        }
Lexer { pp } } } impl Iterator for Lexer<'_> { type Item = LexerResult; fn next(&mut self) -> Option { let pp_token = match self.pp.next()? { Ok(t) => t, Err((err, loc)) => { return Some(LexerResult { kind: LexerResultKind::Error(err), meta: loc.into(), }); } }; let meta = pp_token.location.into(); let value = match pp_token.value { PPTokenValue::Extension(extension) => { return Some(LexerResult { kind: LexerResultKind::Directive(Directive { kind: DirectiveKind::Extension, tokens: extension.tokens, }), meta, }) } PPTokenValue::Float(float) => TokenValue::FloatConstant(float), PPTokenValue::Ident(ident) => { match ident.as_str() { // Qualifiers "layout" => TokenValue::Layout, "in" => TokenValue::In, "out" => TokenValue::Out, "uniform" => TokenValue::Uniform, "buffer" => TokenValue::Buffer, "shared" => TokenValue::Shared, "invariant" => TokenValue::Invariant, "flat" => TokenValue::Interpolation(crate::Interpolation::Flat), "noperspective" => TokenValue::Interpolation(crate::Interpolation::Linear), "smooth" => TokenValue::Interpolation(crate::Interpolation::Perspective), "centroid" => TokenValue::Sampling(crate::Sampling::Centroid), "sample" => TokenValue::Sampling(crate::Sampling::Sample), "const" => TokenValue::Const, "inout" => TokenValue::InOut, "precision" => TokenValue::Precision, "highp" => TokenValue::PrecisionQualifier(Precision::High), "mediump" => TokenValue::PrecisionQualifier(Precision::Medium), "lowp" => TokenValue::PrecisionQualifier(Precision::Low), "restrict" => TokenValue::Restrict, "readonly" => TokenValue::MemoryQualifier(StorageAccess::LOAD), "writeonly" => TokenValue::MemoryQualifier(StorageAccess::STORE), // values "true" => TokenValue::BoolConstant(true), "false" => TokenValue::BoolConstant(false), // jump statements "continue" => TokenValue::Continue, "break" => TokenValue::Break, "return" => TokenValue::Return, "discard" => TokenValue::Discard, // selection statements "if" => TokenValue::If, "else" => TokenValue::Else, "switch" => 
TokenValue::Switch, "case" => TokenValue::Case, "default" => TokenValue::Default, // iteration statements "while" => TokenValue::While, "do" => TokenValue::Do, "for" => TokenValue::For, // types "void" => TokenValue::Void, "struct" => TokenValue::Struct, word => match parse_type(word) { Some(t) => TokenValue::TypeName(t), None => TokenValue::Identifier(String::from(word)), }, } } PPTokenValue::Integer(integer) => TokenValue::IntConstant(integer), PPTokenValue::Punct(punct) => match punct { // Compound assignments Punct::AddAssign => TokenValue::AddAssign, Punct::SubAssign => TokenValue::SubAssign, Punct::MulAssign => TokenValue::MulAssign, Punct::DivAssign => TokenValue::DivAssign, Punct::ModAssign => TokenValue::ModAssign, Punct::LeftShiftAssign => TokenValue::LeftShiftAssign, Punct::RightShiftAssign => TokenValue::RightShiftAssign, Punct::AndAssign => TokenValue::AndAssign, Punct::XorAssign => TokenValue::XorAssign, Punct::OrAssign => TokenValue::OrAssign, // Two character punctuation Punct::Increment => TokenValue::Increment, Punct::Decrement => TokenValue::Decrement, Punct::LogicalAnd => TokenValue::LogicalAnd, Punct::LogicalOr => TokenValue::LogicalOr, Punct::LogicalXor => TokenValue::LogicalXor, Punct::LessEqual => TokenValue::LessEqual, Punct::GreaterEqual => TokenValue::GreaterEqual, Punct::EqualEqual => TokenValue::Equal, Punct::NotEqual => TokenValue::NotEqual, Punct::LeftShift => TokenValue::LeftShift, Punct::RightShift => TokenValue::RightShift, // Parenthesis or similar Punct::LeftBrace => TokenValue::LeftBrace, Punct::RightBrace => TokenValue::RightBrace, Punct::LeftParen => TokenValue::LeftParen, Punct::RightParen => TokenValue::RightParen, Punct::LeftBracket => TokenValue::LeftBracket, Punct::RightBracket => TokenValue::RightBracket, // Other one character punctuation Punct::LeftAngle => TokenValue::LeftAngle, Punct::RightAngle => TokenValue::RightAngle, Punct::Semicolon => TokenValue::Semicolon, Punct::Comma => TokenValue::Comma, Punct::Colon => 
TokenValue::Colon, Punct::Dot => TokenValue::Dot, Punct::Equal => TokenValue::Assign, Punct::Bang => TokenValue::Bang, Punct::Minus => TokenValue::Dash, Punct::Tilde => TokenValue::Tilde, Punct::Plus => TokenValue::Plus, Punct::Star => TokenValue::Star, Punct::Slash => TokenValue::Slash, Punct::Percent => TokenValue::Percent, Punct::Pipe => TokenValue::VerticalBar, Punct::Caret => TokenValue::Caret, Punct::Ampersand => TokenValue::Ampersand, Punct::Question => TokenValue::Question, }, PPTokenValue::Pragma(pragma) => { return Some(LexerResult { kind: LexerResultKind::Directive(Directive { kind: DirectiveKind::Pragma, tokens: pragma.tokens, }), meta, }) } PPTokenValue::Version(version) => { return Some(LexerResult { kind: LexerResultKind::Directive(Directive { kind: DirectiveKind::Version { is_first_directive: version.is_first_directive, }, tokens: version.tokens, }), meta, }) } }; Some(LexerResult { kind: LexerResultKind::Token(Token { value, meta }), meta, }) } } #[cfg(test)] mod tests { use alloc::vec; use pp_rs::token::{Integer, Location, Token as PPToken, TokenValue as PPTokenValue}; use super::{ super::token::{Directive, DirectiveKind, Token, TokenValue}, Lexer, LexerResult, LexerResultKind, }; use crate::Span; #[test] fn lex_tokens() { let defines = crate::FastHashMap::default(); // line comments let mut lex = Lexer::new("#version 450\nvoid main () {}", &defines); let mut location = Location::default(); location.start = 9; location.end = 12; assert_eq!( lex.next().unwrap(), LexerResult { kind: LexerResultKind::Directive(Directive { kind: DirectiveKind::Version { is_first_directive: true }, tokens: vec![PPToken { value: PPTokenValue::Integer(Integer { signed: true, value: 450, width: 32 }), location }] }), meta: Span::new(1, 8) } ); assert_eq!( lex.next().unwrap(), LexerResult { kind: LexerResultKind::Token(Token { value: TokenValue::Void, meta: Span::new(13, 17) }), meta: Span::new(13, 17) } ); assert_eq!( lex.next().unwrap(), LexerResult { kind: 
LexerResultKind::Token(Token { value: TokenValue::Identifier("main".into()), meta: Span::new(18, 22) }), meta: Span::new(18, 22) } ); assert_eq!( lex.next().unwrap(), LexerResult { kind: LexerResultKind::Token(Token { value: TokenValue::LeftParen, meta: Span::new(23, 24) }), meta: Span::new(23, 24) } ); assert_eq!( lex.next().unwrap(), LexerResult { kind: LexerResultKind::Token(Token { value: TokenValue::RightParen, meta: Span::new(24, 25) }), meta: Span::new(24, 25) } ); assert_eq!( lex.next().unwrap(), LexerResult { kind: LexerResultKind::Token(Token { value: TokenValue::LeftBrace, meta: Span::new(26, 27) }), meta: Span::new(26, 27) } ); assert_eq!( lex.next().unwrap(), LexerResult { kind: LexerResultKind::Token(Token { value: TokenValue::RightBrace, meta: Span::new(27, 28) }), meta: Span::new(27, 28) } ); assert_eq!(lex.next(), None); } } naga-29.0.3/src/front/glsl/mod.rs000064400000000000000000000155661046102023000146470ustar 00000000000000/*! Frontend for [GLSL][glsl] (OpenGL Shading Language). To begin, take a look at the documentation for the [`Frontend`]. # Supported versions ## Vulkan - 440 (partial) - 450 - 460 [glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php */ pub use ast::{Precision, Profile}; pub use error::{Error, ErrorKind, ExpectedToken, ParseErrors}; pub use token::TokenValue; use alloc::{string::String, vec::Vec}; use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type}; use ast::{EntryArg, FunctionDeclaration, GlobalLookup}; use parser::ParsingContext; mod ast; mod builtins; mod context; mod error; mod functions; mod lex; mod offset; mod parser; #[cfg(test)] mod parser_tests; mod token; mod types; mod variables; type Result = core::result::Result; /// Per-shader options passed to [`parse`](Frontend::parse). /// /// The [`From`] trait is implemented for [`ShaderStage`] to provide a quick way /// to create an `Options` instance. 
///
/// ```rust
/// # use naga::ShaderStage;
/// # use naga::front::glsl::Options;
/// Options::from(ShaderStage::Vertex);
/// ```
#[derive(Debug)]
pub struct Options {
    /// The shader stage in the pipeline.
    pub stage: ShaderStage,
    /// Preprocessor definitions to be used, akin to having
    /// ```glsl
    /// #define key value
    /// ```
    /// for each key-value pair in the map.
    pub defines: FastHashMap<String, String>,
}

impl From<ShaderStage> for Options {
    fn from(stage: ShaderStage) -> Self {
        Options {
            stage,
            defines: FastHashMap::default(),
        }
    }
}

/// Additional information about the GLSL shader.
///
/// Stores additional information about the GLSL shader which might not be
/// stored in the shader [`Module`].
#[derive(Debug)]
pub struct ShaderMetadata {
    /// The GLSL version specified in the shader through the use of the
    /// `#version` preprocessor directive.
    pub version: u16,
    /// The GLSL profile specified in the shader through the use of the
    /// `#version` preprocessor directive.
    pub profile: Profile,
    /// The shader stage in the pipeline, passed to the [`parse`](Frontend::parse)
    /// method via the [`Options`] struct.
    pub stage: ShaderStage,
    /// The workgroup size for compute shaders. Defaults to `[1; 3]` for
    /// compute shaders and `[0; 3]` for non-compute shaders.
    pub workgroup_size: [u32; 3],
    /// Whether or not early fragment tests were requested by the shader.
    /// Defaults to `false`.
    pub early_fragment_tests: bool,
    /// The shader can request extensions via the `#extension` preprocessor
    /// directive. The directive's behavior parameter controls whether the
    /// extension is disabled, warns on usage, is enabled if possible, or is
    /// required.
    ///
    /// This field only stores extensions that were required, or requested to
    /// be enabled if possible, and that are supported.
    pub extensions: FastHashSet<String>,
}

impl ShaderMetadata {
    fn reset(&mut self, stage: ShaderStage) {
        self.version = 0;
        self.profile = Profile::Core;
        self.stage = stage;
        self.workgroup_size = [u32::from(stage.compute_like()); 3];
        self.early_fragment_tests = false;
        self.extensions.clear();
    }
}

impl Default for ShaderMetadata {
    fn default() -> Self {
        ShaderMetadata {
            version: 0,
            profile: Profile::Core,
            stage: ShaderStage::Vertex,
            workgroup_size: [0; 3],
            early_fragment_tests: false,
            extensions: FastHashSet::default(),
        }
    }
}

/// The `Frontend` is the central structure of the GLSL frontend.
///
/// A new `Frontend` is created through the [`Default`] trait: calling the
/// associated function [`Frontend::default`](Frontend::default) returns a
/// new `Frontend` instance.
///
/// To parse a shader, call the [`parse`](Frontend::parse) method with an
/// [`Options`] struct and a [`&str`](str) holding the GLSL source.
///
/// After parsing, [`metadata`](Frontend::metadata) returns further
/// information about the previously parsed shader, such as its version and
/// the extensions it uses (see the documentation for [`ShaderMetadata`] for
/// all the returned information).
///
/// # Example usage
/// ```rust
/// use naga::ShaderStage;
/// use naga::front::glsl::{Frontend, Options};
///
/// let glsl = r#"
///     #version 450 core
///
///     void main() {}
/// "#;
///
/// let mut frontend = Frontend::default();
/// let options = Options::from(ShaderStage::Vertex);
/// frontend.parse(&options, glsl);
/// ```
///
/// # Reusability
///
/// If more than one shader needs to be parsed, reusing the same `Frontend`
/// instance may be beneficial, since its internal allocations will be reused.
///
/// Calling the [`parse`](Frontend::parse) method multiple times resets the
/// `Frontend`, so no extra care is needed when reusing it.
#[derive(Debug, Default)]
pub struct Frontend {
    meta: ShaderMetadata,

    lookup_function: FastHashMap<String, FunctionDeclaration>,
    lookup_type: FastHashMap<String, Handle<Type>>,

    global_variables: Vec<(String, GlobalLookup)>,

    entry_args: Vec<EntryArg>,

    layouter: Layouter,

    errors: Vec<Error>,
}

impl Frontend {
    fn reset(&mut self, stage: ShaderStage) {
        self.meta.reset(stage);

        self.lookup_function.clear();
        self.lookup_type.clear();
        self.global_variables.clear();
        self.entry_args.clear();
        self.layouter.clear();
    }

    /// Parses a shader, outputting either a shader [`Module`] or a list of
    /// [`Error`]s.
    ///
    /// Multiple calls using the same `Frontend` and different shaders are supported.
    pub fn parse(
        &mut self,
        options: &Options,
        source: &str,
    ) -> core::result::Result<Module, ParseErrors> {
        self.reset(options.stage);

        let lexer = lex::Lexer::new(source, &options.defines);
        let mut ctx = ParsingContext::new(lexer);

        match ctx.parse(self) {
            Ok(module) => {
                if self.errors.is_empty() {
                    Ok(module)
                } else {
                    Err(core::mem::take(&mut self.errors).into())
                }
            }
            Err(e) => {
                self.errors.push(e);
                Err(core::mem::take(&mut self.errors).into())
            }
        }
    }

    /// Returns additional information about the parsed shader which might not
    /// be stored in the [`Module`]. See the documentation for
    /// [`ShaderMetadata`] for more information about the returned data.
    ///
    /// # Notes
    ///
    /// After an unsuccessful parse, the state of the returned information
    /// is undefined: it might contain only partial information about the
    /// current shader, the previous shader, or both.
    pub const fn metadata(&self) -> &ShaderMetadata {
        &self.meta
    }
}
naga-29.0.3/src/front/glsl/offset.rs000064400000000000000000000155341046102023000153510ustar 00000000000000
/*!
Module responsible for calculating the offset and span for types.

There are two types of layouts: std140 and std430. (Technically there are
two more layouts, shared and packed. Shared is not supported by SPIR-V, and
packed is implementation-dependent and is currently just implemented as an
alias of std140.)
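
The following stand-alone sketch illustrates how rule 4 differs between the two layouts. It computes an array's base alignment and stride from its element's alignment and size. This is a simplified model, not this module's API. The helper names `round_up`, `array_align_and_stride` and `array_span` are hypothetical, and alignments are plain `u32` byte counts rather than `Alignment`. The 16-byte figure stands in for the base alignment of a `vec4`, which `calculate_offset` expresses as `Alignment::MIN_UNIFORM`.

```rust
/// Rounds `value` up to the next multiple of `align` (`align` must be non-zero).
fn round_up(align: u32, value: u32) -> u32 {
    (value + align - 1) / align * align
}

/// Base alignment and stride of an array whose element has the given
/// alignment and size, following rule 4. std140 additionally raises the
/// alignment to that of a `vec4` (16 bytes); std430 does not.
fn array_align_and_stride(elem_align: u32, elem_size: u32, std140: bool) -> (u32, u32) {
    let align = if std140 { elem_align.max(16) } else { elem_align };
    (align, round_up(align, elem_size))
}

/// Total span of a fixed-size array: one stride per element.
fn array_span(len: u32, stride: u32) -> u32 {
    len * stride
}
```

Take a `float[4]`, whose elements have alignment and size 4. Under std140 it gets a 16-byte stride and a 64-byte span. Under std430 it gets a 4-byte stride and a 16-byte span. A `vec3` element (alignment 16, size 12) has a 16-byte stride under both layouts.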
The OpenGL spec (the layout rules are defined by the OpenGL spec in section
7.6.2.2, as opposed to the GLSL spec) uses the term "basic machine units",
which are equivalent to bytes.
*/

use alloc::vec::Vec;

use super::{
    ast::StructLayout,
    error::{Error, ErrorKind},
    Span,
};
use crate::{proc::Alignment, Handle, Scalar, Type, TypeInner, UniqueArena};

/// Struct with information needed for defining a struct member.
///
/// Returned by [`calculate_offset`].
#[derive(Debug)]
pub struct TypeAlignSpan {
    /// The handle to the type. This might be the same handle passed to
    /// [`calculate_offset`], or a new array type with a different stride set.
    pub ty: Handle<Type>,
    /// The alignment required by the type.
    pub align: Alignment,
    /// The size of the type.
    pub span: u32,
}

/// Returns the type, alignment and span of a struct member according to a [`StructLayout`].
///
/// The function returns a [`TypeAlignSpan`] whose `ty` member should be used
/// as the struct member type, because some types (arrays, for example) may need
/// a different stride and therefore a different type.
pub fn calculate_offset(
    mut ty: Handle<Type>,
    meta: Span,
    layout: StructLayout,
    types: &mut UniqueArena<Type>,
    errors: &mut Vec<Error>,
) -> TypeAlignSpan {
    // When using the std430 storage layout, shader storage blocks will be laid out in buffer storage
    // identically to uniform and shader storage blocks using the std140 layout, except
    // that the base alignment and stride of arrays of scalars and vectors in rule 4 and of
    // structures in rule 9 are not rounded up to a multiple of the base alignment of a vec4.

    let (align, span) = match types[ty].inner {
        // 1. If the member is a scalar consuming N basic machine units,
        // the base alignment is N.
        TypeInner::Scalar(Scalar { width, .. }) => (Alignment::from_width(width), width as u32),
        // 2. If the member is a two- or four-component vector with components
        // consuming N basic machine units, the base alignment is 2N or 4N, respectively.
        // 3.
If the member is a three-component vector with components consuming N // basic machine units, the base alignment is 4N. TypeInner::Vector { size, scalar: Scalar { width, .. }, } => ( Alignment::from(size) * Alignment::from_width(width), size as u32 * width as u32, ), // 4. If the member is an array of scalars or vectors, the base alignment and array // stride are set to match the base alignment of a single array element, according // to rules (1), (2), and (3), and rounded up to the base alignment of a vec4. // TODO: Matrices array TypeInner::Array { base, size, .. } => { let info = calculate_offset(base, meta, layout, types, errors); let name = types[ty].name.clone(); // See comment at the beginning of the function let (align, stride) = if StructLayout::Std430 == layout { (info.align, info.align.round_up(info.span)) } else { let align = info.align.max(Alignment::MIN_UNIFORM); (align, align.round_up(info.span)) }; let span = match size { crate::ArraySize::Constant(size) => size.get() * stride, crate::ArraySize::Pending(_) => unreachable!(), crate::ArraySize::Dynamic => stride, }; let ty_span = types.get_span(ty); ty = types.insert( Type { name, inner: TypeInner::Array { base: info.ty, size, stride, }, }, ty_span, ); (align, span) } // 5. 
If the member is a column-major matrix with C columns and R rows, the // matrix is stored identically to an array of C column vectors with R // components each, according to rule (4) // TODO: Row major matrices TypeInner::Matrix { columns, rows, scalar, } => { let mut align = Alignment::from(rows) * Alignment::from_width(scalar.width); // See comment at the beginning of the function if StructLayout::Std430 != layout { align = align.max(Alignment::MIN_UNIFORM); } // See comment on the error kind if StructLayout::Std140 == layout { // Do the f16 test first, as it's more specific if scalar == Scalar::F16 { errors.push(Error { kind: ErrorKind::UnsupportedF16MatrixInStd140 { columns: columns as u8, rows: rows as u8, }, meta, }); } if rows == crate::VectorSize::Bi { errors.push(Error { kind: ErrorKind::UnsupportedMatrixWithTwoRowsInStd140 { columns: columns as u8, }, meta, }); } } (align, align * columns as u32) } TypeInner::Struct { ref members, .. } => { let mut span = 0; let mut align = Alignment::ONE; let mut members = members.clone(); let name = types[ty].name.clone(); for member in members.iter_mut() { let info = calculate_offset(member.ty, meta, layout, types, errors); let member_alignment = info.align; span = member_alignment.round_up(span); align = member_alignment.max(align); member.ty = info.ty; member.offset = span; span += info.span; } span = align.round_up(span); let ty_span = types.get_span(ty); ty = types.insert( Type { name, inner: TypeInner::Struct { members, span }, }, ty_span, ); (align, span) } _ => { errors.push(Error { kind: ErrorKind::SemanticError("Invalid struct member type".into()), meta, }); (Alignment::ONE, 0) } }; TypeAlignSpan { ty, align, span } } naga-29.0.3/src/front/glsl/parser/declarations.rs000064400000000000000000000635631046102023000200340ustar 00000000000000use alloc::{string::String, vec, vec::Vec}; use super::{DeclarationContext, ParsingContext, Result}; use crate::{ front::glsl::{ ast::{ GlobalLookup, GlobalLookupKind, 
            Precision, QualifierKey, QualifierValue, StorageQualifier, StructLayout,
            TypeQualifiers,
        },
        context::{Context, ExprPos},
        error::ExpectedToken,
        offset,
        token::{Token, TokenValue},
        types::scalar_components,
        variables::{GlobalOrConstant, VarDeclaration},
        Error, ErrorKind, Frontend, Span,
    },
    proc::Alignment,
    AddressSpace, Expression, FunctionResult, Handle, Scalar, ScalarKind, Statement, StructMember,
    Type, TypeInner,
};

/// Helper method used to retrieve the child type of `ty` at
/// index `i`.
///
/// # Note
///
/// Does not check if the index is valid, and returns the same type
/// when indexing a struct out of bounds or when indexing a
/// non-indexable type.
fn element_or_member_type(
    ty: Handle<Type>,
    i: usize,
    types: &mut crate::UniqueArena<Type>,
) -> Handle<Type> {
    match types[ty].inner {
        // The child type of a vector is a scalar of the same kind and width
        TypeInner::Vector { scalar, .. } => types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(scalar),
            },
            Default::default(),
        ),
        // The child type of a matrix is a vector of floats with the same
        // width and the size of the matrix rows.
        TypeInner::Matrix { rows, scalar, .. } => types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector { size: rows, scalar },
            },
            Default::default(),
        ),
        // The child type of an array is the base type of the array
        TypeInner::Array { base, .. } => base,
        // The child type of a struct at index `i` is the type of its
        // member at that same index.
        //
        // If the index is out of bounds, the same type is returned
        TypeInner::Struct { ref members, .. } => {
            members.get(i).map(|member| member.ty).unwrap_or(ty)
        }
        // The type isn't indexable, so the same type is returned
        _ => ty,
    }
}

impl ParsingContext<'_> {
    pub fn parse_external_declaration(
        &mut self,
        frontend: &mut Frontend,
        global_ctx: &mut Context,
    ) -> Result<()> {
        if self
            .parse_declaration(frontend, global_ctx, true, false)?
.is_none() { let token = self.bump(frontend)?; match token.value { TokenValue::Semicolon if frontend.meta.version == 460 => Ok(()), _ => { let expected = match frontend.meta.version { 460 => vec![TokenValue::Semicolon.into(), ExpectedToken::Eof], _ => vec![ExpectedToken::Eof], }; Err(Error { kind: ErrorKind::InvalidToken(token.value, expected), meta: token.meta, }) } } } else { Ok(()) } } pub fn parse_initializer( &mut self, frontend: &mut Frontend, ty: Handle, ctx: &mut Context, ) -> Result<(Handle, Span)> { // initializer: // assignment_expression // LEFT_BRACE initializer_list RIGHT_BRACE // LEFT_BRACE initializer_list COMMA RIGHT_BRACE // // initializer_list: // initializer // initializer_list COMMA initializer if let Some(Token { mut meta, .. }) = self.bump_if(frontend, TokenValue::LeftBrace) { // initializer_list let mut components = Vec::new(); loop { // The type expected to be parsed inside the initializer list let new_ty = element_or_member_type(ty, components.len(), &mut ctx.module.types); components.push(self.parse_initializer(frontend, new_ty, ctx)?.0); let token = self.bump(frontend)?; match token.value { TokenValue::Comma => { if let Some(Token { meta: end_meta, .. 
                        }) = self.bump_if(frontend, TokenValue::RightBrace)
                        {
                            meta.subsume(end_meta);
                            break;
                        }
                    }
                    TokenValue::RightBrace => {
                        meta.subsume(token.meta);
                        break;
                    }
                    _ => {
                        return Err(Error {
                            kind: ErrorKind::InvalidToken(
                                token.value,
                                vec![TokenValue::Comma.into(), TokenValue::RightBrace.into()],
                            ),
                            meta: token.meta,
                        })
                    }
                }
            }

            Ok((
                ctx.add_expression(Expression::Compose { ty, components }, meta)?,
                meta,
            ))
        } else {
            let mut stmt = ctx.stmt_ctx();
            let expr = self.parse_assignment(frontend, ctx, &mut stmt)?;
            let (mut init, init_meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?;

            let scalar_components = scalar_components(&ctx.module.types[ty].inner);
            if let Some(scalar) = scalar_components {
                ctx.implicit_conversion(&mut init, init_meta, scalar)?;
            }

            Ok((init, init_meta))
        }
    }

    // Note: the caller preparsed the type and qualifiers.
    // Note: the caller skips this if the fallthrough token is not expected to be consumed here,
    // so this produces Error::InvalidToken if it isn't consumed.
    pub fn parse_init_declarator_list(
        &mut self,
        frontend: &mut Frontend,
        mut ty: Handle<Type>,
        ctx: &mut DeclarationContext,
    ) -> Result<()> {
        // init_declarator_list:
        //     single_declaration
        //     init_declarator_list COMMA IDENTIFIER
        //     init_declarator_list COMMA IDENTIFIER array_specifier
        //     init_declarator_list COMMA IDENTIFIER array_specifier EQUAL initializer
        //     init_declarator_list COMMA IDENTIFIER EQUAL initializer
        //
        // single_declaration:
        //     fully_specified_type
        //     fully_specified_type IDENTIFIER
        //     fully_specified_type IDENTIFIER array_specifier
        //     fully_specified_type IDENTIFIER array_specifier EQUAL initializer
        //     fully_specified_type IDENTIFIER EQUAL initializer

        // Consume any leading comma, e.g.
this is valid: `float, a=1;` if self .peek(frontend) .is_some_and(|t| t.value == TokenValue::Comma) { self.next(frontend); } loop { let token = self.bump(frontend)?; let name = match token.value { TokenValue::Semicolon => break, TokenValue::Identifier(name) => name, _ => { return Err(Error { kind: ErrorKind::InvalidToken( token.value, vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()], ), meta: token.meta, }) } }; let mut meta = token.meta; // array_specifier // array_specifier EQUAL initializer // EQUAL initializer // parse an array specifier if it exists // NOTE: unlike other parse methods this one doesn't expect an array specifier and // returns Ok(None) rather than an error if there is not one self.parse_array_specifier(frontend, ctx.ctx, &mut meta, &mut ty)?; let is_global_const = ctx.qualifiers.storage.0 == StorageQualifier::Const && ctx.external; let init = self .bump_if(frontend, TokenValue::Assign) .map::, _>(|_| { let prev_const = ctx.ctx.is_const; ctx.ctx.is_const = is_global_const; let (mut expr, init_meta) = self.parse_initializer(frontend, ty, ctx.ctx)?; let scalar_components = scalar_components(&ctx.ctx.module.types[ty].inner); if let Some(scalar) = scalar_components { ctx.ctx.implicit_conversion(&mut expr, init_meta, scalar)?; } ctx.ctx.is_const = prev_const; meta.subsume(init_meta); Ok(expr) }) .transpose()?; let decl_initializer; let late_initializer; if is_global_const { decl_initializer = init; late_initializer = None; } else if ctx.external { decl_initializer = init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok()); late_initializer = None; } else if let Some(init) = init { if ctx.is_inside_loop || !ctx.ctx.local_expression_kind_tracker.is_const(init) { decl_initializer = None; late_initializer = Some(init); } else { decl_initializer = Some(init); late_initializer = None; } } else { decl_initializer = None; late_initializer = None; }; let pointer = ctx.add_var(frontend, ty, name, decl_initializer, meta)?; if let 
Some(value) = late_initializer { ctx.ctx.emit_restart(); ctx.ctx.body.push(Statement::Store { pointer, value }, meta); } let token = self.bump(frontend)?; match token.value { TokenValue::Semicolon => break, TokenValue::Comma => {} _ => { return Err(Error { kind: ErrorKind::InvalidToken( token.value, vec![TokenValue::Comma.into(), TokenValue::Semicolon.into()], ), meta: token.meta, }) } } } Ok(()) } /// `external` whether or not we are in a global or local context pub fn parse_declaration( &mut self, frontend: &mut Frontend, ctx: &mut Context, external: bool, is_inside_loop: bool, ) -> Result> { //declaration: // function_prototype SEMICOLON // // init_declarator_list SEMICOLON // PRECISION precision_qualifier type_specifier SEMICOLON // // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE SEMICOLON // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER SEMICOLON // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER array_specifier SEMICOLON // type_qualifier SEMICOLON type_qualifier IDENTIFIER SEMICOLON // type_qualifier IDENTIFIER identifier_list SEMICOLON if self.peek_type_qualifier(frontend) || self.peek_type_name(frontend) { let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?; if self.peek_type_name(frontend) { // This branch handles variables and function prototypes and if // external is true also function definitions let (ty, mut meta) = self.parse_type(frontend, ctx)?; let token = self.bump(frontend)?; let token_fallthrough = match token.value { TokenValue::Identifier(name) => match self.expect_peek(frontend)?.value { TokenValue::LeftParen => { // This branch handles function definition and prototypes self.bump(frontend)?; let result = ty.map(|ty| FunctionResult { ty, binding: None }); let mut context = Context::new( frontend, ctx.module, false, ctx.global_expression_kind_tracker, )?; self.parse_function_args(frontend, &mut context)?; let end_meta = 
self.expect(frontend, TokenValue::RightParen)?.meta; meta.subsume(end_meta); let token = self.bump(frontend)?; return match token.value { TokenValue::Semicolon => { // This branch handles function prototypes frontend.add_prototype(context, name, result, meta); Ok(Some(meta)) } TokenValue::LeftBrace if external => { // This branch handles function definitions // as you can see by the guard this branch // only happens if external is also true // parse the body self.parse_compound_statement( token.meta, frontend, &mut context, &mut None, false, )?; frontend.add_function(context, name, result, meta); Ok(Some(meta)) } _ if external => Err(Error { kind: ErrorKind::InvalidToken( token.value, vec![ TokenValue::LeftBrace.into(), TokenValue::Semicolon.into(), ], ), meta: token.meta, }), _ => Err(Error { kind: ErrorKind::InvalidToken( token.value, vec![TokenValue::Semicolon.into()], ), meta: token.meta, }), }; } // Pass the token to the init_declarator_list parser _ => Token { value: TokenValue::Identifier(name), meta: token.meta, }, }, // Pass the token to the init_declarator_list parser _ => token, }; // If program execution has reached here then this will be a // init_declarator_list // token_fallthrough will have a token that was already bumped if let Some(ty) = ty { let mut ctx = DeclarationContext { qualifiers, external, is_inside_loop, ctx, }; self.backtrack(token_fallthrough)?; self.parse_init_declarator_list(frontend, ty, &mut ctx)?; } else { frontend.errors.push(Error { kind: ErrorKind::SemanticError("Declaration cannot have void type".into()), meta, }) } Ok(Some(meta)) } else { // This branch handles struct definitions and modifiers like // ```glsl // layout(early_fragment_tests); // ``` let token = self.bump(frontend)?; match token.value { TokenValue::Identifier(ty_name) => { if self.bump_if(frontend, TokenValue::LeftBrace).is_some() { self.parse_block_declaration( frontend, ctx, &mut qualifiers, ty_name, token.meta, ) .map(Some) } else { if 
qualifiers.invariant.take().is_some() { frontend.make_variable_invariant(ctx, &ty_name, token.meta)?; qualifiers.unused_errors(&mut frontend.errors); self.expect(frontend, TokenValue::Semicolon)?; return Ok(Some(qualifiers.span)); } //TODO: declaration // type_qualifier IDENTIFIER SEMICOLON // type_qualifier IDENTIFIER identifier_list SEMICOLON Err(Error { kind: ErrorKind::NotImplemented("variable qualifier"), meta: token.meta, }) } } TokenValue::Semicolon => { if let Some(value) = qualifiers.uint_layout_qualifier("local_size_x", &mut frontend.errors) { frontend.meta.workgroup_size[0] = value; } if let Some(value) = qualifiers.uint_layout_qualifier("local_size_y", &mut frontend.errors) { frontend.meta.workgroup_size[1] = value; } if let Some(value) = qualifiers.uint_layout_qualifier("local_size_z", &mut frontend.errors) { frontend.meta.workgroup_size[2] = value; } frontend.meta.early_fragment_tests |= qualifiers .none_layout_qualifier("early_fragment_tests", &mut frontend.errors); qualifiers.unused_errors(&mut frontend.errors); Ok(Some(qualifiers.span)) } _ => Err(Error { kind: ErrorKind::InvalidToken( token.value, vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()], ), meta: token.meta, }), } } } else { match self.peek(frontend).map(|t| &t.value) { Some(&TokenValue::Precision) => { // PRECISION precision_qualifier type_specifier SEMICOLON self.bump(frontend)?; let token = self.bump(frontend)?; let _ = match token.value { TokenValue::PrecisionQualifier(p) => p, _ => { return Err(Error { kind: ErrorKind::InvalidToken( token.value, vec![ TokenValue::PrecisionQualifier(Precision::High).into(), TokenValue::PrecisionQualifier(Precision::Medium).into(), TokenValue::PrecisionQualifier(Precision::Low).into(), ], ), meta: token.meta, }) } }; let (ty, meta) = self.parse_type_non_void(frontend, ctx)?; match ctx.module.types[ty].inner { TypeInner::Scalar(Scalar { kind: ScalarKind::Float | ScalarKind::Sint, .. 
}) => {} _ => frontend.errors.push(Error { kind: ErrorKind::SemanticError( "Precision statement can only work on floats and ints".into(), ), meta, }), } self.expect(frontend, TokenValue::Semicolon)?; Ok(Some(meta)) } _ => Ok(None), } } } pub fn parse_block_declaration( &mut self, frontend: &mut Frontend, ctx: &mut Context, qualifiers: &mut TypeQualifiers, ty_name: String, mut meta: Span, ) -> Result { let layout = match qualifiers.layout_qualifiers.remove(&QualifierKey::Layout) { Some((QualifierValue::Layout(l), _)) => l, None => { if let StorageQualifier::AddressSpace(AddressSpace::Storage { .. }) = qualifiers.storage.0 { StructLayout::Std430 } else { StructLayout::Std140 } } _ => unreachable!(), }; let mut members = Vec::new(); let span = self.parse_struct_declaration_list(frontend, ctx, &mut members, layout)?; self.expect(frontend, TokenValue::RightBrace)?; let mut ty = ctx.module.types.insert( Type { name: Some(ty_name), inner: TypeInner::Struct { members: members.clone(), span, }, }, Default::default(), ); let token = self.bump(frontend)?; let name = match token.value { TokenValue::Semicolon => None, TokenValue::Identifier(name) => { self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?; self.expect(frontend, TokenValue::Semicolon)?; Some(name) } _ => { return Err(Error { kind: ErrorKind::InvalidToken( token.value, vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()], ), meta: token.meta, }) } }; let global = frontend.add_global_var( ctx, VarDeclaration { qualifiers, ty, name, init: None, meta, }, )?; for (i, k, ty) in members.into_iter().enumerate().filter_map(|(i, m)| { let ty = m.ty; m.name.map(|s| (i as u32, s, ty)) }) { let lookup = GlobalLookup { kind: match global { GlobalOrConstant::Global(handle) => GlobalLookupKind::BlockSelect(handle, i), GlobalOrConstant::Constant(handle) => GlobalLookupKind::Constant(handle, ty), GlobalOrConstant::Override(handle) => GlobalLookupKind::Override(handle, ty), }, entry_arg: None, mutable: true, }; 
            ctx.add_global(&k, lookup)?;

            frontend.global_variables.push((k, lookup));
        }

        Ok(meta)
    }

    // TODO: Accept layout arguments
    pub fn parse_struct_declaration_list(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        members: &mut Vec<StructMember>,
        layout: StructLayout,
    ) -> Result<u32> {
        let mut span = 0;
        let mut align = Alignment::ONE;

        loop {
            // TODO: type_qualifier

            let (base_ty, mut meta) = self.parse_type_non_void(frontend, ctx)?;

            loop {
                let (name, name_meta) = self.expect_ident(frontend)?;
                let mut ty = base_ty;
                self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?;

                meta.subsume(name_meta);

                let info = offset::calculate_offset(
                    ty,
                    meta,
                    layout,
                    &mut ctx.module.types,
                    &mut frontend.errors,
                );

                let member_alignment = info.align;
                span = member_alignment.round_up(span);
                align = member_alignment.max(align);

                members.push(StructMember {
                    name: Some(name),
                    ty: info.ty,
                    binding: None,
                    offset: span,
                });

                span += info.span;

                if self.bump_if(frontend, TokenValue::Comma).is_none() {
                    break;
                }
            }

            self.expect(frontend, TokenValue::Semicolon)?;

            if let TokenValue::RightBrace = self.expect_peek(frontend)?.value {
                break;
            }
        }

        span = align.round_up(span);

        Ok(span)
    }
}
naga-29.0.3/src/front/glsl/parser/expressions.rs000064400000000000000000000510431046102023000177340ustar 00000000000000
use alloc::{vec, vec::Vec};
use core::num::NonZeroU32;

use crate::{
    front::glsl::{
        ast::{FunctionCall, FunctionCallKind, HirExpr, HirExprKind},
        context::{Context, StmtContext},
        error::{ErrorKind, ExpectedToken},
        parser::ParsingContext,
        token::{Token, TokenValue},
        Error, Frontend, Result, Span,
    },
    ArraySize, BinaryOperator, Handle, Literal, Type, TypeInner, UnaryOperator,
};

impl ParsingContext<'_> {
    pub fn parse_primary(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        stmt: &mut StmtContext,
    ) -> Result<Handle<HirExpr>> {
        let mut token = self.bump(frontend)?;

        let literal = match token.value {
            TokenValue::IntConstant(int) => {
                if int.width != 32 {
                    frontend.errors.push(Error {
                        kind: ErrorKind::SemanticError("Unsupported non-32bit integer".into()),
                        meta: token.meta,
                    });
                }
                if int.signed {
                    Literal::I32(int.value as i32)
                } else {
                    Literal::U32(int.value as u32)
                }
            }
            TokenValue::FloatConstant(float) => {
                if float.width != 32 {
                    frontend.errors.push(Error {
                        kind: ErrorKind::SemanticError(
                            concat!(
                                "Unsupported floating-point value ",
                                "(expected single-precision floating-point number)"
                            )
                            .into(),
                        ),
                        meta: token.meta,
                    });
                }

                Literal::F32(float.value)
            }
            TokenValue::BoolConstant(value) => Literal::Bool(value),
            TokenValue::LeftParen => {
                let expr = self.parse_expression(frontend, ctx, stmt)?;
                let meta = self.expect(frontend, TokenValue::RightParen)?.meta;

                token.meta.subsume(meta);

                return Ok(expr);
            }
            _ => {
                return Err(Error {
                    kind: ErrorKind::InvalidToken(
                        token.value,
                        vec![
                            TokenValue::LeftParen.into(),
                            ExpectedToken::IntLiteral,
                            ExpectedToken::FloatLiteral,
                            ExpectedToken::BoolLiteral,
                        ],
                    ),
                    meta: token.meta,
                });
            }
        };

        Ok(stmt.hir_exprs.append(
            HirExpr {
                kind: HirExprKind::Literal(literal),
                meta: token.meta,
            },
            Default::default(),
        ))
    }

    pub fn parse_function_call_args(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        stmt: &mut StmtContext,
        meta: &mut Span,
    ) -> Result<Vec<Handle<HirExpr>>> {
        let mut args = Vec::new();
        if let Some(token) = self.bump_if(frontend, TokenValue::RightParen) {
            meta.subsume(token.meta);
        } else {
            loop {
                args.push(self.parse_assignment(frontend, ctx, stmt)?);

                let token = self.bump(frontend)?;
                match token.value {
                    TokenValue::Comma => {}
                    TokenValue::RightParen => {
                        meta.subsume(token.meta);
                        break;
                    }
                    _ => {
                        return Err(Error {
                            kind: ErrorKind::InvalidToken(
                                token.value,
                                vec![TokenValue::Comma.into(), TokenValue::RightParen.into()],
                            ),
                            meta: token.meta,
                        });
                    }
                }
            }
        }

        Ok(args)
    }

    pub fn parse_postfix(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        stmt: &mut StmtContext,
    ) -> Result<Handle<HirExpr>> {
        let mut base = if self.peek_type_name(frontend) {
            let (mut handle, mut meta) = self.parse_type_non_void(frontend, ctx)?;

            self.expect(frontend, TokenValue::LeftParen)?;
            let args =
                self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?;

            if let TypeInner::Array {
                size: ArraySize::Dynamic,
                stride,
                base,
            } = ctx.module.types[handle].inner
            {
                let span = ctx.module.types.get_span(handle);

                let size = u32::try_from(args.len())
                    .ok()
                    .and_then(NonZeroU32::new)
                    .ok_or(Error {
                        kind: ErrorKind::SemanticError(
                            "There must be at least one argument".into(),
                        ),
                        meta,
                    })?;

                handle = ctx.module.types.insert(
                    Type {
                        name: None,
                        inner: TypeInner::Array {
                            stride,
                            base,
                            size: ArraySize::Constant(size),
                        },
                    },
                    span,
                )
            }

            stmt.hir_exprs.append(
                HirExpr {
                    kind: HirExprKind::Call(FunctionCall {
                        kind: FunctionCallKind::TypeConstructor(handle),
                        args,
                    }),
                    meta,
                },
                Default::default(),
            )
        } else if let TokenValue::Identifier(_) = self.expect_peek(frontend)?.value {
            let (name, mut meta) = self.expect_ident(frontend)?;

            let expr = if self.bump_if(frontend, TokenValue::LeftParen).is_some() {
                let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?;

                let kind = match frontend.lookup_type.get(&name) {
                    Some(ty) => FunctionCallKind::TypeConstructor(*ty),
                    None => FunctionCallKind::Function(name),
                };

                HirExpr {
                    kind: HirExprKind::Call(FunctionCall { kind, args }),
                    meta,
                }
            } else {
                let var = match frontend.lookup_variable(ctx, &name, meta)? {
                    Some(var) => var,
                    None => {
                        return Err(Error {
                            kind: ErrorKind::UnknownVariable(name),
                            meta,
                        })
                    }
                };

                HirExpr {
                    kind: HirExprKind::Variable(var),
                    meta,
                }
            };

            stmt.hir_exprs.append(expr, Default::default())
        } else {
            self.parse_primary(frontend, ctx, stmt)?
        };

        while let TokenValue::LeftBracket
        | TokenValue::Dot
        | TokenValue::Increment
        | TokenValue::Decrement = self.expect_peek(frontend)?.value
        {
            let Token { value, mut meta } = self.bump(frontend)?;

            match value {
                TokenValue::LeftBracket => {
                    let index = self.parse_expression(frontend, ctx, stmt)?;
                    let end_meta = self.expect(frontend, TokenValue::RightBracket)?.meta;

                    meta.subsume(end_meta);
                    base = stmt.hir_exprs.append(
                        HirExpr {
                            kind: HirExprKind::Access { base, index },
                            meta,
                        },
                        Default::default(),
                    )
                }
                TokenValue::Dot => {
                    let (field, end_meta) = self.expect_ident(frontend)?;

                    if self.bump_if(frontend, TokenValue::LeftParen).is_some() {
                        let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?;

                        base = stmt.hir_exprs.append(
                            HirExpr {
                                kind: HirExprKind::Method {
                                    expr: base,
                                    name: field,
                                    args,
                                },
                                meta,
                            },
                            Default::default(),
                        );
                        continue;
                    }

                    meta.subsume(end_meta);
                    base = stmt.hir_exprs.append(
                        HirExpr {
                            kind: HirExprKind::Select { base, field },
                            meta,
                        },
                        Default::default(),
                    )
                }
                TokenValue::Increment | TokenValue::Decrement => {
                    base = stmt.hir_exprs.append(
                        HirExpr {
                            kind: HirExprKind::PrePostfix {
                                op: match value {
                                    TokenValue::Increment => BinaryOperator::Add,
                                    _ => BinaryOperator::Subtract,
                                },
                                postfix: true,
                                expr: base,
                            },
                            meta,
                        },
                        Default::default(),
                    )
                }
                _ => unreachable!(),
            }
        }

        Ok(base)
    }

    pub fn parse_unary(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        stmt: &mut StmtContext,
    ) -> Result<Handle<HirExpr>> {
        Ok(match self.expect_peek(frontend)?.value {
            TokenValue::Plus | TokenValue::Dash | TokenValue::Bang | TokenValue::Tilde => {
                let Token { value, mut meta } = self.bump(frontend)?;

                let expr = self.parse_unary(frontend, ctx, stmt)?;
                let end_meta = stmt.hir_exprs[expr].meta;

                let kind = match value {
                    TokenValue::Dash => HirExprKind::Unary {
                        op: UnaryOperator::Negate,
                        expr,
                    },
                    TokenValue::Bang => HirExprKind::Unary {
                        op: UnaryOperator::LogicalNot,
                        expr,
                    },
                    TokenValue::Tilde => HirExprKind::Unary {
                        op: UnaryOperator::BitwiseNot,
                        expr,
                    },
                    _ => return Ok(expr),
                };
                meta.subsume(end_meta);
                stmt.hir_exprs
                    .append(HirExpr { kind, meta }, Default::default())
            }
            TokenValue::Increment | TokenValue::Decrement => {
                let Token { value, meta } = self.bump(frontend)?;

                let expr = self.parse_unary(frontend, ctx, stmt)?;

                stmt.hir_exprs.append(
                    HirExpr {
                        kind: HirExprKind::PrePostfix {
                            op: match value {
                                TokenValue::Increment => BinaryOperator::Add,
                                _ => BinaryOperator::Subtract,
                            },
                            postfix: false,
                            expr,
                        },
                        meta,
                    },
                    Default::default(),
                )
            }
            _ => self.parse_postfix(frontend, ctx, stmt)?,
        })
    }

    pub fn parse_binary(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        stmt: &mut StmtContext,
        passthrough: Option<Handle<HirExpr>>,
        min_bp: u8,
    ) -> Result<Handle<HirExpr>> {
        let mut left = passthrough
            .ok_or(ErrorKind::EndOfFile /* Dummy error */)
            .or_else(|_| self.parse_unary(frontend, ctx, stmt))?;
        let mut meta = stmt.hir_exprs[left].meta;

        while let Some((l_bp, r_bp)) = binding_power(&self.expect_peek(frontend)?.value) {
            if l_bp < min_bp {
                break;
            }

            let Token { value, .. } = self.bump(frontend)?;

            let right = self.parse_binary(frontend, ctx, stmt, None, r_bp)?;
            let end_meta = stmt.hir_exprs[right].meta;

            meta.subsume(end_meta);
            left = stmt.hir_exprs.append(
                HirExpr {
                    kind: HirExprKind::Binary {
                        left,
                        op: match value {
                            TokenValue::LogicalOr => BinaryOperator::LogicalOr,
                            TokenValue::LogicalXor => BinaryOperator::NotEqual,
                            TokenValue::LogicalAnd => BinaryOperator::LogicalAnd,
                            TokenValue::VerticalBar => BinaryOperator::InclusiveOr,
                            TokenValue::Caret => BinaryOperator::ExclusiveOr,
                            TokenValue::Ampersand => BinaryOperator::And,
                            TokenValue::Equal => BinaryOperator::Equal,
                            TokenValue::NotEqual => BinaryOperator::NotEqual,
                            TokenValue::GreaterEqual => BinaryOperator::GreaterEqual,
                            TokenValue::LessEqual => BinaryOperator::LessEqual,
                            TokenValue::LeftAngle => BinaryOperator::Less,
                            TokenValue::RightAngle => BinaryOperator::Greater,
                            TokenValue::LeftShift => BinaryOperator::ShiftLeft,
                            TokenValue::RightShift => BinaryOperator::ShiftRight,
                            TokenValue::Plus => BinaryOperator::Add,
                            TokenValue::Dash =>
                                BinaryOperator::Subtract,
                            TokenValue::Star => BinaryOperator::Multiply,
                            TokenValue::Slash => BinaryOperator::Divide,
                            TokenValue::Percent => BinaryOperator::Modulo,
                            _ => unreachable!(),
                        },
                        right,
                    },
                    meta,
                },
                Default::default(),
            )
        }

        Ok(left)
    }

    pub fn parse_conditional(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        stmt: &mut StmtContext,
        passthrough: Option<Handle<HirExpr>>,
    ) -> Result<Handle<HirExpr>> {
        let mut condition = self.parse_binary(frontend, ctx, stmt, passthrough, 0)?;
        let mut meta = stmt.hir_exprs[condition].meta;

        if self.bump_if(frontend, TokenValue::Question).is_some() {
            let accept = self.parse_expression(frontend, ctx, stmt)?;
            self.expect(frontend, TokenValue::Colon)?;
            let reject = self.parse_assignment(frontend, ctx, stmt)?;
            let end_meta = stmt.hir_exprs[reject].meta;

            meta.subsume(end_meta);
            condition = stmt.hir_exprs.append(
                HirExpr {
                    kind: HirExprKind::Conditional {
                        condition,
                        accept,
                        reject,
                    },
                    meta,
                },
                Default::default(),
            )
        }

        Ok(condition)
    }

    pub fn parse_assignment(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        stmt: &mut StmtContext,
    ) -> Result<Handle<HirExpr>> {
        let tgt = self.parse_unary(frontend, ctx, stmt)?;
        let mut meta = stmt.hir_exprs[tgt].meta;

        Ok(match self.expect_peek(frontend)?.value {
            TokenValue::Assign => {
                self.bump(frontend)?;
                let value = self.parse_assignment(frontend, ctx, stmt)?;
                let end_meta = stmt.hir_exprs[value].meta;

                meta.subsume(end_meta);
                stmt.hir_exprs.append(
                    HirExpr {
                        kind: HirExprKind::Assign { tgt, value },
                        meta,
                    },
                    Default::default(),
                )
            }
            TokenValue::OrAssign
            | TokenValue::AndAssign
            | TokenValue::AddAssign
            | TokenValue::DivAssign
            | TokenValue::ModAssign
            | TokenValue::SubAssign
            | TokenValue::MulAssign
            | TokenValue::LeftShiftAssign
            | TokenValue::RightShiftAssign
            | TokenValue::XorAssign => {
                let token = self.bump(frontend)?;
                let right = self.parse_assignment(frontend, ctx, stmt)?;
                let end_meta = stmt.hir_exprs[right].meta;

                meta.subsume(end_meta);
                let value = stmt.hir_exprs.append(
                    HirExpr {
                        meta,
                        kind: HirExprKind::Binary {
                            left: tgt,
                            op: match
                                token.value {
                                TokenValue::OrAssign => BinaryOperator::InclusiveOr,
                                TokenValue::AndAssign => BinaryOperator::And,
                                TokenValue::AddAssign => BinaryOperator::Add,
                                TokenValue::DivAssign => BinaryOperator::Divide,
                                TokenValue::ModAssign => BinaryOperator::Modulo,
                                TokenValue::SubAssign => BinaryOperator::Subtract,
                                TokenValue::MulAssign => BinaryOperator::Multiply,
                                TokenValue::LeftShiftAssign => BinaryOperator::ShiftLeft,
                                TokenValue::RightShiftAssign => BinaryOperator::ShiftRight,
                                TokenValue::XorAssign => BinaryOperator::ExclusiveOr,
                                _ => unreachable!(),
                            },
                            right,
                        },
                    },
                    Default::default(),
                );

                stmt.hir_exprs.append(
                    HirExpr {
                        kind: HirExprKind::Assign { tgt, value },
                        meta,
                    },
                    Default::default(),
                )
            }
            _ => self.parse_conditional(frontend, ctx, stmt, Some(tgt))?,
        })
    }

    pub fn parse_expression(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        stmt: &mut StmtContext,
    ) -> Result<Handle<HirExpr>> {
        let mut exprs = Vec::new();

        let mut expr = self.parse_assignment(frontend, ctx, stmt)?;
        exprs.push(expr);

        while let TokenValue::Comma = self.expect_peek(frontend)?.value {
            self.bump(frontend)?;
            expr = self.parse_assignment(frontend, ctx, stmt)?;
            exprs.push(expr);
        }

        if exprs.len() == 1 {
            Ok(expr)
        } else {
            let mut meta = stmt.hir_exprs[exprs[0]].meta;
            for &e in &exprs[1..]
            {
                meta.subsume(stmt.hir_exprs[e].meta);
            }

            expr = stmt.hir_exprs.append(
                HirExpr {
                    kind: HirExprKind::Sequence { exprs },
                    meta,
                },
                Default::default(),
            );

            Ok(expr)
        }
    }
}

const fn binding_power(value: &TokenValue) -> Option<(u8, u8)> {
    Some(match *value {
        TokenValue::LogicalOr => (1, 2),
        TokenValue::LogicalXor => (3, 4),
        TokenValue::LogicalAnd => (5, 6),
        TokenValue::VerticalBar => (7, 8),
        TokenValue::Caret => (9, 10),
        TokenValue::Ampersand => (11, 12),
        TokenValue::Equal | TokenValue::NotEqual => (13, 14),
        TokenValue::GreaterEqual
        | TokenValue::LessEqual
        | TokenValue::LeftAngle
        | TokenValue::RightAngle => (15, 16),
        TokenValue::LeftShift | TokenValue::RightShift => (17, 18),
        TokenValue::Plus | TokenValue::Dash => (19, 20),
        TokenValue::Star | TokenValue::Slash | TokenValue::Percent => (21, 22),
        _ => return None,
    })
}
naga-29.0.3/src/front/glsl/parser/functions.rs000064400000000000000000000621561046102023000173710ustar 00000000000000
use alloc::{vec, vec::Vec};

use crate::front::glsl::context::ExprPos;
use crate::front::glsl::Span;
use crate::Literal;
use crate::{
    front::glsl::{
        ast::ParameterQualifier,
        context::Context,
        parser::ParsingContext,
        token::{Token, TokenValue},
        variables::VarDeclaration,
        Error, ErrorKind, Frontend, Result,
    },
    Block, Expression, Statement, SwitchCase, UnaryOperator,
};

impl ParsingContext<'_> {
    pub fn peek_parameter_qualifier(&mut self, frontend: &mut Frontend) -> bool {
        self.peek(frontend).is_some_and(|t| match t.value {
            TokenValue::In | TokenValue::Out | TokenValue::InOut | TokenValue::Const => true,
            _ => false,
        })
    }

    /// Returns the parsed `ParameterQualifier` or `ParameterQualifier::In`
    pub fn parse_parameter_qualifier(&mut self, frontend: &mut Frontend) -> ParameterQualifier {
        if self.peek_parameter_qualifier(frontend) {
            match self.bump(frontend).unwrap().value {
                TokenValue::In => ParameterQualifier::In,
                TokenValue::Out => ParameterQualifier::Out,
                TokenValue::InOut => ParameterQualifier::InOut,
                TokenValue::Const => ParameterQualifier::Const,
                _ => unreachable!(),
            }
        } else {
            ParameterQualifier::In
        }
    }

    pub fn parse_statement(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        terminator: &mut Option<usize>,
        is_inside_loop: bool,
    ) -> Result<Option<Span>> {
        // Type qualifiers always identify a declaration statement
        if self.peek_type_qualifier(frontend) {
            return self.parse_declaration(frontend, ctx, false, is_inside_loop);
        }

        // Type names can identify either declaration statements or type constructors
        // depending on whether the token following the type name is a `(` (LeftParen)
        if self.peek_type_name(frontend) {
            // Start by consuming the type name so that we can peek the token after it
            let token = self.bump(frontend)?;
            // Peek the next token and check if it's a `(` (LeftParen); if so, the statement
            // is a constructor, otherwise it's a declaration. We need to do the check
            // beforehand and not in the if since we will backtrack before the if
            let declaration = TokenValue::LeftParen != self.expect_peek(frontend)?.value;

            self.backtrack(token)?;

            if declaration {
                return self.parse_declaration(frontend, ctx, false, is_inside_loop);
            }
        }

        let new_break = || {
            let mut block = Block::new();
            block.push(Statement::Break, Span::default());
            block
        };

        let &Token {
            ref value,
            mut meta,
        } = self.expect_peek(frontend)?;

        let meta_rest = match *value {
            TokenValue::Continue => {
                let meta = self.bump(frontend)?.meta;
                ctx.body.push(Statement::Continue, meta);
                terminator.get_or_insert(ctx.body.len());
                self.expect(frontend, TokenValue::Semicolon)?.meta
            }
            TokenValue::Break => {
                let meta = self.bump(frontend)?.meta;
                ctx.body.push(Statement::Break, meta);
                terminator.get_or_insert(ctx.body.len());
                self.expect(frontend, TokenValue::Semicolon)?.meta
            }
            TokenValue::Return => {
                self.bump(frontend)?;
                let (value, meta) = match self.expect_peek(frontend)?.value {
                    TokenValue::Semicolon => (None, self.bump(frontend)?.meta),
                    _ => {
                        // TODO: Implicit conversions
                        let mut stmt = ctx.stmt_ctx();
                        let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
                        self.expect(frontend, TokenValue::Semicolon)?;
                        let (handle, meta) =
                            ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?;
                        (Some(handle), meta)
                    }
                };

                ctx.emit_restart();

                ctx.body.push(Statement::Return { value }, meta);
                terminator.get_or_insert(ctx.body.len());

                meta
            }
            TokenValue::Discard => {
                let meta = self.bump(frontend)?.meta;
                ctx.body.push(Statement::Kill, meta);
                terminator.get_or_insert(ctx.body.len());

                self.expect(frontend, TokenValue::Semicolon)?.meta
            }
            TokenValue::If => {
                let mut meta = self.bump(frontend)?.meta;

                self.expect(frontend, TokenValue::LeftParen)?;
                let condition = {
                    let mut stmt = ctx.stmt_ctx();
                    let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
                    let (handle, more_meta) =
                        ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?;
                    meta.subsume(more_meta);
                    handle
                };
                self.expect(frontend, TokenValue::RightParen)?;

                let accept = ctx.new_body(|ctx| {
                    if let Some(more_meta) =
                        self.parse_statement(frontend, ctx, &mut None, is_inside_loop)?
                    {
                        meta.subsume(more_meta);
                    }
                    Ok(())
                })?;

                let reject = ctx.new_body(|ctx| {
                    if self.bump_if(frontend, TokenValue::Else).is_some() {
                        if let Some(more_meta) =
                            self.parse_statement(frontend, ctx, &mut None, is_inside_loop)?
                        {
                            meta.subsume(more_meta);
                        }
                    }
                    Ok(())
                })?;

                ctx.body.push(
                    Statement::If {
                        condition,
                        accept,
                        reject,
                    },
                    meta,
                );

                meta
            }
            TokenValue::Switch => {
                let mut meta = self.bump(frontend)?.meta;
                let end_meta;

                self.expect(frontend, TokenValue::LeftParen)?;

                let (selector, uint) = {
                    let mut stmt = ctx.stmt_ctx();
                    let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
                    let (root, meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?;
                    let uint = ctx.resolve_type(root, meta)?.scalar_kind()
                        == Some(crate::ScalarKind::Uint);
                    (root, uint)
                };

                self.expect(frontend, TokenValue::RightParen)?;

                ctx.emit_restart();

                let mut cases = Vec::new();
                // Track if any default case is present in the switch statement.
                let mut default_present = false;

                self.expect(frontend, TokenValue::LeftBrace)?;
                loop {
                    let value = match self.expect_peek(frontend)?.value {
                        TokenValue::Case => {
                            self.bump(frontend)?;

                            let (const_expr, meta) = self.parse_constant_expression(
                                frontend,
                                ctx.module,
                                ctx.global_expression_kind_tracker,
                            )?;

                            match ctx.module.global_expressions[const_expr] {
                                Expression::Literal(Literal::I32(value)) => match uint {
                                    // This unchecked cast isn't good, but since
                                    // we only reach this code when the selector
                                    // is unsigned but the case label is signed,
                                    // verification will reject the module
                                    // anyway (which also matches GLSL's rules).
                                    true => crate::SwitchValue::U32(value as u32),
                                    false => crate::SwitchValue::I32(value),
                                },
                                Expression::Literal(Literal::U32(value)) => {
                                    crate::SwitchValue::U32(value)
                                }
                                _ => {
                                    frontend.errors.push(Error {
                                        kind: ErrorKind::SemanticError(
                                            "Case values can only be integers".into(),
                                        ),
                                        meta,
                                    });

                                    crate::SwitchValue::I32(0)
                                }
                            }
                        }
                        TokenValue::Default => {
                            self.bump(frontend)?;
                            default_present = true;
                            crate::SwitchValue::Default
                        }
                        TokenValue::RightBrace => {
                            end_meta = self.bump(frontend)?.meta;
                            break;
                        }
                        _ => {
                            let Token { value, meta } = self.bump(frontend)?;
                            return Err(Error {
                                kind: ErrorKind::InvalidToken(
                                    value,
                                    vec![
                                        TokenValue::Case.into(),
                                        TokenValue::Default.into(),
                                        TokenValue::RightBrace.into(),
                                    ],
                                ),
                                meta,
                            });
                        }
                    };

                    self.expect(frontend, TokenValue::Colon)?;

                    let mut fall_through = true;

                    let body = ctx.new_body(|ctx| {
                        let mut case_terminator = None;
                        loop {
                            match self.expect_peek(frontend)?.value {
                                TokenValue::Case | TokenValue::Default | TokenValue::RightBrace => {
                                    break
                                }
                                _ => {
                                    self.parse_statement(
                                        frontend,
                                        ctx,
                                        &mut case_terminator,
                                        is_inside_loop,
                                    )?;
                                }
                            }
                        }

                        if let Some(mut idx) = case_terminator {
                            if let Statement::Break = ctx.body[idx - 1] {
                                fall_through = false;
                                idx -= 1;
                            }

                            ctx.body.cull(idx..)
                        }
                        Ok(())
                    })?;

                    cases.push(SwitchCase {
                        value,
                        body,
                        fall_through,
                    })
                }

                meta.subsume(end_meta);

                // NOTE: do not unwrap here since a switch statement isn't required
                // to have any cases.
                if let Some(case) = cases.last_mut() {
                    // GLSL requires that the last case not be empty, so we check
                    // that here and produce an error otherwise (fall_through must
                    // also be checked because `break`s count as statements but
                    // they aren't added to the body)
                    if case.body.is_empty() && case.fall_through {
                        frontend.errors.push(Error {
                            kind: ErrorKind::SemanticError(
                                "last case/default label must be followed by statements".into(),
                            ),
                            meta,
                        })
                    }

                    // GLSL allows the last case to omit the `break` statement, which
                    // would mark it as fall-through, but naga's IR requires that the
                    // last case not fall through, so we always mark the last case as
                    // not falling through.
                    case.fall_through = false;
                }

                // Add an empty default case if none was present. This is needed because
                // naga's IR requires every switch statement to have a default case, but
                // GLSL doesn't, so we might need to add an empty default case ourselves.
                if !default_present {
                    cases.push(SwitchCase {
                        value: crate::SwitchValue::Default,
                        body: Block::new(),
                        fall_through: false,
                    })
                }

                ctx.body.push(Statement::Switch { selector, cases }, meta);

                meta
            }
            TokenValue::While => {
                let mut meta = self.bump(frontend)?.meta;

                let loop_body = ctx.new_body(|ctx| {
                    let mut stmt = ctx.stmt_ctx();
                    self.expect(frontend, TokenValue::LeftParen)?;
                    let root = self.parse_expression(frontend, ctx, &mut stmt)?;
                    meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta);

                    let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?;
                    let condition = ctx.add_expression(
                        Expression::Unary {
                            op: UnaryOperator::LogicalNot,
                            expr,
                        },
                        expr_meta,
                    )?;

                    ctx.emit_restart();

                    ctx.body.push(
                        Statement::If {
                            condition,
                            accept: new_break(),
                            reject: Block::new(),
                        },
                        Span::default(),
                    );

                    meta.subsume(expr_meta);

                    if let Some(body_meta) = self.parse_statement(frontend, ctx, &mut None, true)? {
                        meta.subsume(body_meta);
                    }
                    Ok(())
                })?;

                ctx.body.push(
                    Statement::Loop {
                        body: loop_body,
                        continuing: Block::new(),
                        break_if: None,
                    },
                    meta,
                );

                meta
            }
            TokenValue::Do => {
                let mut meta = self.bump(frontend)?.meta;

                let loop_body = ctx.new_body(|ctx| {
                    let mut terminator = None;
                    self.parse_statement(frontend, ctx, &mut terminator, true)?;

                    let mut stmt = ctx.stmt_ctx();

                    self.expect(frontend, TokenValue::While)?;
                    self.expect(frontend, TokenValue::LeftParen)?;
                    let root = self.parse_expression(frontend, ctx, &mut stmt)?;
                    let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta;

                    meta.subsume(end_meta);

                    let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?;
                    let condition = ctx.add_expression(
                        Expression::Unary {
                            op: UnaryOperator::LogicalNot,
                            expr,
                        },
                        expr_meta,
                    )?;

                    ctx.emit_restart();

                    ctx.body.push(
                        Statement::If {
                            condition,
                            accept: new_break(),
                            reject: Block::new(),
                        },
                        Span::default(),
                    );

                    if let Some(idx) = terminator {
                        ctx.body.cull(idx..)
                    }
                    Ok(())
                })?;

                ctx.body.push(
                    Statement::Loop {
                        body: loop_body,
                        continuing: Block::new(),
                        break_if: None,
                    },
                    meta,
                );

                meta
            }
            TokenValue::For => {
                let mut meta = self.bump(frontend)?.meta;

                ctx.symbol_table.push_scope();
                self.expect(frontend, TokenValue::LeftParen)?;

                if self.bump_if(frontend, TokenValue::Semicolon).is_none() {
                    if self.peek_type_name(frontend) || self.peek_type_qualifier(frontend) {
                        self.parse_declaration(frontend, ctx, false, is_inside_loop)?;
                    } else {
                        let mut stmt = ctx.stmt_ctx();
                        let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
                        ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?;
                        self.expect(frontend, TokenValue::Semicolon)?;
                    }
                }

                let loop_body = ctx.new_body(|ctx| {
                    if self.bump_if(frontend, TokenValue::Semicolon).is_none() {
                        let (expr, expr_meta) = if self.peek_type_name(frontend)
                            || self.peek_type_qualifier(frontend)
                        {
                            let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?;
                            let (ty, mut meta) = self.parse_type_non_void(frontend, ctx)?;
                            let name = self.expect_ident(frontend)?.0;

                            self.expect(frontend, TokenValue::Assign)?;

                            let (value, end_meta) = self.parse_initializer(frontend, ty, ctx)?;
                            meta.subsume(end_meta);

                            let decl = VarDeclaration {
                                qualifiers: &mut qualifiers,
                                ty,
                                name: Some(name),
                                init: None,
                                meta,
                            };

                            let pointer = frontend.add_local_var(ctx, decl)?;

                            ctx.emit_restart();

                            ctx.body.push(Statement::Store { pointer, value }, meta);

                            (value, end_meta)
                        } else {
                            let mut stmt = ctx.stmt_ctx();
                            let root = self.parse_expression(frontend, ctx, &mut stmt)?;
                            ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?
                        };

                        let condition = ctx.add_expression(
                            Expression::Unary {
                                op: UnaryOperator::LogicalNot,
                                expr,
                            },
                            expr_meta,
                        )?;

                        ctx.emit_restart();

                        ctx.body.push(
                            Statement::If {
                                condition,
                                accept: new_break(),
                                reject: Block::new(),
                            },
                            Span::default(),
                        );

                        self.expect(frontend, TokenValue::Semicolon)?;
                    }
                    Ok(())
                })?;

                let continuing = ctx.new_body(|ctx| {
                    match self.expect_peek(frontend)?.value {
                        TokenValue::RightParen => {}
                        _ => {
                            let mut stmt = ctx.stmt_ctx();
                            let rest = self.parse_expression(frontend, ctx, &mut stmt)?;
                            ctx.lower(stmt, frontend, rest, ExprPos::Rhs)?;
                        }
                    }
                    Ok(())
                })?;

                meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta);

                let loop_body = ctx.with_body(loop_body, |ctx| {
                    if let Some(stmt_meta) = self.parse_statement(frontend, ctx, &mut None, true)? {
                        meta.subsume(stmt_meta);
                    }
                    Ok(())
                })?;

                ctx.body.push(
                    Statement::Loop {
                        body: loop_body,
                        continuing,
                        break_if: None,
                    },
                    meta,
                );

                ctx.symbol_table.pop_scope();

                meta
            }
            TokenValue::LeftBrace => {
                let mut meta = self.bump(frontend)?.meta;

                let mut block_terminator = None;

                let block = ctx.new_body(|ctx| {
                    let block_meta = self.parse_compound_statement(
                        meta,
                        frontend,
                        ctx,
                        &mut block_terminator,
                        is_inside_loop,
                    )?;
                    meta.subsume(block_meta);
                    Ok(())
                })?;

                ctx.body.push(Statement::Block(block), meta);
                if block_terminator.is_some() {
                    terminator.get_or_insert(ctx.body.len());
                }

                meta
            }
            TokenValue::Semicolon => self.bump(frontend)?.meta,
            _ => {
                // Attempt to force expression parsing for the remainder of the
                // tokens. Unknown or invalid tokens will be caught there and
                // turned into an error.
                let mut stmt = ctx.stmt_ctx();
                let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
                ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?;
                self.expect(frontend, TokenValue::Semicolon)?.meta
            }
        };

        meta.subsume(meta_rest);
        Ok(Some(meta))
    }

    pub fn parse_compound_statement(
        &mut self,
        mut meta: Span,
        frontend: &mut Frontend,
        ctx: &mut Context,
        terminator: &mut Option<usize>,
        is_inside_loop: bool,
    ) -> Result<Span> {
        ctx.symbol_table.push_scope();

        loop {
            if let Some(Token {
                meta: brace_meta, ..
            }) = self.bump_if(frontend, TokenValue::RightBrace)
            {
                meta.subsume(brace_meta);
                break;
            }

            let stmt = self.parse_statement(frontend, ctx, terminator, is_inside_loop)?;

            if let Some(stmt_meta) = stmt {
                meta.subsume(stmt_meta);
            }
        }

        if let Some(idx) = *terminator {
            ctx.body.cull(idx..)
        }

        ctx.symbol_table.pop_scope();

        Ok(meta)
    }

    pub fn parse_function_args(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
    ) -> Result<()> {
        if self.bump_if(frontend, TokenValue::Void).is_some() {
            return Ok(());
        }

        loop {
            if self.peek_type_name(frontend) || self.peek_parameter_qualifier(frontend) {
                let qualifier = self.parse_parameter_qualifier(frontend);
                let mut ty = self.parse_type_non_void(frontend, ctx)?.0;

                match self.expect_peek(frontend)?.value {
                    TokenValue::Comma => {
                        self.bump(frontend)?;
                        ctx.add_function_arg(None, ty, qualifier)?;
                        continue;
                    }
                    TokenValue::Identifier(_) => {
                        let mut name = self.expect_ident(frontend)?;
                        self.parse_array_specifier(frontend, ctx, &mut name.1, &mut ty)?;

                        ctx.add_function_arg(Some(name), ty, qualifier)?;

                        if self.bump_if(frontend, TokenValue::Comma).is_some() {
                            continue;
                        }

                        break;
                    }
                    _ => break,
                }
            }

            break;
        }

        Ok(())
    }
}
naga-29.0.3/src/front/glsl/parser/types.rs000064400000000000000000000411251046102023000165160ustar 00000000000000
use alloc::{vec, vec::Vec};
use core::num::NonZeroU32;

use crate::{
    front::glsl::{
        ast::{QualifierKey, QualifierValue, StorageQualifier, StructLayout, TypeQualifiers},
        context::Context,
        error::ExpectedToken,
        parser::ParsingContext,
        token::{Token,
            TokenValue},
        Error, ErrorKind, Frontend, Result,
    },
    AddressSpace, ArraySize, Handle, Span, Type, TypeInner,
};

impl ParsingContext<'_> {
    /// Parses zero or more consecutive array_specifiers, wrapping the type
    /// handle in an array type for each one that is present
    pub fn parse_array_specifier(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        span: &mut Span,
        ty: &mut Handle<Type>,
    ) -> Result<()> {
        while self.parse_array_specifier_single(frontend, ctx, span, ty)? {}
        Ok(())
    }

    /// Implementation of [`Self::parse_array_specifier`] for a single array_specifier
    fn parse_array_specifier_single(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
        span: &mut Span,
        ty: &mut Handle<Type>,
    ) -> Result<bool> {
        if self.bump_if(frontend, TokenValue::LeftBracket).is_some() {
            let size = if let Some(Token { meta, .. }) =
                self.bump_if(frontend, TokenValue::RightBracket)
            {
                span.subsume(meta);
                ArraySize::Dynamic
            } else {
                let (value, constant_span) = self.parse_uint_constant(frontend, ctx)?;
                let size = NonZeroU32::new(value).ok_or(Error {
                    kind: ErrorKind::SemanticError("Array size must be greater than zero".into()),
                    meta: constant_span,
                })?;
                let end_span = self.expect(frontend, TokenValue::RightBracket)?.meta;
                span.subsume(end_span);
                ArraySize::Constant(size)
            };

            frontend.layouter.update(ctx.module.to_ctx()).unwrap();
            let stride = frontend.layouter[*ty].to_stride();
            *ty = ctx.module.types.insert(
                Type {
                    name: None,
                    inner: TypeInner::Array {
                        base: *ty,
                        size,
                        stride,
                    },
                },
                *span,
            );

            Ok(true)
        } else {
            Ok(false)
        }
    }

    pub fn parse_type(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
    ) -> Result<(Option<Handle<Type>>, Span)> {
        let token = self.bump(frontend)?;
        let mut handle = match token.value {
            TokenValue::Void => return Ok((None, token.meta)),
            TokenValue::TypeName(ty) => ctx.module.types.insert(ty, token.meta),
            TokenValue::Struct => {
                let mut meta = token.meta;
                let ty_name = self.expect_ident(frontend)?.0;
                self.expect(frontend, TokenValue::LeftBrace)?;
                let mut members = Vec::new();
                let span =
                    self.parse_struct_declaration_list(
                        frontend,
                        ctx,
                        &mut members,
                        StructLayout::Std140,
                    )?;
                let end_meta = self.expect(frontend, TokenValue::RightBrace)?.meta;
                meta.subsume(end_meta);
                let ty = ctx.module.types.insert(
                    Type {
                        name: Some(ty_name.clone()),
                        inner: TypeInner::Struct { members, span },
                    },
                    meta,
                );
                frontend.lookup_type.insert(ty_name, ty);
                ty
            }
            TokenValue::Identifier(ident) => match frontend.lookup_type.get(&ident) {
                Some(ty) => *ty,
                None => {
                    return Err(Error {
                        kind: ErrorKind::UnknownType(ident),
                        meta: token.meta,
                    })
                }
            },
            _ => {
                return Err(Error {
                    kind: ErrorKind::InvalidToken(
                        token.value,
                        vec![
                            TokenValue::Void.into(),
                            TokenValue::Struct.into(),
                            ExpectedToken::TypeName,
                        ],
                    ),
                    meta: token.meta,
                });
            }
        };

        let mut span = token.meta;
        self.parse_array_specifier(frontend, ctx, &mut span, &mut handle)?;
        Ok((Some(handle), span))
    }

    pub fn parse_type_non_void(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
    ) -> Result<(Handle<Type>, Span)> {
        let (maybe_ty, meta) = self.parse_type(frontend, ctx)?;
        let ty = maybe_ty.ok_or_else(|| Error {
            kind: ErrorKind::SemanticError("Type can't be void".into()),
            meta,
        })?;

        Ok((ty, meta))
    }

    pub fn peek_type_qualifier(&mut self, frontend: &mut Frontend) -> bool {
        self.peek(frontend).is_some_and(|t| match t.value {
            TokenValue::Invariant
            | TokenValue::Interpolation(_)
            | TokenValue::Sampling(_)
            | TokenValue::PrecisionQualifier(_)
            | TokenValue::Const
            | TokenValue::In
            | TokenValue::Out
            | TokenValue::Uniform
            | TokenValue::Shared
            | TokenValue::Buffer
            | TokenValue::Restrict
            | TokenValue::MemoryQualifier(_)
            | TokenValue::Layout => true,
            _ => false,
        })
    }

    pub fn parse_type_qualifiers<'a>(
        &mut self,
        frontend: &mut Frontend,
        ctx: &mut Context,
    ) -> Result<TypeQualifiers<'a>> {
        let mut qualifiers = TypeQualifiers::default();

        while self.peek_type_qualifier(frontend) {
            let token = self.bump(frontend)?;

            // Handle layout qualifiers outside the match since this can push multiple values
            if token.value == TokenValue::Layout {
self.parse_layout_qualifier_id_list(frontend, ctx, &mut qualifiers)?; continue; } qualifiers.span.subsume(token.meta); match token.value { TokenValue::Invariant => { if qualifiers.invariant.is_some() { frontend.errors.push(Error { kind: ErrorKind::SemanticError( "Cannot use more than one invariant qualifier per declaration" .into(), ), meta: token.meta, }) } qualifiers.invariant = Some(token.meta); } TokenValue::Interpolation(i) => { if qualifiers.interpolation.is_some() { frontend.errors.push(Error { kind: ErrorKind::SemanticError( "Cannot use more than one interpolation qualifier per declaration" .into(), ), meta: token.meta, }) } qualifiers.interpolation = Some((i, token.meta)); } TokenValue::Const | TokenValue::In | TokenValue::Out | TokenValue::Uniform | TokenValue::Shared | TokenValue::Buffer => { let storage = match token.value { TokenValue::Const => StorageQualifier::Const, TokenValue::In => StorageQualifier::Input, TokenValue::Out => StorageQualifier::Output, TokenValue::Uniform => { StorageQualifier::AddressSpace(AddressSpace::Uniform) } TokenValue::Shared => { StorageQualifier::AddressSpace(AddressSpace::WorkGroup) } TokenValue::Buffer => { StorageQualifier::AddressSpace(AddressSpace::Storage { access: crate::StorageAccess::LOAD | crate::StorageAccess::STORE, }) } _ => unreachable!(), }; if StorageQualifier::AddressSpace(AddressSpace::Function) != qualifiers.storage.0 { frontend.errors.push(Error { kind: ErrorKind::SemanticError( "Cannot use more than one storage qualifier per declaration".into(), ), meta: token.meta, }); } qualifiers.storage = (storage, token.meta); } TokenValue::Sampling(s) => { if qualifiers.sampling.is_some() { frontend.errors.push(Error { kind: ErrorKind::SemanticError( "Cannot use more than one sampling qualifier per declaration" .into(), ), meta: token.meta, }) } qualifiers.sampling = Some((s, token.meta)); } TokenValue::PrecisionQualifier(p) => { if qualifiers.precision.is_some() { frontend.errors.push(Error { kind: 
ErrorKind::SemanticError( "Cannot use more than one precision qualifier per declaration" .into(), ), meta: token.meta, }) } qualifiers.precision = Some((p, token.meta)); } TokenValue::MemoryQualifier(access) => { let load_store = crate::StorageAccess::LOAD | crate::StorageAccess::STORE; let storage_access = qualifiers .storage_access .get_or_insert((load_store, Span::default())); if !storage_access.0.contains(!access & load_store) { frontend.errors.push(Error { kind: ErrorKind::SemanticError( "The same memory qualifier can only be used once".into(), ), meta: token.meta, }) } storage_access.0 &= access; storage_access.1.subsume(token.meta); } TokenValue::Restrict => continue, _ => unreachable!(), }; } Ok(qualifiers) } pub fn parse_layout_qualifier_id_list( &mut self, frontend: &mut Frontend, ctx: &mut Context, qualifiers: &mut TypeQualifiers, ) -> Result<()> { self.expect(frontend, TokenValue::LeftParen)?; loop { self.parse_layout_qualifier_id(frontend, ctx, &mut qualifiers.layout_qualifiers)?; if self.bump_if(frontend, TokenValue::Comma).is_some() { continue; } break; } let token = self.expect(frontend, TokenValue::RightParen)?; qualifiers.span.subsume(token.meta); Ok(()) } pub fn parse_layout_qualifier_id( &mut self, frontend: &mut Frontend, ctx: &mut Context, qualifiers: &mut crate::FastHashMap<QualifierKey, (QualifierValue, Span)>, ) -> Result<()> { // layout_qualifier_id: // IDENTIFIER // IDENTIFIER EQUAL constant_expression // SHARED let mut token = self.bump(frontend)?; match token.value { TokenValue::Identifier(name) => { let (key, value) = match name.as_str() { "std140" => ( QualifierKey::Layout, QualifierValue::Layout(StructLayout::Std140), ), "std430" => ( QualifierKey::Layout, QualifierValue::Layout(StructLayout::Std430), ), "index" => { self.expect(frontend, TokenValue::Assign)?; let (value, end_meta) = self.parse_uint_constant(frontend, ctx)?; token.meta.subsume(end_meta); (QualifierKey::Index, QualifierValue::Uint(value)) } word => { if let 
Some(format) = map_image_format(word) { 
(QualifierKey::Format, QualifierValue::Format(format)) } else { let key = QualifierKey::String(name.into()); let value = if self.bump_if(frontend, TokenValue::Assign).is_some() { let (value, end_meta) = match self.parse_uint_constant(frontend, ctx) { Ok(v) => v, Err(e) => { frontend.errors.push(e); (0, Span::default()) } }; token.meta.subsume(end_meta); QualifierValue::Uint(value) } else { QualifierValue::None }; (key, value) } } }; qualifiers.insert(key, (value, token.meta)); } _ => frontend.errors.push(Error { kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]), meta: token.meta, }), } Ok(()) } pub fn peek_type_name(&mut self, frontend: &mut Frontend) -> bool { self.peek(frontend).is_some_and(|t| match t.value { TokenValue::TypeName(_) | TokenValue::Void => true, TokenValue::Struct => true, TokenValue::Identifier(ref ident) => frontend.lookup_type.contains_key(ident), _ => false, }) } } fn map_image_format(word: &str) -> Option<crate::StorageFormat> { use crate::StorageFormat as Sf; let format = match word { // float-image-format-qualifier: "rgba32f" => Sf::Rgba32Float, "rgba16f" => Sf::Rgba16Float, "rg32f" => Sf::Rg32Float, "rg16f" => Sf::Rg16Float, "r11f_g11f_b10f" => Sf::Rg11b10Ufloat, "r32f" => Sf::R32Float, "r16f" => Sf::R16Float, "rgba16" => Sf::Rgba16Unorm, "rgb10_a2ui" => Sf::Rgb10a2Uint, "rgb10_a2" => Sf::Rgb10a2Unorm, "rgba8" => Sf::Rgba8Unorm, "rg16" => Sf::Rg16Unorm, "rg8" => Sf::Rg8Unorm, "r16" => Sf::R16Unorm, "r8" => Sf::R8Unorm, "rgba16_snorm" => Sf::Rgba16Snorm, "rgba8_snorm" => Sf::Rgba8Snorm, "rg16_snorm" => Sf::Rg16Snorm, "rg8_snorm" => Sf::Rg8Snorm, "r16_snorm" => Sf::R16Snorm, "r8_snorm" => Sf::R8Snorm, // int-image-format-qualifier: "rgba32i" => Sf::Rgba32Sint, "rgba16i" => Sf::Rgba16Sint, "rgba8i" => Sf::Rgba8Sint, "rg32i" => Sf::Rg32Sint, "rg16i" => Sf::Rg16Sint, "rg8i" => Sf::Rg8Sint, "r32i" => Sf::R32Sint, "r16i" => Sf::R16Sint, "r8i" => Sf::R8Sint, // uint-image-format-qualifier: "rgba32ui" => Sf::Rgba32Uint, "rgba16ui" => 
Sf::Rgba16Uint, "rgba8ui" => Sf::Rgba8Uint, "r64ui" => Sf::R64Uint, "rg32ui" => Sf::Rg32Uint, "rg16ui" => Sf::Rg16Uint, "rg8ui" => Sf::Rg8Uint, "r32ui" => Sf::R32Uint, "r16ui" => Sf::R16Uint, "r8ui" => Sf::R8Uint, // TODO: These next ones seem incorrect to me // "rgb10_a2ui" => Sf::Rgb10a2Unorm, _ => return None, }; Some(format) } naga-29.0.3/src/front/glsl/parser.rs000064400000000000000000000375361046102023000153650ustar 00000000000000use alloc::{string::String, vec}; use core::iter::Peekable; use pp_rs::token::{PreprocessorError, Token as PPToken, TokenValue as PPTokenValue}; use super::{ ast::{FunctionKind, Profile, TypeQualifiers}, context::{Context, ExprPos}, error::ExpectedToken, error::{Error, ErrorKind}, lex::{Lexer, LexerResultKind}, token::{Directive, DirectiveKind}, token::{Token, TokenValue}, variables::{GlobalOrConstant, VarDeclaration}, Frontend, Result, }; use crate::{arena::Handle, proc::ConstValueError, Expression, Module, Span, Type}; mod declarations; mod expressions; mod functions; mod types; pub struct ParsingContext<'source> { lexer: Peekable<Lexer<'source>>, /// Used to store tokens already consumed by the parser but that need to be backtracked backtracked_token: Option<Token>, last_meta: Span, } impl<'source> ParsingContext<'source> { pub fn new(lexer: Lexer<'source>) -> Self { ParsingContext { lexer: lexer.peekable(), backtracked_token: None, last_meta: Span::default(), } } /// Helper method for backtracking from a consumed token /// /// This method should always be used instead of assigning to `backtracked_token` since /// it validates that backtracking hasn't occurred more than one time in a row /// /// # Errors /// - Returns an internal error if the parser already backtracked without bumping in between pub fn backtrack(&mut self, token: Token) -> Result<()> { // This should never happen if let Some(ref prev_token) = self.backtracked_token { return Err(Error { kind: ErrorKind::InternalError("The parser tried to backtrack twice in a row"), 
meta: prev_token.meta, }); } 
self.backtracked_token = Some(token); Ok(()) } pub fn expect_ident(&mut self, frontend: &mut Frontend) -> Result<(String, Span)> { let token = self.bump(frontend)?; match token.value { TokenValue::Identifier(name) => Ok((name, token.meta)), _ => Err(Error { kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]), meta: token.meta, }), } } pub fn expect(&mut self, frontend: &mut Frontend, value: TokenValue) -> Result<Token> { let token = self.bump(frontend)?; if token.value != value { Err(Error { kind: ErrorKind::InvalidToken(token.value, vec![value.into()]), meta: token.meta, }) } else { Ok(token) } } pub fn next(&mut self, frontend: &mut Frontend) -> Option<Token> { loop { if let Some(token) = self.backtracked_token.take() { self.last_meta = token.meta; break Some(token); } let res = self.lexer.next()?; match res.kind { LexerResultKind::Token(token) => { self.last_meta = token.meta; break Some(token); } LexerResultKind::Directive(directive) => { frontend.handle_directive(directive, res.meta) } LexerResultKind::Error(error) => frontend.errors.push(Error { kind: ErrorKind::PreprocessorError(error), meta: res.meta, }), } } } pub fn bump(&mut self, frontend: &mut Frontend) -> Result<Token> { self.next(frontend).ok_or(Error { kind: ErrorKind::EndOfFile, meta: self.last_meta, }) } /// Consumes and returns the next token if it equals `value`; returns `None` /// (rather than an error, like other methods) if it doesn't match or at the end of the file pub fn bump_if(&mut self, frontend: &mut Frontend, value: TokenValue) -> Option<Token> { if self.peek(frontend).filter(|t| t.value == value).is_some() { self.bump(frontend).ok() } else { None } } pub fn peek(&mut self, frontend: &mut Frontend) -> Option<&Token> { loop { if let Some(ref token) = self.backtracked_token { break Some(token); } match self.lexer.peek()?.kind { LexerResultKind::Token(_) => { let res = self.lexer.peek()?; match res.kind { LexerResultKind::Token(ref token) => break Some(token), _ => 
unreachable!(), } } LexerResultKind::Error(_) | LexerResultKind::Directive(_) => { let res = self.lexer.next()?; 
match res.kind { LexerResultKind::Directive(directive) => { frontend.handle_directive(directive, res.meta) } LexerResultKind::Error(error) => frontend.errors.push(Error { kind: ErrorKind::PreprocessorError(error), meta: res.meta, }), LexerResultKind::Token(_) => unreachable!(), } } } } } pub fn expect_peek(&mut self, frontend: &mut Frontend) -> Result<&Token> { let meta = self.last_meta; self.peek(frontend).ok_or(Error { kind: ErrorKind::EndOfFile, meta, }) } pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> { let mut module = Module::default(); let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); // Body and expression arena for global initialization let mut ctx = Context::new( frontend, &mut module, false, &mut global_expression_kind_tracker, )?; while self.peek(frontend).is_some() { self.parse_external_declaration(frontend, &mut ctx)?; } // Add an `EntryPoint` to `parser.module` for `main`, if a // suitable overload exists. Error out if we can't find one. 
if let Some(declaration) = frontend.lookup_function.get("main") { for decl in declaration.overloads.iter() { if let FunctionKind::Call(handle) = decl.kind { if decl.defined && decl.parameters.is_empty() { frontend.add_entry_point(handle, ctx)?; return Ok(module); } } } } Err(Error { kind: ErrorKind::SemanticError("Missing entry point".into()), meta: Span::default(), }) } fn parse_uint_constant( &mut self, frontend: &mut Frontend, ctx: &mut Context, ) -> Result<(u32, Span)> { let (const_expr, meta) = self.parse_constant_expression( frontend, ctx.module, ctx.global_expression_kind_tracker, )?; let res = ctx.module.to_ctx().get_const_val(const_expr); let int = match res { Ok(value) => Ok(value), Err(ConstValueError::Negative) => Err(Error { kind: ErrorKind::SemanticError("int constant overflows".into()), meta, }), Err(ConstValueError::NonConst | ConstValueError::InvalidType) => Err(Error { kind: ErrorKind::SemanticError("Expected a uint constant".into()), meta, }), }?; Ok((int, meta)) } fn parse_constant_expression( &mut self, frontend: &mut Frontend, module: &mut Module, global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, ) -> Result<(Handle<Expression>, Span)> { let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?; let mut stmt_ctx = ctx.stmt_ctx(); let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?; let (root, meta) = ctx.lower_expect(stmt_ctx, frontend, expr, ExprPos::Rhs)?; Ok((root, meta)) } } impl Frontend { fn handle_directive(&mut self, directive: Directive, meta: Span) { let mut tokens = directive.tokens.into_iter(); match directive.kind { DirectiveKind::Version { is_first_directive } => { if !is_first_directive { self.errors.push(Error { kind: ErrorKind::SemanticError( "#version must occur first in shader".into(), ), meta, }) } match tokens.next() { Some(PPToken { value: PPTokenValue::Integer(int), location, }) => match int.value { 440 | 450 | 460 => self.meta.version = int.value as u16, _ => 
self.errors.push(Error { kind: ErrorKind::InvalidVersion(int.value), meta: location.into(), }), }, Some(PPToken { value, location }) => self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( value, )), meta: location.into(), }), None => self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine), meta, }), }; match tokens.next() { Some(PPToken { value: PPTokenValue::Ident(name), location, }) => match name.as_str() { "core" => self.meta.profile = Profile::Core, _ => self.errors.push(Error { kind: ErrorKind::InvalidProfile(name), meta: location.into(), }), }, Some(PPToken { value, location }) => self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( value, )), meta: location.into(), }), None => {} }; if let Some(PPToken { value, location }) = tokens.next() { self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( value, )), meta: location.into(), }) } } DirectiveKind::Extension => { // TODO: Proper extension handling // - Checking for extension support in the compiler // - Handle behaviors such as warn // - Handle the all extension let name = match tokens.next() { Some(PPToken { value: PPTokenValue::Ident(name), .. }) => Some(name), Some(PPToken { value, location }) => { self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( value, )), meta: location.into(), }); None } None => { self.errors.push(Error { kind: ErrorKind::PreprocessorError( PreprocessorError::UnexpectedNewLine, ), meta, }); None } }; match tokens.next() { Some(PPToken { value: PPTokenValue::Punct(pp_rs::token::Punct::Colon), .. 
}) => {} Some(PPToken { value, location }) => self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( value, )), meta: location.into(), }), None => self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine), meta, }), }; match tokens.next() { Some(PPToken { value: PPTokenValue::Ident(behavior), location, }) => match behavior.as_str() { "require" | "enable" | "warn" | "disable" => { if let Some(name) = name { self.meta.extensions.insert(name); } } _ => self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( PPTokenValue::Ident(behavior), )), meta: location.into(), }), }, Some(PPToken { value, location }) => self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( value, )), meta: location.into(), }), None => self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine), meta, }), } if let Some(PPToken { value, location }) = tokens.next() { self.errors.push(Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( value, )), meta: location.into(), }) } } DirectiveKind::Pragma => { // TODO: handle some common pragmas? 
} } } } pub struct DeclarationContext<'ctx, 'qualifiers, 'a> { qualifiers: TypeQualifiers<'qualifiers>, /// Indicates a global declaration external: bool, is_inside_loop: bool, ctx: &'ctx mut Context<'a>, } impl DeclarationContext<'_, '_, '_> { fn add_var( &mut self, frontend: &mut Frontend, ty: Handle<Type>, name: String, init: Option<Handle<Expression>>, meta: Span, ) -> Result<Handle<Expression>> { let decl = VarDeclaration { qualifiers: &mut self.qualifiers, ty, name: Some(name), init, meta, }; match self.external { true => { let global = frontend.add_global_var(self.ctx, decl)?; let expr = match global { GlobalOrConstant::Global(handle) => Expression::GlobalVariable(handle), GlobalOrConstant::Constant(handle) => Expression::Constant(handle), GlobalOrConstant::Override(handle) => Expression::Override(handle), }; Ok(self.ctx.add_expression(expr, meta)?) } false => frontend.add_local_var(self.ctx, decl), } } } naga-29.0.3/src/front/glsl/parser_tests.rs000064400000000000000000000441741046102023000166030ustar 00000000000000use alloc::{borrow::ToOwned, vec}; use pp_rs::token::PreprocessorError; use super::{ ast::Profile, error::ExpectedToken, error::{Error, ErrorKind, ParseErrors}, token::TokenValue, Frontend, Options, Span, }; use crate::ShaderStage; #[cfg(test)] use std::println; #[test] fn version() { let mut frontend = Frontend::default(); // invalid versions assert_eq!( frontend .parse( &Options::from(ShaderStage::Vertex), "#version 99000\n void main(){}", ) .err() .unwrap(), ParseErrors { errors: vec![Error { kind: ErrorKind::InvalidVersion(99000), meta: Span::new(9, 14) }], }, ); assert_eq!( frontend .parse( &Options::from(ShaderStage::Vertex), "#version 449\n void main(){}", ) .err() .unwrap(), ParseErrors { errors: vec![Error { kind: ErrorKind::InvalidVersion(449), meta: Span::new(9, 12) }] }, ); assert_eq!( frontend .parse( &Options::from(ShaderStage::Vertex), "#version 450 smart\n void main(){}", ) .err() .unwrap(), ParseErrors { errors: vec![Error { 
kind: ErrorKind::InvalidProfile("smart".into()), 
meta: Span::new(13, 18), }] }, ); assert_eq!( frontend .parse( &Options::from(ShaderStage::Vertex), "#version 450\nvoid main(){} #version 450", ) .err() .unwrap(), ParseErrors { errors: vec![ Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,), meta: Span::new(27, 28), }, Error { kind: ErrorKind::InvalidToken( TokenValue::Identifier("version".into()), vec![ExpectedToken::Eof] ), meta: Span::new(28, 35) } ] }, ); // valid versions frontend .parse( &Options::from(ShaderStage::Vertex), " # version 450\nvoid main() {}", ) .unwrap(); assert_eq!( (frontend.metadata().version, frontend.metadata().profile), (450, Profile::Core) ); frontend .parse( &Options::from(ShaderStage::Vertex), "#version 450\nvoid main() {}", ) .unwrap(); assert_eq!( (frontend.metadata().version, frontend.metadata().profile), (450, Profile::Core) ); frontend .parse( &Options::from(ShaderStage::Vertex), "#version 450 core\nvoid main(void) {}", ) .unwrap(); assert_eq!( (frontend.metadata().version, frontend.metadata().profile), (450, Profile::Core) ); } #[test] fn control_flow() { let mut frontend = Frontend::default(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { if (true) { return 1; } else { return 2; } } "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { if (true) { return 1; } } "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { int x; int y = 3; switch (5) { case 2: x = 2; case 5: x = 5; y = 2; break; default: x = 0; } } "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { int x = 0; while(x < 5) { x = x + 1; } do { x = x - 1; } while(x >= 4) } "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { int x = 0; for(int i = 0; i < 10;) { x = x + 2; } for(;;); return x; } "#, ) .unwrap(); } #[test] fn declarations() { let mut frontend 
= Frontend::default(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" #version 450 layout(location = 0) in vec2 v_uv; layout(location = 0) out vec4 o_color; layout(set = 1, binding = 1) uniform texture2D tex; layout(set = 1, binding = 2) uniform sampler tex_sampler; layout(early_fragment_tests) in; void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" #version 450 layout(std140, set = 2, binding = 0) uniform u_locals { vec3 model_offs; float load_time; ivec4 atlas_offs; }; void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" #version 450 layout(push_constant) uniform u_locals { vec3 model_offs; float load_time; ivec4 atlas_offs; }; void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" #version 450 layout(std430, set = 2, binding = 0) uniform u_locals { vec3 model_offs; float load_time; ivec4 atlas_offs; }; void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" #version 450 layout(std140, set = 2, binding = 0) uniform u_locals { vec3 model_offs; float load_time; } block_var; void main() { load_time * model_offs; block_var.load_time * block_var.model_offs; } "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" #version 450 float vector = vec4(1.0 / 17.0, 9.0 / 17.0, 3.0 / 17.0, 11.0 / 17.0); void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" #version 450 precision highp float; void main() {} "#, ) .unwrap(); } #[test] fn textures() { let mut frontend = Frontend::default(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" #version 450 layout(location = 0) in vec2 v_uv; layout(location = 0) out vec4 o_color; layout(set = 1, binding = 1) uniform texture2D tex; layout(set = 1, binding = 2) uniform sampler tex_sampler; void main() { o_color = texture(sampler2D(tex, tex_sampler), v_uv); o_color.a = texture(sampler2D(tex, tex_sampler), v_uv, 2.0).a; } 
"#, ) .unwrap(); } #[test] fn functions() { let mut frontend = Frontend::default(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void test1(float); void test1(float) {} void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void test2(float a) {} void test3(float a, float b) {} void test4(float, float) {} void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 float test(float a) { return a; } void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 float test(vec4 p) { return p.x; } void main() {} "#, ) .unwrap(); // Function overloading frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 float test(vec2 p); float test(vec3 p); float test(vec4 p); float test(vec2 p) { return p.x; } float test(vec3 p) { return p.x; } float test(vec4 p) { return p.x; } void main() {} "#, ) .unwrap(); assert_eq!( frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 int test(vec4 p) { return p.x; } float test(vec4 p) { return p.x; } void main() {} "#, ) .err() .unwrap(), ParseErrors { errors: vec![Error { kind: ErrorKind::SemanticError("Function already defined".into()), meta: Span::new(134, 152), }] }, ); println!(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 float callee(uint q) { return float(q); } float caller() { callee(1u); } void main() {} "#, ) .unwrap(); // Nested function call frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 layout(set = 0, binding = 1) uniform texture2D t_noise; layout(set = 0, binding = 2) uniform sampler s_noise; void main() { textureLod(sampler2D(t_noise, s_noise), vec2(1.0), 0); } "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void fun(vec2 in_parameter, out float out_parameter) { ivec2 _ = ivec2(in_parameter); } void main() { float a; fun(vec2(1.0), a); 
} "#, ) .unwrap(); } #[test] fn constants() { use crate::{Constant, Expression, Type, TypeInner}; let mut frontend = Frontend::default(); let module = frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 const float a = 1.0; float global = a; const float b = a; void main() {} "#, ) .unwrap(); let mut types = module.types.iter(); let mut constants = module.constants.iter(); let mut global_expressions = module.global_expressions.iter(); let (ty_handle, ty) = types.next().unwrap(); assert_eq!( ty, &Type { name: None, inner: TypeInner::Scalar(crate::Scalar::F32) } ); let (init_handle, init) = global_expressions.next().unwrap(); assert_eq!(init, &Expression::Literal(crate::Literal::F32(1.0))); assert_eq!( constants.next().unwrap().1, &Constant { name: Some("a".to_owned()), ty: ty_handle, init: init_handle } ); assert_eq!( constants.next().unwrap().1, &Constant { name: Some("b".to_owned()), ty: ty_handle, init: init_handle } ); assert!(constants.next().is_none()); } #[test] fn function_overloading() { let mut frontend = Frontend::default(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 float saturate(float v) { return clamp(v, 0.0, 1.0); } vec2 saturate(vec2 v) { return clamp(v, vec2(0.0), vec2(1.0)); } vec3 saturate(vec3 v) { return clamp(v, vec3(0.0), vec3(1.0)); } vec4 saturate(vec4 v) { return clamp(v, vec4(0.0), vec4(1.0)); } void main() { float v1 = saturate(1.5); vec2 v2 = saturate(vec2(0.5, 1.5)); vec3 v3 = saturate(vec3(0.5, 1.5, 2.5)); vec3 v4 = saturate(vec4(0.5, 1.5, 2.5, 3.5)); } "#, ) .unwrap(); } #[test] fn implicit_conversions() { let mut frontend = Frontend::default(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { mat4 a = mat4(1); float b = 1u; float c = 1 + 2.0; } "#, ) .unwrap(); assert_eq!( frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void test(int a) {} void test(uint a) {} void main() { test(1.0); } "#, ) .err() .unwrap(), ParseErrors { 
errors: vec![Error { kind: ErrorKind::SemanticError("Unknown function \'test\'".into()), meta: Span::new(156, 165), }] }, ); assert_eq!( frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void test(float a) {} void test(uint a) {} void main() { test(1); } "#, ) .err() .unwrap(), ParseErrors { errors: vec![Error { kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()), meta: Span::new(158, 165), }] } ); } #[test] fn structs() { let mut frontend = Frontend::default(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 Test { vec4 pos; } xx; void main() {} "#, ) .unwrap_err(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 struct Test { vec4 pos; }; void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 const int NUM_VECS = 42; struct Test { vec4 vecs[NUM_VECS]; }; void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 struct Hello { vec4 test; } test() { return Hello( vec4(1.0) ); } void main() {} "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 struct Test {}; void main() {} "#, ) .unwrap_err(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 inout struct Test { vec4 x; }; void main() {} "#, ) .unwrap_err(); } #[test] fn swizzles() { let mut frontend = Frontend::default(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { vec4 v = vec4(1); v.xyz = vec3(2); v.x = 5.0; v.xyz.zxy.yx.xy = vec2(5.0, 1.0); } "#, ) .unwrap(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { vec4 v = vec4(1); v.xx = vec2(5.0); } "#, ) .unwrap_err(); frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { vec3 v = vec3(1); v.w = 2.0; } "#, ) .unwrap_err(); } #[test] fn expressions() { let mut frontend = Frontend::default(); // Vector indexing 
frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 float test(int index) { vec4 v = vec4(1.0, 2.0, 3.0, 4.0); return v[index] + 1.0; } void main() {} "#, ) .unwrap(); // Prefix increment/decrement frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { uint index = 0; --index; ++index; } "#, ) .unwrap(); // Dynamic indexing of array frontend .parse( &Options::from(ShaderStage::Vertex), r#" # version 450 void main() { const vec4 positions[1] = { vec4(0) }; gl_Position = positions[gl_VertexIndex]; } "#, ) .unwrap(); } naga-29.0.3/src/front/glsl/token.rs000064400000000000000000000045131046102023000151760ustar 00000000000000pub use pp_rs::token::{Float, Integer, Location, Token as PPToken}; use alloc::{string::String, vec::Vec}; use super::ast::Precision; use crate::{Interpolation, Sampling, Span, Type}; impl From<Location> for Span { fn from(loc: Location) -> Self { Span::new(loc.start, loc.end) } } #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub struct Token { pub value: TokenValue, pub meta: Span, } /// A token produced by the lexer and consumed by the parser. /// /// This type is exported since it's returned in the /// [`InvalidToken`](super::ErrorKind::InvalidToken) error. 
#[derive(Clone, Debug, PartialEq)] pub enum TokenValue { Identifier(String), FloatConstant(Float), IntConstant(Integer), BoolConstant(bool), Layout, In, Out, InOut, Uniform, Buffer, Const, Shared, Restrict, /// A `glsl` memory qualifier such as `writeonly` /// /// The associated [`crate::StorageAccess`] is the access being allowed /// (for example `writeonly` has an associated value of [`crate::StorageAccess::STORE`]) MemoryQualifier(crate::StorageAccess), Invariant, Interpolation(Interpolation), Sampling(Sampling), Precision, PrecisionQualifier(Precision), Continue, Break, Return, Discard, If, Else, Switch, Case, Default, While, Do, For, Void, Struct, TypeName(Type), Assign, AddAssign, SubAssign, MulAssign, DivAssign, ModAssign, LeftShiftAssign, RightShiftAssign, AndAssign, XorAssign, OrAssign, Increment, Decrement, LogicalOr, LogicalAnd, LogicalXor, LessEqual, GreaterEqual, Equal, NotEqual, LeftShift, RightShift, LeftBrace, RightBrace, LeftParen, RightParen, LeftBracket, RightBracket, LeftAngle, RightAngle, Comma, Semicolon, Colon, Dot, Bang, Dash, Tilde, Plus, Star, Slash, Percent, VerticalBar, Caret, Ampersand, Question, } #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub struct Directive { pub kind: DirectiveKind, pub tokens: Vec<PPToken>, } #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum DirectiveKind { Version { is_first_directive: bool }, Extension, Pragma, } naga-29.0.3/src/front/glsl/types.rs000064400000000000000000000310551046102023000152230ustar 00000000000000use alloc::format; use super::{context::Context, Error, ErrorKind, Result, Span}; use crate::{ proc::ResolveContext, Expression, Handle, ImageClass, ImageDimension, Scalar, ScalarKind, Type, TypeInner, VectorSize, }; pub fn parse_type(type_name: &str) -> Option<Type> { match type_name { "bool" => Some(Type { name: None, inner: TypeInner::Scalar(Scalar::BOOL), }), "float16_t" => Some(Type { name: None, inner: TypeInner::Scalar(Scalar::F16), }), "float" => Some(Type { name: None, inner: 
TypeInner::Scalar(Scalar::F32), }), "double" => Some(Type { name: None, inner: TypeInner::Scalar(Scalar::F64), }), "int" => Some(Type { name: None, inner: TypeInner::Scalar(Scalar::I32), }), "uint" => Some(Type { name: None, inner: TypeInner::Scalar(Scalar::U32), }), "sampler" | "samplerShadow" => Some(Type { name: None, inner: TypeInner::Sampler { comparison: type_name == "samplerShadow", }, }), word => { fn kind_width_parse(ty: &str) -> Option<Scalar> { Some(match ty { "" => Scalar::F32, "b" => Scalar::BOOL, "i" => Scalar::I32, "u" => Scalar::U32, "d" => Scalar::F64, "f16" => Scalar::F16, _ => return None, }) } fn size_parse(n: &str) -> Option<VectorSize> { Some(match n { "2" => VectorSize::Bi, "3" => VectorSize::Tri, "4" => VectorSize::Quad, _ => return None, }) } let vec_parse = |word: &str| { let mut iter = word.split("vec"); let kind = iter.next()?; let size = iter.next()?; let scalar = kind_width_parse(kind)?; let size = size_parse(size)?; Some(Type { name: None, inner: TypeInner::Vector { size, scalar }, }) }; let mat_parse = |word: &str| { let mut iter = word.split("mat"); let kind = iter.next()?; let size = iter.next()?; let scalar = kind_width_parse(kind)?; let (columns, rows) = if let Some(size) = size_parse(size) { (size, size) } else { let mut iter = size.split('x'); match (iter.next()?, iter.next()?, iter.next()) { (col, row, None) => (size_parse(col)?, size_parse(row)?), _ => return None, } }; Some(Type { name: None, inner: TypeInner::Matrix { columns, rows, scalar, }, }) }; let texture_parse = |word: &str| { let mut iter = word.split("texture"); let texture_kind = |ty| { Some(match ty { "" => ScalarKind::Float, "i" => ScalarKind::Sint, "u" => ScalarKind::Uint, _ => return None, }) }; let kind = iter.next()?; let size = iter.next()?; let kind = texture_kind(kind)?; let sampled = |multi| ImageClass::Sampled { kind, multi }; let (dim, arrayed, class) = match size { "1D" => (ImageDimension::D1, false, sampled(false)), "1DArray" => (ImageDimension::D1, true, 
sampled(false)), "2D" => (ImageDimension::D2, false, sampled(false)), "2DArray" => (ImageDimension::D2, true, sampled(false)), "2DMS" => (ImageDimension::D2, false, sampled(true)), "2DMSArray" => (ImageDimension::D2, true, sampled(true)), "3D" => (ImageDimension::D3, false, sampled(false)), "Cube" => (ImageDimension::Cube, false, sampled(false)), "CubeArray" => (ImageDimension::Cube, true, sampled(false)), _ => return None, }; Some(Type { name: None, inner: TypeInner::Image { dim, arrayed, class, }, }) }; let image_parse = |word: &str| { let mut iter = word.split("image"); let texture_kind = |ty| { Some(match ty { "" => ScalarKind::Float, "i" => ScalarKind::Sint, "u" => ScalarKind::Uint, _ => return None, }) }; let kind = iter.next()?; let size = iter.next()?; // TODO: Check that the texture format and the kind match let _ = texture_kind(kind)?; let class = ImageClass::Storage { format: crate::StorageFormat::R8Uint, access: crate::StorageAccess::LOAD | crate::StorageAccess::STORE, }; // TODO: GLSL supports multisampled storage images, but Naga doesn't let (dim, arrayed) = match size { "1D" => (ImageDimension::D1, false), "1DArray" => (ImageDimension::D1, true), "2D" => (ImageDimension::D2, false), "2DArray" => (ImageDimension::D2, true), "3D" => (ImageDimension::D3, false), // Naga doesn't support cube storage images and their usefulness // is questionable, so they won't be supported for now // "Cube" => (ImageDimension::Cube, false), // "CubeArray" => (ImageDimension::Cube, true), _ => return None, }; Some(Type { name: None, inner: TypeInner::Image { dim, arrayed, class, }, }) }; vec_parse(word) .or_else(|| mat_parse(word)) .or_else(|| texture_parse(word)) .or_else(|| image_parse(word)) } } } pub const fn scalar_components(ty: &TypeInner) -> Option<Scalar> { match *ty { TypeInner::Scalar(scalar) | TypeInner::Vector { scalar, .. } | TypeInner::ValuePointer { scalar, .. } | TypeInner::Matrix { scalar, ..
} => Some(scalar), _ => None, } } pub const fn type_power(scalar: Scalar) -> Option { Some(match scalar.kind { ScalarKind::Sint => 0, ScalarKind::Uint => 1, ScalarKind::Float if scalar.width == 4 => 2, ScalarKind::Float => 3, ScalarKind::Bool | ScalarKind::AbstractInt | ScalarKind::AbstractFloat => return None, }) } impl Context<'_> { /// Resolves the types of the expressions until `expr` (inclusive) /// /// This needs to be done before the [`typifier`] can be queried for /// the types of the expressions in the range between the last grow and `expr`. /// /// # Note /// /// The `resolve_type*` methods (like [`resolve_type`]) automatically /// grow the [`typifier`] so calling this method is not necessary when using /// them. /// /// [`typifier`]: Context::typifier /// [`resolve_type`]: Self::resolve_type pub(crate) fn typifier_grow(&mut self, expr: Handle, meta: Span) -> Result<()> { let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments); let typifier = if self.is_const { &mut self.const_typifier } else { &mut self.typifier }; let expressions = if self.is_const { &self.module.global_expressions } else { &self.expressions }; typifier .grow(expr, expressions, &resolve_ctx) .map_err(|error| Error { kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()), meta, }) } pub(crate) fn get_type(&self, expr: Handle) -> &TypeInner { let typifier = if self.is_const { &self.const_typifier } else { &self.typifier }; typifier.get(expr, &self.module.types) } /// Gets the type for the result of the `expr` expression /// /// Automatically grows the [`typifier`] to `expr` so calling /// [`typifier_grow`] is not necessary /// /// [`typifier`]: Context::typifier /// [`typifier_grow`]: Self::typifier_grow pub(crate) fn resolve_type( &mut self, expr: Handle, meta: Span, ) -> Result<&TypeInner> { self.typifier_grow(expr, meta)?; Ok(self.get_type(expr)) } /// Gets the type handle for the result of the `expr` expression /// /// 
Automatically grows the [`typifier`] to `expr` so calling /// [`typifier_grow`] is not necessary /// /// # Note /// /// Consider using [`resolve_type`] whenever possible /// since it doesn't require adding each type to the [`types`] arena /// and it doesn't need to mutably borrow the [`Parser`][Self] /// /// [`types`]: crate::Module::types /// [`typifier`]: Context::typifier /// [`typifier_grow`]: Self::typifier_grow /// [`resolve_type`]: Self::resolve_type pub(crate) fn resolve_type_handle( &mut self, expr: Handle, meta: Span, ) -> Result> { self.typifier_grow(expr, meta)?; let typifier = if self.is_const { &mut self.const_typifier } else { &mut self.typifier }; Ok(typifier.register_type(expr, &mut self.module.types)) } /// Invalidates the cached type resolution for `expr` forcing a recomputation pub(crate) fn invalidate_expression( &mut self, expr: Handle, meta: Span, ) -> Result<()> { let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments); let typifier = if self.is_const { &mut self.const_typifier } else { &mut self.typifier }; typifier .invalidate(expr, &self.expressions, &resolve_ctx) .map_err(|error| Error { kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()), meta, }) } pub(crate) fn lift_up_const_expression( &mut self, expr: Handle, ) -> Result> { let meta = self.expressions.get_span(expr); let h = match self.expressions[expr] { ref expr @ (Expression::Literal(_) | Expression::Constant(_) | Expression::ZeroValue(_)) => { self.module.global_expressions.append(expr.clone(), meta) } Expression::Compose { ty, ref components } => { let mut components = components.clone(); for component in &mut components { *component = self.lift_up_const_expression(*component)?; } self.module .global_expressions .append(Expression::Compose { ty, components }, meta) } Expression::Splat { size, value } => { let value = self.lift_up_const_expression(value)?; self.module .global_expressions .append(Expression::Splat { 
size, value }, meta) } _ => { return Err(Error { kind: ErrorKind::SemanticError("Expression is not const-expression".into()), meta, }) } }; self.global_expression_kind_tracker .insert(h, crate::proc::ExpressionKind::Const); Ok(h) } } naga-29.0.3/src/front/glsl/variables.rs000064400000000000000000000647021046102023000160340ustar 00000000000000use alloc::{format, string::String, vec::Vec}; use super::{ ast::*, context::{Context, ExprPos}, error::{Error, ErrorKind}, Frontend, Result, Span, }; use crate::{ AddressSpace, Binding, BuiltIn, Constant, Expression, GlobalVariable, Handle, Interpolation, LocalVariable, Override, ResourceBinding, Scalar, ScalarKind, ShaderStage, SwizzleComponent, Type, TypeInner, VectorSize, }; pub struct VarDeclaration<'a, 'key> { pub qualifiers: &'a mut TypeQualifiers<'key>, pub ty: Handle, pub name: Option, pub init: Option>, pub meta: Span, } /// Information about a builtin used in [`add_builtin`](Frontend::add_builtin). struct BuiltInData { /// The type of the builtin. inner: TypeInner, /// The associated builtin class. builtin: BuiltIn, /// Whether the builtin can be written to or not. mutable: bool, /// The storage used for the builtin. 
storage: StorageQualifier, } pub enum GlobalOrConstant { Global(Handle), Constant(Handle), Override(Handle), } impl Frontend { /// Adds a builtin and returns a variable reference to it fn add_builtin( &mut self, ctx: &mut Context, name: &str, data: BuiltInData, meta: Span, ) -> Result> { let ty = ctx.module.types.insert( Type { name: None, inner: data.inner, }, meta, ); let handle = ctx.module.global_variables.append( GlobalVariable { name: Some(name.into()), space: AddressSpace::Private, binding: None, ty, init: None, memory_decorations: crate::MemoryDecorations::empty(), }, meta, ); let idx = self.entry_args.len(); self.entry_args.push(EntryArg { name: Some(name.into()), binding: Binding::BuiltIn(data.builtin), handle, storage: data.storage, }); self.global_variables.push(( name.into(), GlobalLookup { kind: GlobalLookupKind::Variable(handle), entry_arg: Some(idx), mutable: data.mutable, }, )); let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta)?; let var = VariableReference { expr, load: true, mutable: data.mutable, constant: None, entry_arg: Some(idx), }; ctx.symbol_table.add_root(name.into(), var.clone()); Ok(Some(var)) } pub(crate) fn lookup_variable( &mut self, ctx: &mut Context, name: &str, meta: Span, ) -> Result> { if let Some(var) = ctx.symbol_table.lookup(name).cloned() { return Ok(Some(var)); } let data = match name { "gl_Position" => BuiltInData { inner: TypeInner::Vector { size: VectorSize::Quad, scalar: Scalar::F32, }, builtin: BuiltIn::Position { invariant: false }, mutable: true, storage: StorageQualifier::Output, }, "gl_FragCoord" => BuiltInData { inner: TypeInner::Vector { size: VectorSize::Quad, scalar: Scalar::F32, }, builtin: BuiltIn::Position { invariant: false }, mutable: false, storage: StorageQualifier::Input, }, "gl_PointCoord" => BuiltInData { inner: TypeInner::Vector { size: VectorSize::Bi, scalar: Scalar::F32, }, builtin: BuiltIn::PointCoord, mutable: false, storage: StorageQualifier::Input, }, 
"gl_GlobalInvocationID" | "gl_NumWorkGroups" | "gl_WorkGroupSize" | "gl_WorkGroupID" | "gl_LocalInvocationID" => BuiltInData { inner: TypeInner::Vector { size: VectorSize::Tri, scalar: Scalar::U32, }, builtin: match name { "gl_GlobalInvocationID" => BuiltIn::GlobalInvocationId, "gl_NumWorkGroups" => BuiltIn::NumWorkGroups, "gl_WorkGroupSize" => BuiltIn::WorkGroupSize, "gl_WorkGroupID" => BuiltIn::WorkGroupId, "gl_LocalInvocationID" => BuiltIn::LocalInvocationId, _ => unreachable!(), }, mutable: false, storage: StorageQualifier::Input, }, "gl_FrontFacing" => BuiltInData { inner: TypeInner::Scalar(Scalar::BOOL), builtin: BuiltIn::FrontFacing, mutable: false, storage: StorageQualifier::Input, }, "gl_PointSize" | "gl_FragDepth" => BuiltInData { inner: TypeInner::Scalar(Scalar::F32), builtin: match name { "gl_PointSize" => BuiltIn::PointSize, "gl_FragDepth" => BuiltIn::FragDepth, _ => unreachable!(), }, mutable: true, storage: StorageQualifier::Output, }, "gl_ClipDistance" | "gl_CullDistance" => { let base = ctx.module.types.insert( Type { name: None, inner: TypeInner::Scalar(Scalar::F32), }, meta, ); BuiltInData { inner: TypeInner::Array { base, size: crate::ArraySize::Dynamic, stride: 4, }, builtin: match name { "gl_ClipDistance" => BuiltIn::ClipDistance, "gl_CullDistance" => BuiltIn::CullDistance, _ => unreachable!(), }, mutable: self.meta.stage == ShaderStage::Vertex, storage: StorageQualifier::Output, } } _ => { let builtin = match name { "gl_BaseVertex" => BuiltIn::BaseVertex, "gl_BaseInstance" => BuiltIn::BaseInstance, "gl_PrimitiveID" => BuiltIn::PrimitiveIndex, "gl_BaryCoordEXT" => BuiltIn::Barycentric { perspective: true }, "gl_BaryCoordNoPerspEXT" => BuiltIn::Barycentric { perspective: false }, "gl_InstanceIndex" => BuiltIn::InstanceIndex, "gl_VertexIndex" => BuiltIn::VertexIndex, "gl_SampleID" => BuiltIn::SampleIndex, "gl_LocalInvocationIndex" => BuiltIn::LocalInvocationIndex, "gl_DrawID" => BuiltIn::DrawIndex, _ => return Ok(None), }; BuiltInData { inner: 
TypeInner::Scalar(Scalar::U32), builtin, mutable: false, storage: StorageQualifier::Input, } } }; self.add_builtin(ctx, name, data, meta) } pub(crate) fn make_variable_invariant( &mut self, ctx: &mut Context, name: &str, meta: Span, ) -> Result<()> { if let Some(var) = self.lookup_variable(ctx, name, meta)? { if let Some(index) = var.entry_arg { if let Binding::BuiltIn(BuiltIn::Position { ref mut invariant }) = self.entry_args[index].binding { *invariant = true; } } } Ok(()) } pub(crate) fn field_selection( &mut self, ctx: &mut Context, pos: ExprPos, expression: Handle, name: &str, meta: Span, ) -> Result> { let (ty, is_pointer) = match *ctx.resolve_type(expression, meta)? { TypeInner::Pointer { base, .. } => (&ctx.module.types[base].inner, true), ref ty => (ty, false), }; match *ty { TypeInner::Struct { ref members, .. } => { let index = members .iter() .position(|m| m.name == Some(name.into())) .ok_or_else(|| Error { kind: ErrorKind::UnknownField(name.into()), meta, })?; let pointer = ctx.add_expression( Expression::AccessIndex { base: expression, index: index as u32, }, meta, )?; Ok(match pos { ExprPos::Rhs if is_pointer => { ctx.add_expression(Expression::Load { pointer }, meta)? } _ => pointer, }) } // swizzles (xyzw, rgba, stpq) TypeInner::Vector { size, .. 
} => { let check_swizzle_components = |comps: &str| { name.chars() .map(|c| { comps .find(c) .filter(|i| *i < size as usize) .map(|i| SwizzleComponent::from_index(i as u32)) }) .collect::>>() }; let components = check_swizzle_components("xyzw") .or_else(|| check_swizzle_components("rgba")) .or_else(|| check_swizzle_components("stpq")); if let Some(components) = components { if let ExprPos::Lhs = pos { let not_unique = (1..components.len()) .any(|i| components[i..].contains(&components[i - 1])); if not_unique { self.errors.push(Error { kind: ErrorKind::SemanticError( format!( concat!( "swizzle cannot have duplicate components in ", "left-hand-side expression for \"{:?}\"" ), name ) .into(), ), meta, }) } } let mut pattern = [SwizzleComponent::X; 4]; for (pat, component) in pattern.iter_mut().zip(&components) { *pat = *component; } // flatten nested swizzles (vec.zyx.xy.x => vec.z) let mut expression = expression; if let Expression::Swizzle { size: _, vector, pattern: ref src_pattern, } = ctx[expression] { expression = vector; for pat in &mut pattern { *pat = src_pattern[pat.index() as usize]; } } let size = match components.len() { // Swizzles with just one component are accesses and not swizzles 1 => { match pos { // If the position is in the right hand side and the base // vector is a pointer, load it, otherwise the swizzle would // produce a pointer ExprPos::Rhs if is_pointer => { expression = ctx.add_expression( Expression::Load { pointer: expression, }, meta, )?; } _ => {} }; return ctx.add_expression( Expression::AccessIndex { base: expression, index: pattern[0].index(), }, meta, ); } 2 => VectorSize::Bi, 3 => VectorSize::Tri, 4 => VectorSize::Quad, _ => { self.errors.push(Error { kind: ErrorKind::SemanticError( format!("Bad swizzle size for \"{name:?}\"").into(), ), meta, }); VectorSize::Quad } }; if is_pointer { // NOTE: for lhs expression, this extra load ends up as an unused expr, because the // assignment will extract the pointer and use it directly 
anyway. Unfortunately we // need it for validation to pass, as swizzles cannot operate on pointer values. expression = ctx.add_expression( Expression::Load { pointer: expression, }, meta, )?; } Ok(ctx.add_expression( Expression::Swizzle { size, vector: expression, pattern, }, meta, )?) } else { Err(Error { kind: ErrorKind::SemanticError( format!("Invalid swizzle for vector \"{name}\"").into(), ), meta, }) } } _ => Err(Error { kind: ErrorKind::SemanticError( format!("Can't lookup field on this type \"{name}\"").into(), ), meta, }), } } pub(crate) fn add_global_var( &mut self, ctx: &mut Context, VarDeclaration { qualifiers, mut ty, name, init, meta, }: VarDeclaration, ) -> Result { let storage = qualifiers.storage.0; let (ret, lookup) = match storage { StorageQualifier::Input | StorageQualifier::Output => { let input = storage == StorageQualifier::Input; // TODO: glslang seems to use a counter for variables without // explicit location (even if that causes collisions) let location = qualifiers .uint_layout_qualifier("location", &mut self.errors) .unwrap_or(0); let interpolation = qualifiers.interpolation.take().map(|(i, _)| i).or_else(|| { let kind = ctx.module.types[ty].inner.scalar_kind()?; Some(match kind { ScalarKind::Float => Interpolation::Perspective, _ => Interpolation::Flat, }) }); let sampling = qualifiers.sampling.take().map(|(s, _)| s); let handle = ctx.module.global_variables.append( GlobalVariable { name: name.clone(), space: AddressSpace::Private, binding: None, ty, init, memory_decorations: crate::MemoryDecorations::empty(), }, meta, ); let blend_src = qualifiers .layout_qualifiers .remove(&QualifierKey::Index) .and_then(|(value, _span)| match value { QualifierValue::Uint(index) => Some(index), _ => None, }); let idx = self.entry_args.len(); self.entry_args.push(EntryArg { name: name.clone(), binding: Binding::Location { location, interpolation, sampling, blend_src, per_primitive: false, }, handle, storage, }); let lookup = GlobalLookup { kind: 
GlobalLookupKind::Variable(handle), entry_arg: Some(idx), mutable: !input, }; (GlobalOrConstant::Global(handle), lookup) } StorageQualifier::Const => { // Check if this is a specialization constant with constant_id let constant_id = qualifiers.uint_layout_qualifier("constant_id", &mut self.errors); if let Some(id) = constant_id { // This is a specialization constant - convert to Override let id: Option = match id.try_into() { Ok(v) => Some(v), Err(_) => { self.errors.push(Error { kind: ErrorKind::SemanticError( format!( "constant_id value {id} is too high (maximum is {})", u16::MAX ) .into(), ), meta, }); None } }; let override_handle = ctx.module.overrides.append( Override { name: name.clone(), id, ty, init, }, meta, ); let lookup = GlobalLookup { kind: GlobalLookupKind::Override(override_handle, ty), entry_arg: None, mutable: false, }; (GlobalOrConstant::Override(override_handle), lookup) } else { // Regular constant let init = init.ok_or_else(|| Error { kind: ErrorKind::SemanticError( "const values must have an initializer".into(), ), meta, })?; let constant = Constant { name: name.clone(), ty, init, }; let handle = ctx.module.constants.append(constant, meta); let lookup = GlobalLookup { kind: GlobalLookupKind::Constant(handle, ty), entry_arg: None, mutable: false, }; (GlobalOrConstant::Constant(handle), lookup) } } StorageQualifier::AddressSpace(mut space) => { match space { AddressSpace::Storage { ref mut access } => { if let Some((allowed_access, _)) = qualifiers.storage_access.take() { *access = allowed_access; } } AddressSpace::Uniform => match ctx.module.types[ty].inner { TypeInner::Image { class, dim, arrayed, } => { if let crate::ImageClass::Storage { mut access, mut format, } = class { if let Some((allowed_access, _)) = qualifiers.storage_access.take() { access = allowed_access; } match qualifiers.layout_qualifiers.remove(&QualifierKey::Format) { Some((QualifierValue::Format(f), _)) => format = f, // TODO: glsl supports images without format qualifier 
// if they are `writeonly` None => self.errors.push(Error { kind: ErrorKind::SemanticError( "image types require a format layout qualifier".into(), ), meta, }), _ => unreachable!(), } ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Image { dim, arrayed, class: crate::ImageClass::Storage { format, access }, }, }, meta, ); } space = AddressSpace::Handle } TypeInner::Sampler { .. } => space = AddressSpace::Handle, _ => { if qualifiers.none_layout_qualifier("push_constant", &mut self.errors) { space = AddressSpace::Immediate } } }, AddressSpace::Function => space = AddressSpace::Private, _ => {} }; let binding = match space { AddressSpace::Uniform | AddressSpace::Storage { .. } | AddressSpace::Handle => { let binding = qualifiers.uint_layout_qualifier("binding", &mut self.errors); if binding.is_none() { self.errors.push(Error { kind: ErrorKind::SemanticError( "uniform/buffer blocks require layout(binding=X)".into(), ), meta, }); } let set = qualifiers.uint_layout_qualifier("set", &mut self.errors); binding.map(|binding| ResourceBinding { group: set.unwrap_or(0), binding, }) } _ => None, }; let handle = ctx.module.global_variables.append( GlobalVariable { name: name.clone(), space, binding, ty, init, memory_decorations: crate::MemoryDecorations::empty(), }, meta, ); let lookup = GlobalLookup { kind: GlobalLookupKind::Variable(handle), entry_arg: None, mutable: true, }; (GlobalOrConstant::Global(handle), lookup) } }; if let Some(name) = name { ctx.add_global(&name, lookup)?; self.global_variables.push((name, lookup)); } qualifiers.unused_errors(&mut self.errors); Ok(ret) } pub(crate) fn add_local_var( &mut self, ctx: &mut Context, decl: VarDeclaration, ) -> Result> { let storage = decl.qualifiers.storage; let mutable = match storage.0 { StorageQualifier::AddressSpace(AddressSpace::Function) => true, StorageQualifier::Const => false, _ => { self.errors.push(Error { kind: ErrorKind::SemanticError("Locals cannot have a storage qualifier".into()), meta: 
storage.1, }); true } }; let handle = ctx.locals.append( LocalVariable { name: decl.name.clone(), ty: decl.ty, init: decl.init, }, decl.meta, ); let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta)?; if let Some(name) = decl.name { let maybe_var = ctx.add_local_var(name.clone(), expr, mutable); if maybe_var.is_some() { self.errors.push(Error { kind: ErrorKind::VariableAlreadyDeclared(name), meta: decl.meta, }) } } decl.qualifiers.unused_errors(&mut self.errors); Ok(expr) } } naga-29.0.3/src/front/interpolator.rs000064400000000000000000000053701046102023000156410ustar 00000000000000/*! Interpolation defaults. */ impl crate::Binding { /// Apply the usual default interpolation for `ty` to `binding`. /// /// Front ends may use this utility to satisfy the Naga IR's requirement /// that all I/O `Binding`s from the vertex shader to the fragment shader /// have non-`None` `interpolation` values. The requirement exists to ensure /// that each input language's defaulting policy has already been applied. /// /// All the shader languages Naga supports have similar rules: /// perspective-correct, center-sampled interpolation is the default for any /// binding that can vary, and everything else either defaults to flat or /// requires an explicit flat qualifier/attribute/what-have-you. /// /// If `binding` is not a [`Location`] binding, or if its [`interpolation`] is /// already set, then make no changes. Otherwise, set `binding`'s interpolation /// and sampling to reasonable defaults depending on `ty`, the type of the value /// being interpolated: /// /// - If `ty` is a floating-point scalar, vector, or matrix type, then /// default to [`Perspective`] interpolation and [`Center`] sampling. /// /// - If `ty` is an integral scalar or vector, then default to [`Flat`] /// interpolation, which has no associated sampling. /// /// - For any other types, make no change.
Such types are not permitted as /// user-defined IO values, and will probably be flagged by the validator. /// /// When structs appear in input or output types, each member ought to have its /// own [`Binding`], so structs are simply covered by the third case. /// /// [`Binding`]: crate::Binding /// [`Location`]: crate::Binding::Location /// [`interpolation`]: crate::Binding::Location::interpolation /// [`Perspective`]: crate::Interpolation::Perspective /// [`Flat`]: crate::Interpolation::Flat /// [`Center`]: crate::Sampling::Center pub fn apply_default_interpolation(&mut self, ty: &crate::TypeInner) { if let crate::Binding::Location { location: _, interpolation: ref mut interpolation @ None, ref mut sampling, blend_src: _, per_primitive: _, } = *self { match ty.scalar_kind() { Some(crate::ScalarKind::Float) => { *interpolation = Some(crate::Interpolation::Perspective); *sampling = Some(crate::Sampling::Center); } Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { *interpolation = Some(crate::Interpolation::Flat); *sampling = None; } Some(_) | None => {} } } } } naga-29.0.3/src/front/mod.rs000064400000000000000000000326561046102023000137050ustar 00000000000000/*! Frontend parsers that consume binary and text shaders and load them into [`Module`](super::Module)s. */ mod interpolator; mod type_gen; #[cfg(feature = "spv-in")] pub mod atomic_upgrade; #[cfg(feature = "glsl-in")] pub mod glsl; #[cfg(feature = "spv-in")] pub mod spv; #[cfg(feature = "wgsl-in")] pub mod wgsl; use alloc::{boxed::Box, vec, vec::Vec}; use core::ops; use crate::{ arena::{Arena, Handle, HandleVec, UniqueArena}, proc::{ResolveContext, ResolveError, TypeResolution}, FastHashMap, }; /// A table of types for an `Arena<Expression>`. /// /// A front end can use a `Typifier` to get types for an arena's expressions /// while it is still contributing expressions to it. At any point, you can call
At any point, you can call /// [`typifier.grow(expr, arena, ctx)`], where `expr` is a `Handle` /// referring to something in `arena`, and the `Typifier` will resolve the types /// of all the expressions up to and including `expr`. Then you can write /// `typifier[handle]` to get the type of any handle at or before `expr`. /// /// Note that `Typifier` does *not* build an `Arena` as a part of its /// usual operation. Ideally, a module's type arena should only contain types /// actually needed by `Handle`s elsewhere in the module — functions, /// variables, [`Compose`] expressions, other types, and so on — so we don't /// want every little thing that occurs as the type of some intermediate /// expression to show up there. /// /// Instead, `Typifier` accumulates a [`TypeResolution`] for each expression, /// which refers to the `Arena` in the [`ResolveContext`] passed to `grow` /// as needed. [`TypeResolution`] is a lightweight representation for /// intermediate types like this; see its documentation for details. /// /// If you do need to register a `Typifier`'s conclusion in an `Arena` /// (say, for a [`LocalVariable`] whose type you've inferred), you can use /// [`register_type`] to do so. /// /// [`typifier.grow(expr, arena)`]: Typifier::grow /// [`register_type`]: Typifier::register_type /// [`Compose`]: crate::Expression::Compose /// [`LocalVariable`]: crate::LocalVariable #[derive(Debug, Default)] pub struct Typifier { resolutions: HandleVec, } impl Typifier { pub const fn new() -> Self { Typifier { resolutions: HandleVec::new(), } } pub fn reset(&mut self) { self.resolutions.clear() } pub fn get<'a>( &'a self, expr_handle: Handle, types: &'a UniqueArena, ) -> &'a crate::TypeInner { self.resolutions[expr_handle].inner_with(types) } /// Add an expression's type to an `Arena`. /// /// Add the type of `expr_handle` to `types`, and return a `Handle` /// referring to it. 
/// /// # Note /// /// If you just need a [`TypeInner`] for `expr_handle`'s type, consider /// using `typifier[expr_handle].inner_with(types)` instead. Calling /// [`TypeResolution::inner_with`] often lets us avoid adding anything to /// the arena, which can significantly reduce the number of types that end /// up in the final module. /// /// [`TypeInner`]: crate::TypeInner pub fn register_type( &self, expr_handle: Handle<crate::Expression>, types: &mut UniqueArena<crate::Type>, ) -> Handle<crate::Type> { match self[expr_handle].clone() { TypeResolution::Handle(handle) => handle, TypeResolution::Value(inner) => { types.insert(crate::Type { name: None, inner }, crate::Span::UNDEFINED) } } } /// Grow this typifier until it contains a type for `expr_handle`. /// /// If `expr_handle` is not covered yet, every expression in `expressions` /// that has not been resolved so far is resolved, in arena order. pub fn grow( &mut self, expr_handle: Handle<crate::Expression>, expressions: &Arena<crate::Expression>, ctx: &ResolveContext, ) -> Result<(), ResolveError> { if self.resolutions.len() <= expr_handle.index() { for (eh, expr) in expressions.iter().skip(self.resolutions.len()) { // Note: the closure can't `Err` by construction let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h]))?; log::debug!("Resolving {eh:?} = {expr:?} : {resolution:?}"); self.resolutions.insert(eh, resolution); } } Ok(()) } /// Recompute the type resolution for `expr_handle`. /// /// If the type of `expr_handle` hasn't yet been calculated, call /// [`grow`](Self::grow) to ensure it is covered. /// /// In either case, when this returns, `self[expr_handle]` should be an /// updated type resolution for `expr_handle`.
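The contract of `grow` and `invalidate` can be modelled without Naga's types. In this hypothetical sketch, `Expr`, `Ty`, and `MiniTypifier` are illustrative stand-ins, operands are indices of earlier expressions (as arena handles are), and a plain `Vec` plays the role of `HandleVec`.

```rust
/// Hypothetical expression language: operands are indices of earlier
/// expressions, the way `Handle<Expression>`s point into an arena.
#[derive(Clone, Debug)]
enum Expr {
    FloatLiteral,
    IntLiteral,
    Splat { value: usize, size: u8 },
    Add(usize, usize),
}

/// Hypothetical resolved type.
#[derive(Clone, Debug, PartialEq)]
enum Ty {
    Scalar(&'static str),
    Vector(u8, &'static str),
}

#[derive(Debug)]
struct ResolveError(String);

/// Hypothetical stand-in for `Typifier`; `resolutions[i]` is the type of `exprs[i]`.
#[derive(Default)]
struct MiniTypifier {
    resolutions: Vec<Ty>,
}

impl MiniTypifier {
    /// Resolves one expression whose operands are already resolved.
    fn resolve(&self, expr: &Expr) -> Result<Ty, ResolveError> {
        Ok(match *expr {
            Expr::FloatLiteral => Ty::Scalar("f32"),
            Expr::IntLiteral => Ty::Scalar("i32"),
            Expr::Splat { value, size } => match self.resolutions[value] {
                Ty::Scalar(kind) => Ty::Vector(size, kind),
                ref other => return Err(ResolveError(format!("cannot splat {other:?}"))),
            },
            Expr::Add(left, right) => {
                let (l, r) = (&self.resolutions[left], &self.resolutions[right]);
                if l != r {
                    return Err(ResolveError(format!("operand mismatch: {l:?} + {r:?}")));
                }
                l.clone()
            }
        })
    }

    /// Like `grow`: if `handle` is not covered yet, resolve every expression
    /// added since the last call, in order.
    fn grow(&mut self, handle: usize, exprs: &[Expr]) -> Result<(), ResolveError> {
        if self.resolutions.len() <= handle {
            for expr in &exprs[self.resolutions.len()..] {
                let ty = self.resolve(expr)?;
                self.resolutions.push(ty);
            }
        }
        Ok(())
    }

    /// Like `invalidate`: recompute one existing entry, or fall back to `grow`.
    fn invalidate(&mut self, handle: usize, exprs: &[Expr]) -> Result<(), ResolveError> {
        if self.resolutions.len() <= handle {
            self.grow(handle, exprs)
        } else {
            self.resolutions[handle] = self.resolve(&exprs[handle])?;
            Ok(())
        }
    }
}
```

As with the method it models, `invalidate` recomputes only the entry it is given; expressions that depend on that entry keep their old resolution until they are invalidated as well.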
pub fn invalidate( &mut self, expr_handle: Handle<crate::Expression>, expressions: &Arena<crate::Expression>, ctx: &ResolveContext, ) -> Result<(), ResolveError> { if self.resolutions.len() <= expr_handle.index() { self.grow(expr_handle, expressions, ctx) } else { let expr = &expressions[expr_handle]; // Note: the closure can't `Err` by construction let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h]))?; self.resolutions[expr_handle] = resolution; Ok(()) } } } impl ops::Index<Handle<crate::Expression>> for Typifier { type Output = TypeResolution; fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output { &self.resolutions[handle] } } /// Type representing a lexical scope, associating a name with a single variable. /// /// The scope is generic over the variable representation and name representation /// to give front ends more flexibility in how they represent them. type Scope<Name, Var> = FastHashMap<Name, Var>; /// Structure responsible for managing variable lookups and keeping track of /// lexical scopes. /// /// The symbol table is generic over the variable representation and its name /// to give front ends more flexibility in how they represent them.
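The scope stack can be sketched with the standard library alone. `ScopedTable` is a hypothetical stand-in for `SymbolTable`, using `std::collections::HashMap` instead of `FastHashMap`; it follows the cursor-based reuse of popped scopes and the one-level-behind lookup flag implemented by the `SymbolTable` methods in this file, but it is an illustration, not Naga's code.

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Hypothetical stand-in for `SymbolTable`: a stack of maps in which only the
/// first `cursor` entries are active. Popped maps stay allocated and are
/// cleared only when a later `push_scope` reuses them.
struct ScopedTable<Name, Var> {
    scopes: Vec<HashMap<Name, Var>>,
    cursor: usize,
    /// When set, lookups ignore the innermost active scope.
    lookup_one_behind: bool,
}

impl<Name: Hash + Eq, Var> ScopedTable<Name, Var> {
    fn new() -> Self {
        // The root scope is created up front and can never be popped.
        Self { scopes: vec![HashMap::new()], cursor: 1, lookup_one_behind: false }
    }

    fn push_scope(&mut self) {
        assert!(!self.lookup_one_behind, "lookup scope doesn't match current scope");
        if self.scopes.len() == self.cursor {
            self.scopes.push(HashMap::new());
        } else {
            // Reuse the map (and its allocation) left behind by an earlier pop.
            self.scopes[self.cursor].clear();
        }
        self.cursor += 1;
    }

    fn pop_scope(&mut self) {
        assert!(self.cursor != 1, "tried to pop the root scope");
        assert!(!self.lookup_one_behind, "lookup scope doesn't match current scope");
        // Entries are not dropped here; `push_scope` clears the map on reuse.
        self.cursor -= 1;
    }

    fn reduce_lookup_scope(&mut self) {
        assert!(!self.lookup_one_behind, "lookup scope already reduced");
        self.lookup_one_behind = true;
    }

    fn reset_lookup_scope(&mut self) {
        assert!(self.lookup_one_behind, "lookup scope already matches current scope");
        self.lookup_one_behind = false;
    }

    fn add(&mut self, name: Name, var: Var) -> Option<Var> {
        self.scopes[self.cursor - 1].insert(name, var)
    }

    /// Searches from the innermost visible scope outwards to the root.
    fn lookup(&self, name: &Name) -> Option<&Var> {
        let end = self.cursor.saturating_sub(usize::from(self.lookup_one_behind));
        self.scopes[..end].iter().rev().find_map(|scope| scope.get(name))
    }

    fn allocated_scopes(&self) -> usize {
        self.scopes.len()
    }
}
```

Keeping popped maps allocated trades a little memory for fewer allocations when scopes at the same nesting depth are entered repeatedly, for example sibling blocks in one function body.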
/// /// ``` /// use naga::front::SymbolTable; /// /// // Create a new symbol table with `u32`s representing the variables /// let mut symbol_table: SymbolTable<&str, u32> = SymbolTable::default(); /// /// // Add two variables named `var1` and `var2` with values 0 and 2 respectively /// symbol_table.add("var1", 0); /// symbol_table.add("var2", 2); /// /// // Check that `var1` exists and is `0` /// assert_eq!(symbol_table.lookup("var1"), Some(&0)); /// /// // Push a new scope and add a variable to it named `var1`, shadowing the /// // variable of our previous scope /// symbol_table.push_scope(); /// symbol_table.add("var1", 1); /// /// // Check that `var1` now points to the new value of `1` and `var2` still /// // exists with its value of `2` /// assert_eq!(symbol_table.lookup("var1"), Some(&1)); /// assert_eq!(symbol_table.lookup("var2"), Some(&2)); /// /// // Pop the scope /// symbol_table.pop_scope(); /// /// // Check that `var1` now refers to our initial variable with value `0` /// assert_eq!(symbol_table.lookup("var1"), Some(&0)); /// ``` /// /// Scopes are ordered as a LIFO stack, so a variable defined in a later scope /// with the same name as a variable defined in an earlier scope takes /// precedence in the lookup. Scopes can be added with [`push_scope`] and /// removed with [`pop_scope`]. /// /// A root scope is added when the symbol table is created and must always be /// present. Trying to pop it will result in a panic. /// /// Variables can be added with [`add`] and looked up with [`lookup`]. A new /// variable is always added to the currently active scope and, as mentioned /// previously, a lookup searches from the current scope down to the root scope. /// /// [`push_scope`]: Self::push_scope /// [`pop_scope`]: Self::pop_scope /// [`add`]: Self::add /// [`lookup`]: Self::lookup pub struct SymbolTable<Name, Var> { /// Stack of lexical scopes. Not all scopes are active; see [`cursor`].
/// /// [`cursor`]: Self::cursor scopes: Vec<Scope<Name, Var>>, /// Limit of the [`scopes`] stack (exclusive). By using a separate value for /// the stack length instead of `Vec`'s own internal length, the scopes can /// be reused to cache memory allocations. /// /// [`scopes`]: Self::scopes cursor: usize, /// Whether [`lookup`] skips the innermost active scope. Set by /// [`reduce_lookup_scope`] and cleared by [`reset_lookup_scope`]. /// /// [`lookup`]: Self::lookup /// [`reduce_lookup_scope`]: Self::reduce_lookup_scope /// [`reset_lookup_scope`]: Self::reset_lookup_scope lookup_cursor_is_one_behind: bool, } impl<Name, Var> SymbolTable<Name, Var> { /// Adds a new lexical scope. /// /// All variables declared after this point will be added to this scope /// until another scope is pushed or [`pop_scope`] is called, causing this /// scope to be removed along with all variables added to it. /// /// # PANICS /// - If the current lookup scope doesn't match the current scope /// /// [`pop_scope`]: Self::pop_scope pub fn push_scope(&mut self) { self.check_lookup_scope_matches_current_scope(); // If the cursor is equal to the scope stack's length then we need to // push another empty scope. Otherwise we can reuse the already existing // scope. if self.scopes.len() == self.cursor { self.scopes.push(FastHashMap::default()) } else { self.scopes[self.cursor].clear(); } self.cursor += 1; } /// Removes the current lexical scope and all its variables /// /// # PANICS /// - If the current lexical scope is the root scope /// - If the current lookup scope doesn't match the current scope pub fn pop_scope(&mut self) { // Despite the method name, the variables are only deleted when the // scope is reused. This is because while a clear is inevitable if the // scope needs to be reused, there are cases where the scope might be // popped and not reused, e.g. if another scope with the same nesting // level is never pushed again. 
assert!(self.cursor != 1, "Tried to pop the root scope"); self.check_lookup_scope_matches_current_scope(); self.cursor -= 1; } /// Reduces the lookup scope by one level.
    ///
    /// # PANICS
    /// - If the current lookup scope doesn't match the current scope
    pub fn reduce_lookup_scope(&mut self) {
        self.check_lookup_scope_matches_current_scope();

        self.lookup_cursor_is_one_behind = true;
    }

    /// Resets the lookup scope to the current scope.
    ///
    /// # PANICS
    /// - If the current lookup scope already matches the current scope
    pub fn reset_lookup_scope(&mut self) {
        assert!(
            self.lookup_cursor_is_one_behind,
            "current lookup scope already matches the current scope"
        );

        self.lookup_cursor_is_one_behind = false;
    }

    fn check_lookup_scope_matches_current_scope(&self) {
        assert!(
            !self.lookup_cursor_is_one_behind,
            "current lookup scope doesn't match the current scope"
        );
    }
}

impl<Name, Var> SymbolTable<Name, Var>
where
    Name: core::hash::Hash + Eq,
{
    /// Performs a lookup for a variable named `name`.
    ///
    /// As stated in the struct-level documentation, the lookup proceeds from
    /// the current lookup scope to the root scope, returning `Some` when a
    /// variable is found or `None` if no variable named `name` exists in any
    /// scope.
    pub fn lookup<Q>(&self, name: &Q) -> Option<&Var>
    where
        Name: core::borrow::Borrow<Q>,
        Q: core::hash::Hash + Eq + ?Sized,
    {
        let cursor = self
            .cursor
            .saturating_sub(self.lookup_cursor_is_one_behind.into());

        // Iterate backwards through the scopes and try to find the variable
        for scope in self.scopes[..cursor].iter().rev() {
            if let Some(var) = scope.get(name) {
                return Some(var);
            }
        }

        None
    }

    /// Adds a new variable to the current scope.
    ///
    /// Returns the previous variable with the same name in this scope if it
    /// exists, so that the frontend might handle it in case variable shadowing
    /// is disallowed.
    pub fn add(&mut self, name: Name, var: Var) -> Option<Var> {
        self.scopes[self.cursor - 1].insert(name, var)
    }

    /// Adds a new variable to the root scope.
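    ///
    /// For instance, a declaration made with `add_root` while a nested scope is
    /// active remains visible after that scope is popped. In this sketch the
    /// name and the value `7` are arbitrary:
    ///
    /// ```
    /// use naga::front::SymbolTable;
    ///
    /// let mut table: SymbolTable<&str, u32> = SymbolTable::default();
    /// table.push_scope();
    /// // Registered in the root scope, not in the nested one.
    /// table.add_root("gl_Position", 7);
    /// table.pop_scope();
    /// assert_eq!(table.lookup("gl_Position"), Some(&7));
    /// ```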
    ///
    /// This is used by the GLSL frontend for builtins, which aren't known in
    /// advance and are only declared when first used, so there must be a way
    /// to add those declarations to the root scope unconditionally from the
    /// current scope.
    ///
    /// Returns the previous variable with the same name in the root scope if it
    /// exists, so that the frontend might handle it in case variable shadowing
    /// is disallowed.
    pub fn add_root(&mut self, name: Name, var: Var) -> Option<Var> {
        self.scopes[0].insert(name, var)
    }
}

impl<Name, Var> Default for SymbolTable<Name, Var> {
    /// Constructs a new symbol table with a root scope.
    fn default() -> Self {
        Self {
            scopes: vec![FastHashMap::default()],
            cursor: 1,
            lookup_cursor_is_one_behind: false,
        }
    }
}

use core::fmt;
impl<Name: fmt::Debug, Var: fmt::Debug> fmt::Debug for SymbolTable<Name, Var> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SymbolTable ")?;
        f.debug_list()
            .entries(self.scopes[..self.cursor].iter())
            .finish()
    }
}

impl crate::Module {
    pub fn get_or_insert_default_doc_comments(&mut self) -> &mut Box<crate::DocComments> {
        self.doc_comments
            .get_or_insert_with(|| Box::new(crate::DocComments::default()))
    }
}
naga-29.0.3/src/front/spv/convert.rs000064400000000000000000000230301046102023000154000ustar 00000000000000
use core::convert::TryInto;

use super::error::Error;

pub(super) const fn map_binary_operator(word: spirv::Op) -> Result<crate::BinaryOperator, Error> {
    use crate::BinaryOperator;
    use spirv::Op;

    match word {
        // Arithmetic Instructions +, -, *, /, %
        Op::IAdd | Op::FAdd => Ok(BinaryOperator::Add),
        Op::ISub | Op::FSub => Ok(BinaryOperator::Subtract),
        Op::IMul | Op::FMul => Ok(BinaryOperator::Multiply),
        Op::UDiv | Op::SDiv | Op::FDiv => Ok(BinaryOperator::Divide),
        Op::SRem => Ok(BinaryOperator::Modulo),
        // Relational and Logical Instructions
        Op::IEqual | Op::FOrdEqual | Op::FUnordEqual | Op::LogicalEqual => {
            Ok(BinaryOperator::Equal)
        }
        Op::INotEqual | Op::FOrdNotEqual | Op::FUnordNotEqual | Op::LogicalNotEqual => {
            Ok(BinaryOperator::NotEqual)
        }
        Op::ULessThan | Op::SLessThan | Op::FOrdLessThan | Op::FUnordLessThan => {
            Ok(BinaryOperator::Less)
        }
        Op::ULessThanEqual
        | Op::SLessThanEqual
        | Op::FOrdLessThanEqual
        | Op::FUnordLessThanEqual => Ok(BinaryOperator::LessEqual),
        Op::UGreaterThan | Op::SGreaterThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan => {
            Ok(BinaryOperator::Greater)
        }
        Op::UGreaterThanEqual
        | Op::SGreaterThanEqual
        | Op::FOrdGreaterThanEqual
        | Op::FUnordGreaterThanEqual => Ok(BinaryOperator::GreaterEqual),
        Op::BitwiseOr => Ok(BinaryOperator::InclusiveOr),
        Op::BitwiseXor => Ok(BinaryOperator::ExclusiveOr),
        Op::BitwiseAnd => Ok(BinaryOperator::And),
        _ => Err(Error::UnknownBinaryOperator(word)),
    }
}

pub(super) const fn map_relational_fun(
    word: spirv::Op,
) -> Result<crate::RelationalFunction, Error> {
    use crate::RelationalFunction as Rf;
    use spirv::Op;

    match word {
        Op::All => Ok(Rf::All),
        Op::Any => Ok(Rf::Any),
        Op::IsNan => Ok(Rf::IsNan),
        Op::IsInf => Ok(Rf::IsInf),
        _ => Err(Error::UnknownRelationalFunction(word)),
    }
}

pub(super) const fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> {
    match word {
        2 => Ok(crate::VectorSize::Bi),
        3 => Ok(crate::VectorSize::Tri),
        4 => Ok(crate::VectorSize::Quad),
        _ => Err(Error::InvalidVectorSize(word)),
    }
}

pub(super) fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> {
    use spirv::Dim as D;
    match D::from_u32(word) {
        Some(D::Dim1D) => Ok(crate::ImageDimension::D1),
        Some(D::Dim2D) => Ok(crate::ImageDimension::D2),
        Some(D::Dim3D) => Ok(crate::ImageDimension::D3),
        Some(D::DimCube) => Ok(crate::ImageDimension::Cube),
        _ => Err(Error::UnsupportedImageDim(word)),
    }
}

pub(super) fn map_image_format(word: spirv::Word) -> Result<crate::StorageFormat, Error> {
    match spirv::ImageFormat::from_u32(word) {
        Some(spirv::ImageFormat::R8) => Ok(crate::StorageFormat::R8Unorm),
        Some(spirv::ImageFormat::R8Snorm) => Ok(crate::StorageFormat::R8Snorm),
        Some(spirv::ImageFormat::R8ui) => Ok(crate::StorageFormat::R8Uint),
        Some(spirv::ImageFormat::R8i) => Ok(crate::StorageFormat::R8Sint),
        Some(spirv::ImageFormat::R16) => Ok(crate::StorageFormat::R16Unorm),
        Some(spirv::ImageFormat::R16Snorm) => Ok(crate::StorageFormat::R16Snorm),
        Some(spirv::ImageFormat::R16ui) => Ok(crate::StorageFormat::R16Uint),
        Some(spirv::ImageFormat::R16i) => Ok(crate::StorageFormat::R16Sint),
        Some(spirv::ImageFormat::R16f) => Ok(crate::StorageFormat::R16Float),
        Some(spirv::ImageFormat::Rg8) => Ok(crate::StorageFormat::Rg8Unorm),
        Some(spirv::ImageFormat::Rg8Snorm) => Ok(crate::StorageFormat::Rg8Snorm),
        Some(spirv::ImageFormat::Rg8ui) => Ok(crate::StorageFormat::Rg8Uint),
        Some(spirv::ImageFormat::Rg8i) => Ok(crate::StorageFormat::Rg8Sint),
        Some(spirv::ImageFormat::R32ui) => Ok(crate::StorageFormat::R32Uint),
        Some(spirv::ImageFormat::R32i) => Ok(crate::StorageFormat::R32Sint),
        Some(spirv::ImageFormat::R32f) => Ok(crate::StorageFormat::R32Float),
        Some(spirv::ImageFormat::Rg16) => Ok(crate::StorageFormat::Rg16Unorm),
        Some(spirv::ImageFormat::Rg16Snorm) => Ok(crate::StorageFormat::Rg16Snorm),
        Some(spirv::ImageFormat::Rg16ui) => Ok(crate::StorageFormat::Rg16Uint),
        Some(spirv::ImageFormat::Rg16i) => Ok(crate::StorageFormat::Rg16Sint),
        Some(spirv::ImageFormat::Rg16f) => Ok(crate::StorageFormat::Rg16Float),
        Some(spirv::ImageFormat::Rgba8) => Ok(crate::StorageFormat::Rgba8Unorm),
        Some(spirv::ImageFormat::Rgba8Snorm) => Ok(crate::StorageFormat::Rgba8Snorm),
        Some(spirv::ImageFormat::Rgba8ui) => Ok(crate::StorageFormat::Rgba8Uint),
        Some(spirv::ImageFormat::Rgba8i) => Ok(crate::StorageFormat::Rgba8Sint),
        Some(spirv::ImageFormat::Rgb10a2ui) => Ok(crate::StorageFormat::Rgb10a2Uint),
        Some(spirv::ImageFormat::Rgb10A2) => Ok(crate::StorageFormat::Rgb10a2Unorm),
        Some(spirv::ImageFormat::R11fG11fB10f) => Ok(crate::StorageFormat::Rg11b10Ufloat),
        Some(spirv::ImageFormat::R64ui) => Ok(crate::StorageFormat::R64Uint),
        Some(spirv::ImageFormat::Rg32ui) => Ok(crate::StorageFormat::Rg32Uint),
        Some(spirv::ImageFormat::Rg32i) => Ok(crate::StorageFormat::Rg32Sint),
        Some(spirv::ImageFormat::Rg32f) => Ok(crate::StorageFormat::Rg32Float),
        Some(spirv::ImageFormat::Rgba16) => Ok(crate::StorageFormat::Rgba16Unorm),
        Some(spirv::ImageFormat::Rgba16Snorm) => Ok(crate::StorageFormat::Rgba16Snorm),
        Some(spirv::ImageFormat::Rgba16ui) => Ok(crate::StorageFormat::Rgba16Uint),
        Some(spirv::ImageFormat::Rgba16i) => Ok(crate::StorageFormat::Rgba16Sint),
        Some(spirv::ImageFormat::Rgba16f) => Ok(crate::StorageFormat::Rgba16Float),
        Some(spirv::ImageFormat::Rgba32ui) => Ok(crate::StorageFormat::Rgba32Uint),
        Some(spirv::ImageFormat::Rgba32i) => Ok(crate::StorageFormat::Rgba32Sint),
        Some(spirv::ImageFormat::Rgba32f) => Ok(crate::StorageFormat::Rgba32Float),
        _ => Err(Error::UnsupportedImageFormat(word)),
    }
}

pub(super) fn map_width(word: spirv::Word) -> Result<crate::Bytes, Error> {
    (word >> 3) // bits to bytes
        .try_into()
        .map_err(|_| Error::InvalidTypeWidth(word))
}

pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::BuiltIn, Error> {
    use spirv::BuiltIn as Bi;
    Ok(match spirv::BuiltIn::from_u32(word) {
        Some(Bi::Position | Bi::FragCoord) => crate::BuiltIn::Position { invariant },
        Some(Bi::ViewIndex) => crate::BuiltIn::ViewIndex,
        // vertex
        Some(Bi::BaseInstance) => crate::BuiltIn::BaseInstance,
        Some(Bi::BaseVertex) => crate::BuiltIn::BaseVertex,
        Some(Bi::ClipDistance) => crate::BuiltIn::ClipDistance,
        Some(Bi::CullDistance) => crate::BuiltIn::CullDistance,
        Some(Bi::InstanceIndex) => crate::BuiltIn::InstanceIndex,
        Some(Bi::PointSize) => crate::BuiltIn::PointSize,
        Some(Bi::VertexIndex) => crate::BuiltIn::VertexIndex,
        Some(Bi::DrawIndex) => crate::BuiltIn::DrawIndex,
        // fragment
        Some(Bi::FragDepth) => crate::BuiltIn::FragDepth,
        Some(Bi::PointCoord) => crate::BuiltIn::PointCoord,
        Some(Bi::FrontFacing) => crate::BuiltIn::FrontFacing,
        Some(Bi::PrimitiveId) => crate::BuiltIn::PrimitiveIndex,
        Some(Bi::BaryCoordKHR) => crate::BuiltIn::Barycentric { perspective: true },
        Some(Bi::BaryCoordNoPerspKHR) => crate::BuiltIn::Barycentric { perspective: false },
        Some(Bi::SampleId) => crate::BuiltIn::SampleIndex,
        Some(Bi::SampleMask) => crate::BuiltIn::SampleMask,
        // compute
        Some(Bi::GlobalInvocationId) => crate::BuiltIn::GlobalInvocationId,
        Some(Bi::LocalInvocationId) => crate::BuiltIn::LocalInvocationId,
        Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex,
        Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
        Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
        Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
        // subgroup
        Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups,
        Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId,
        Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize,
        Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId,
        _ => return Err(Error::UnsupportedBuiltIn(word)),
    })
}

pub(super) fn map_storage_class(word: spirv::Word) -> Result<super::ExtendedClass, Error> {
    use super::ExtendedClass as Ec;
    use spirv::StorageClass as Sc;
    Ok(match Sc::from_u32(word) {
        Some(Sc::Function) => Ec::Global(crate::AddressSpace::Function),
        Some(Sc::Input) => Ec::Input,
        Some(Sc::Output) => Ec::Output,
        Some(Sc::Private) => Ec::Global(crate::AddressSpace::Private),
        Some(Sc::UniformConstant) => Ec::Global(crate::AddressSpace::Handle),
        Some(Sc::StorageBuffer) => Ec::Global(crate::AddressSpace::Storage {
            //Note: this is restricted by decorations later
            access: crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
        }),
        // we expect the `Storage` case to be filtered out before calling this function.
        Some(Sc::Uniform) => Ec::Global(crate::AddressSpace::Uniform),
        Some(Sc::Workgroup) => Ec::Global(crate::AddressSpace::WorkGroup),
        Some(Sc::PushConstant) => Ec::Global(crate::AddressSpace::Immediate),
        _ => return Err(Error::UnsupportedStorageClass(word)),
    })
}
naga-29.0.3/src/front/spv/error.rs000064400000000000000000000165101046102023000150560ustar 00000000000000
use alloc::{
    format,
    string::{String, ToString},
};

use codespan_reporting::diagnostic::Diagnostic;
use codespan_reporting::files::SimpleFile;
use codespan_reporting::term;

use super::ModuleState;
#[cfg(feature = "stderr")]
use crate::error::ErrorWrite;
use crate::{arena::Handle, error::replace_control_chars, front::atomic_upgrade};

#[derive(Clone, Debug, thiserror::Error)]
pub enum Error {
    #[error("invalid header")]
    InvalidHeader,
    #[error("invalid word count")]
    InvalidWordCount,
    #[error("unknown instruction {0}")]
    UnknownInstruction(u16),
    #[error("unknown capability %{0}")]
    UnknownCapability(spirv::Word),
    #[error("unsupported instruction {1:?} at {0:?}")]
    UnsupportedInstruction(ModuleState, spirv::Op),
    #[error("unsupported capability {0:?}")]
    UnsupportedCapability(spirv::Capability),
    #[error("unsupported extension {0}")]
    UnsupportedExtension(String),
    #[error("unsupported extension set {0}")]
    UnsupportedExtSet(String),
    #[error("unsupported extension instantiation set %{0}")]
    UnsupportedExtInstSet(spirv::Word),
    #[error("unsupported extension instantiation %{0}")]
    UnsupportedExtInst(spirv::Word),
    #[error("unsupported type {0:?}")]
    UnsupportedType(Handle<crate::Type>),
    #[error("unsupported execution model %{0}")]
    UnsupportedExecutionModel(spirv::Word),
    #[error("unsupported execution mode %{0}")]
    UnsupportedExecutionMode(spirv::Word),
    #[error("unsupported storage class %{0}")]
    UnsupportedStorageClass(spirv::Word),
    #[error("unsupported image dimension %{0}")]
    UnsupportedImageDim(spirv::Word),
    #[error("unsupported image format %{0}")]
    UnsupportedImageFormat(spirv::Word),
    #[error("unsupported builtin %{0}")]
    UnsupportedBuiltIn(spirv::Word),
    #[error("unsupported control flow %{0}")]
    UnsupportedControlFlow(spirv::Word),
    #[error("unsupported binary operator %{0}")]
    UnsupportedBinaryOperator(spirv::Word),
    #[error("Naga supports OpTypeRuntimeArray in the StorageBuffer storage class only")]
    UnsupportedRuntimeArrayStorageClass,
    #[error(
        "unsupported matrix stride {} for a {}x{} matrix with scalar width={}",
        stride,
        columns,
        rows,
        width
    )]
    UnsupportedMatrixStride {
        stride: u32,
        columns: u8,
        rows: u8,
        width: u8,
    },
    #[error("unknown binary operator {0:?}")]
    UnknownBinaryOperator(spirv::Op),
    #[error("unknown relational function {0:?}")]
    UnknownRelationalFunction(spirv::Op),
    #[error("unsupported group operation %{0}")]
    UnsupportedGroupOperation(spirv::Word),
    #[error("invalid parameter {0:?}")]
    InvalidParameter(spirv::Op),
    #[error("invalid operand count {1} for {0:?}")]
    InvalidOperandCount(spirv::Op, u16),
    #[error("invalid operand")]
    InvalidOperand,
    #[error("invalid id %{0}")]
    InvalidId(spirv::Word),
    #[error("invalid decoration %{0}")]
    InvalidDecoration(spirv::Word),
    #[error("invalid type width %{0}")]
    InvalidTypeWidth(spirv::Word),
    #[error("invalid sign %{0}")]
    InvalidSign(spirv::Word),
    #[error("invalid inner type %{0}")]
    InvalidInnerType(spirv::Word),
    #[error("invalid vector size %{0}")]
    InvalidVectorSize(spirv::Word),
    #[error("invalid access type %{0}")]
    InvalidAccessType(spirv::Word),
    #[error("invalid access {0:?}")]
    InvalidAccess(crate::Expression),
    #[error("invalid access index %{0}")]
    InvalidAccessIndex(spirv::Word),
    #[error("invalid index type %{0}")]
    InvalidIndexType(spirv::Word),
    #[error("invalid binding %{0}")]
    InvalidBinding(spirv::Word),
    #[error("invalid global var {0:?}")]
    InvalidGlobalVar(crate::Expression),
    #[error("invalid image/sampler expression {0:?}")]
    InvalidImageExpression(crate::Expression),
    #[error("cannot create a OpTypeImage as both a depth and storage image")]
    InvalidImageDepthStorage,
    #[error("image read/write without format is not currently supported. See https://github.com/gfx-rs/wgpu/issues/6797")]
    InvalidStorageImageWithoutFormat,
    #[error("invalid image base type {0:?}")]
    InvalidImageBaseType(Handle<crate::Type>),
    #[error("invalid image {0:?}")]
    InvalidImage(Handle<crate::Type>),
    #[error("invalid as type {0:?}")]
    InvalidAsType(Handle<crate::Type>),
    #[error("invalid vector type {0:?}")]
    InvalidVectorType(Handle<crate::Type>),
    #[error("inconsistent comparison sampling {0:?}")]
    InconsistentComparisonSampling(Handle<crate::GlobalVariable>),
    #[error("wrong function result type %{0}")]
    WrongFunctionResultType(spirv::Word),
    #[error("wrong function argument type %{0}")]
    WrongFunctionArgumentType(spirv::Word),
    #[error("missing decoration {0:?}")]
    MissingDecoration(spirv::Decoration),
    #[error("bad string")]
    BadString,
    #[error("incomplete data")]
    IncompleteData,
    #[error("invalid terminator")]
    InvalidTerminator,
    #[error("invalid edge classification")]
    InvalidEdgeClassification,
    #[error("cycle detected in the CFG during traversal at {0}")]
    ControlFlowGraphCycle(crate::front::spv::BlockId),
    #[error("recursive function call %{0}")]
    FunctionCallCycle(spirv::Word),
    #[error("invalid array size %{0}")]
    InvalidArraySize(spirv::Word),
    #[error("invalid barrier scope %{0}")]
    InvalidBarrierScope(spirv::Word),
    #[error("invalid barrier memory semantics %{0}")]
    InvalidBarrierMemorySemantics(spirv::Word),
    #[error(
        "arrays of images / samplers are supported only through bindings for \
         now (i.e. you can't create an array of images or samplers that doesn't \
         come from a binding)"
    )]
    NonBindingArrayOfImageOrSamplers,
    #[error("naga only supports specialization constant IDs up to 65535 but was given {0}")]
    SpecIdTooHigh(u32),
    #[error("atomic upgrade error: {0}")]
    AtomicUpgradeError(atomic_upgrade::Error),
}

impl Error {
    #[cfg(feature = "stderr")]
    pub fn emit_to_writer(&self, writer: &mut impl ErrorWrite, source: &str) {
        self.emit_to_writer_with_path(writer, source, "spv");
    }

    #[cfg(feature = "stderr")]
    pub fn emit_to_writer_with_path(&self, writer: &mut impl ErrorWrite, source: &str, path: &str) {
        let path = path.to_string();
        let files = SimpleFile::new(path, replace_control_chars(source));
        let config = term::Config::default();
        let diagnostic = Diagnostic::error().with_message(format!("{self:?}"));

        crate::error::emit_to_writer(writer, &config, &files, &diagnostic)
            .expect("cannot write error");
    }

    pub fn emit_to_string(&self, source: &str) -> String {
        self.emit_to_string_with_path(source, "spv")
    }

    pub fn emit_to_string_with_path(&self, source: &str, path: &str) -> String {
        let path = path.to_string();
        let files = SimpleFile::new(path, replace_control_chars(source));
        let config = term::Config::default();
        let diagnostic = Diagnostic::error().with_message(format!("{self:?}"));

        let mut writer = crate::error::DiagnosticBuffer::new();
        writer
            .emit_to_self(&config, &files, &diagnostic)
            .expect("cannot write error");
        writer.into_string()
    }
}

impl From<atomic_upgrade::Error> for Error {
    fn from(source: atomic_upgrade::Error) -> Self {
        Error::AtomicUpgradeError(source)
    }
}
naga-29.0.3/src/front/spv/function.rs000064400000000000000000000734731046102023000155650ustar 00000000000000
use alloc::{format, vec, vec::Vec};

use super::{Error, Instruction, LookupExpression, LookupHelper as _};
use crate::proc::Emitter;
use crate::{
    arena::{Arena, Handle},
    front::spv::{BlockContext, BodyIndex},
};

pub type BlockId = u32;

impl<I: Iterator<Item = u32>> super::Frontend<I> {
    // Registers a function call. It will generate a dummy handle to call, which
    // gets resolved after all the functions are processed.
    pub(super) fn add_call(
        &mut self,
        from: spirv::Word,
        to: spirv::Word,
    ) -> Handle<crate::Function> {
        let dummy_handle = self
            .dummy_functions
            .append(crate::Function::default(), Default::default());
        self.deferred_function_calls.push(to);
        self.function_call_graph.add_edge(from, to, ());
        dummy_handle
    }

    pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> {
        let start = self.data_offset;
        self.lookup_expression.clear();
        self.lookup_load_override.clear();
        self.lookup_sampled_image.clear();

        let result_type_id = self.next()?;
        let fun_id = self.next()?;
        let _fun_control = self.next()?;
        let fun_type_id = self.next()?;

        let mut fun = {
            let ft = self.lookup_function_type.lookup(fun_type_id)?;
            if ft.return_type_id != result_type_id {
                return Err(Error::WrongFunctionResultType(result_type_id));
            }
            crate::Function {
                name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
                arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
                result: if self.lookup_void_type == Some(result_type_id) {
                    None
                } else {
                    let lookup_result_ty = self.lookup_type.lookup(result_type_id)?;
                    Some(crate::FunctionResult {
                        ty: lookup_result_ty.handle,
                        binding: None,
                    })
                },
                local_variables: Arena::new(),
                expressions: self.make_expression_storage(
                    &module.global_variables,
                    &module.constants,
                    &module.overrides,
                ),
                named_expressions: crate::NamedExpressions::default(),
                body: crate::Block::new(),
                diagnostic_filter_leaf: None,
            }
        };

        // read parameters
        for i in 0..fun.arguments.capacity() {
            let start = self.data_offset;
            match self.next_inst()? {
                Instruction {
                    op: spirv::Op::FunctionParameter,
                    wc: 3,
                } => {
                    let type_id = self.next()?;
                    let id = self.next()?;
                    let handle = fun.expressions.append(
                        crate::Expression::FunctionArgument(i as u32),
                        self.span_from(start),
                    );
                    self.lookup_expression.insert(
                        id,
                        LookupExpression {
                            handle,
                            type_id,
                            // Setting this to an invalid id will cause get_expr_handle
                            // to default to the main body making sure no load/stores
                            // are added.
                            block_id: 0,
                        },
                    );
                    //Note: we redo the lookup in order to work around `self` borrowing

                    if type_id
                        != self
                            .lookup_function_type
                            .lookup(fun_type_id)?
                            .parameter_type_ids[i]
                    {
                        return Err(Error::WrongFunctionArgumentType(type_id));
                    }
                    let ty = self.lookup_type.lookup(type_id)?.handle;
                    let decor = self.future_decor.remove(&id).unwrap_or_default();
                    fun.arguments.push(crate::FunctionArgument {
                        name: decor.name,
                        ty,
                        binding: None,
                    });
                }
                Instruction { op, .. } => return Err(Error::InvalidParameter(op)),
            }
        }

        // Note the index this function's handle will be assigned, for tracing.
        let function_index = module.functions.len();

        // Read body
        self.function_call_graph.add_node(fun_id);
        let mut parameters_sampling =
            vec![super::image::SamplingFlags::empty(); fun.arguments.len()];

        let mut block_ctx = BlockContext {
            phis: Default::default(),
            blocks: Default::default(),
            body_for_label: Default::default(),
            mergers: Default::default(),
            bodies: Default::default(),
            module,
            function_id: fun_id,
            expressions: &mut fun.expressions,
            local_arena: &mut fun.local_variables,
            arguments: &fun.arguments,
            parameter_sampling: &mut parameters_sampling,
        };
        // Insert the main body, whose parent is itself
        block_ctx.bodies.push(super::Body::with_parent(0));

        // Scan the blocks and add them as nodes
        loop {
            let fun_inst = self.next_inst()?;
            log::debug!("{:?}", fun_inst.op);
            match fun_inst.op {
                spirv::Op::Line => {
                    fun_inst.expect(4)?;
                    let _file_id = self.next()?;
                    let _row_id = self.next()?;
                    let _col_id = self.next()?;
                }
                spirv::Op::Label => {
                    // Read the label ID
                    fun_inst.expect(2)?;
                    let block_id = self.next()?;

                    self.next_block(block_id, &mut block_ctx)?;
                }
                spirv::Op::FunctionEnd => {
                    fun_inst.expect(1)?;
                    break;
                }
                spirv::Op::ExtInst => {
                    let _ = self.next()?;
                    let _ = self.next()?;
                    let set_id = self.next()?;
                    if Some(set_id) == self.ext_non_semantic_id {
                        for _ in 0..fun_inst.wc - 4 {
                            self.next()?;
                        }
                    } else {
                        return Err(Error::UnsupportedInstruction(self.state, fun_inst.op));
                    }
                }
                _ => {
                    return Err(Error::UnsupportedInstruction(self.state, fun_inst.op));
                }
            }
        }

        if let Some(ref prefix) = self.options.block_ctx_dump_prefix {
            let dump_suffix = match self.lookup_entry_point.get(&fun_id) {
                Some(ep) => format!("block_ctx.{:?}-{}.txt", ep.stage, ep.name),
                None => format!("block_ctx.Fun-{function_index}.txt"),
            };

            cfg_if::cfg_if! {
                if #[cfg(feature = "fs")] {
                    let prefix: &std::path::Path = prefix.as_ref();
                    let dest = prefix.join(dump_suffix);
                    let dump = format!("{block_ctx:#?}");
                    if let Err(e) = std::fs::write(&dest, dump) {
                        log::error!("Unable to dump the block context into {dest:?}: {e}");
                    }
                } else {
                    log::error!("Unable to dump the block context into {prefix:?}/{dump_suffix}: file system integration was not enabled with the `fs` feature");
                }
            }
        }

        // Emit `Store` statements to properly initialize all the local variables we
        // created for `phi` expressions.
        //
        // Note that get_expr_handle also contributes slightly odd entries to this table,
        // to get the spill.
        for phi in block_ctx.phis.iter() {
            // Get a pointer to the local variable for the phi's value.
            let phi_pointer: Handle<crate::Expression> = block_ctx.expressions.append(
                crate::Expression::LocalVariable(phi.local),
                crate::Span::default(),
            );

            // At the end of each of `phi`'s predecessor blocks, store the corresponding
            // source value in the phi's local variable.
            for &(source, predecessor) in phi.expressions.iter() {
                let source_lexp = &self.lookup_expression[&source];
                let predecessor_body_idx = block_ctx.body_for_label[&predecessor];
                // If the expression is a global/argument it will have a 0 block
                // id so we must use a default value instead of panicking
                let source_body_idx = block_ctx
                    .body_for_label
                    .get(&source_lexp.block_id)
                    .copied()
                    .unwrap_or(0);

                // If the Naga `Expression` generated for `source` is in scope, then we
                // can simply store that in the phi's local variable.
                //
                // Otherwise, spill the source value to a local variable in the block that
                // defines it. (We know this store dominates the predecessor; otherwise,
                // the phi wouldn't have been able to refer to that source expression in
                // the first place.) Then, the predecessor block can count on finding the
                // source's value in that local variable.
                let value = if super::is_parent(predecessor_body_idx, source_body_idx, &block_ctx) {
                    source_lexp.handle
                } else {
                    // The source SPIR-V expression is not defined in the phi's
                    // predecessor block, nor is it a globally available expression. So it
                    // must be defined off in some other block that merely dominates the
                    // predecessor. This means that the corresponding Naga `Expression`
                    // may not be in scope in the predecessor block.
                    //
                    // In the block that defines `source`, spill it to a fresh local
                    // variable, to ensure we can still use it at the end of the
                    // predecessor.
                    let ty = self.lookup_type[&source_lexp.type_id].handle;
                    let local = block_ctx.local_arena.append(
                        crate::LocalVariable {
                            name: None,
                            ty,
                            init: None,
                        },
                        crate::Span::default(),
                    );

                    let pointer = block_ctx.expressions.append(
                        crate::Expression::LocalVariable(local),
                        crate::Span::default(),
                    );

                    // Get the spilled value of the source expression.
                    let start = block_ctx.expressions.len();
                    let expr = block_ctx
                        .expressions
                        .append(crate::Expression::Load { pointer }, crate::Span::default());
                    let range = block_ctx.expressions.range_from(start);

                    block_ctx
                        .blocks
                        .get_mut(&predecessor)
                        .unwrap()
                        .push(crate::Statement::Emit(range), crate::Span::default());

                    // At the end of the block that defines it, spill the source
                    // expression's value.
                    block_ctx
                        .blocks
                        .get_mut(&source_lexp.block_id)
                        .unwrap()
                        .push(
                            crate::Statement::Store {
                                pointer,
                                value: source_lexp.handle,
                            },
                            crate::Span::default(),
                        );

                    expr
                };

                // At the end of the phi predecessor block, store the source
                // value in the phi's value.
                block_ctx.blocks.get_mut(&predecessor).unwrap().push(
                    crate::Statement::Store {
                        pointer: phi_pointer,
                        value,
                    },
                    crate::Span::default(),
                )
            }
        }

        fun.body = block_ctx.lower();

        // done
        let fun_handle = module.functions.append(fun, self.span_from_with_op(start));
        self.lookup_function.insert(
            fun_id,
            super::LookupFunction {
                handle: fun_handle,
                parameters_sampling,
            },
        );
        if let Some(ep) = self.lookup_entry_point.remove(&fun_id) {
            self.deferred_entry_points.push((ep, fun_id));
        }

        Ok(())
    }

    pub(super) fn process_entry_point(
        &mut self,
        module: &mut crate::Module,
        ep: super::EntryPoint,
        fun_id: u32,
    ) -> Result<(), Error> {
        // create a wrapping function
        let mut function = crate::Function {
            name: Some(format!("{}_wrap", ep.name)),
            arguments: Vec::new(),
            result: None,
            local_variables: Arena::new(),
            expressions: Arena::new(),
            named_expressions: crate::NamedExpressions::default(),
            body: crate::Block::new(),
            diagnostic_filter_leaf: None,
        };

        // 1. copy the inputs from arguments to privates
        for &v_id in ep.variable_ids.iter() {
            let lvar = self.lookup_variable.lookup(v_id)?;
            if let super::Variable::Input(ref arg) = lvar.inner {
                let span = module.global_variables.get_span(lvar.handle);
                let arg_expr = function.expressions.append(
                    crate::Expression::FunctionArgument(function.arguments.len() as u32),
                    span,
                );
                let load_expr = if arg.ty == module.global_variables[lvar.handle].ty {
                    arg_expr
                } else {
                    // The only case where the type is different is if we need to treat
                    // unsigned integer as signed.
                    let mut emitter = Emitter::default();
                    emitter.start(&function.expressions);
                    let handle = function.expressions.append(
                        crate::Expression::As {
                            expr: arg_expr,
                            kind: crate::ScalarKind::Sint,
                            convert: Some(4),
                        },
                        span,
                    );
                    function.body.extend(emitter.finish(&function.expressions));
                    handle
                };
                function.body.push(
                    crate::Statement::Store {
                        pointer: function
                            .expressions
                            .append(crate::Expression::GlobalVariable(lvar.handle), span),
                        value: load_expr,
                    },
                    span,
                );

                let mut arg = arg.clone();
                if ep.stage == crate::ShaderStage::Fragment {
                    if let Some(ref mut binding) = arg.binding {
                        binding.apply_default_interpolation(&module.types[arg.ty].inner);
                    }
                }
                function.arguments.push(arg);
            }
        }
        // 2. call the wrapped function
        let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision
        let dummy_handle = self.add_call(fake_id, fun_id);
        function.body.push(
            crate::Statement::Call {
                function: dummy_handle,
                arguments: Vec::new(),
                result: None,
            },
            crate::Span::default(),
        );

        // 3. copy the outputs from privates to the result
        //
        // It would be nice to share struct layout code here with `parse_type_struct`,
        // but that case needs to take into account offset decorations, which makes an
        // abstraction harder to follow than just writing out what we mean. `Layouter`
        // and `Alignment` cover the worst parts already.
        let mut members = Vec::new();
        self.layouter.update(module.to_ctx()).unwrap();
        let mut next_member_offset = 0;
        let mut struct_alignment = crate::proc::Alignment::ONE;
        let mut components = Vec::new();
        for &v_id in ep.variable_ids.iter() {
            let lvar = self.lookup_variable.lookup(v_id)?;
            if let super::Variable::Output(ref result) = lvar.inner {
                let span = module.global_variables.get_span(lvar.handle);
                let expr_handle = function
                    .expressions
                    .append(crate::Expression::GlobalVariable(lvar.handle), span);

                // Cull problematic builtins of gl_PerVertex.
                // See the docs for `Frontend::gl_per_vertex_builtin_access`.
{ let ty = &module.types[result.ty]; if let crate::TypeInner::Struct { members: ref original_members, span, } = ty.inner { let mut new_members = None; for (idx, member) in original_members.iter().enumerate() { if let Some(crate::Binding::BuiltIn(built_in)) = member.binding { if !self.gl_per_vertex_builtin_access.contains(&built_in) { new_members.get_or_insert_with(|| original_members.clone()) [idx] .binding = None; } } } if let Some(new_members) = new_members { module.types.replace( result.ty, crate::Type { name: ty.name.clone(), inner: crate::TypeInner::Struct { members: new_members, span, }, }, ); } } } match module.types[result.ty].inner { crate::TypeInner::Struct { members: ref sub_members, .. } => { for (index, sm) in sub_members.iter().enumerate() { if sm.binding.is_none() { continue; } let mut sm = sm.clone(); if let Some(ref mut binding) = sm.binding { if ep.stage == crate::ShaderStage::Vertex { binding.apply_default_interpolation(&module.types[sm.ty].inner); } } let member_alignment = self.layouter[sm.ty].alignment; next_member_offset = member_alignment.round_up(next_member_offset); sm.offset = next_member_offset; struct_alignment = struct_alignment.max(member_alignment); next_member_offset += self.layouter[sm.ty].size; members.push(sm); components.push(function.expressions.append( crate::Expression::AccessIndex { base: expr_handle, index: index as u32, }, span, )); } } ref inner => { let mut binding = result.binding.clone(); if let Some(ref mut binding) = binding { if ep.stage == crate::ShaderStage::Vertex { binding.apply_default_interpolation(inner); } } let member_alignment = self.layouter[result.ty].alignment; next_member_offset = member_alignment.round_up(next_member_offset); members.push(crate::StructMember { name: None, ty: result.ty, binding, offset: next_member_offset, }); struct_alignment = struct_alignment.max(member_alignment); next_member_offset += self.layouter[result.ty].size; // populate just the globals first, then do `Load` in a // 
separate step, so that we can get a range. components.push(expr_handle); } } } } for (member_index, member) in members.iter().enumerate() { match member.binding { Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. })) if self.options.adjust_coordinate_space => { let mut emitter = Emitter::default(); emitter.start(&function.expressions); let global_expr = components[member_index]; let span = function.expressions.get_span(global_expr); let access_expr = function.expressions.append( crate::Expression::AccessIndex { base: global_expr, index: 1, }, span, ); let load_expr = function.expressions.append( crate::Expression::Load { pointer: access_expr, }, span, ); let neg_expr = function.expressions.append( crate::Expression::Unary { op: crate::UnaryOperator::Negate, expr: load_expr, }, span, ); function.body.extend(emitter.finish(&function.expressions)); function.body.push( crate::Statement::Store { pointer: access_expr, value: neg_expr, }, span, ); } _ => {} } } let mut emitter = Emitter::default(); emitter.start(&function.expressions); for component in components.iter_mut() { let load_expr = crate::Expression::Load { pointer: *component, }; let span = function.expressions.get_span(*component); *component = function.expressions.append(load_expr, span); } match members[..] 
{ [] => {} [ref member] => { function.body.extend(emitter.finish(&function.expressions)); let span = function.expressions.get_span(components[0]); function.body.push( crate::Statement::Return { value: components.first().cloned(), }, span, ); function.result = Some(crate::FunctionResult { ty: member.ty, binding: member.binding.clone(), }); } _ => { let span = crate::Span::total_span( components.iter().map(|h| function.expressions.get_span(*h)), ); let ty = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Struct { members, span: struct_alignment.round_up(next_member_offset), }, }, span, ); let result_expr = function .expressions .append(crate::Expression::Compose { ty, components }, span); function.body.extend(emitter.finish(&function.expressions)); function.body.push( crate::Statement::Return { value: Some(result_expr), }, span, ); function.result = Some(crate::FunctionResult { ty, binding: None }); } } module.entry_points.push(crate::EntryPoint { name: ep.name, stage: ep.stage, early_depth_test: ep.early_depth_test, workgroup_size: ep.workgroup_size, workgroup_size_overrides: None, function, mesh_info: None, task_payload: None, incoming_ray_payload: None, }); Ok(()) } } impl BlockContext<'_> { pub(super) const fn gctx(&self) -> crate::proc::GlobalCtx<'_> { crate::proc::GlobalCtx { types: &self.module.types, constants: &self.module.constants, overrides: &self.module.overrides, global_expressions: &self.module.global_expressions, } } /// Consumes the `BlockContext`, producing an IR [`Block`](crate::Block) fn lower(mut self) -> crate::Block { fn lower_impl( blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>, bodies: &[super::Body], body_idx: BodyIndex, ) -> crate::Block { let mut block = crate::Block::new(); for item in bodies[body_idx].data.iter() { match *item { super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()), super::BodyFragment::If { condition, accept, reject, } => { let accept = lower_impl(blocks, bodies, accept); let reject =
lower_impl(blocks, bodies, reject); block.push( crate::Statement::If { condition, accept, reject, }, crate::Span::default(), ) } super::BodyFragment::Loop { body, continuing, break_if, } => { let body = lower_impl(blocks, bodies, body); let continuing = lower_impl(blocks, bodies, continuing); block.push( crate::Statement::Loop { body, continuing, break_if, }, crate::Span::default(), ) } super::BodyFragment::Switch { selector, ref cases, default, } => { let mut ir_cases: Vec<_> = cases .iter() .map(|&(value, body_idx)| { let body = lower_impl(blocks, bodies, body_idx); // Handle simple cases that would make a fallthrough statement unreachable code let fall_through = body.last().is_none_or(|s| !s.is_terminator()); crate::SwitchCase { value: crate::SwitchValue::I32(value), body, fall_through, } }) .collect(); ir_cases.push(crate::SwitchCase { value: crate::SwitchValue::Default, body: lower_impl(blocks, bodies, default), fall_through: false, }); block.push( crate::Statement::Switch { selector, cases: ir_cases, }, crate::Span::default(), ) } super::BodyFragment::Break => { block.push(crate::Statement::Break, crate::Span::default()) } super::BodyFragment::Continue => { block.push(crate::Statement::Continue, crate::Span::default()) } } } block } lower_impl(&mut self.blocks, &self.bodies, 0) } } naga-29.0.3/src/front/spv/image.rs000064400000000000000000001050431046102023000150070ustar 00000000000000use alloc::vec::Vec; use crate::{ arena::{Handle, UniqueArena}, Scalar, }; use super::{Error, LookupExpression, LookupHelper as _}; #[derive(Clone, Debug)] pub(super) struct LookupSampledImage { image: Handle<crate::Expression>, sampler: Handle<crate::Expression>, } bitflags::bitflags! { /// Flags describing sampling method. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct SamplingFlags: u32 { /// Regular sampling. const REGULAR = 0x1; /// Comparison sampling.
const COMPARISON = 0x2; } } impl super::BlockContext<'_> { fn get_image_expr_ty( &self, handle: Handle<crate::Expression>, ) -> Result<Handle<crate::Type>, Error> { match self.expressions[handle] { crate::Expression::GlobalVariable(handle) => { Ok(self.module.global_variables[handle].ty) } crate::Expression::FunctionArgument(i) => Ok(self.arguments[i as usize].ty), crate::Expression::Access { base, .. } => Ok(self.get_image_expr_ty(base)?), ref other => Err(Error::InvalidImageExpression(other.clone())), } } } /// Options of a sampling operation. #[derive(Debug)] pub struct SamplingOptions { /// Projection sampling: the division by W is expected to happen /// in the texture unit. pub project: bool, /// Depth comparison sampling with a reference value. pub compare: bool, /// Gather sampling: Operates on four samples of one channel. pub gather: bool, } enum ExtraCoordinate { ArrayLayer, Projection, Garbage, } /// Return the texture coordinates separated from the array layer, /// and/or divided by the projection term. /// /// The Proj sampling ops expect an extra coordinate for the W. /// The arrayed (can't be Proj!) images expect an extra coordinate for the layer. fn extract_image_coordinates( image_dim: crate::ImageDimension, extra_coordinate: ExtraCoordinate, base: Handle<crate::Expression>, coordinate_ty: Handle<crate::Type>, ctx: &mut super::BlockContext, ) -> (Handle<crate::Expression>, Option<Handle<crate::Expression>>) { let (given_size, kind) = match ctx.module.types[coordinate_ty].inner { crate::TypeInner::Scalar(Scalar { kind, .. }) => (None, kind), crate::TypeInner::Vector { size, scalar: Scalar { kind, ..
}, } => (Some(size), kind), ref other => unreachable!("Unexpected texture coordinate {:?}", other), }; let required_size = image_dim.required_coordinate_size(); let required_ty = required_size.map(|size| { ctx.module .types .get(&crate::Type { name: None, inner: crate::TypeInner::Vector { size, scalar: Scalar { kind, width: 4 }, }, }) .expect("Required coordinate type should have been set up by `parse_type_image`!") }); let extra_expr = crate::Expression::AccessIndex { base, index: required_size.map_or(1, |size| size as u32), }; let base_span = ctx.expressions.get_span(base); match extra_coordinate { ExtraCoordinate::ArrayLayer => { let extracted = match required_size { None => ctx .expressions .append(crate::Expression::AccessIndex { base, index: 0 }, base_span), Some(size) => { let mut components = Vec::with_capacity(size as usize); for index in 0..size as u32 { let comp = ctx .expressions .append(crate::Expression::AccessIndex { base, index }, base_span); components.push(comp); } ctx.expressions.append( crate::Expression::Compose { ty: required_ty.unwrap(), components, }, base_span, ) } }; let array_index_f32 = ctx.expressions.append(extra_expr, base_span); let array_index = ctx.expressions.append( crate::Expression::As { kind: crate::ScalarKind::Sint, expr: array_index_f32, convert: Some(4), }, base_span, ); (extracted, Some(array_index)) } ExtraCoordinate::Projection => { let projection = ctx.expressions.append(extra_expr, base_span); let divided = match required_size { None => { let temp = ctx .expressions .append(crate::Expression::AccessIndex { base, index: 0 }, base_span); ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Divide, left: temp, right: projection, }, base_span, ) } Some(size) => { let mut components = Vec::with_capacity(size as usize); for index in 0..size as u32 { let temp = ctx .expressions .append(crate::Expression::AccessIndex { base, index }, base_span); let comp = ctx.expressions.append( 
crate::Expression::Binary { op: crate::BinaryOperator::Divide, left: temp, right: projection, }, base_span, ); components.push(comp); } ctx.expressions.append( crate::Expression::Compose { ty: required_ty.unwrap(), components, }, base_span, ) } }; (divided, None) } ExtraCoordinate::Garbage if given_size == required_size => (base, None), ExtraCoordinate::Garbage => { use crate::SwizzleComponent as Sc; let cut_expr = match required_size { None => crate::Expression::AccessIndex { base, index: 0 }, Some(size) => crate::Expression::Swizzle { size, vector: base, pattern: [Sc::X, Sc::Y, Sc::Z, Sc::W], }, }; (ctx.expressions.append(cut_expr, base_span), None) } } } pub(super) fn patch_comparison_type( flags: SamplingFlags, var: &mut crate::GlobalVariable, arena: &mut UniqueArena<crate::Type>, ) -> bool { if !flags.contains(SamplingFlags::COMPARISON) { return true; } if flags == SamplingFlags::all() { return false; } log::debug!("Flipping comparison for {var:?}"); let original_ty = &arena[var.ty]; let original_ty_span = arena.get_span(var.ty); let ty_inner = match original_ty.inner { crate::TypeInner::Image { class: crate::ImageClass::Sampled { multi, .. }, dim, arrayed, } => crate::TypeInner::Image { class: crate::ImageClass::Depth { multi }, dim, arrayed, }, crate::TypeInner::Sampler { ..
} => crate::TypeInner::Sampler { comparison: true }, ref other => unreachable!("Unexpected type for comparison mutation: {:?}", other), }; let name = original_ty.name.clone(); var.ty = arena.insert( crate::Type { name, inner: ty_inner, }, original_ty_span, ); true } impl<I: Iterator<Item = u32>> super::Frontend<I> { pub(super) fn parse_image_couple(&mut self) -> Result<(), Error> { let _result_type_id = self.next()?; let result_id = self.next()?; let image_id = self.next()?; let sampler_id = self.next()?; let image_lexp = self.lookup_expression.lookup(image_id)?; let sampler_lexp = self.lookup_expression.lookup(sampler_id)?; self.lookup_sampled_image.insert( result_id, LookupSampledImage { image: image_lexp.handle, sampler: sampler_lexp.handle, }, ); Ok(()) } pub(super) fn parse_image_uncouple(&mut self, block_id: spirv::Word) -> Result<(), Error> { let result_type_id = self.next()?; let result_id = self.next()?; let sampled_image_id = self.next()?; self.lookup_expression.insert( result_id, LookupExpression { handle: self.lookup_sampled_image.lookup(sampled_image_id)?.image, type_id: result_type_id, block_id, }, ); Ok(()) } pub(super) fn parse_image_write( &mut self, words_left: u16, ctx: &mut super::BlockContext, emitter: &mut crate::proc::Emitter, block: &mut crate::Block, body_idx: usize, ) -> Result<crate::Statement, Error> { let image_id = self.next()?; let coordinate_id = self.next()?; let value_id = self.next()?; let image_ops = if words_left != 0 { self.next()?
} else { 0 }; if image_ops != 0 { let other = spirv::ImageOperands::from_bits_truncate(image_ops); log::warn!("Unknown image write ops {other:?}"); for _ in 1..words_left { self.next()?; } } let image_lexp = self.lookup_expression.lookup(image_id)?; let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?; let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; let coord_handle = self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx); let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; let (coordinate, array_index) = match ctx.module.types[image_ty].inner { crate::TypeInner::Image { dim, arrayed, class: _, } => extract_image_coordinates( dim, if arrayed { ExtraCoordinate::ArrayLayer } else { ExtraCoordinate::Garbage }, coord_handle, coord_type_handle, ctx, ), _ => return Err(Error::InvalidImage(image_ty)), }; let value_lexp = self.lookup_expression.lookup(value_id)?; let value = self.get_expr_handle(value_id, value_lexp, ctx, emitter, block, body_idx); let value_type = self.lookup_type.lookup(value_lexp.type_id)?.handle; // In HLSL etc., the written value may not be a 4-component vector. let expanded_value = match ctx.module.types[value_type].inner { crate::TypeInner::Scalar(_) => Some(crate::Expression::Splat { value, size: crate::VectorSize::Quad, }), crate::TypeInner::Vector { size, ..
} => match size { crate::VectorSize::Bi => Some(crate::Expression::Swizzle { size: crate::VectorSize::Quad, vector: value, pattern: [ crate::SwizzleComponent::X, crate::SwizzleComponent::Y, crate::SwizzleComponent::Y, crate::SwizzleComponent::Y, ], }), crate::VectorSize::Tri => Some(crate::Expression::Swizzle { size: crate::VectorSize::Quad, vector: value, pattern: [ crate::SwizzleComponent::X, crate::SwizzleComponent::Y, crate::SwizzleComponent::Z, crate::SwizzleComponent::Z, ], }), crate::VectorSize::Quad => None, }, _ => return Err(Error::InvalidVectorType(value_type)), }; let value_patched = if let Some(s) = expanded_value { ctx.expressions.append(s, crate::Span::default()) } else { value }; Ok(crate::Statement::ImageStore { image: image_lexp.handle, coordinate, array_index, value: value_patched, }) } pub(super) fn parse_image_load( &mut self, mut words_left: u16, ctx: &mut super::BlockContext, emitter: &mut crate::proc::Emitter, block: &mut crate::Block, block_id: spirv::Word, body_idx: usize, ) -> Result<(), Error> { let start = self.data_offset; let result_type_id = self.next()?; let result_id = self.next()?; let image_id = self.next()?; let coordinate_id = self.next()?; let mut image_ops = if words_left != 0 { words_left -= 1; self.next()? 
} else { 0 }; let mut sample = None; let mut level = None; while image_ops != 0 { let bit = 1 << image_ops.trailing_zeros(); match spirv::ImageOperands::from_bits_truncate(bit) { spirv::ImageOperands::LOD => { let lod_expr = self.next()?; let lod_lexp = self.lookup_expression.lookup(lod_expr)?; let lod_handle = self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx); level = Some(lod_handle); words_left -= 1; } spirv::ImageOperands::SAMPLE => { let sample_expr = self.next()?; let sample_handle = self.lookup_expression.lookup(sample_expr)?.handle; sample = Some(sample_handle); words_left -= 1; } other => { log::warn!("Unknown image load op {other:?}"); for _ in 0..words_left { self.next()?; } break; } } image_ops ^= bit; } // No need to call get_expr_handle here since only globals/arguments are // allowed as images and they are always in the root scope let image_lexp = self.lookup_expression.lookup(image_id)?; let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?; let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; let coord_handle = self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx); let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; let (coordinate, array_index, is_depth) = match ctx.module.types[image_ty].inner { crate::TypeInner::Image { dim, arrayed, class, } => { let (coord, array_index) = extract_image_coordinates( dim, if arrayed { ExtraCoordinate::ArrayLayer } else { ExtraCoordinate::Garbage }, coord_handle, coord_type_handle, ctx, ); (coord, array_index, class.is_depth()) } _ => return Err(Error::InvalidImage(image_ty)), }; let image_load_expr = crate::Expression::ImageLoad { image: image_lexp.handle, coordinate, array_index, sample, level, }; let image_load_handle = ctx .expressions .append(image_load_expr, self.span_from_with_op(start)); let handle = if is_depth { let result_ty = self.lookup_type.lookup(result_type_id)?; // The return type of `OpImageRead` can be a 
scalar or vector. match ctx.module.types[result_ty.handle].inner { crate::TypeInner::Vector { size, .. } => { let splat_expr = crate::Expression::Splat { size, value: image_load_handle, }; ctx.expressions .append(splat_expr, self.span_from_with_op(start)) } _ => image_load_handle, } } else { image_load_handle }; self.lookup_expression.insert( result_id, LookupExpression { handle, type_id: result_type_id, block_id, }, ); Ok(()) } #[allow(clippy::too_many_arguments)] pub(super) fn parse_image_sample( &mut self, mut words_left: u16, options: SamplingOptions, ctx: &mut super::BlockContext, emitter: &mut crate::proc::Emitter, block: &mut crate::Block, block_id: spirv::Word, body_idx: usize, ) -> Result<(), Error> { let start = self.data_offset; let result_type_id = self.next()?; let result_id = self.next()?; let sampled_image_id = self.next()?; let coordinate_id = self.next()?; let (component_id, dref_id) = match (options.gather, options.compare) { (true, false) => (Some(self.next()?), None), (_, true) => (None, Some(self.next()?)), (_, _) => (None, None), }; let span = self.span_from_with_op(start); let mut image_ops = if words_left != 0 { words_left -= 1; self.next()? 
} else { 0 }; let mut level = crate::SampleLevel::Auto; let mut offset = None; while image_ops != 0 { let bit = 1 << image_ops.trailing_zeros(); match spirv::ImageOperands::from_bits_truncate(bit) { spirv::ImageOperands::BIAS => { let bias_expr = self.next()?; let bias_lexp = self.lookup_expression.lookup(bias_expr)?; let bias_handle = self.get_expr_handle(bias_expr, bias_lexp, ctx, emitter, block, body_idx); level = crate::SampleLevel::Bias(bias_handle); words_left -= 1; } spirv::ImageOperands::LOD => { let lod_expr = self.next()?; let lod_lexp = self.lookup_expression.lookup(lod_expr)?; let lod_handle = self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx); let is_depth_image = { let image_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?; let image_ty = ctx.get_image_expr_ty(image_lexp.image)?; matches!( ctx.module.types[image_ty].inner, crate::TypeInner::Image { class: crate::ImageClass::Depth { .. }, .. } ) }; level = if options.compare { log::debug!("Assuming {lod_handle:?} is zero"); crate::SampleLevel::Zero } else if is_depth_image { log::debug!( "Assuming level {lod_handle:?} converts losslessly to an integer" ); let expr = crate::Expression::As { expr: lod_handle, kind: crate::ScalarKind::Sint, convert: Some(4), }; let s32_lod_handle = ctx.expressions.append(expr, span); crate::SampleLevel::Exact(s32_lod_handle) } else { crate::SampleLevel::Exact(lod_handle) }; words_left -= 1; } spirv::ImageOperands::GRAD => { let grad_x_expr = self.next()?; let grad_x_lexp = self.lookup_expression.lookup(grad_x_expr)?; let grad_x_handle = self.get_expr_handle( grad_x_expr, grad_x_lexp, ctx, emitter, block, body_idx, ); let grad_y_expr = self.next()?; let grad_y_lexp = self.lookup_expression.lookup(grad_y_expr)?; let grad_y_handle = self.get_expr_handle( grad_y_expr, grad_y_lexp, ctx, emitter, block, body_idx, ); level = if options.compare { log::debug!( "Assuming gradients {grad_x_handle:?} and {grad_y_handle:?} are not greater than 1" ); 
crate::SampleLevel::Zero } else { crate::SampleLevel::Gradient { x: grad_x_handle, y: grad_y_handle, } }; words_left -= 2; } spirv::ImageOperands::CONST_OFFSET => { let offset_expr = self.next()?; let offset_lexp = self.lookup_expression.lookup(offset_expr)?; let offset_handle = self.get_expr_handle( offset_expr, offset_lexp, ctx, emitter, block, body_idx, ); offset = Some(offset_handle); words_left -= 1; } other => { log::warn!("Unknown image sample operand {other:?}"); for _ in 0..words_left { self.next()?; } break; } } image_ops ^= bit; } let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?; let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; let coord_handle = self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx); let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; let gather = match (options.gather, component_id) { (true, Some(component_id)) => { let component_lexp = self.lookup_expression.lookup(component_id)?; let component_value = match ctx.expressions[component_lexp.handle] { // VUID-StandaloneSpirv-OpImageGather-04664: // The “Component” operand of OpImageGather and OpImageSparseGather must be the // result of a constant instruction. crate::Expression::Constant(const_handle) => { let constant = &ctx.module.constants[const_handle]; match ctx.module.global_expressions[constant.init] { // SPIR-V specification: "It must be a 32-bit integer type scalar." crate::Expression::Literal(crate::Literal::U32(value)) => value, crate::Expression::Literal(crate::Literal::I32(value)) => value as u32, _ => { log::error!( "Image gather component constant must be a 32-bit integer literal" ); return Err(Error::InvalidOperand); } } } _ => { log::error!("Image gather component must be a constant"); return Err(Error::InvalidOperand); } }; debug_assert_eq!(level, crate::SampleLevel::Auto); level = crate::SampleLevel::Zero; // SPIR-V specification: "Behavior is undefined if its value is not 0, 1, 2 or 3."
match component_value { 0 => Some(crate::SwizzleComponent::X), 1 => Some(crate::SwizzleComponent::Y), 2 => Some(crate::SwizzleComponent::Z), 3 => Some(crate::SwizzleComponent::W), other => { log::error!("Invalid gather component operand: {other}"); return Err(Error::InvalidOperand); } } } (true, None) => { debug_assert_eq!(level, crate::SampleLevel::Auto); level = crate::SampleLevel::Zero; Some(crate::SwizzleComponent::X) } (_, _) => None, }; let sampling_bit = if options.compare { SamplingFlags::COMPARISON } else { SamplingFlags::REGULAR }; let image_ty = match ctx.expressions[si_lexp.image] { crate::Expression::GlobalVariable(handle) => { if let Some(flags) = self.handle_sampling.get_mut(&handle) { *flags |= sampling_bit; } ctx.module.global_variables[handle].ty } crate::Expression::FunctionArgument(i) => { ctx.parameter_sampling[i as usize] |= sampling_bit; ctx.arguments[i as usize].ty } crate::Expression::Access { base, .. } => match ctx.expressions[base] { crate::Expression::GlobalVariable(handle) => { if let Some(flags) = self.handle_sampling.get_mut(&handle) { *flags |= sampling_bit; } match ctx.module.types[ctx.module.global_variables[handle].ty].inner { crate::TypeInner::BindingArray { base, .. } => base, _ => return Err(Error::InvalidGlobalVar(ctx.expressions[base].clone())), } } ref other => return Err(Error::InvalidGlobalVar(other.clone())), }, ref other => return Err(Error::InvalidGlobalVar(other.clone())), }; match ctx.expressions[si_lexp.sampler] { crate::Expression::GlobalVariable(handle) => { *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit; } crate::Expression::FunctionArgument(i) => { ctx.parameter_sampling[i as usize] |= sampling_bit; } crate::Expression::Access { base, .. 
} => match ctx.expressions[base] { crate::Expression::GlobalVariable(handle) => { *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit; } ref other => return Err(Error::InvalidGlobalVar(other.clone())), }, ref other => return Err(Error::InvalidGlobalVar(other.clone())), } let ((coordinate, array_index), depth_ref, is_depth) = match ctx.module.types[image_ty].inner { crate::TypeInner::Image { dim, arrayed, class, } => ( extract_image_coordinates( dim, if options.project { ExtraCoordinate::Projection } else if arrayed { ExtraCoordinate::ArrayLayer } else { ExtraCoordinate::Garbage }, coord_handle, coord_type_handle, ctx, ), { match dref_id { Some(id) => { let expr_lexp = self.lookup_expression.lookup(id)?; let mut expr = self .get_expr_handle(id, expr_lexp, ctx, emitter, block, body_idx); if options.project { let required_size = dim.required_coordinate_size(); let right = ctx.expressions.append( crate::Expression::AccessIndex { base: coord_handle, index: required_size.map_or(1, |size| size as u32), }, crate::Span::default(), ); expr = ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Divide, left: expr, right, }, crate::Span::default(), ) }; Some(expr) } None => None, } }, class.is_depth(), ), _ => return Err(Error::InvalidImage(image_ty)), }; let expr = crate::Expression::ImageSample { image: si_lexp.image, sampler: si_lexp.sampler, gather, coordinate, array_index, offset, level, depth_ref, clamp_to_edge: false, }; let image_sample_handle = ctx.expressions.append(expr, self.span_from_with_op(start)); let handle = if is_depth && depth_ref.is_none() { let splat_expr = crate::Expression::Splat { size: crate::VectorSize::Quad, value: image_sample_handle, }; ctx.expressions .append(splat_expr, self.span_from_with_op(start)) } else { image_sample_handle }; self.lookup_expression.insert( result_id, LookupExpression { handle, type_id: result_type_id, block_id, }, ); Ok(()) } pub(super) fn parse_image_query_size( &mut self, at_level: 
bool, ctx: &mut super::BlockContext, emitter: &mut crate::proc::Emitter, block: &mut crate::Block, block_id: spirv::Word, body_idx: usize, ) -> Result<(), Error> { let start = self.data_offset; let result_type_id = self.next()?; let result_id = self.next()?; let image_id = self.next()?; let level = if at_level { let level_id = self.next()?; let level_lexp = self.lookup_expression.lookup(level_id)?; Some(self.get_expr_handle(level_id, level_lexp, ctx, emitter, block, body_idx)) } else { None }; // No need to call get_expr_handle here since only globals/arguments are // allowed as images and they are always in the root scope //TODO: handle arrays and cubes let image_lexp = self.lookup_expression.lookup(image_id)?; let expr = crate::Expression::ImageQuery { image: image_lexp.handle, query: crate::ImageQuery::Size { level }, }; let result_type_handle = self.lookup_type.lookup(result_type_id)?.handle; let maybe_scalar_kind = ctx.module.types[result_type_handle].inner.scalar_kind(); let expr = if maybe_scalar_kind == Some(crate::ScalarKind::Sint) { crate::Expression::As { expr: ctx.expressions.append(expr, self.span_from_with_op(start)), kind: crate::ScalarKind::Sint, convert: Some(4), } } else { expr }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, self.span_from_with_op(start)), type_id: result_type_id, block_id, }, ); Ok(()) } pub(super) fn parse_image_query_other( &mut self, query: crate::ImageQuery, ctx: &mut super::BlockContext, block_id: spirv::Word, ) -> Result<(), Error> { let start = self.data_offset; let result_type_id = self.next()?; let result_id = self.next()?; let image_id = self.next()?; // No need to call get_expr_handle here since only globals/arguments are // allowed as images and they are always in the root scope let image_lexp = self.lookup_expression.lookup(image_id)?.clone(); let expr = crate::Expression::ImageQuery { image: image_lexp.handle, query, }; let result_type_handle = 
self.lookup_type.lookup(result_type_id)?.handle; let maybe_scalar_kind = ctx.module.types[result_type_handle].inner.scalar_kind(); let expr = if maybe_scalar_kind == Some(crate::ScalarKind::Sint) { crate::Expression::As { expr: ctx.expressions.append(expr, self.span_from_with_op(start)), kind: crate::ScalarKind::Sint, convert: Some(4), } } else { expr }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, self.span_from_with_op(start)), type_id: result_type_id, block_id, }, ); Ok(()) } } naga-29.0.3/src/front/spv/mod.rs000064400000000000000000003505101046102023000145050ustar 00000000000000/*! Frontend for [SPIR-V][spv] (Standard Portable Intermediate Representation). ## ID lookups Our IR links to everything with `Handle`, while SPIR-V uses IDs. In order to keep track of the associations, the parser has many lookup tables. These map a `spirv::Word` to a specific IR handle, plus potentially a bit of extra info, such as the related SPIR-V type ID. TODO: it would be nice to find ways to avoid so many lookups. ## Inputs/Outputs We create a private variable for each input/output. The relevant inputs are populated at the start of an entry point. The outputs are saved at the end. The function associated with an entry point is wrapped in another function, such that we can handle any `Return` statements without problems. ## Row-major matrices We don't handle them natively, since the IR only supports column-major matrices. Instead, we detect when such a matrix is accessed through an `OpAccessChain`, and we generate a parallel expression that loads the value, but transposed. This value is then used in place of the `OpLoad` result later on.
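
A single transpose is enough because a row-major matrix with R rows and C columns has exactly the same memory layout as a column-major matrix with C rows and R columns, i.e. its transpose. The sketch below illustrates only that index arithmetic on plain `Vec<Vec<f32>>` values; the helpers `load_column_major`, `transpose`, and `load_row_major` are hypothetical and are not part of this frontend, which works on IR expressions rather than raw buffers.

```rust
/// Reads `data` as a column-major matrix with `cols` columns of `rows`
/// elements each, returned as `m[column][row]`.
fn load_column_major(data: &[f32], cols: usize, rows: usize) -> Vec<Vec<f32>> {
    assert_eq!(data.len(), cols * rows, "buffer size must match dimensions");
    (0..cols)
        .map(|c| data[c * rows..(c + 1) * rows].to_vec())
        .collect()
}

/// Swaps rows and columns of a matrix stored as `m[column][row]`.
fn transpose(m: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let cols = m.len();
    let rows = m.first().map_or(0, |column| column.len());
    (0..rows)
        .map(|r| (0..cols).map(|c| m[c][r]).collect())
        .collect()
}

/// Loads a row-major matrix with `rows` rows and `cols` columns and returns
/// it in column-major form (`m[column][row]`).
fn load_row_major(data: &[f32], rows: usize, cols: usize) -> Vec<Vec<f32>> {
    // Memory holds `rows` contiguous rows of `cols` elements, which is the
    // column-major layout of the transpose, so load with swapped dimensions...
    let transposed = load_column_major(data, rows, cols);
    // ...then transpose once to recover the intended matrix.
    transpose(&transposed)
}
```

For a 2x3 matrix stored row-major as `[1, 2, 3, 4, 5, 6]`, `load_row_major(&data, 2, 3)` yields the columns `[1, 4]`, `[2, 5]`, `[3, 6]`, which is the same matrix in the column-major form the IR expects.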
[spv]: https://www.khronos.org/registry/SPIR-V/ */ mod convert; mod error; mod function; mod image; mod next_block; mod null; pub use error::Error; use alloc::{borrow::ToOwned, string::String, vec, vec::Vec}; use core::{convert::TryInto, mem, num::NonZeroU32}; use half::f16; use petgraph::graphmap::GraphMap; use super::atomic_upgrade::Upgrades; use crate::{ arena::{Arena, Handle, UniqueArena}, proc::{Alignment, Layouter}, FastHashMap, FastHashSet, FastIndexMap, }; use convert::*; use function::*; pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[ spirv::Capability::Shader, spirv::Capability::VulkanMemoryModel, spirv::Capability::ClipDistance, spirv::Capability::CullDistance, spirv::Capability::SampleRateShading, spirv::Capability::DerivativeControl, spirv::Capability::Matrix, spirv::Capability::ImageQuery, spirv::Capability::Sampled1D, spirv::Capability::Image1D, spirv::Capability::SampledCubeArray, spirv::Capability::ImageCubeArray, spirv::Capability::StorageImageExtendedFormats, spirv::Capability::Int8, spirv::Capability::Int16, spirv::Capability::Int64, spirv::Capability::Int64Atomics, spirv::Capability::Float16, spirv::Capability::AtomicFloat32AddEXT, spirv::Capability::Float64, spirv::Capability::Geometry, spirv::Capability::MultiView, spirv::Capability::StorageBuffer16BitAccess, spirv::Capability::UniformAndStorageBuffer16BitAccess, spirv::Capability::GroupNonUniform, spirv::Capability::GroupNonUniformVote, spirv::Capability::GroupNonUniformArithmetic, spirv::Capability::GroupNonUniformBallot, spirv::Capability::GroupNonUniformShuffle, spirv::Capability::GroupNonUniformShuffleRelative, spirv::Capability::RuntimeDescriptorArray, spirv::Capability::StorageImageMultisample, spirv::Capability::FragmentBarycentricKHR, // tricky ones spirv::Capability::UniformBufferArrayDynamicIndexing, spirv::Capability::StorageBufferArrayDynamicIndexing, ]; pub const SUPPORTED_EXTENSIONS: &[&str] = &[ "SPV_KHR_storage_buffer_storage_class", "SPV_KHR_vulkan_memory_model", 
"SPV_KHR_multiview", "SPV_EXT_descriptor_indexing", "SPV_EXT_shader_atomic_float_add", "SPV_KHR_16bit_storage", "SPV_KHR_non_semantic_info", "SPV_KHR_fragment_shader_barycentric", ]; #[derive(Copy, Clone)] pub struct Instruction { op: spirv::Op, wc: u16, } impl Instruction { const fn expect(self, count: u16) -> Result<(), Error> { if self.wc == count { Ok(()) } else { Err(Error::InvalidOperandCount(self.op, self.wc)) } } fn expect_at_least(self, count: u16) -> Result<u16, Error> { self.wc .checked_sub(count) .ok_or(Error::InvalidOperandCount(self.op, self.wc)) } } impl crate::TypeInner { fn can_comparison_sample(&self, module: &crate::Module) -> bool { match *self { crate::TypeInner::Image { class: crate::ImageClass::Sampled { kind: crate::ScalarKind::Float, multi: false, }, .. } => true, crate::TypeInner::Sampler { .. } => true, crate::TypeInner::BindingArray { base, .. } => { module.types[base].inner.can_comparison_sample(module) } _ => false, } } } #[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] pub enum ModuleState { Empty, Capability, Extension, ExtInstImport, MemoryModel, EntryPoint, ExecutionMode, Source, Name, ModuleProcessed, Annotation, Type, Function, } trait LookupHelper { type Target; fn lookup(&self, key: spirv::Word) -> Result<&Self::Target, Error>; } impl<T> LookupHelper for FastHashMap<spirv::Word, T> { type Target = T; fn lookup(&self, key: spirv::Word) -> Result<&T, Error> { self.get(&key).ok_or(Error::InvalidId(key)) } } impl crate::ImageDimension { const fn required_coordinate_size(&self) -> Option<crate::VectorSize> { match *self { crate::ImageDimension::D1 => None, crate::ImageDimension::D2 => Some(crate::VectorSize::Bi), crate::ImageDimension::D3 => Some(crate::VectorSize::Tri), crate::ImageDimension::Cube => Some(crate::VectorSize::Tri), } } } type MemberIndex = u32; bitflags::bitflags!
{ #[derive(Clone, Copy, Debug, Default)] struct DecorationFlags: u32 { const NON_READABLE = 0x1; const NON_WRITABLE = 0x2; const COHERENT = 0x4; const VOLATILE = 0x8; } } impl DecorationFlags { fn to_storage_access(self) -> crate::StorageAccess { let mut access = crate::StorageAccess::LOAD | crate::StorageAccess::STORE; if self.contains(DecorationFlags::NON_READABLE) { access &= !crate::StorageAccess::LOAD; } if self.contains(DecorationFlags::NON_WRITABLE) { access &= !crate::StorageAccess::STORE; } access } fn to_memory_decorations(self) -> crate::MemoryDecorations { let mut decorations = crate::MemoryDecorations::empty(); if self.contains(DecorationFlags::COHERENT) { decorations |= crate::MemoryDecorations::COHERENT; } if self.contains(DecorationFlags::VOLATILE) { decorations |= crate::MemoryDecorations::VOLATILE; } decorations } } #[derive(Debug, PartialEq)] enum Majority { Column, Row, } #[derive(Debug, Default)] struct Decoration { name: Option, built_in: Option, location: Option, index: Option, desc_set: Option, desc_index: Option, specialization_constant_id: Option, storage_buffer: bool, offset: Option, array_stride: Option, matrix_stride: Option, matrix_major: Option, invariant: bool, interpolation: Option, sampling: Option, flags: DecorationFlags, } impl Decoration { const fn debug_name(&self) -> &str { match self.name { Some(ref name) => name.as_str(), None => "?", } } const fn resource_binding(&self) -> Option { match *self { Decoration { desc_set: Some(group), desc_index: Some(binding), .. } => Some(crate::ResourceBinding { group, binding }), _ => None, } } fn io_binding(&self) -> Result { match *self { Decoration { built_in: Some(built_in), location: None, invariant, .. } => Ok(crate::Binding::BuiltIn(map_builtin(built_in, invariant)?)), Decoration { built_in: None, location: Some(location), index: Some(index), .. 
            } => Ok(crate::Binding::Location {
                location,
                interpolation: None,
                sampling: None,
                blend_src: Some(index),
                per_primitive: false,
            }),
            Decoration {
                built_in: None,
                location: Some(location),
                interpolation,
                sampling,
                ..
            } => Ok(crate::Binding::Location {
                location,
                interpolation,
                sampling,
                blend_src: None,
                per_primitive: false,
            }),
            _ => Err(Error::MissingDecoration(spirv::Decoration::Location)),
        }
    }
}

#[derive(Debug)]
struct LookupFunctionType {
    parameter_type_ids: Vec<spirv::Word>,
    return_type_id: spirv::Word,
}

struct LookupFunction {
    handle: Handle<crate::Function>,
    parameters_sampling: Vec<image::SamplingFlags>,
}

#[derive(Debug)]
struct EntryPoint {
    stage: crate::ShaderStage,
    name: String,
    early_depth_test: Option<crate::EarlyDepthTest>,
    workgroup_size: [u32; 3],
    variable_ids: Vec<spirv::Word>,
}

#[derive(Clone, Debug)]
struct LookupType {
    handle: Handle<crate::Type>,
    base_id: Option<spirv::Word>,
}

#[derive(Debug)]
enum Constant {
    Constant(Handle<crate::Constant>),
    Override(Handle<crate::Override>),
}

impl Constant {
    const fn to_expr(&self) -> crate::Expression {
        match *self {
            Self::Constant(c) => crate::Expression::Constant(c),
            Self::Override(o) => crate::Expression::Override(o),
        }
    }
}

#[derive(Debug)]
struct LookupConstant {
    inner: Constant,
    type_id: spirv::Word,
}

#[derive(Debug)]
enum Variable {
    Global,
    Input(crate::FunctionArgument),
    Output(crate::FunctionResult),
}

#[derive(Debug)]
struct LookupVariable {
    inner: Variable,
    handle: Handle<crate::GlobalVariable>,
    type_id: spirv::Word,
}

/// Information about SPIR-V result ids, stored in `Frontend::lookup_expression`.
#[derive(Clone, Debug)]
struct LookupExpression {
    /// The `Expression` constructed for this result.
    ///
    /// Note that, while a SPIR-V result id can be used in any block dominated
    /// by its definition, a Naga `Expression` is only in scope for the rest of
    /// its subtree. `Frontend::get_expr_handle` takes care of spilling the result
    /// to a `LocalVariable` which can then be used anywhere.
    handle: Handle<crate::Expression>,

    /// The SPIR-V type of this result.
    type_id: spirv::Word,

    /// The label id of the block that defines this expression.
    ///
    /// This is zero for globals, constants, and function parameters, since they
    /// originate outside any function's block.
    block_id: spirv::Word,
}

#[derive(Debug)]
struct LookupMember {
    type_id: spirv::Word,
    // This is true for either matrices, or arrays of matrices (yikes).
    row_major: bool,
}

#[derive(Clone, Debug)]
enum LookupLoadOverride {
    /// For arrays of matrices, we track them but don't load them yet.
    Pending,
    /// For matrices, vectors, and scalars, we pre-load the data.
    Loaded(Handle<crate::Expression>),
}

#[derive(PartialEq)]
enum ExtendedClass {
    Global(crate::AddressSpace),
    Input,
    Output,
}

#[derive(Clone, Debug)]
pub struct Options {
    /// The IR coordinate space matches all the APIs except SPIR-V,
    /// so by default we flip the Y coordinate of the `BuiltIn::Position`.
    /// This flag can be used to avoid this.
    pub adjust_coordinate_space: bool,
    /// Only allow shaders with the known set of capabilities.
    pub strict_capabilities: bool,
    pub block_ctx_dump_prefix: Option<String>,
}

impl Default for Options {
    fn default() -> Self {
        Options {
            adjust_coordinate_space: true,
            strict_capabilities: true,
            block_ctx_dump_prefix: None,
        }
    }
}

/// An index into the `BlockContext::bodies` table.
type BodyIndex = usize;

/// An intermediate representation of a Naga [`Statement`].
///
/// `Body` and `BodyFragment` values form a tree: the `BodyIndex` fields of the
/// variants are indices of the child `Body` values in [`BlockContext::bodies`].
/// The `lower` function assembles the final `Statement` tree from this `Body`
/// tree. See [`BlockContext`] for details.
///
/// [`Statement`]: crate::Statement
#[derive(Debug)]
enum BodyFragment {
    BlockId(spirv::Word),
    If {
        condition: Handle<crate::Expression>,
        accept: BodyIndex,
        reject: BodyIndex,
    },
    Loop {
        /// The body of the loop. Its [`Body::parent`] is the block containing
        /// this `Loop` fragment.
        body: BodyIndex,

        /// The loop's continuing block. This is a grandchild: its
        /// [`Body::parent`] is the loop body block, whose index is above.
        continuing: BodyIndex,

        /// If the SPIR-V loop's back-edge branch is conditional, this is the
        /// expression that must be `false` for the back-edge to be taken, with
        /// `true` being for the "loop merge" (which breaks out of the loop).
        break_if: Option<Handle<crate::Expression>>,
    },
    Switch {
        selector: Handle<crate::Expression>,
        cases: Vec<(i32, BodyIndex)>,
        default: BodyIndex,
    },
    Break,
    Continue,
}

/// An intermediate representation of a Naga [`Block`].
///
/// This will be assembled into a `Block` once we've added spills for phi nodes
/// and out-of-scope expressions. See [`BlockContext`] for details.
///
/// [`Block`]: crate::Block
#[derive(Debug)]
struct Body {
    /// The index of the direct parent of this body
    parent: usize,
    data: Vec<BodyFragment>,
}

impl Body {
    /// Creates a new empty `Body` with the specified `parent`
    pub const fn with_parent(parent: usize) -> Self {
        Body {
            parent,
            data: Vec::new(),
        }
    }
}

#[derive(Debug)]
struct PhiExpression {
    /// The local variable used for the phi node
    local: Handle<crate::LocalVariable>,
    /// List of (expression, block)
    expressions: Vec<(spirv::Word, spirv::Word)>,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum MergeBlockInformation {
    LoopMerge,
    LoopContinue,
    SelectionMerge,
    SwitchMerge,
}

/// Fragments of Naga IR, to be assembled into `Statements` once data flow is
/// resolved.
///
/// We can't build a Naga `Statement` tree directly from SPIR-V blocks for three
/// main reasons:
///
/// - We parse a function's SPIR-V blocks in the order they appear in the file.
///   Within a function, SPIR-V requires that a block must precede any blocks it
///   structurally dominates, but doesn't say much else about the order in which
///   they must appear. So while we know we'll see control flow header blocks
///   before their child constructs and merge blocks, those children and the
///   merge blocks may appear in any order - perhaps even intermingled with
///   children of other constructs.
///
/// - A SPIR-V expression can be used in any SPIR-V block dominated by its
///   definition, whereas Naga expressions are scoped to the rest of their
///   subtree. This means that discovering an expression use later in the
///   function retroactively requires us to have spilled that expression into a
///   local variable back before we left its scope. (The docs for
///   [`Frontend::get_expr_handle`] explain this in more detail.)
///
/// - We translate SPIR-V OpPhi expressions as Naga local variables in which we
///   store the appropriate value before jumping to the OpPhi's block.
///
/// All these cases require us to go back and amend previously generated Naga IR
/// based on things we discover later. But modifying old blocks in arbitrary
/// spots in a `Statement` tree is awkward.
///
/// Instead, as we iterate through the function's body, we accumulate
/// control-flow-free fragments of Naga IR in the [`blocks`] table, while
/// building a skeleton of the Naga `Statement` tree in [`bodies`]. We note any
/// spills and temporaries we must introduce in [`phis`].
///
/// Finally, once we've processed the entire function, we add temporaries and
/// spills to the fragmentary `Blocks` as directed by `phis`, and assemble them
/// into the final Naga `Statement` tree as directed by `bodies`.
///
/// [`blocks`]: BlockContext::blocks
/// [`bodies`]: BlockContext::bodies
/// [`phis`]: BlockContext::phis
#[derive(Debug)]
struct BlockContext<'function> {
    /// Phi nodes encountered when parsing the function, used to generate spills
    /// to local variables.
    phis: Vec<PhiExpression>,

    /// Fragments of control-flow-free Naga IR.
    ///
    /// These will be stitched together into a proper [`Statement`] tree according
    /// to `bodies`, once parsing is complete.
    ///
    /// [`Statement`]: crate::Statement
    blocks: FastHashMap<spirv::Word, crate::Block>,

    /// Map from each SPIR-V block's label id to the index of the [`Body`] in
    /// [`bodies`] the block should append its contents to.
    ///
    /// Since each statement in a Naga [`Block`] dominates the next, we are sure
    /// to encounter their SPIR-V blocks in order. Thus, by having this table
    /// map a SPIR-V structured control flow construct's merge block to the same
    /// body index as its header block, when we encounter the merge block, we
    /// will simply pick up building the [`Body`] where the header left off.
    ///
    /// A function's first block is special: it is the only block we encounter
    /// without having seen its label mentioned in advance. (It's simply the
    /// first `OpLabel` after the `OpFunction`.) We thus assume that any block
    /// missing an entry here must be the first block, which always has body
    /// index zero.
    ///
    /// [`bodies`]: BlockContext::bodies
    /// [`Block`]: crate::Block
    body_for_label: FastHashMap<spirv::Word, BodyIndex>,

    /// SPIR-V metadata about merge/continue blocks.
    mergers: FastHashMap<spirv::Word, MergeBlockInformation>,

    /// A table of `Body` values, each representing a block in the final IR.
    ///
    /// The first element is always the function's top-level block.
    bodies: Vec<Body>,

    /// The module we're building.
    module: &'function mut crate::Module,

    /// Id of the function currently being processed
    function_id: spirv::Word,
    /// Expression arena of the function currently being processed
    expressions: &'function mut Arena<crate::Expression>,
    /// Local variables arena of the function currently being processed
    local_arena: &'function mut Arena<crate::LocalVariable>,
    /// Arguments of the function currently being processed
    arguments: &'function [crate::FunctionArgument],
    /// Metadata about the usage of function parameters as sampling objects
    parameter_sampling: &'function mut [image::SamplingFlags],
}

enum SignAnchor {
    Result,
    Operand,
}

pub struct Frontend<I> {
    data: I,
    data_offset: usize,
    state: ModuleState,
    layouter: Layouter,
    temp_bytes: Vec<u8>,
    ext_glsl_id: Option<spirv::Word>,
    ext_non_semantic_id: Option<spirv::Word>,
    future_decor: FastHashMap<spirv::Word, Decoration>,
    future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>,
    lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>,
    handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>,

    /// A record of what is accessed by [`Atomic`] statements we've
    /// generated, so we can upgrade the types of their operands.
    ///
    /// [`Atomic`]: crate::Statement::Atomic
    upgrade_atomics: Upgrades,

    lookup_type: FastHashMap<spirv::Word, LookupType>,
    lookup_void_type: Option<spirv::Word>,
    lookup_storage_buffer_types: FastHashMap<Handle<crate::Type>, crate::StorageAccess>,
    lookup_constant: FastHashMap<spirv::Word, LookupConstant>,
    lookup_variable: FastHashMap<spirv::Word, LookupVariable>,
    lookup_expression: FastHashMap<spirv::Word, LookupExpression>,
    // Load overrides are used to work around row-major matrices
    lookup_load_override: FastHashMap<spirv::Word, LookupLoadOverride>,
    lookup_sampled_image: FastHashMap<spirv::Word, image::LookupSampledImage>,
    lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>,
    lookup_function: FastHashMap<spirv::Word, LookupFunction>,
    lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>,
    // When parsing functions, each entry point function gets an entry here so that additional
    // processing for them can be performed after all function parsing.
    deferred_entry_points: Vec<(EntryPoint, spirv::Word)>,
    // Note: each `OpFunctionCall` gets a single entry here, indexed by the
    // dummy `Handle` of the call site.
    deferred_function_calls: Vec<spirv::Word>,
    dummy_functions: Arena<crate::Function>,
    // Graph of all function calls through the module.
    // It's used to sort the functions (as nodes) topologically,
    // so that in the IR any called function is already known.
    function_call_graph: GraphMap<
        spirv::Word,
        (),
        petgraph::Directed,
        core::hash::BuildHasherDefault<rustc_hash::FxHasher>,
    >,

    options: Options,

    /// For a switch, maps each case target block id to the respective body and
    /// the case literals that branch to that target.
    ///
    /// Kept here so that its allocation can be reused between instructions.
    switch_cases: FastIndexMap<spirv::Word, (BodyIndex, Vec<i32>)>,

    /// Tracks accesses to `gl_PerVertex`'s builtins. It is used to cull unused
    /// builtins, since initializing them can affect performance, and the mere
    /// presence of some of these builtins might cause backends to error if
    /// they are unsupported.
    ///
    /// The problematic builtins are: PointSize, ClipDistance and CullDistance.
    ///
    /// glslang declares those by default even though they are never written to
    /// (see )
    gl_per_vertex_builtin_access: FastHashSet<crate::BuiltIn>,
}

impl<I: Iterator<Item = u32>> Frontend<I> {
    pub fn new(data: I, options: &Options) -> Self {
        Frontend {
            data,
            data_offset: 0,
            state: ModuleState::Empty,
            layouter: Layouter::default(),
            temp_bytes: Vec::new(),
            ext_glsl_id: None,
            ext_non_semantic_id: None,
            future_decor: FastHashMap::default(),
            future_member_decor: FastHashMap::default(),
            handle_sampling: FastHashMap::default(),
            lookup_member: FastHashMap::default(),
            upgrade_atomics: Default::default(),
            lookup_type: FastHashMap::default(),
            lookup_void_type: None,
            lookup_storage_buffer_types: FastHashMap::default(),
            lookup_constant: FastHashMap::default(),
            lookup_variable: FastHashMap::default(),
            lookup_expression: FastHashMap::default(),
            lookup_load_override: FastHashMap::default(),
            lookup_sampled_image: FastHashMap::default(),
            lookup_function_type: FastHashMap::default(),
            lookup_function: FastHashMap::default(),
            lookup_entry_point: FastHashMap::default(),
            deferred_entry_points: Vec::default(),
            deferred_function_calls: Vec::default(),
            dummy_functions: Arena::new(),
            function_call_graph: GraphMap::new(),
            options: options.clone(),
            switch_cases: FastIndexMap::default(),
            gl_per_vertex_builtin_access: FastHashSet::default(),
        }
    }

    fn span_from(&self, from: usize) -> crate::Span {
        crate::Span::from(from..self.data_offset)
    }

    fn span_from_with_op(&self, from: usize) -> crate::Span {
        crate::Span::from((from - 4)..self.data_offset)
    }

    fn next(&mut self) -> Result<u32, Error> {
        if let Some(res) = self.data.next() {
            self.data_offset += 4;
            Ok(res)
        } else {
            Err(Error::IncompleteData)
        }
    }

    fn next_inst(&mut self) -> Result<Instruction, Error> {
        let word = self.next()?;
        let (wc, opcode) = ((word >> 16) as u16, (word & 0xffff) as u16);
        if wc == 0 {
            return Err(Error::InvalidWordCount);
        }
        let op = spirv::Op::from_u32(opcode as u32).ok_or(Error::UnknownInstruction(opcode))?;

        Ok(Instruction { op, wc })
    }

    fn next_string(&mut self, mut count: u16) -> Result<(String, u16), Error> {
        self.temp_bytes.clear();
        loop {
            if count == 0 {
                return Err(Error::BadString);
            }
            count -= 1;
            let chars = self.next()?.to_le_bytes();
            let pos = chars.iter().position(|&c| c == 0).unwrap_or(4);
            self.temp_bytes.extend_from_slice(&chars[..pos]);
            if pos < 4 {
                break;
            }
        }
        core::str::from_utf8(&self.temp_bytes)
            .map(|s| (s.to_owned(), count))
            .map_err(|_| Error::BadString)
    }

    fn next_decoration(
        &mut self,
        inst: Instruction,
        base_words: u16,
        dec: &mut Decoration,
    ) -> Result<(), Error> {
        let raw = self.next()?;
        let dec_typed = spirv::Decoration::from_u32(raw).ok_or(Error::InvalidDecoration(raw))?;
        log::trace!("\t\t{}: {:?}", dec.debug_name(), dec_typed);
        match dec_typed {
            spirv::Decoration::BuiltIn => {
                inst.expect(base_words + 2)?;
                dec.built_in = Some(self.next()?);
            }
            spirv::Decoration::Location => {
                inst.expect(base_words + 2)?;
                dec.location = Some(self.next()?);
            }
            spirv::Decoration::Index => {
                inst.expect(base_words + 2)?;
                dec.index = Some(self.next()?);
            }
            spirv::Decoration::DescriptorSet => {
                inst.expect(base_words + 2)?;
                dec.desc_set = Some(self.next()?);
            }
            spirv::Decoration::Binding => {
                inst.expect(base_words + 2)?;
                dec.desc_index = Some(self.next()?);
            }
            spirv::Decoration::BufferBlock => {
                dec.storage_buffer = true;
            }
            spirv::Decoration::Offset => {
                inst.expect(base_words + 2)?;
                dec.offset = Some(self.next()?);
            }
            spirv::Decoration::ArrayStride => {
                inst.expect(base_words + 2)?;
                dec.array_stride = NonZeroU32::new(self.next()?);
            }
            spirv::Decoration::MatrixStride => {
                inst.expect(base_words + 2)?;
                dec.matrix_stride = NonZeroU32::new(self.next()?);
            }
            spirv::Decoration::Invariant => {
                dec.invariant = true;
            }
            spirv::Decoration::NoPerspective => {
                dec.interpolation = Some(crate::Interpolation::Linear);
            }
            spirv::Decoration::Flat => {
                dec.interpolation = Some(crate::Interpolation::Flat);
            }
            spirv::Decoration::PerVertexKHR => {
                dec.interpolation = Some(crate::Interpolation::PerVertex);
            }
            spirv::Decoration::Centroid => {
                dec.sampling = Some(crate::Sampling::Centroid);
            }
            spirv::Decoration::Sample => {
                dec.sampling = Some(crate::Sampling::Sample);
            }
            spirv::Decoration::NonReadable => {
                dec.flags |= DecorationFlags::NON_READABLE;
            }
            spirv::Decoration::NonWritable => {
                dec.flags |= DecorationFlags::NON_WRITABLE;
            }
            spirv::Decoration::Coherent => {
                dec.flags |= DecorationFlags::COHERENT;
            }
            spirv::Decoration::Volatile => {
                dec.flags |= DecorationFlags::VOLATILE;
            }
            spirv::Decoration::ColMajor => {
                dec.matrix_major = Some(Majority::Column);
            }
            spirv::Decoration::RowMajor => {
                dec.matrix_major = Some(Majority::Row);
            }
            spirv::Decoration::SpecId => {
                dec.specialization_constant_id = Some(self.next()?);
            }
            other => {
                let level = match other {
                    // Block decorations show up everywhere and we don't
                    // really care about them, so to prevent log spam
                    // we demote them to debug level.
                    spirv::Decoration::Block => log::Level::Debug,
                    _ => log::Level::Warn,
                };
                log::log!(level, "Unknown decoration {other:?}");
                for _ in base_words + 1..inst.wc {
                    let _var = self.next()?;
                }
            }
        }
        Ok(())
    }

    /// Return the Naga [`Expression`] to use in `body_idx` to refer to the SPIR-V result `id`.
    ///
    /// Ideally, we would just have a map from each SPIR-V instruction id to the
    /// [`Handle`] for the Naga [`Expression`] we generated for it.
    /// Unfortunately, SPIR-V and Naga IR are different enough that such a
    /// straightforward relationship isn't possible.
    ///
    /// In SPIR-V, an instruction's result id can be used by any instruction
    /// dominated by that instruction. In Naga, an [`Expression`] is only in
    /// scope for the remainder of its [`Block`]. In pseudocode:
    ///
    /// ```ignore
    ///     loop {
    ///         a = f();
    ///         g(a);
    ///         break;
    ///     }
    ///     h(a);
    /// ```
    ///
    /// Suppose the calls to `f`, `g`, and `h` are SPIR-V instructions. In
    /// SPIR-V, both the `g` and `h` instructions are allowed to refer to `a`,
    /// because the loop body, including `f`, dominates both of them.
    ///
    /// But if `a` is a Naga [`Expression`], its scope ends at the end of the
    /// block it's evaluated in: the loop body. Thus, while the [`Expression`]
    /// we generate for `g` can refer to `a`, the one we generate for `h`
    /// cannot.
    ///
    /// Instead, the SPIR-V front end must generate Naga IR like this:
    ///
    /// ```ignore
    ///     var temp; // INTRODUCED
    ///     loop {
    ///         a = f();
    ///         g(a);
    ///         temp = a; // INTRODUCED
    ///         break;
    ///     }
    ///     h(temp); // ADJUSTED
    /// ```
    ///
    /// In other words, where `a` is in scope, [`Expression`]s can refer to it
    /// directly; but once it is out of scope, we need to spill it to a
    /// temporary and refer to that instead.
    ///
    /// Given a SPIR-V expression `id` and the index `body_idx` of the [body]
    /// that wants to refer to it:
    ///
    /// - If the Naga [`Expression`] we generated for `id` is in scope in
    ///   `body_idx`, then we simply return its `Handle`.
    ///
    /// - Otherwise, introduce a new [`LocalVariable`], and add an entry to
    ///   [`BlockContext::phis`] to arrange for `id`'s value to be spilled to
    ///   it. Then emit a fresh [`Load`] of that temporary variable for use in
    ///   `body_idx`'s block, and return its `Handle`.
    ///
    /// The SPIR-V domination rule ensures that the introduced [`LocalVariable`]
    /// will always have been initialized before it is used.
    ///
    /// `lookup` must be the [`LookupExpression`] for `id`.
    ///
    /// `body_idx` must be the index of the [`Body`] that wants to use `id`'s
    /// [`Expression`].
    ///
    /// [`Expression`]: crate::Expression
    /// [`Handle`]: crate::Handle
    /// [`Block`]: crate::Block
    /// [body]: BlockContext::bodies
    /// [`LocalVariable`]: crate::LocalVariable
    /// [`Load`]: crate::Expression::Load
    fn get_expr_handle(
        &self,
        id: spirv::Word,
        lookup: &LookupExpression,
        ctx: &mut BlockContext,
        emitter: &mut crate::proc::Emitter,
        block: &mut crate::Block,
        body_idx: BodyIndex,
    ) -> Handle<crate::Expression> {
        // What `Body` was `id` defined in?
        let expr_body_idx = ctx
            .body_for_label
            .get(&lookup.block_id)
            .copied()
            .unwrap_or(0);

        // We don't need a load/store if the expression is in the main body
        // or in the same body as the one the query came from. The body_idx
        // might not be the final one if a loop or conditional occurs, but in
        // those cases we know that the new body will be a subscope of the body
        // that was passed, so we can still reuse the handle and skip the
        // load/store.
        if is_parent(body_idx, expr_body_idx, ctx) {
            lookup.handle
        } else {
            // Add a temporary variable of the same type which will be used to
            // store the original expression and used in the current block
            let ty = self.lookup_type[&lookup.type_id].handle;
            let local = ctx.local_arena.append(
                crate::LocalVariable {
                    name: None,
                    ty,
                    init: None,
                },
                crate::Span::default(),
            );

            block.extend(emitter.finish(ctx.expressions));
            let pointer = ctx.expressions.append(
                crate::Expression::LocalVariable(local),
                crate::Span::default(),
            );
            emitter.start(ctx.expressions);
            let expr = ctx
                .expressions
                .append(crate::Expression::Load { pointer }, crate::Span::default());

            // Add a slightly odd entry to the phi table, so that while `id`'s
            // `Expression` is still in scope, the usual phi processing will
            // spill its value to `local`, where we can find it later.
            //
            // This pretends that the block in which `id` is defined is the
            // predecessor of some other block with a phi in it that cites id as
            // one of its sources, and uses `local` as its variable. There is no
            // such phi, but nobody needs to know that.
            ctx.phis.push(PhiExpression {
                local,
                expressions: vec![(id, lookup.block_id)],
            });

            expr
        }
    }

    fn parse_expr_unary_op(
        &mut self,
        ctx: &mut BlockContext,
        emitter: &mut crate::proc::Emitter,
        block: &mut crate::Block,
        block_id: spirv::Word,
        body_idx: usize,
        op: crate::UnaryOperator,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        let result_type_id = self.next()?;
        let result_id = self.next()?;
        let p_id = self.next()?;

        let p_lexp = self.lookup_expression.lookup(p_id)?;
        let handle = self.get_expr_handle(p_id, p_lexp, ctx, emitter, block, body_idx);

        let expr = crate::Expression::Unary { op, expr: handle };
        self.lookup_expression.insert(
            result_id,
            LookupExpression {
                handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
                type_id: result_type_id,
                block_id,
            },
        );
        Ok(())
    }

    fn parse_expr_binary_op(
        &mut self,
        ctx: &mut BlockContext,
        emitter: &mut crate::proc::Emitter,
        block: &mut crate::Block,
        block_id: spirv::Word,
        body_idx: usize,
        op: crate::BinaryOperator,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        let result_type_id = self.next()?;
        let result_id = self.next()?;
        let p1_id = self.next()?;
        let p2_id = self.next()?;

        let p1_lexp = self.lookup_expression.lookup(p1_id)?;
        let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
        let p2_lexp = self.lookup_expression.lookup(p2_id)?;
        let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);

        let expr = crate::Expression::Binary { op, left, right };
        self.lookup_expression.insert(
            result_id,
            LookupExpression {
                handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
                type_id: result_type_id,
                block_id,
            },
        );
        Ok(())
    }

    /// A more complicated version of the unary op,
    /// where we force the operand to have the same type as the result.
    fn parse_expr_unary_op_sign_adjusted(
        &mut self,
        ctx: &mut BlockContext,
        emitter: &mut crate::proc::Emitter,
        block: &mut crate::Block,
        block_id: spirv::Word,
        body_idx: usize,
        op: crate::UnaryOperator,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        let result_type_id = self.next()?;
        let result_id = self.next()?;
        let p1_id = self.next()?;
        let span = self.span_from_with_op(start);

        let p1_lexp = self.lookup_expression.lookup(p1_id)?;
        let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);

        let result_lookup_ty = self.lookup_type.lookup(result_type_id)?;
        let kind = ctx.module.types[result_lookup_ty.handle]
            .inner
            .scalar_kind()
            .unwrap();

        let expr = crate::Expression::Unary {
            op,
            expr: if p1_lexp.type_id == result_type_id {
                left
            } else {
                ctx.expressions.append(
                    crate::Expression::As {
                        expr: left,
                        kind,
                        convert: None,
                    },
                    span,
                )
            },
        };

        self.lookup_expression.insert(
            result_id,
            LookupExpression {
                handle: ctx.expressions.append(expr, span),
                type_id: result_type_id,
                block_id,
            },
        );
        Ok(())
    }

    /// A more complicated version of the binary op,
    /// where we force both operands to have the type selected by `anchor`:
    /// either the result type or the type of the first operand.
    /// This is mostly needed for "i++" and "i--" coming from GLSL.
    #[allow(clippy::too_many_arguments)]
    fn parse_expr_binary_op_sign_adjusted(
        &mut self,
        ctx: &mut BlockContext,
        emitter: &mut crate::proc::Emitter,
        block: &mut crate::Block,
        block_id: spirv::Word,
        body_idx: usize,
        op: crate::BinaryOperator,
        // For arithmetic operations, we need the sign of the operands to match the result.
        // For boolean operations, however, the operands' signs need to match each other,
        // but the result is always different: a boolean.
        anchor: SignAnchor,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        let result_type_id = self.next()?;
        let result_id = self.next()?;
        let p1_id = self.next()?;
        let p2_id = self.next()?;
        let span = self.span_from_with_op(start);

        let p1_lexp = self.lookup_expression.lookup(p1_id)?;
        let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
        let p2_lexp = self.lookup_expression.lookup(p2_id)?;
        let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);

        let expected_type_id = match anchor {
            SignAnchor::Result => result_type_id,
            SignAnchor::Operand => p1_lexp.type_id,
        };
        let expected_lookup_ty = self.lookup_type.lookup(expected_type_id)?;
        let kind = ctx.module.types[expected_lookup_ty.handle]
            .inner
            .scalar_kind()
            .unwrap();

        let expr = crate::Expression::Binary {
            op,
            left: if p1_lexp.type_id == expected_type_id {
                left
            } else {
                ctx.expressions.append(
                    crate::Expression::As {
                        expr: left,
                        kind,
                        convert: None,
                    },
                    span,
                )
            },
            right: if p2_lexp.type_id == expected_type_id {
                right
            } else {
                ctx.expressions.append(
                    crate::Expression::As {
                        expr: right,
                        kind,
                        convert: None,
                    },
                    span,
                )
            },
        };

        self.lookup_expression.insert(
            result_id,
            LookupExpression {
                handle: ctx.expressions.append(expr, span),
                type_id: result_type_id,
                block_id,
            },
        );
        Ok(())
    }

    /// A version of the binary op where one or both of the arguments might need to be cast to a
    /// specific integer kind (unsigned or signed), used for operations like OpINotEqual or
    /// OpUGreaterThan.
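    ///
    /// For example, this sketch assumes the caller passes `kind: ScalarKind::Uint` for
    /// `OpUGreaterThan`. The ids `%a` (an `int`) and `%b` (a `uint`) are hypothetical.
    /// Given these, the instruction
    ///
    /// ```ignore
    ///     %r = OpUGreaterThan %bool %a %b
    /// ```
    ///
    /// is lowered to roughly the following Naga IR. Only the signed operand is
    /// wrapped, and `convert: None` makes the `As` a bitcast rather than a value
    /// conversion:
    ///
    /// ```ignore
    ///     Binary {
    ///         op: Greater,
    ///         left: As { expr: a, kind: Uint, convert: None },
    ///         right: b,
    ///     }
    /// ```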
    #[allow(clippy::too_many_arguments)]
    fn parse_expr_int_comparison(
        &mut self,
        ctx: &mut BlockContext,
        emitter: &mut crate::proc::Emitter,
        block: &mut crate::Block,
        block_id: spirv::Word,
        body_idx: usize,
        op: crate::BinaryOperator,
        kind: crate::ScalarKind,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        let result_type_id = self.next()?;
        let result_id = self.next()?;
        let p1_id = self.next()?;
        let p2_id = self.next()?;
        let span = self.span_from_with_op(start);

        let p1_lexp = self.lookup_expression.lookup(p1_id)?;
        let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
        let p1_lookup_ty = self.lookup_type.lookup(p1_lexp.type_id)?;
        let p1_kind = ctx.module.types[p1_lookup_ty.handle]
            .inner
            .scalar_kind()
            .unwrap();
        let p2_lexp = self.lookup_expression.lookup(p2_id)?;
        let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
        let p2_lookup_ty = self.lookup_type.lookup(p2_lexp.type_id)?;
        let p2_kind = ctx.module.types[p2_lookup_ty.handle]
            .inner
            .scalar_kind()
            .unwrap();

        let expr = crate::Expression::Binary {
            op,
            left: if p1_kind == kind {
                left
            } else {
                ctx.expressions.append(
                    crate::Expression::As {
                        expr: left,
                        kind,
                        convert: None,
                    },
                    span,
                )
            },
            right: if p2_kind == kind {
                right
            } else {
                ctx.expressions.append(
                    crate::Expression::As {
                        expr: right,
                        kind,
                        convert: None,
                    },
                    span,
                )
            },
        };

        self.lookup_expression.insert(
            result_id,
            LookupExpression {
                handle: ctx.expressions.append(expr, span),
                type_id: result_type_id,
                block_id,
            },
        );
        Ok(())
    }

    fn parse_expr_shift_op(
        &mut self,
        ctx: &mut BlockContext,
        emitter: &mut crate::proc::Emitter,
        block: &mut crate::Block,
        block_id: spirv::Word,
        body_idx: usize,
        op: crate::BinaryOperator,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        let result_type_id = self.next()?;
        let result_id = self.next()?;
        let p1_id = self.next()?;
        let p2_id = self.next()?;

        let span = self.span_from_with_op(start);

        let p1_lexp = self.lookup_expression.lookup(p1_id)?;
        let left =
            self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
        let p2_lexp = self.lookup_expression.lookup(p2_id)?;
        let p2_handle = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
        // Bitcast the shift amount to `Uint` (`convert: None` keeps the bits as-is)
        let right = ctx.expressions.append(
            crate::Expression::As {
                expr: p2_handle,
                kind: crate::ScalarKind::Uint,
                convert: None,
            },
            span,
        );

        let expr = crate::Expression::Binary { op, left, right };
        self.lookup_expression.insert(
            result_id,
            LookupExpression {
                handle: ctx.expressions.append(expr, span),
                type_id: result_type_id,
                block_id,
            },
        );
        Ok(())
    }

    fn parse_expr_derivative(
        &mut self,
        ctx: &mut BlockContext,
        emitter: &mut crate::proc::Emitter,
        block: &mut crate::Block,
        block_id: spirv::Word,
        body_idx: usize,
        (axis, ctrl): (crate::DerivativeAxis, crate::DerivativeControl),
    ) -> Result<(), Error> {
        let start = self.data_offset;
        let result_type_id = self.next()?;
        let result_id = self.next()?;
        let arg_id = self.next()?;

        let arg_lexp = self.lookup_expression.lookup(arg_id)?;
        let arg_handle = self.get_expr_handle(arg_id, arg_lexp, ctx, emitter, block, body_idx);

        let expr = crate::Expression::Derivative {
            axis,
            ctrl,
            expr: arg_handle,
        };
        self.lookup_expression.insert(
            result_id,
            LookupExpression {
                handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
                type_id: result_type_id,
                block_id,
            },
        );
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    fn insert_composite(
        &self,
        root_expr: Handle<crate::Expression>,
        root_type_id: spirv::Word,
        object_expr: Handle<crate::Expression>,
        selections: &[spirv::Word],
        type_arena: &UniqueArena<crate::Type>,
        expressions: &mut Arena<crate::Expression>,
        span: crate::Span,
    ) -> Result<Handle<crate::Expression>, Error> {
        let selection = match selections.first() {
            Some(&index) => index,
            None => return Ok(object_expr),
        };
        let root_span = expressions.get_span(root_expr);
        let root_lookup = self.lookup_type.lookup(root_type_id)?;

        let (count, child_type_id) = match type_arena[root_lookup.handle].inner {
            crate::TypeInner::Struct { ref members, ..
            } => {
                let child_member = self
                    .lookup_member
                    .get(&(root_lookup.handle, selection))
                    .ok_or(Error::InvalidAccessType(root_type_id))?;
                (members.len(), child_member.type_id)
            }
            crate::TypeInner::Array { size, .. } => {
                let size = match size {
                    crate::ArraySize::Constant(size) => size.get(),
                    crate::ArraySize::Pending(_) => {
                        unreachable!();
                    }
                    // A runtime sized array is not a composite type
                    crate::ArraySize::Dynamic => {
                        return Err(Error::InvalidAccessType(root_type_id))
                    }
                };

                let child_type_id = root_lookup
                    .base_id
                    .ok_or(Error::InvalidAccessType(root_type_id))?;

                (size as usize, child_type_id)
            }
            crate::TypeInner::Vector { size, .. }
            | crate::TypeInner::Matrix { columns: size, .. } => {
                let child_type_id = root_lookup
                    .base_id
                    .ok_or(Error::InvalidAccessType(root_type_id))?;
                (size as usize, child_type_id)
            }
            _ => return Err(Error::InvalidAccessType(root_type_id)),
        };

        let mut components = Vec::with_capacity(count);
        for index in 0..count as u32 {
            let expr = expressions.append(
                crate::Expression::AccessIndex {
                    base: root_expr,
                    index,
                },
                if index == selection { span } else { root_span },
            );
            components.push(expr);
        }
        components[selection as usize] = self.insert_composite(
            components[selection as usize],
            child_type_id,
            object_expr,
            &selections[1..],
            type_arena,
            expressions,
            span,
        )?;

        Ok(expressions.append(
            crate::Expression::Compose {
                ty: root_lookup.handle,
                components,
            },
            span,
        ))
    }

    /// Return the Naga [`Expression`] for `pointer_id`, and its referent [`Type`].
    ///
    /// Return a [`Handle`] for a Naga [`Expression`] that holds the value of
    /// the SPIR-V instruction `pointer_id`, along with the [`Type`] to which it
    /// is a pointer.
    ///
    /// This may entail spilling `pointer_id`'s value to a temporary:
    /// see [`get_expr_handle`]'s documentation.
    ///
    /// [`Expression`]: crate::Expression
    /// [`Type`]: crate::Type
    /// [`Handle`]: crate::Handle
    /// [`get_expr_handle`]: Frontend::get_expr_handle
    fn get_exp_and_base_ty_handles(
        &self,
        pointer_id: spirv::Word,
        ctx: &mut BlockContext,
        emitter: &mut crate::proc::Emitter,
        block: &mut crate::Block,
        body_idx: usize,
    ) -> Result<(Handle<crate::Expression>, Handle<crate::Type>), Error> {
        log::trace!("\t\t\tlooking up pointer expr {pointer_id:?}");
        let p_lexp_handle;
        let p_lexp_ty_id;
        {
            let lexp = self.lookup_expression.lookup(pointer_id)?;
            p_lexp_handle = self.get_expr_handle(pointer_id, lexp, ctx, emitter, block, body_idx);
            p_lexp_ty_id = lexp.type_id;
        };

        log::trace!("\t\t\tlooking up pointer type {pointer_id:?}");
        let p_ty = self.lookup_type.lookup(p_lexp_ty_id)?;
        let p_ty_base_id = p_ty.base_id.ok_or(Error::InvalidAccessType(p_lexp_ty_id))?;

        log::trace!("\t\t\tlooking up pointer base type {p_ty_base_id:?} of {p_ty:?}");
        let p_base_ty = self.lookup_type.lookup(p_ty_base_id)?;

        Ok((p_lexp_handle, p_base_ty.handle))
    }

    #[allow(clippy::too_many_arguments)]
    fn parse_atomic_expr_with_value(
        &mut self,
        inst: Instruction,
        emitter: &mut crate::proc::Emitter,
        ctx: &mut BlockContext,
        block: &mut crate::Block,
        block_id: spirv::Word,
        body_idx: usize,
        atomic_function: crate::AtomicFunction,
    ) -> Result<(), Error> {
        inst.expect(7)?;
        let start = self.data_offset;
        let result_type_id = self.next()?;
        let result_id = self.next()?;
        let pointer_id = self.next()?;
        let _scope_id = self.next()?;
        let _memory_semantics_id = self.next()?;
        let value_id = self.next()?;
        let span = self.span_from_with_op(start);

        let (p_lexp_handle, p_base_ty_handle) =
            self.get_exp_and_base_ty_handles(pointer_id, ctx, emitter, block, body_idx)?;

        log::trace!("\t\t\tlooking up value expr {value_id:?}");
        let v_lexp_handle = self.lookup_expression.lookup(value_id)?.handle;

        block.extend(emitter.finish(ctx.expressions));
        // Create an expression for our result
        let r_lexp_handle = {
            let expr = crate::Expression::AtomicResult {
                ty: p_base_ty_handle,
                comparison: false,
            };
            let handle = ctx.expressions.append(expr, span);
            self.lookup_expression.insert(
                result_id,
                LookupExpression {
                    handle,
                    type_id: result_type_id,
                    block_id,
                },
            );
            handle
        };

        emitter.start(ctx.expressions);

        // Create a statement for the op itself
        let stmt = crate::Statement::Atomic {
            pointer: p_lexp_handle,
            fun: atomic_function,
            value: v_lexp_handle,
            result: Some(r_lexp_handle),
        };
        block.push(stmt, span);

        // Store any associated global variables so we can upgrade their types later
        self.record_atomic_access(ctx, p_lexp_handle)?;

        Ok(())
    }

    fn make_expression_storage(
        &mut self,
        globals: &Arena<crate::GlobalVariable>,
        constants: &Arena<crate::Constant>,
        overrides: &Arena<crate::Override>,
    ) -> Arena<crate::Expression> {
        let mut expressions = Arena::new();
        assert!(self.lookup_expression.is_empty());
        // register global variables
        for (&id, var) in self.lookup_variable.iter() {
            let span = globals.get_span(var.handle);
            let handle = expressions.append(crate::Expression::GlobalVariable(var.handle), span);
            self.lookup_expression.insert(
                id,
                LookupExpression {
                    type_id: var.type_id,
                    handle,
                    // Setting this to an invalid id makes `get_expr_handle`
                    // default to the main body, so that no loads/stores
                    // are added.
                    block_id: 0,
                },
            );
        }
        // register constants
        for (&id, con) in self.lookup_constant.iter() {
            let (expr, span) = match con.inner {
                Constant::Constant(c) => (crate::Expression::Constant(c), constants.get_span(c)),
                Constant::Override(o) => (crate::Expression::Override(o), overrides.get_span(o)),
            };
            let handle = expressions.append(expr, span);
            self.lookup_expression.insert(
                id,
                LookupExpression {
                    type_id: con.type_id,
                    handle,
                    // Setting this to an invalid id makes `get_expr_handle`
                    // default to the main body, so that no loads/stores
                    // are added.
                    block_id: 0,
                },
            );
        }
        // done
        expressions
    }

    fn switch(&mut self, state: ModuleState, op: spirv::Op) -> Result<(), Error> {
        if state < self.state {
            Err(Error::UnsupportedInstruction(self.state, op))
        } else {
            self.state = state;
            Ok(())
        }
    }

    /// Walk the statement tree and patch it in the following cases:
    /// 1. Function call targets are replaced with their final handles, using the
    ///    `deferred_function_calls` map.
    /// 2. Sampling flags required by a callee's parameters are propagated to the
    ///    global variables or function arguments passed to it.
    fn patch_statements(
        &mut self,
        statements: &mut crate::Block,
        expressions: &mut Arena<crate::Expression>,
        fun_parameter_sampling: &mut [image::SamplingFlags],
    ) -> Result<(), Error> {
        use crate::Statement as S;
        let mut i = 0usize;
        while i < statements.len() {
            match statements[i] {
                S::Emit(_) => {}
                S::Block(ref mut block) => {
                    self.patch_statements(block, expressions, fun_parameter_sampling)?;
                }
                S::If {
                    condition: _,
                    ref mut accept,
                    ref mut reject,
                } => {
                    self.patch_statements(reject, expressions, fun_parameter_sampling)?;
                    self.patch_statements(accept, expressions, fun_parameter_sampling)?;
                }
                S::Switch {
                    selector: _,
                    ref mut cases,
                } => {
                    for case in cases.iter_mut() {
                        self.patch_statements(&mut case.body, expressions, fun_parameter_sampling)?;
                    }
                }
                S::Loop {
                    ref mut body,
                    ref mut continuing,
                    break_if: _,
                } => {
                    self.patch_statements(body, expressions, fun_parameter_sampling)?;
                    self.patch_statements(continuing, expressions, fun_parameter_sampling)?;
                }
                S::Break
                | S::Continue
                | S::Return { .. }
                | S::Kill
                | S::ControlBarrier(_)
                | S::MemoryBarrier(_)
                | S::Store { .. }
                | S::ImageStore { .. }
                | S::Atomic { .. }
                | S::ImageAtomic { .. }
                | S::RayQuery { .. }
                | S::SubgroupBallot { .. }
                | S::SubgroupCollectiveOperation { .. }
                | S::SubgroupGather { .. }
                | S::RayPipelineFunction(..) => {}
                S::Call {
                    function: ref mut callee,
                    ref arguments,
                    ..
                } => {
                    let fun_id = self.deferred_function_calls[callee.index()];
                    let fun_lookup = self.lookup_function.lookup(fun_id)?;
                    *callee = fun_lookup.handle;

                    // Patch sampling flags
                    for (arg_index, arg) in arguments.iter().enumerate() {
                        let flags = match fun_lookup.parameters_sampling.get(arg_index) {
                            Some(&flags) if !flags.is_empty() => flags,
                            _ => continue,
                        };

                        match expressions[*arg] {
                            crate::Expression::GlobalVariable(handle) => {
                                if let Some(sampling) = self.handle_sampling.get_mut(&handle) {
                                    *sampling |= flags
                                }
                            }
                            crate::Expression::FunctionArgument(i) => {
                                fun_parameter_sampling[i as usize] |= flags;
                            }
                            ref other => return Err(Error::InvalidGlobalVar(other.clone())),
                        }
                    }
                }
                S::WorkGroupUniformLoad { .. } => unreachable!(),
                S::CooperativeStore { .. } => unreachable!(),
            }
            i += 1;
        }
        Ok(())
    }

    fn patch_function(
        &mut self,
        handle: Option<Handle<crate::Function>>,
        fun: &mut crate::Function,
    ) -> Result<(), Error> {
        // Note: this linear search over all functions is a bit unfortunate
        let (fun_id, mut parameters_sampling) = match handle {
            Some(h) => {
                let (&fun_id, lookup) = self
                    .lookup_function
                    .iter_mut()
                    .find(|&(_, ref lookup)| lookup.handle == h)
                    .unwrap();
                (fun_id, mem::take(&mut lookup.parameters_sampling))
            }
            None => (0, Vec::new()),
        };

        for (_, expr) in fun.expressions.iter_mut() {
            if let crate::Expression::CallResult(ref mut function) = *expr {
                let fun_id = self.deferred_function_calls[function.index()];
                *function = self.lookup_function.lookup(fun_id)?.handle;
            }
        }

        self.patch_statements(
            &mut fun.body,
            &mut fun.expressions,
            &mut parameters_sampling,
        )?;

        if let Some(lookup) = self.lookup_function.get_mut(&fun_id) {
            lookup.parameters_sampling = parameters_sampling;
        }
        Ok(())
    }

    pub fn parse(mut self) -> Result<crate::Module, Error> {
        let mut module = {
            if self.next()?
                != spirv::MAGIC_NUMBER
            {
                return Err(Error::InvalidHeader);
            }
            let version_raw = self.next()?;
            let generator = self.next()?;
            let _bound = self.next()?;
            let _schema = self.next()?;
            log::debug!("Generated by {generator} version {version_raw:x}");
            crate::Module::default()
        };

        self.layouter.clear();
        self.dummy_functions = Arena::new();
        self.lookup_function.clear();
        self.function_call_graph.clear();

        loop {
            use spirv::Op;

            let inst = match self.next_inst() {
                Ok(inst) => inst,
                Err(Error::IncompleteData) => break,
                Err(other) => return Err(other),
            };
            log::debug!("\t{:?} [{}]", inst.op, inst.wc);

            match inst.op {
                Op::Capability => self.parse_capability(inst),
                Op::Extension => self.parse_extension(inst),
                Op::ExtInstImport => self.parse_ext_inst_import(inst),
                Op::MemoryModel => self.parse_memory_model(inst),
                Op::EntryPoint => self.parse_entry_point(inst),
                Op::ExecutionMode => self.parse_execution_mode(inst),
                Op::String => self.parse_string(inst),
                Op::Source => self.parse_source(inst),
                Op::SourceExtension => self.parse_source_extension(inst),
                Op::Name => self.parse_name(inst),
                Op::MemberName => self.parse_member_name(inst),
                Op::ModuleProcessed => self.parse_module_processed(inst),
                Op::Decorate => self.parse_decorate(inst),
                Op::MemberDecorate => self.parse_member_decorate(inst),
                Op::TypeVoid => self.parse_type_void(inst),
                Op::TypeBool => self.parse_type_bool(inst, &mut module),
                Op::TypeInt => self.parse_type_int(inst, &mut module),
                Op::TypeFloat => self.parse_type_float(inst, &mut module),
                Op::TypeVector => self.parse_type_vector(inst, &mut module),
                Op::TypeMatrix => self.parse_type_matrix(inst, &mut module),
                Op::TypeFunction => self.parse_type_function(inst),
                Op::TypePointer => self.parse_type_pointer(inst, &mut module),
                Op::TypeArray => self.parse_type_array(inst, &mut module),
                Op::TypeRuntimeArray => self.parse_type_runtime_array(inst, &mut module),
                Op::TypeStruct => self.parse_type_struct(inst, &mut module),
                Op::TypeImage => self.parse_type_image(inst, &mut module),
                Op::TypeSampledImage => self.parse_type_sampled_image(inst),
                Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
                Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
                Op::ConstantComposite | Op::SpecConstantComposite => {
                    self.parse_composite_constant(inst, &mut module)
                }
                Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
                Op::ConstantTrue | Op::SpecConstantTrue => {
                    self.parse_bool_constant(inst, true, &mut module)
                }
                Op::ConstantFalse | Op::SpecConstantFalse => {
                    self.parse_bool_constant(inst, false, &mut module)
                }
                Op::Variable => self.parse_global_variable(inst, &mut module),
                Op::Function => {
                    self.switch(ModuleState::Function, inst.op)?;
                    inst.expect(5)?;
                    self.parse_function(&mut module)
                }
                Op::ExtInst => {
                    // Ignore the result type and result id
                    let _ = self.next()?;
                    let _ = self.next()?;
                    let set_id = self.next()?;
                    if Some(set_id) == self.ext_non_semantic_id {
                        // We've already consumed four words: the first instruction word
                        // (word count and opcode), the result type, the result id, and
                        // the instruction set id.
                        for _ in 0..inst.wc - 4 {
                            self.next()?;
                        }
                        Ok(())
                    } else {
                        return Err(Error::UnsupportedInstruction(self.state, inst.op));
                    }
                }
                _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO
            }?;
        }

        if !self.upgrade_atomics.is_empty() {
            log::debug!("Upgrading atomic pointers...");
            module.upgrade_atomics(&self.upgrade_atomics)?;
        }

        // Do entry point specific processing after all functions are parsed so that we can
        // cull unused problematic builtins of gl_PerVertex.
        for (ep, fun_id) in mem::take(&mut self.deferred_entry_points) {
            self.process_entry_point(&mut module, ep, fun_id)?;
        }

        log::debug!("Patching...");
        {
            let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None)
                .map_err(|cycle| Error::FunctionCallCycle(cycle.node_id()))?;
            nodes.reverse(); // we need dominated first
            let mut functions = mem::take(&mut module.functions);
            for fun_id in nodes {
                if fun_id > !(functions.len() as u32) {
                    // skip all the fake IDs registered for the entry points
                    continue;
                }
                let lookup = self.lookup_function.get_mut(&fun_id).unwrap();
                // take out the function from the old array
                let fun = mem::take(&mut functions[lookup.handle]);
                // add it to the newly formed arena, and adjust the lookup
                lookup.handle = module
                    .functions
                    .append(fun, functions.get_span(lookup.handle));
            }
        }
        // patch all the functions
        for (handle, fun) in module.functions.iter_mut() {
            self.patch_function(Some(handle), fun)?;
        }
        for ep in module.entry_points.iter_mut() {
            self.patch_function(None, &mut ep.function)?;
        }

        // Check that all images and samplers have a consistent comparison property.
        for (handle, flags) in self.handle_sampling.drain() {
            if !image::patch_comparison_type(
                flags,
                module.global_variables.get_mut(handle),
                &mut module.types,
            ) {
                return Err(Error::InconsistentComparisonSampling(handle));
            }
        }

        if !self.future_decor.is_empty() {
            log::debug!("Unused item decorations: {:?}", self.future_decor);
            self.future_decor.clear();
        }
        if !self.future_member_decor.is_empty() {
            log::debug!("Unused member decorations: {:?}", self.future_member_decor);
            self.future_member_decor.clear();
        }

        Ok(module)
    }

    fn parse_capability(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Capability, inst.op)?;
        inst.expect(2)?;
        let capability = self.next()?;
        let cap =
            spirv::Capability::from_u32(capability).ok_or(Error::UnknownCapability(capability))?;
        if !SUPPORTED_CAPABILITIES.contains(&cap) {
            if self.options.strict_capabilities {
                return Err(Error::UnsupportedCapability(cap));
            } else {
                log::warn!("Unsupported capability {cap:?}");
            }
        }
        Ok(())
    }

    fn parse_extension(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Extension, inst.op)?;
        inst.expect_at_least(2)?;
        let (name, left) = self.next_string(inst.wc - 1)?;
        if left != 0 {
            return Err(Error::InvalidOperand);
        }
        if !SUPPORTED_EXTENSIONS.contains(&name.as_str()) {
            return Err(Error::UnsupportedExtension(name));
        }
        Ok(())
    }

    fn parse_ext_inst_import(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Extension, inst.op)?;
        inst.expect_at_least(3)?;
        let result_id = self.next()?;
        let (name, left) = self.next_string(inst.wc - 2)?;
        if left != 0 {
            return Err(Error::InvalidOperand);
        }
        if &name == "GLSL.std.450" {
            self.ext_glsl_id = Some(result_id);
        } else if &name == "NonSemantic.Shader.DebugInfo.100" {
            // We completely ignore this extension. All related instructions are
            // non-semantic and only for debug purposes, and the spec says they
            // are ignorable. Many compilers (dxc, slang, etc.) will emit these
            // instructions depending on configuration.
            self.ext_non_semantic_id = Some(result_id);
        } else {
            return Err(Error::UnsupportedExtSet(name));
        }
        Ok(())
    }

    fn parse_memory_model(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::MemoryModel, inst.op)?;
        inst.expect(3)?;
        let _addressing_model = self.next()?;
        let _memory_model = self.next()?;
        Ok(())
    }

    fn parse_entry_point(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::EntryPoint, inst.op)?;
        inst.expect_at_least(4)?;
        let exec_model = self.next()?;
        let exec_model = spirv::ExecutionModel::from_u32(exec_model)
            .ok_or(Error::UnsupportedExecutionModel(exec_model))?;
        let function_id = self.next()?;
        let (name, left) = self.next_string(inst.wc - 3)?;
        let ep = EntryPoint {
            stage: match exec_model {
                spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex,
                spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment,
                spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute,
                spirv::ExecutionModel::TaskEXT => crate::ShaderStage::Task,
                spirv::ExecutionModel::MeshEXT => crate::ShaderStage::Mesh,
                _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)),
            },
            name,
            early_depth_test: None,
            workgroup_size: [0; 3],
            variable_ids: self.data.by_ref().take(left as usize).collect(),
        };
        self.lookup_entry_point.insert(function_id, ep);
        Ok(())
    }

    fn parse_execution_mode(&mut self, inst: Instruction) -> Result<(), Error> {
        use spirv::ExecutionMode;

        self.switch(ModuleState::ExecutionMode, inst.op)?;
        inst.expect_at_least(3)?;

        let ep_id = self.next()?;
        let mode_id = self.next()?;
        let args: Vec<spirv::Word> = self.data.by_ref().take(inst.wc as usize - 3).collect();

        let ep = self
            .lookup_entry_point
            .get_mut(&ep_id)
            .ok_or(Error::InvalidId(ep_id))?;
        let mode =
            ExecutionMode::from_u32(mode_id).ok_or(Error::UnsupportedExecutionMode(mode_id))?;

        match mode {
            ExecutionMode::EarlyFragmentTests => {
                ep.early_depth_test = Some(crate::EarlyDepthTest::Force);
            }
            ExecutionMode::DepthUnchanged => {
                if let &mut Some(ref mut early_depth_test) =
                    &mut ep.early_depth_test
                {
                    if let &mut crate::EarlyDepthTest::Allow {
                        ref mut conservative,
                    } = early_depth_test
                    {
                        *conservative = crate::ConservativeDepth::Unchanged;
                    }
                } else {
                    ep.early_depth_test = Some(crate::EarlyDepthTest::Allow {
                        conservative: crate::ConservativeDepth::Unchanged,
                    });
                }
            }
            ExecutionMode::DepthGreater => {
                if let &mut Some(ref mut early_depth_test) = &mut ep.early_depth_test {
                    if let &mut crate::EarlyDepthTest::Allow {
                        ref mut conservative,
                    } = early_depth_test
                    {
                        *conservative = crate::ConservativeDepth::GreaterEqual;
                    }
                } else {
                    ep.early_depth_test = Some(crate::EarlyDepthTest::Allow {
                        conservative: crate::ConservativeDepth::GreaterEqual,
                    });
                }
            }
            ExecutionMode::DepthLess => {
                if let &mut Some(ref mut early_depth_test) = &mut ep.early_depth_test {
                    if let &mut crate::EarlyDepthTest::Allow {
                        ref mut conservative,
                    } = early_depth_test
                    {
                        *conservative = crate::ConservativeDepth::LessEqual;
                    }
                } else {
                    ep.early_depth_test = Some(crate::EarlyDepthTest::Allow {
                        conservative: crate::ConservativeDepth::LessEqual,
                    });
                }
            }
            ExecutionMode::DepthReplacing => {
                // Ignored because it can be deduced from the IR.
            }
            ExecutionMode::OriginUpperLeft => {
                // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode.
            }
            ExecutionMode::LocalSize => {
                ep.workgroup_size = [args[0], args[1], args[2]];
            }
            _ => {
                return Err(Error::UnsupportedExecutionMode(mode_id));
            }
        }

        Ok(())
    }

    fn parse_string(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Source, inst.op)?;
        inst.expect_at_least(3)?;
        let _id = self.next()?;
        let (_name, _) = self.next_string(inst.wc - 2)?;
        Ok(())
    }

    fn parse_source(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Source, inst.op)?;
        for _ in 1..inst.wc {
            let _ = self.next()?;
        }
        Ok(())
    }

    fn parse_source_extension(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Source, inst.op)?;
        inst.expect_at_least(2)?;
        let (_name, _) = self.next_string(inst.wc - 1)?;
        Ok(())
    }

    fn parse_name(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Name, inst.op)?;
        inst.expect_at_least(3)?;
        let id = self.next()?;
        let (name, left) = self.next_string(inst.wc - 2)?;
        if left != 0 {
            return Err(Error::InvalidOperand);
        }
        self.future_decor.entry(id).or_default().name = Some(name);
        Ok(())
    }

    fn parse_member_name(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Name, inst.op)?;
        inst.expect_at_least(4)?;
        let id = self.next()?;
        let member = self.next()?;
        let (name, left) = self.next_string(inst.wc - 3)?;
        if left != 0 {
            return Err(Error::InvalidOperand);
        }

        self.future_member_decor
            .entry((id, member))
            .or_default()
            .name = Some(name);
        Ok(())
    }

    fn parse_module_processed(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Name, inst.op)?;
        inst.expect_at_least(2)?;
        let (_info, left) = self.next_string(inst.wc - 1)?;
        //Note: string is ignored
        if left != 0 {
            return Err(Error::InvalidOperand);
        }
        Ok(())
    }

    fn parse_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Annotation, inst.op)?;
        inst.expect_at_least(3)?;
        let id = self.next()?;
        let mut dec = self.future_decor.remove(&id).unwrap_or_default();
        self.next_decoration(inst, 2, &mut dec)?;
        self.future_decor.insert(id, dec);
        Ok(())
    }

    fn parse_member_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Annotation, inst.op)?;
        inst.expect_at_least(4)?;
        let id = self.next()?;
        let member = self.next()?;

        let mut dec = self
            .future_member_decor
            .remove(&(id, member))
            .unwrap_or_default();
        self.next_decoration(inst, 3, &mut dec)?;
        self.future_member_decor.insert((id, member), dec);
        Ok(())
    }

    fn parse_type_void(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(2)?;
        let id = self.next()?;
        self.lookup_void_type = Some(id);
        Ok(())
    }

    fn parse_type_bool(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(2)?;
        let id = self.next()?;
        let inner = crate::TypeInner::Scalar(crate::Scalar::BOOL);
        self.lookup_type.insert(
            id,
            LookupType {
                handle: module.types.insert(
                    crate::Type {
                        name: self.future_decor.remove(&id).and_then(|dec| dec.name),
                        inner,
                    },
                    self.span_from_with_op(start),
                ),
                base_id: None,
            },
        );
        Ok(())
    }

    fn parse_type_int(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(4)?;
        let id = self.next()?;
        let width = self.next()?;
        let sign = self.next()?;
        let inner = crate::TypeInner::Scalar(crate::Scalar {
            kind: match sign {
                0 => crate::ScalarKind::Uint,
                1 => crate::ScalarKind::Sint,
                _ => return Err(Error::InvalidSign(sign)),
            },
            width: map_width(width)?,
        });
        self.lookup_type.insert(
            id,
            LookupType {
                handle: module.types.insert(
                    crate::Type {
                        name: self.future_decor.remove(&id).and_then(|dec| dec.name),
                        inner,
                    },
                    self.span_from_with_op(start),
                ),
                base_id: None,
            },
        );
        Ok(())
    }

    fn parse_type_float(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(3)?;
        let id = self.next()?;
        let width = self.next()?;
        let inner = crate::TypeInner::Scalar(crate::Scalar::float(map_width(width)?));
        self.lookup_type.insert(
            id,
            LookupType {
                handle: module.types.insert(
                    crate::Type {
                        name: self.future_decor.remove(&id).and_then(|dec| dec.name),
                        inner,
                    },
                    self.span_from_with_op(start),
                ),
                base_id: None,
            },
        );
        Ok(())
    }

    fn parse_type_vector(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(4)?;
        let id = self.next()?;
        let type_id = self.next()?;
        let type_lookup = self.lookup_type.lookup(type_id)?;
        let scalar = match module.types[type_lookup.handle].inner {
            crate::TypeInner::Scalar(scalar) => scalar,
            _ => return Err(Error::InvalidInnerType(type_id)),
        };
        let component_count = self.next()?;
        let inner = crate::TypeInner::Vector {
            size: map_vector_size(component_count)?,
            scalar,
        };
        self.lookup_type.insert(
            id,
            LookupType {
                handle: module.types.insert(
                    crate::Type {
                        name: self.future_decor.remove(&id).and_then(|dec| dec.name),
                        inner,
                    },
                    self.span_from_with_op(start),
                ),
                base_id: Some(type_id),
            },
        );
        Ok(())
    }

    fn parse_type_matrix(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(4)?;
        let id = self.next()?;
        let vector_type_id = self.next()?;
        let num_columns = self.next()?;
        let decor = self.future_decor.remove(&id);

        let vector_type_lookup = self.lookup_type.lookup(vector_type_id)?;
        let inner = match module.types[vector_type_lookup.handle].inner {
            crate::TypeInner::Vector { size, scalar } => crate::TypeInner::Matrix {
                columns: map_vector_size(num_columns)?,
                rows: size,
                scalar,
            },
            _ => return Err(Error::InvalidInnerType(vector_type_id)),
        };

        self.lookup_type.insert(
            id,
            LookupType {
                handle: module.types.insert(
                    crate::Type {
                        name: decor.and_then(|dec| dec.name),
                        inner,
                    },
                    self.span_from_with_op(start),
                ),
                base_id: Some(vector_type_id),
            },
        );
        Ok(())
    }

    fn parse_type_function(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect_at_least(3)?;
        let id = self.next()?;
        let return_type_id = self.next()?;
        let parameter_type_ids = self.data.by_ref().take(inst.wc as usize - 3).collect();
        self.lookup_function_type.insert(
            id,
            LookupFunctionType {
                parameter_type_ids,
                return_type_id,
            },
        );
        Ok(())
    }

    fn parse_type_pointer(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(4)?;
        let id = self.next()?;
        let storage_class = self.next()?;
        let type_id = self.next()?;

        let decor = self.future_decor.remove(&id);
        let base_lookup_ty = self.lookup_type.lookup(type_id)?;
        let base_inner = &module.types[base_lookup_ty.handle].inner;

        let space = if let Some(space) = base_inner.pointer_space() {
            space
        } else if self
            .lookup_storage_buffer_types
            .contains_key(&base_lookup_ty.handle)
        {
            crate::AddressSpace::Storage {
                access: crate::StorageAccess::default(),
            }
        } else {
            match map_storage_class(storage_class)? {
                ExtendedClass::Global(space) => space,
                ExtendedClass::Input | ExtendedClass::Output => crate::AddressSpace::Private,
            }
        };

        // We don't support pointers to runtime-sized arrays in the `Uniform`
        // storage class with the `BufferBlock` decoration. Runtime-sized arrays
        // should be in the StorageBuffer class.
        if let crate::TypeInner::Array {
            size: crate::ArraySize::Dynamic,
            ..
        } = *base_inner
        {
            match space {
                crate::AddressSpace::Storage { .. } => {}
                _ => {
                    return Err(Error::UnsupportedRuntimeArrayStorageClass);
                }
            }
        }

        // Don't bother with pointer stuff for `Handle` types.
        let lookup_ty = if space == crate::AddressSpace::Handle {
            base_lookup_ty.clone()
        } else {
            LookupType {
                handle: module.types.insert(
                    crate::Type {
                        name: decor.and_then(|dec| dec.name),
                        inner: crate::TypeInner::Pointer {
                            base: base_lookup_ty.handle,
                            space,
                        },
                    },
                    self.span_from_with_op(start),
                ),
                base_id: Some(type_id),
            }
        };
        self.lookup_type.insert(id, lookup_ty);
        Ok(())
    }

    fn parse_type_array(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(4)?;
        let id = self.next()?;
        let type_id = self.next()?;
        let length_id = self.next()?;
        let length_const = self.lookup_constant.lookup(length_id)?;

        let size = resolve_constant(module.to_ctx(), &length_const.inner)
            .and_then(NonZeroU32::new)
            .ok_or(Error::InvalidArraySize(length_id))?;

        let decor = self.future_decor.remove(&id).unwrap_or_default();
        let base = self.lookup_type.lookup(type_id)?.handle;

        self.layouter.update(module.to_ctx()).unwrap();

        // HACK: if the underlying type is an image or a sampler, assume that
        // we're dealing with a binding array.
        //
        // Note that this is not a strictly correct assumption, but rather a
        // trade-off caused by an impedance mismatch between SPIR-V's and Naga's
        // type systems: Naga distinguishes between arrays and binding arrays via
        // types (i.e. both kinds of arrays are just different types), while
        // SPIR-V distinguishes between them through usage. For example, given:
        //
        // ```
        // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f
        // %uint_256 = OpConstant %uint 256
        // %image_array = OpTypeArray %image %uint_256
        // ```
        //
        // ```
        // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f
        // %uint_256 = OpConstant %uint 256
        // %image_array = OpTypeArray %image %uint_256
        // %image_array_ptr = OpTypePointer UniformConstant %image_array
        // ```
        //
        // ... in the first case, `%image_array` should technically correspond
        // to `TypeInner::Array`, while in the second case it should be
        // `TypeInner::BindingArray` (roughly, depending on whether `%image_array`
        // is ever used as a freestanding type or only through the pointer
        // indirection).
        //
        // Anyway, at the moment we don't support image / sampler arrays other
        // than binding-based ones, so this assumption is pretty safe for now.
        let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } =
            module.types[base].inner
        {
            crate::TypeInner::BindingArray {
                base,
                size: crate::ArraySize::Constant(size),
            }
        } else {
            crate::TypeInner::Array {
                base,
                size: crate::ArraySize::Constant(size),
                stride: match decor.array_stride {
                    Some(stride) => stride.get(),
                    None => self.layouter[base].to_stride(),
                },
            }
        };

        self.lookup_type.insert(
            id,
            LookupType {
                handle: module.types.insert(
                    crate::Type {
                        name: decor.name,
                        inner,
                    },
                    self.span_from_with_op(start),
                ),
                base_id: Some(type_id),
            },
        );
        Ok(())
    }

    fn parse_type_runtime_array(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(3)?;
        let id = self.next()?;
        let type_id = self.next()?;

        let decor = self.future_decor.remove(&id).unwrap_or_default();
        let base = self.lookup_type.lookup(type_id)?.handle;

        self.layouter.update(module.to_ctx()).unwrap();

        // HACK same case as in `parse_type_array()`
        let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } =
            module.types[base].inner
        {
            crate::TypeInner::BindingArray {
                base,
                size: crate::ArraySize::Dynamic,
            }
        } else {
            crate::TypeInner::Array {
                base,
                size: crate::ArraySize::Dynamic,
                stride: match decor.array_stride {
                    Some(stride) => stride.get(),
                    None => self.layouter[base].to_stride(),
                },
            }
        };

        self.lookup_type.insert(
            id,
            LookupType {
                handle: module.types.insert(
                    crate::Type {
                        name: decor.name,
                        inner,
                    },
                    self.span_from_with_op(start),
                ),
                base_id: Some(type_id),
            },
        );
        Ok(())
    }

    fn parse_type_struct(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect_at_least(2)?;
        let id = self.next()?;
        let parent_decor = self.future_decor.remove(&id);
        let is_storage_buffer = parent_decor
            .as_ref()
            .is_some_and(|decor| decor.storage_buffer);

        self.layouter.update(module.to_ctx()).unwrap();

        let mut members = Vec::<crate::StructMember>::with_capacity(inst.wc as usize - 2);
        let mut member_lookups = Vec::with_capacity(members.capacity());
        let mut storage_access = crate::StorageAccess::empty();
        let mut span = 0;
        let mut alignment = Alignment::ONE;
        for i in 0..u32::from(inst.wc) - 2 {
            let type_id = self.next()?;
            let ty = self.lookup_type.lookup(type_id)?.handle;
            let decor = self
                .future_member_decor
                .remove(&(id, i))
                .unwrap_or_default();

            storage_access |= decor.flags.to_storage_access();

            member_lookups.push(LookupMember {
                type_id,
                row_major: decor.matrix_major == Some(Majority::Row),
            });

            let member_alignment = self.layouter[ty].alignment;
            span = member_alignment.round_up(span);
            alignment = member_alignment.max(alignment);

            let binding = decor.io_binding().ok();
            if let Some(offset) = decor.offset {
                span = offset;
            }
            let offset = span;

            span += self.layouter[ty].size;

            let inner = &module.types[ty].inner;
            if let crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } = *inner
            {
                if let Some(stride) = decor.matrix_stride {
                    let expected_stride = Alignment::from(rows) * scalar.width as u32;
                    if stride.get() != expected_stride {
                        return Err(Error::UnsupportedMatrixStride {
                            stride: stride.get(),
                            columns: columns as u8,
                            rows: rows as u8,
                            width: scalar.width,
                        });
                    }
                }
            }

            members.push(crate::StructMember {
                name: decor.name,
                ty,
                binding,
                offset,
            });
        }

        span = alignment.round_up(span);

        let inner = crate::TypeInner::Struct { span, members };

        let ty_handle = module.types.insert(
            crate::Type {
                name: parent_decor.and_then(|dec| dec.name),
                inner,
            },
            self.span_from_with_op(start),
        );

        if is_storage_buffer {
            self.lookup_storage_buffer_types
                .insert(ty_handle, storage_access);
        }

        for (i, member_lookup) in member_lookups.into_iter().enumerate() {
            self.lookup_member
                .insert((ty_handle, i as u32), member_lookup);
        }

        self.lookup_type.insert(
            id,
            LookupType {
                handle: ty_handle,
                base_id: None,
            },
        );
        Ok(())
    }

    fn parse_type_image(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(9)?;

        let id = self.next()?;
        let sample_type_id = self.next()?;
        let dim = self.next()?;
        let is_depth = self.next()?;
        let is_array = self.next()? != 0;
        let is_msaa = self.next()?
            != 0;
        let is_sampled = self.next()?;
        let format = self.next()?;

        let dim = map_image_dim(dim)?;
        let decor = self.future_decor.remove(&id).unwrap_or_default();

        // Ensure there is a type for texture coordinates without extra components.
        module.types.insert(
            crate::Type {
                name: None,
                inner: {
                    let scalar = crate::Scalar::F32;
                    match dim.required_coordinate_size() {
                        None => crate::TypeInner::Scalar(scalar),
                        Some(size) => crate::TypeInner::Vector { size, scalar },
                    }
                },
            },
            Default::default(),
        );

        let base_handle = self.lookup_type.lookup(sample_type_id)?.handle;
        let kind = module.types[base_handle]
            .inner
            .scalar_kind()
            .ok_or(Error::InvalidImageBaseType(base_handle))?;

        let inner = crate::TypeInner::Image {
            class: if is_depth == 1 {
                if is_sampled == 2 {
                    return Err(Error::InvalidImageDepthStorage);
                }

                crate::ImageClass::Depth { multi: is_msaa }
            }
            // If we have an unknown format and a storage texture, this is
            // StorageRead/WriteWithoutFormat, which we don't currently support.
            else if is_sampled == 2 && format == 0 {
                return Err(Error::InvalidStorageImageWithoutFormat);
            }
            // If we have explicit class information (is_sampled = 2 = Storage), use it.
            //
            // If we have unknown class information (is_sampled = 0 = Unknown), infer the
            // class from the presence of an explicit format.
            else if format != 0 && (is_sampled == 0 || is_sampled == 2) {
                crate::ImageClass::Storage {
                    format: map_image_format(format)?,
                    access: crate::StorageAccess::default(),
                }
            }
            // We hit this case either when sampled is 1, or when sampled is 0
            // (unknown) and there is no explicit format.
            else {
                crate::ImageClass::Sampled {
                    kind,
                    multi: is_msaa,
                }
            },
            dim,
            arrayed: is_array,
        };

        let handle = module.types.insert(
            crate::Type {
                name: decor.name,
                inner,
            },
            self.span_from_with_op(start),
        );

        self.lookup_type.insert(
            id,
            LookupType {
                handle,
                base_id: Some(sample_type_id),
            },
        );
        Ok(())
    }

    fn parse_type_sampled_image(&mut self, inst: Instruction) -> Result<(), Error> {
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(3)?;
        let id = self.next()?;
        let image_id = self.next()?;
        self.lookup_type.insert(
            id,
            LookupType {
                handle: self.lookup_type.lookup(image_id)?.handle,
                base_id: Some(image_id),
            },
        );
        Ok(())
    }

    fn parse_type_sampler(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect(2)?;
        let id = self.next()?;
        let decor = self.future_decor.remove(&id).unwrap_or_default();
        let handle = module.types.insert(
            crate::Type {
                name: decor.name,
                inner: crate::TypeInner::Sampler { comparison: false },
            },
            self.span_from_with_op(start),
        );
        self.lookup_type.insert(
            id,
            LookupType {
                handle,
                base_id: None,
            },
        );
        Ok(())
    }

    fn parse_constant(
        &mut self,
        inst: Instruction,
        module: &mut crate::Module,
    ) -> Result<(), Error> {
        let start = self.data_offset;
        self.switch(ModuleState::Type, inst.op)?;
        inst.expect_at_least(4)?;
        let type_id = self.next()?;
        let id = self.next()?;
        let type_lookup = self.lookup_type.lookup(type_id)?;
        let ty = type_lookup.handle;

        let literal = match module.types[ty].inner {
            crate::TypeInner::Scalar(crate::Scalar {
                kind: crate::ScalarKind::Uint,
                width,
            }) => {
                let low = self.next()?;
                match width {
                    4 => crate::Literal::U32(low),
                    8 => {
                        inst.expect(5)?;
                        let high = self.next()?;
                        crate::Literal::U64((u64::from(high) << 32) | u64::from(low))
                    }
                    _ => return Err(Error::InvalidTypeWidth(width as u32)),
                }
            }
            crate::TypeInner::Scalar(crate::Scalar {
                kind: crate::ScalarKind::Sint,
                width,
            }) => {
                let low = self.next()?;
                match width {
                    4 =>
crate::Literal::I32(low as i32), 8 => { inst.expect(5)?; let high = self.next()?; crate::Literal::I64(((u64::from(high) << 32) | u64::from(low)) as i64) } _ => return Err(Error::InvalidTypeWidth(width as u32)), } } crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Float, width, }) => { let low = self.next()?; match width { // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Literal // If a numeric type’s bit width is less than 32-bits, the value appears in the low-order bits of the word. 2 => crate::Literal::F16(f16::from_bits(low as u16)), 4 => crate::Literal::F32(f32::from_bits(low)), 8 => { inst.expect(5)?; let high = self.next()?; crate::Literal::F64(f64::from_bits( (u64::from(high) << 32) | u64::from(low), )) } _ => return Err(Error::InvalidTypeWidth(width as u32)), } } _ => return Err(Error::UnsupportedType(type_lookup.handle)), }; let span = self.span_from_with_op(start); let init = module .global_expressions .append(crate::Expression::Literal(literal), span); self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_composite_constant( &mut self, inst: Instruction, module: &mut crate::Module, ) -> Result<(), Error> { let start = self.data_offset; self.switch(ModuleState::Type, inst.op)?; inst.expect_at_least(3)?; let type_id = self.next()?; let id = self.next()?; let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; let mut components = Vec::with_capacity(inst.wc as usize - 3); for _ in 0..components.capacity() { let start = self.data_offset; let component_id = self.next()?; let span = self.span_from_with_op(start); let constant = self.lookup_constant.lookup(component_id)?; let expr = module .global_expressions .append(constant.inner.to_expr(), span); components.push(expr); } let span = self.span_from_with_op(start); let init = module .global_expressions .append(crate::Expression::Compose { ty, components }, span); self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn 
parse_null_constant( &mut self, inst: Instruction, module: &mut crate::Module, ) -> Result<(), Error> { let start = self.data_offset; self.switch(ModuleState::Type, inst.op)?; inst.expect(3)?; let type_id = self.next()?; let id = self.next()?; let span = self.span_from_with_op(start); let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; let init = module .global_expressions .append(crate::Expression::ZeroValue(ty), span); self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_bool_constant( &mut self, inst: Instruction, value: bool, module: &mut crate::Module, ) -> Result<(), Error> { let start = self.data_offset; self.switch(ModuleState::Type, inst.op)?; inst.expect(3)?; let type_id = self.next()?; let id = self.next()?; let span = self.span_from_with_op(start); let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; let init = module.global_expressions.append( crate::Expression::Literal(crate::Literal::Bool(value)), span, ); self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn insert_parsed_constant( &mut self, module: &mut crate::Module, id: u32, type_id: u32, ty: Handle<crate::Type>, init: Handle<crate::Expression>, span: crate::Span, ) -> Result<(), Error> { let decor = self.future_decor.remove(&id).unwrap_or_default(); let inner = if let Some(id) = decor.specialization_constant_id { let o = crate::Override { name: decor.name, id: Some(id.try_into().map_err(|_| Error::SpecIdTooHigh(id))?), ty, init: Some(init), }; Constant::Override(module.overrides.append(o, span)) } else { let c = crate::Constant { name: decor.name, ty, init, }; Constant::Constant(module.constants.append(c, span)) }; self.lookup_constant .insert(id, LookupConstant { inner, type_id }); Ok(()) } fn parse_global_variable( &mut self, inst: Instruction, module: &mut crate::Module, ) -> Result<(), Error> { let start = self.data_offset; self.switch(ModuleState::Type, inst.op)?; 
inst.expect_at_least(4)?; let type_id = self.next()?; let id
= self.next()?; let storage_class = self.next()?; let init = if inst.wc > 4 { inst.expect(5)?; let start = self.data_offset; let init_id = self.next()?; let span = self.span_from_with_op(start); let lconst = self.lookup_constant.lookup(init_id)?; let expr = module .global_expressions .append(lconst.inner.to_expr(), span); Some(expr) } else { None }; let span = self.span_from_with_op(start); let dec = self.future_decor.remove(&id).unwrap_or_default(); let original_ty = self.lookup_type.lookup(type_id)?.handle; let mut ty = original_ty; if let crate::TypeInner::Pointer { base, space: _ } = module.types[original_ty].inner { ty = base; } if let crate::TypeInner::BindingArray { .. } = module.types[original_ty].inner { // Inside `parse_type_array()` we guess that an array of images or // samplers must be a binding array, and here we validate that guess if dec.desc_set.is_none() || dec.desc_index.is_none() { return Err(Error::NonBindingArrayOfImageOrSamplers); } } if let crate::TypeInner::Image { dim, arrayed, class: crate::ImageClass::Storage { format, access: _ }, } = module.types[ty].inner { // Storage image types in IR have to contain the access, but not in the SPIR-V. // The same image type in SPIR-V can be used (and has to be used) for multiple images. // So we copy the type out and apply the variable access decorations. 
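            // Illustrative sketch (hypothetical, not part of naga's API): one way the
            // SPIR-V `NonReadable` / `NonWritable` variable decorations could be folded
            // into a storage access mask before it is attached to the copied image type.
            // The function name and the `LOAD` / `STORE` bit values are assumptions made
            // for this example only; the real code below calls
            // `dec.flags.to_storage_access()`, whose details may differ.
            #[allow(dead_code)]
            fn sketch_storage_access(non_readable: bool, non_writable: bool) -> u8 {
                const LOAD: u8 = 0b01; // assumed "may read" bit
                const STORE: u8 = 0b10; // assumed "may write" bit
                // Without decorations, assume the image is both readable and writable.
                let mut access = LOAD | STORE;
                if non_readable {
                    access &= !LOAD; // NonReadable drops read access
                }
                if non_writable {
                    access &= !STORE; // NonWritable drops write access
                }
                access
            }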
let access = dec.flags.to_storage_access(); ty = module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Image { dim, arrayed, class: crate::ImageClass::Storage { format, access }, }, }, Default::default(), ); } let ext_class = match self.lookup_storage_buffer_types.get(&ty) { Some(&access) => ExtendedClass::Global(crate::AddressSpace::Storage { access }), None => map_storage_class(storage_class)?, }; let (inner, var) = match ext_class { ExtendedClass::Global(mut space) => { if let crate::AddressSpace::Storage { ref mut access } = space { *access &= dec.flags.to_storage_access(); } let var = crate::GlobalVariable { binding: dec.resource_binding(), name: dec.name, space, ty, init, memory_decorations: dec.flags.to_memory_decorations(), }; (Variable::Global, var) } ExtendedClass::Input => { let binding = dec.io_binding()?; let mut unsigned_ty = ty; if let crate::Binding::BuiltIn(built_in) = binding { let needs_inner_uint = match built_in { crate::BuiltIn::BaseInstance | crate::BuiltIn::BaseVertex | crate::BuiltIn::InstanceIndex | crate::BuiltIn::SampleIndex | crate::BuiltIn::VertexIndex | crate::BuiltIn::PrimitiveIndex | crate::BuiltIn::LocalInvocationIndex => { Some(crate::TypeInner::Scalar(crate::Scalar::U32)) } crate::BuiltIn::GlobalInvocationId | crate::BuiltIn::LocalInvocationId | crate::BuiltIn::WorkGroupId | crate::BuiltIn::WorkGroupSize => Some(crate::TypeInner::Vector { size: crate::VectorSize::Tri, scalar: crate::Scalar::U32, }), crate::BuiltIn::Barycentric { perspective: false } => { Some(crate::TypeInner::Vector { size: crate::VectorSize::Tri, scalar: crate::Scalar::F32, }) } _ => None, }; if let (Some(inner), Some(crate::ScalarKind::Sint)) = (needs_inner_uint, module.types[ty].inner.scalar_kind()) { unsigned_ty = module .types .insert(crate::Type { name: None, inner }, Default::default()); } } let var = crate::GlobalVariable { name: dec.name.clone(), space: crate::AddressSpace::Private, binding: None, ty, init: None, memory_decorations: 
crate::MemoryDecorations::empty(), }; let inner = Variable::Input(crate::FunctionArgument { name: dec.name, ty: unsigned_ty, binding: Some(binding), }); (inner, var) } ExtendedClass::Output => { // For output interface blocks, this would be a structure. let binding = dec.io_binding().ok(); let init = match binding { Some(crate::Binding::BuiltIn(built_in)) => { match null::generate_default_built_in( Some(built_in), ty, &mut module.global_expressions, span, ) { Ok(handle) => Some(handle), Err(e) => { log::warn!("Failed to initialize output built-in: {e}"); None } } } Some(crate::Binding::Location { .. }) => None, None => match module.types[ty].inner { crate::TypeInner::Struct { ref members, .. } => { let mut components = Vec::with_capacity(members.len()); for member in members.iter() { let built_in = match member.binding { Some(crate::Binding::BuiltIn(built_in)) => Some(built_in), _ => None, }; let handle = null::generate_default_built_in( built_in, member.ty, &mut module.global_expressions, span, )?; components.push(handle); } Some( module .global_expressions .append(crate::Expression::Compose { ty, components }, span), ) } _ => None, }, }; let var = crate::GlobalVariable { name: dec.name, space: crate::AddressSpace::Private, binding: None, ty, init, memory_decorations: crate::MemoryDecorations::empty(), }; let inner = Variable::Output(crate::FunctionResult { ty, binding }); (inner, var) } }; let handle = module.global_variables.append(var, span); if module.types[ty].inner.can_comparison_sample(module) { log::debug!("\t\ttracking {handle:?} for sampling properties"); self.handle_sampling .insert(handle, image::SamplingFlags::empty()); } self.lookup_variable.insert( id, LookupVariable { inner, handle, type_id, }, ); Ok(()) } /// Record an atomic access to some component of a global variable. 
    ///
    /// Given `handle`, an expression referring to a scalar that has had an
    /// atomic operation applied to it, descend into the expression, noting
    /// which global variable it ultimately refers to, and which struct fields
    /// of that global's value it accesses.
    ///
    /// Return the handle of the type of the expression.
    ///
    /// If the expression doesn't actually refer to something in a global
    /// variable, we can't upgrade its type in a way that Naga validation would
    /// pass, so reject the input instead.
    fn record_atomic_access(
        &mut self,
        ctx: &BlockContext,
        handle: Handle<crate::Expression>,
    ) -> Result<Handle<crate::Type>, Error> {
        log::debug!("\t\tlocating global variable in {handle:?}");
        match ctx.expressions[handle] {
            crate::Expression::Access { base, index } => {
                log::debug!("\t\t access {handle:?} {index:?}");
                let ty = self.record_atomic_access(ctx, base)?;
                let crate::TypeInner::Array { base, .. } = ctx.module.types[ty].inner else {
                    unreachable!("Atomic operations on Access expressions only work for arrays");
                };
                Ok(base)
            }
            crate::Expression::AccessIndex { base, index } => {
                log::debug!("\t\t access index {handle:?} {index:?}");
                let ty = self.record_atomic_access(ctx, base)?;
                match ctx.module.types[ty].inner {
                    crate::TypeInner::Struct { ref members, .. } => {
                        let index = index as usize;
                        self.upgrade_atomics.insert_field(ty, index);
                        Ok(members[index].ty)
                    }
                    crate::TypeInner::Array { base, ..
} => {
                        Ok(base)
                    }
                    _ => unreachable!("Atomic operations on AccessIndex expressions only work for structs and arrays"),
                }
            }
            crate::Expression::GlobalVariable(h) => {
                log::debug!("\t\t found {h:?}");
                self.upgrade_atomics.insert_global(h);
                Ok(ctx.module.global_variables[h].ty)
            }
            _ => Err(Error::AtomicUpgradeError(
                crate::front::atomic_upgrade::Error::GlobalVariableMissing,
            )),
        }
    }
}

fn resolve_constant(gctx: crate::proc::GlobalCtx, constant: &Constant) -> Option<u32> {
    let constant = match *constant {
        Constant::Constant(constant) => constant,
        Constant::Override(_) => return None,
    };
    match gctx.global_expressions[gctx.constants[constant].init] {
        crate::Expression::Literal(crate::Literal::U32(id)) => Some(id),
        crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32),
        _ => None,
    }
}

pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> {
    if !data.len().is_multiple_of(4) {
        return Err(Error::IncompleteData);
    }

    let words = data
        .chunks(4)
        .map(|c| u32::from_le_bytes(c.try_into().unwrap()));
    Frontend::new(words, options).parse()
}

/// Helper function to check if `child` is in the scope of `parent`.
fn is_parent(mut child: usize, parent: usize, block_ctx: &BlockContext) -> bool {
    loop {
        if child == parent {
            // The child is in the parent's scope.
            break true;
        } else if child == 0 {
            // The search reached the root without meeting `parent`, so the
            // child isn't in the parent's body.
            break false;
        }

        child = block_ctx.bodies[child].parent;
    }
}

#[cfg(test)]
mod test {
    use alloc::vec;

    #[test]
    fn parse() {
        let bin = vec![
            // Magic number. Version number: 1.0.
            0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00,
            // Generator number: 0. Bound: 0.
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            // Reserved word: 0.
            0x00, 0x00, 0x00, 0x00,
            // OpMemoryModel. Logical.
            0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00,
            // GLSL450.
            0x01, 0x00, 0x00, 0x00,
        ];
        let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap();
    }
}
naga-29.0.3/src/front/spv/next_block.rs000064400000000000000000004301671046102023000160650ustar 00000000000000
//! Implementation of [`Frontend::next_block()`].
//!
//! This method is split out into its own module purely because it is so long.

use alloc::{format, vec, vec::Vec};

use crate::front::spv::{
    convert::{map_binary_operator, map_relational_fun},
    image, resolve_constant, BlockContext, Body, BodyFragment, Constant, Error, Frontend,
    LookupExpression, LookupHelper as _, LookupLoadOverride, MergeBlockInformation,
    PhiExpression, SignAnchor,
};
use crate::Handle;

impl<I: Iterator<Item = u32>> Frontend<I> {
    /// Add the next SPIR-V block's contents to `block_ctx`.
    ///
    /// Except for the function's entry block, `block_id` should be the label of
    /// a block we've seen mentioned before, with an entry in
    /// `block_ctx.body_for_label` to tell us which `Body` it contributes to.
    pub(in crate::front::spv) fn next_block(
        &mut self,
        block_id: spirv::Word,
        ctx: &mut BlockContext,
    ) -> Result<(), Error> {
        // Extend `body` with the correct form for a branch to `target`.
        fn merger(body: &mut Body, target: &MergeBlockInformation) {
            body.data.push(match *target {
                MergeBlockInformation::LoopContinue => BodyFragment::Continue,
                MergeBlockInformation::LoopMerge | MergeBlockInformation::SwitchMerge => {
                    BodyFragment::Break
                }

                // Finishing a selection merge means just falling off the end of
                // the `accept` or `reject` block of the `If` statement.
                MergeBlockInformation::SelectionMerge => return,
            })
        }

        let mut emitter = crate::proc::Emitter::default();
        emitter.start(ctx.expressions);

        // Find the `Body` to which this block contributes.
        //
        // If this is some SPIR-V structured control flow construct's merge
        // block, then `body_idx` will refer to the same `Body` as the header,
        // so that we simply pick up accumulating the `Body` where the header
        // left off.
        // Each of the statements in a block dominates the next, so
        // we're sure to encounter their SPIR-V blocks in order, ensuring that
        // the `Body` will be assembled in the proper order.
        //
        // Note that, unlike every other kind of SPIR-V block, we don't know the
        // function's first block's label in advance. Thus, we assume that if
        // this block has no entry in `ctx.body_for_label`, it must be the
        // function's first block. This always has body index zero.
        let mut body_idx = *ctx.body_for_label.entry(block_id).or_default();

        // The Naga IR block this call builds. This will end up as
        // `ctx.blocks[&block_id]`, and `ctx.bodies[body_idx]` will refer to it
        // via a `BodyFragment::BlockId`.
        let mut block = crate::Block::new();

        // Stores the merge block defined by an `OpSelectionMerge`, or `None`
        // if there isn't one.
        //
        // `OpSwitch` uses this to promote the `MergeBlockInformation` from
        // `SelectionMerge` to `SwitchMerge`, which allows `Break`s. This isn't
        // desirable for `LoopMerge`s, because then `Continue`s wouldn't be
        // allowed.
        let mut selection_merge_block = None;

        macro_rules! get_expr_handle {
            ($id:expr, $lexp:expr) => {
                self.get_expr_handle($id, $lexp, ctx, &mut emitter, &mut block, body_idx)
            };
        }
        macro_rules!
parse_expr_op { ($op:expr, BINARY) => { self.parse_expr_binary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op) }; ($op:expr, SHIFT) => { self.parse_expr_shift_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op) }; ($op:expr, UNARY) => { self.parse_expr_unary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op) }; ($axis:expr, $ctrl:expr, DERIVATIVE) => { self.parse_expr_derivative( ctx, &mut emitter, &mut block, block_id, body_idx, ($axis, $ctrl), ) }; } let terminator = loop { use spirv::Op; let start = self.data_offset; let inst = self.next_inst()?; let span = crate::Span::from(start..(start + 4 * (inst.wc as usize))); log::debug!("\t\t{:?} [{}]", inst.op, inst.wc); match inst.op { Op::Line => { inst.expect(4)?; let _file_id = self.next()?; let _row_id = self.next()?; let _col_id = self.next()?; } Op::NoLine => inst.expect(1)?, Op::Undef => { inst.expect(3)?; let type_id = self.next()?; let id = self.next()?; let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; self.lookup_expression.insert( id, LookupExpression { handle: ctx .expressions .append(crate::Expression::ZeroValue(ty), span), type_id, block_id, }, ); } Op::Variable => { inst.expect_at_least(4)?; block.extend(emitter.finish(ctx.expressions)); let result_type_id = self.next()?; let result_id = self.next()?; let _storage_class = self.next()?; let init = if inst.wc > 4 { inst.expect(5)?; let init_id = self.next()?; let lconst = self.lookup_constant.lookup(init_id)?; Some(ctx.expressions.append(lconst.inner.to_expr(), span)) } else { None }; let name = self .future_decor .remove(&result_id) .and_then(|decor| decor.name); if let Some(ref name) = name { log::debug!("\t\t\tid={result_id} name={name}"); } let lookup_ty = self.lookup_type.lookup(result_type_id)?; let var_handle = ctx.local_arena.append( crate::LocalVariable { name, ty: match ctx.module.types[lookup_ty.handle].inner { crate::TypeInner::Pointer { base, .. 
} => base,
                            _ => lookup_ty.handle,
                        },
                        init,
                    },
                    span,
                );

                self.lookup_expression.insert(
                    result_id,
                    LookupExpression {
                        handle: ctx
                            .expressions
                            .append(crate::Expression::LocalVariable(var_handle), span),
                        type_id: result_type_id,
                        block_id,
                    },
                );
                emitter.start(ctx.expressions);
            }
            Op::Phi => {
                inst.expect_at_least(3)?;
                block.extend(emitter.finish(ctx.expressions));

                let result_type_id = self.next()?;
                let result_id = self.next()?;

                let name = format!("phi_{result_id}");
                let local = ctx.local_arena.append(
                    crate::LocalVariable {
                        name: Some(name),
                        ty: self.lookup_type.lookup(result_type_id)?.handle,
                        init: None,
                    },
                    self.span_from(start),
                );
                let pointer = ctx
                    .expressions
                    .append(crate::Expression::LocalVariable(local), span);

                let in_count = (inst.wc - 3) / 2;
                let mut phi = PhiExpression {
                    local,
                    expressions: Vec::with_capacity(in_count as usize),
                };
                for _ in 0..in_count {
                    let expr = self.next()?;
                    let block = self.next()?;
                    phi.expressions.push((expr, block));
                }

                ctx.phis.push(phi);
                emitter.start(ctx.expressions);

                // Associate the lookup with an actual value, which is emitted
                // into the current block.
                self.lookup_expression.insert(
                    result_id,
                    LookupExpression {
                        handle: ctx
                            .expressions
                            .append(crate::Expression::Load { pointer }, span),
                        type_id: result_type_id,
                        block_id,
                    },
                );
            }
            Op::AccessChain | Op::InBoundsAccessChain => {
                struct AccessExpression {
                    base_handle: Handle<crate::Expression>,
                    type_id: spirv::Word,
                    load_override: Option<LookupLoadOverride>,
                }

                inst.expect_at_least(4)?;

                let result_type_id = self.next()?;
                let result_id = self.next()?;
                let base_id = self.next()?;
                log::trace!("\t\t\tlooking up expr {base_id:?}");

                let mut acex = {
                    let lexp = self.lookup_expression.lookup(base_id)?;
                    let lty = self.lookup_type.lookup(lexp.type_id)?;

                    // HACK `OpAccessChain` and `OpInBoundsAccessChain`
                    // require the result type to be a pointer, but if
                    // we're given a pointer to an image / sampler, it will
                    // be *already* dereferenced, since we do that early
                    // during `parse_type_pointer()`.
// // This can happen only through `BindingArray`, since // that's the only case where one can obtain a pointer // to an image / sampler, and so let's match on that: let dereference = match ctx.module.types[lty.handle].inner { crate::TypeInner::BindingArray { .. } => false, _ => true, }; let type_id = if dereference { lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))? } else { lexp.type_id }; AccessExpression { base_handle: get_expr_handle!(base_id, lexp), type_id, load_override: self.lookup_load_override.get(&base_id).cloned(), } }; for _ in 4..inst.wc { let access_id = self.next()?; log::trace!("\t\t\tlooking up index expr {access_id:?}"); let index_expr = self.lookup_expression.lookup(access_id)?.clone(); let index_expr_handle = get_expr_handle!(access_id, &index_expr); let index_expr_data = &ctx.expressions[index_expr.handle]; let index_maybe = match *index_expr_data { crate::Expression::Constant(const_handle) => Some( ctx.gctx() .get_const_val(ctx.module.constants[const_handle].init) .map_err(|_| { Error::InvalidAccess(crate::Expression::Constant( const_handle, )) })?, ), _ => None, }; log::trace!("\t\t\tlooking up type {:?}", acex.type_id); let type_lookup = self.lookup_type.lookup(acex.type_id)?; let ty = &ctx.module.types[type_lookup.handle]; acex = match ty.inner { // can only index a struct with a constant crate::TypeInner::Struct { ref members, .. 
} => { let index = index_maybe .ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?; let lookup_member = self .lookup_member .get(&(type_lookup.handle, index)) .ok_or(Error::InvalidAccessType(acex.type_id))?; let base_handle = ctx.expressions.append( crate::Expression::AccessIndex { base: acex.base_handle, index, }, span, ); if let Some(crate::Binding::BuiltIn(built_in)) = members[index as usize].binding { self.gl_per_vertex_builtin_access.insert(built_in); } AccessExpression { base_handle, type_id: lookup_member.type_id, load_override: if lookup_member.row_major { debug_assert!(acex.load_override.is_none()); let sub_type_lookup = self.lookup_type.lookup(lookup_member.type_id)?; Some(match ctx.module.types[sub_type_lookup.handle].inner { // load it transposed, to match column major expectations crate::TypeInner::Matrix { .. } => { let loaded = ctx.expressions.append( crate::Expression::Load { pointer: base_handle, }, span, ); let transposed = ctx.expressions.append( crate::Expression::Math { fun: crate::MathFunction::Transpose, arg: loaded, arg1: None, arg2: None, arg3: None, }, span, ); LookupLoadOverride::Loaded(transposed) } _ => LookupLoadOverride::Pending, }) } else { None }, } } crate::TypeInner::Matrix { .. 
} => { let load_override = match acex.load_override { // We are indexing inside a row-major matrix Some(LookupLoadOverride::Loaded(load_expr)) => { let index = index_maybe.ok_or_else(|| { Error::InvalidAccess(index_expr_data.clone()) })?; let sub_handle = ctx.expressions.append( crate::Expression::AccessIndex { base: load_expr, index, }, span, ); Some(LookupLoadOverride::Loaded(sub_handle)) } _ => None, }; let sub_expr = match index_maybe { Some(index) => crate::Expression::AccessIndex { base: acex.base_handle, index, }, None => crate::Expression::Access { base: acex.base_handle, index: index_expr_handle, }, }; AccessExpression { base_handle: ctx.expressions.append(sub_expr, span), type_id: type_lookup .base_id .ok_or(Error::InvalidAccessType(acex.type_id))?, load_override, } } // This must be a vector or an array. _ => { let base_handle = ctx.expressions.append( crate::Expression::Access { base: acex.base_handle, index: index_expr_handle, }, span, ); let load_override = match acex.load_override { // If there is a load override in place, then we always end up // with a side-loaded value here. Some(lookup_load_override) => { let sub_expr = match lookup_load_override { // We must be indexing into the array of row-major matrices. // Let's load the result of indexing and transpose it. LookupLoadOverride::Pending => { let loaded = ctx.expressions.append( crate::Expression::Load { pointer: base_handle, }, span, ); ctx.expressions.append( crate::Expression::Math { fun: crate::MathFunction::Transpose, arg: loaded, arg1: None, arg2: None, arg3: None, }, span, ) } // We are indexing inside a row-major matrix. 
LookupLoadOverride::Loaded(load_expr) => { ctx.expressions.append( crate::Expression::Access { base: load_expr, index: index_expr_handle, }, span, ) } }; Some(LookupLoadOverride::Loaded(sub_expr)) } None => None, }; AccessExpression { base_handle, type_id: type_lookup .base_id .ok_or(Error::InvalidAccessType(acex.type_id))?, load_override, } } }; } if let Some(load_expr) = acex.load_override { self.lookup_load_override.insert(result_id, load_expr); } let lookup_expression = LookupExpression { handle: acex.base_handle, type_id: result_type_id, block_id, }; self.lookup_expression.insert(result_id, lookup_expression); } Op::VectorExtractDynamic => { inst.expect(5)?; let result_type_id = self.next()?; let id = self.next()?; let composite_id = self.next()?; let index_id = self.next()?; let root_lexp = self.lookup_expression.lookup(composite_id)?; let root_handle = get_expr_handle!(composite_id, root_lexp); let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?; let index_lexp = self.lookup_expression.lookup(index_id)?; let index_handle = get_expr_handle!(index_id, index_lexp); let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle; let num_components = match ctx.module.types[root_type_lookup.handle].inner { crate::TypeInner::Vector { size, .. 
} => size as u32, _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)), }; let mut make_index = |ctx: &mut BlockContext, index: u32| { make_index_literal( ctx, index, &mut block, &mut emitter, index_type, index_lexp.type_id, span, ) }; let index_expr = make_index(ctx, 0)?; let mut handle = ctx.expressions.append( crate::Expression::Access { base: root_handle, index: index_expr, }, span, ); for index in 1..num_components { let index_expr = make_index(ctx, index)?; let access_expr = ctx.expressions.append( crate::Expression::Access { base: root_handle, index: index_expr, }, span, ); let cond = ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Equal, left: index_expr, right: index_handle, }, span, ); handle = ctx.expressions.append( crate::Expression::Select { condition: cond, accept: access_expr, reject: handle, }, span, ); } self.lookup_expression.insert( id, LookupExpression { handle, type_id: result_type_id, block_id, }, ); } Op::VectorInsertDynamic => { inst.expect(6)?; let result_type_id = self.next()?; let id = self.next()?; let composite_id = self.next()?; let object_id = self.next()?; let index_id = self.next()?; let object_lexp = self.lookup_expression.lookup(object_id)?; let object_handle = get_expr_handle!(object_id, object_lexp); let root_lexp = self.lookup_expression.lookup(composite_id)?; let root_handle = get_expr_handle!(composite_id, root_lexp); let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?; let index_lexp = self.lookup_expression.lookup(index_id)?; let index_handle = get_expr_handle!(index_id, index_lexp); let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle; let num_components = match ctx.module.types[root_type_lookup.handle].inner { crate::TypeInner::Vector { size, .. 
} => size as u32, _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)), }; let mut components = Vec::with_capacity(num_components as usize); for index in 0..num_components { let index_expr = make_index_literal( ctx, index, &mut block, &mut emitter, index_type, index_lexp.type_id, span, )?; let access_expr = ctx.expressions.append( crate::Expression::Access { base: root_handle, index: index_expr, }, span, ); let cond = ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Equal, left: index_expr, right: index_handle, }, span, ); let handle = ctx.expressions.append( crate::Expression::Select { condition: cond, accept: object_handle, reject: access_expr, }, span, ); components.push(handle); } let handle = ctx.expressions.append( crate::Expression::Compose { ty: root_type_lookup.handle, components, }, span, ); self.lookup_expression.insert( id, LookupExpression { handle, type_id: result_type_id, block_id, }, ); } Op::CompositeExtract => { inst.expect_at_least(4)?; let result_type_id = self.next()?; let result_id = self.next()?; let base_id = self.next()?; log::trace!("\t\t\tlooking up expr {base_id:?}"); let mut lexp = self.lookup_expression.lookup(base_id)?.clone(); lexp.handle = get_expr_handle!(base_id, &lexp); for _ in 4..inst.wc { let index = self.next()?; log::trace!("\t\t\tlooking up type {:?}", lexp.type_id); let type_lookup = self.lookup_type.lookup(lexp.type_id)?; let type_id = match ctx.module.types[type_lookup.handle].inner { crate::TypeInner::Struct { .. } => { self.lookup_member .get(&(type_lookup.handle, index)) .ok_or(Error::InvalidAccessType(lexp.type_id))? .type_id } crate::TypeInner::Array { .. } | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. 
} => type_lookup .base_id .ok_or(Error::InvalidAccessType(lexp.type_id))?, ref other => { log::warn!("composite type {other:?}"); return Err(Error::UnsupportedType(type_lookup.handle)); } }; lexp = LookupExpression { handle: ctx.expressions.append( crate::Expression::AccessIndex { base: lexp.handle, index, }, span, ), type_id, block_id, }; } self.lookup_expression.insert( result_id, LookupExpression { handle: lexp.handle, type_id: result_type_id, block_id, }, ); } Op::CompositeInsert => { inst.expect_at_least(5)?; let result_type_id = self.next()?; let id = self.next()?; let object_id = self.next()?; let composite_id = self.next()?; let mut selections = Vec::with_capacity(inst.wc as usize - 5); for _ in 5..inst.wc { selections.push(self.next()?); } let object_lexp = self.lookup_expression.lookup(object_id)?.clone(); let object_handle = get_expr_handle!(object_id, &object_lexp); let root_lexp = self.lookup_expression.lookup(composite_id)?.clone(); let root_handle = get_expr_handle!(composite_id, &root_lexp); let handle = self.insert_composite( root_handle, result_type_id, object_handle, &selections, &ctx.module.types, ctx.expressions, span, )?; self.lookup_expression.insert( id, LookupExpression { handle, type_id: result_type_id, block_id, }, ); } Op::CompositeConstruct => { inst.expect_at_least(3)?; let result_type_id = self.next()?; let id = self.next()?; let mut components = Vec::with_capacity(inst.wc as usize - 2); for _ in 3..inst.wc { let comp_id = self.next()?; log::trace!("\t\t\tlooking up expr {comp_id:?}"); let lexp = self.lookup_expression.lookup(comp_id)?; let handle = get_expr_handle!(comp_id, lexp); components.push(handle); } let ty = self.lookup_type.lookup(result_type_id)?.handle; let first = components[0]; let expr = match ctx.module.types[ty].inner { // this is an optimization to detect the splat crate::TypeInner::Vector { size, .. 
} if components.len() == size as usize && components[1..].iter().all(|&c| c == first) => { crate::Expression::Splat { size, value: first } } _ => crate::Expression::Compose { ty, components }, }; self.lookup_expression.insert( id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::Load => { inst.expect_at_least(4)?; let result_type_id = self.next()?; let result_id = self.next()?; let pointer_id = self.next()?; if inst.wc != 4 { inst.expect(5)?; let _memory_access = self.next()?; } let base_lexp = self.lookup_expression.lookup(pointer_id)?; let base_handle = get_expr_handle!(pointer_id, base_lexp); let type_lookup = self.lookup_type.lookup(base_lexp.type_id)?; let handle = match ctx.module.types[type_lookup.handle].inner { crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => { base_handle } _ => match self.lookup_load_override.get(&pointer_id) { Some(&LookupLoadOverride::Loaded(handle)) => handle, //Note: we aren't handling `LookupLoadOverride::Pending` properly here _ => ctx.expressions.append( crate::Expression::Load { pointer: base_handle, }, span, ), }, }; self.lookup_expression.insert( result_id, LookupExpression { handle, type_id: result_type_id, block_id, }, ); } Op::Store => { inst.expect_at_least(3)?; let pointer_id = self.next()?; let value_id = self.next()?; if inst.wc != 3 { inst.expect(4)?; let _memory_access = self.next()?; } let base_expr = self.lookup_expression.lookup(pointer_id)?; let base_handle = get_expr_handle!(pointer_id, base_expr); let value_expr = self.lookup_expression.lookup(value_id)?; let value_handle = get_expr_handle!(value_id, value_expr); block.extend(emitter.finish(ctx.expressions)); block.push( crate::Statement::Store { pointer: base_handle, value: value_handle, }, span, ); emitter.start(ctx.expressions); } // Arithmetic Instructions +, -, *, /, % Op::SNegate | Op::FNegate => { inst.expect(4)?; self.parse_expr_unary_op_sign_adjusted( ctx, &mut emitter, &mut 
block, block_id, body_idx, crate::UnaryOperator::Negate, )?; } Op::IAdd | Op::ISub | Op::IMul | Op::BitwiseOr | Op::BitwiseXor | Op::BitwiseAnd | Op::SDiv | Op::SRem => { inst.expect(5)?; let operator = map_binary_operator(inst.op)?; self.parse_expr_binary_op_sign_adjusted( ctx, &mut emitter, &mut block, block_id, body_idx, operator, SignAnchor::Result, )?; } Op::IEqual | Op::INotEqual => { inst.expect(5)?; let operator = map_binary_operator(inst.op)?; self.parse_expr_binary_op_sign_adjusted( ctx, &mut emitter, &mut block, block_id, body_idx, operator, SignAnchor::Operand, )?; } Op::FAdd => { inst.expect(5)?; parse_expr_op!(crate::BinaryOperator::Add, BINARY)?; } Op::FSub => { inst.expect(5)?; parse_expr_op!(crate::BinaryOperator::Subtract, BINARY)?; } Op::FMul => { inst.expect(5)?; parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?; } Op::UDiv | Op::FDiv => { inst.expect(5)?; parse_expr_op!(crate::BinaryOperator::Divide, BINARY)?; } Op::UMod | Op::FRem => { inst.expect(5)?; parse_expr_op!(crate::BinaryOperator::Modulo, BINARY)?; } Op::SMod => { inst.expect(5)?; // x - y * int(floor(float(x) / float(y))) let start = self.data_offset; let result_type_id = self.next()?; let result_id = self.next()?; let p1_id = self.next()?; let p2_id = self.next()?; let span = self.span_from_with_op(start); let p1_lexp = self.lookup_expression.lookup(p1_id)?; let left = self.get_expr_handle( p1_id, p1_lexp, ctx, &mut emitter, &mut block, body_idx, ); let p2_lexp = self.lookup_expression.lookup(p2_id)?; let right = self.get_expr_handle( p2_id, p2_lexp, ctx, &mut emitter, &mut block, body_idx, ); let result_ty = self.lookup_type.lookup(result_type_id)?; let inner = &ctx.module.types[result_ty.handle].inner; let kind = inner.scalar_kind().unwrap(); let size = inner.size(ctx.gctx()) as u8; let left_cast = ctx.expressions.append( crate::Expression::As { expr: left, kind: crate::ScalarKind::Float, convert: Some(size), }, span, ); let right_cast = ctx.expressions.append( 
crate::Expression::As { expr: right, kind: crate::ScalarKind::Float, convert: Some(size), }, span, ); let div = ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Divide, left: left_cast, right: right_cast, }, span, ); let floor = ctx.expressions.append( crate::Expression::Math { fun: crate::MathFunction::Floor, arg: div, arg1: None, arg2: None, arg3: None, }, span, ); let cast = ctx.expressions.append( crate::Expression::As { expr: floor, kind, convert: Some(size), }, span, ); let mult = ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Multiply, left: cast, right, }, span, ); let sub = ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Subtract, left, right: mult, }, span, ); self.lookup_expression.insert( result_id, LookupExpression { handle: sub, type_id: result_type_id, block_id, }, ); } Op::FMod => { inst.expect(5)?; // x - y * floor(x / y) let start = self.data_offset; let span = self.span_from_with_op(start); let result_type_id = self.next()?; let result_id = self.next()?; let p1_id = self.next()?; let p2_id = self.next()?; let p1_lexp = self.lookup_expression.lookup(p1_id)?; let left = self.get_expr_handle( p1_id, p1_lexp, ctx, &mut emitter, &mut block, body_idx, ); let p2_lexp = self.lookup_expression.lookup(p2_id)?; let right = self.get_expr_handle( p2_id, p2_lexp, ctx, &mut emitter, &mut block, body_idx, ); let div = ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Divide, left, right, }, span, ); let floor = ctx.expressions.append( crate::Expression::Math { fun: crate::MathFunction::Floor, arg: div, arg1: None, arg2: None, arg3: None, }, span, ); let mult = ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Multiply, left: floor, right, }, span, ); let sub = ctx.expressions.append( crate::Expression::Binary { op: crate::BinaryOperator::Subtract, left, right: mult, }, span, ); self.lookup_expression.insert( 
result_id, LookupExpression { handle: sub, type_id: result_type_id, block_id, }, ); } Op::VectorTimesScalar | Op::VectorTimesMatrix | Op::MatrixTimesScalar | Op::MatrixTimesVector | Op::MatrixTimesMatrix => { inst.expect(5)?; parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?; } Op::Transpose => { inst.expect(4)?; let result_type_id = self.next()?; let result_id = self.next()?; let matrix_id = self.next()?; let matrix_lexp = self.lookup_expression.lookup(matrix_id)?; let matrix_handle = get_expr_handle!(matrix_id, matrix_lexp); let expr = crate::Expression::Math { fun: crate::MathFunction::Transpose, arg: matrix_handle, arg1: None, arg2: None, arg3: None, }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::Dot => { inst.expect(5)?; let result_type_id = self.next()?; let result_id = self.next()?; let left_id = self.next()?; let right_id = self.next()?; let left_lexp = self.lookup_expression.lookup(left_id)?; let left_handle = get_expr_handle!(left_id, left_lexp); let right_lexp = self.lookup_expression.lookup(right_id)?; let right_handle = get_expr_handle!(right_id, right_lexp); let expr = crate::Expression::Math { fun: crate::MathFunction::Dot, arg: left_handle, arg1: Some(right_handle), arg2: None, arg3: None, }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::BitFieldInsert => { inst.expect(7)?; let start = self.data_offset; let span = self.span_from_with_op(start); let result_type_id = self.next()?; let result_id = self.next()?; let base_id = self.next()?; let insert_id = self.next()?; let offset_id = self.next()?; let count_id = self.next()?; let base_lexp = self.lookup_expression.lookup(base_id)?; let base_handle = get_expr_handle!(base_id, base_lexp); let insert_lexp = self.lookup_expression.lookup(insert_id)?; let insert_handle = 
get_expr_handle!(insert_id, insert_lexp); let offset_lexp = self.lookup_expression.lookup(offset_id)?; let offset_handle = get_expr_handle!(offset_id, offset_lexp); let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?; let count_lexp = self.lookup_expression.lookup(count_id)?; let count_handle = get_expr_handle!(count_id, count_lexp); let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?; let offset_kind = ctx.module.types[offset_lookup_ty.handle] .inner .scalar_kind() .unwrap(); let count_kind = ctx.module.types[count_lookup_ty.handle] .inner .scalar_kind() .unwrap(); let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint { ctx.expressions.append( crate::Expression::As { expr: offset_handle, kind: crate::ScalarKind::Uint, convert: None, }, span, ) } else { offset_handle }; let count_cast_handle = if count_kind != crate::ScalarKind::Uint { ctx.expressions.append( crate::Expression::As { expr: count_handle, kind: crate::ScalarKind::Uint, convert: None, }, span, ) } else { count_handle }; let expr = crate::Expression::Math { fun: crate::MathFunction::InsertBits, arg: base_handle, arg1: Some(insert_handle), arg2: Some(offset_cast_handle), arg3: Some(count_cast_handle), }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::BitFieldSExtract | Op::BitFieldUExtract => { inst.expect(6)?; let result_type_id = self.next()?; let result_id = self.next()?; let base_id = self.next()?; let offset_id = self.next()?; let count_id = self.next()?; let base_lexp = self.lookup_expression.lookup(base_id)?; let base_handle = get_expr_handle!(base_id, base_lexp); let offset_lexp = self.lookup_expression.lookup(offset_id)?; let offset_handle = get_expr_handle!(offset_id, offset_lexp); let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?; let count_lexp = self.lookup_expression.lookup(count_id)?; let count_handle = 
get_expr_handle!(count_id, count_lexp); let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?; let offset_kind = ctx.module.types[offset_lookup_ty.handle] .inner .scalar_kind() .unwrap(); let count_kind = ctx.module.types[count_lookup_ty.handle] .inner .scalar_kind() .unwrap(); let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint { ctx.expressions.append( crate::Expression::As { expr: offset_handle, kind: crate::ScalarKind::Uint, convert: None, }, span, ) } else { offset_handle }; let count_cast_handle = if count_kind != crate::ScalarKind::Uint { ctx.expressions.append( crate::Expression::As { expr: count_handle, kind: crate::ScalarKind::Uint, convert: None, }, span, ) } else { count_handle }; let expr = crate::Expression::Math { fun: crate::MathFunction::ExtractBits, arg: base_handle, arg1: Some(offset_cast_handle), arg2: Some(count_cast_handle), arg3: None, }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::BitReverse | Op::BitCount => { inst.expect(4)?; let result_type_id = self.next()?; let result_id = self.next()?; let base_id = self.next()?; let base_lexp = self.lookup_expression.lookup(base_id)?; let base_handle = get_expr_handle!(base_id, base_lexp); let expr = crate::Expression::Math { fun: match inst.op { Op::BitReverse => crate::MathFunction::ReverseBits, Op::BitCount => crate::MathFunction::CountOneBits, _ => unreachable!(), }, arg: base_handle, arg1: None, arg2: None, arg3: None, }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::OuterProduct => { inst.expect(5)?; let result_type_id = self.next()?; let result_id = self.next()?; let left_id = self.next()?; let right_id = self.next()?; let left_lexp = self.lookup_expression.lookup(left_id)?; let left_handle = get_expr_handle!(left_id, left_lexp); let right_lexp = 
self.lookup_expression.lookup(right_id)?; let right_handle = get_expr_handle!(right_id, right_lexp); let expr = crate::Expression::Math { fun: crate::MathFunction::Outer, arg: left_handle, arg1: Some(right_handle), arg2: None, arg3: None, }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } // Bitwise instructions Op::Not => { inst.expect(4)?; self.parse_expr_unary_op_sign_adjusted( ctx, &mut emitter, &mut block, block_id, body_idx, crate::UnaryOperator::BitwiseNot, )?; } Op::ShiftRightLogical => { inst.expect(5)?; //TODO: convert input and result to unsigned parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?; } Op::ShiftRightArithmetic => { inst.expect(5)?; //TODO: convert input and result to signed parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?; } Op::ShiftLeftLogical => { inst.expect(5)?; parse_expr_op!(crate::BinaryOperator::ShiftLeft, SHIFT)?; } // Sampling Op::Image => { inst.expect(4)?; self.parse_image_uncouple(block_id)?; } Op::SampledImage => { inst.expect(5)?; self.parse_image_couple()?; } Op::ImageWrite => { let extra = inst.expect_at_least(4)?; let stmt = self.parse_image_write(extra, ctx, &mut emitter, &mut block, body_idx)?; block.extend(emitter.finish(ctx.expressions)); block.push(stmt, span); emitter.start(ctx.expressions); } Op::ImageFetch | Op::ImageRead => { let extra = inst.expect_at_least(5)?; self.parse_image_load( extra, ctx, &mut emitter, &mut block, block_id, body_idx, )?; } Op::ImageSampleImplicitLod | Op::ImageSampleExplicitLod => { let extra = inst.expect_at_least(5)?; let options = image::SamplingOptions { compare: false, project: false, gather: false, }; self.parse_image_sample( extra, options, ctx, &mut emitter, &mut block, block_id, body_idx, )?; } Op::ImageSampleProjImplicitLod | Op::ImageSampleProjExplicitLod => { let extra = inst.expect_at_least(5)?; let options = image::SamplingOptions { compare: false, project: 
true, gather: false, }; self.parse_image_sample( extra, options, ctx, &mut emitter, &mut block, block_id, body_idx, )?; } Op::ImageSampleDrefImplicitLod | Op::ImageSampleDrefExplicitLod => { let extra = inst.expect_at_least(6)?; let options = image::SamplingOptions { compare: true, project: false, gather: false, }; self.parse_image_sample( extra, options, ctx, &mut emitter, &mut block, block_id, body_idx, )?; } Op::ImageSampleProjDrefImplicitLod | Op::ImageSampleProjDrefExplicitLod => { let extra = inst.expect_at_least(6)?; let options = image::SamplingOptions { compare: true, project: true, gather: false, }; self.parse_image_sample( extra, options, ctx, &mut emitter, &mut block, block_id, body_idx, )?; } Op::ImageGather => { let extra = inst.expect_at_least(6)?; let options = image::SamplingOptions { compare: false, project: false, gather: true, }; self.parse_image_sample( extra, options, ctx, &mut emitter, &mut block, block_id, body_idx, )?; } Op::ImageDrefGather => { let extra = inst.expect_at_least(6)?; let options = image::SamplingOptions { compare: true, project: false, gather: true, }; self.parse_image_sample( extra, options, ctx, &mut emitter, &mut block, block_id, body_idx, )?; } Op::ImageQuerySize => { inst.expect(4)?; self.parse_image_query_size( false, ctx, &mut emitter, &mut block, block_id, body_idx, )?; } Op::ImageQuerySizeLod => { inst.expect(5)?; self.parse_image_query_size( true, ctx, &mut emitter, &mut block, block_id, body_idx, )?; } Op::ImageQueryLevels => { inst.expect(4)?; self.parse_image_query_other(crate::ImageQuery::NumLevels, ctx, block_id)?; } Op::ImageQuerySamples => { inst.expect(4)?; self.parse_image_query_other(crate::ImageQuery::NumSamples, ctx, block_id)?; } // other ops Op::Select => { inst.expect(6)?; let result_type_id = self.next()?; let result_id = self.next()?; let condition = self.next()?; let o1_id = self.next()?; let o2_id = self.next()?; let cond_lexp = self.lookup_expression.lookup(condition)?; let cond_handle = 
get_expr_handle!(condition, cond_lexp); let o1_lexp = self.lookup_expression.lookup(o1_id)?; let o1_handle = get_expr_handle!(o1_id, o1_lexp); let o2_lexp = self.lookup_expression.lookup(o2_id)?; let o2_handle = get_expr_handle!(o2_id, o2_lexp); let expr = crate::Expression::Select { condition: cond_handle, accept: o1_handle, reject: o2_handle, }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::VectorShuffle => { inst.expect_at_least(5)?; let result_type_id = self.next()?; let result_id = self.next()?; let v1_id = self.next()?; let v2_id = self.next()?; let v1_lexp = self.lookup_expression.lookup(v1_id)?; let v1_lty = self.lookup_type.lookup(v1_lexp.type_id)?; let v1_handle = get_expr_handle!(v1_id, v1_lexp); let n1 = match ctx.module.types[v1_lty.handle].inner { crate::TypeInner::Vector { size, .. } => size as u32, _ => return Err(Error::InvalidInnerType(v1_lexp.type_id)), }; let v2_lexp = self.lookup_expression.lookup(v2_id)?; let v2_lty = self.lookup_type.lookup(v2_lexp.type_id)?; let v2_handle = get_expr_handle!(v2_id, v2_lexp); let n2 = match ctx.module.types[v2_lty.handle].inner { crate::TypeInner::Vector { size, .. } => size as u32, _ => return Err(Error::InvalidInnerType(v2_lexp.type_id)), }; self.temp_bytes.clear(); let mut max_component = 0; for _ in 5..inst.wc as usize { let mut index = self.next()?; if index == u32::MAX { // treat Undefined as X index = 0; } max_component = max_component.max(index); self.temp_bytes.push(index as u8); } // Check for swizzle first. 
let expr = if max_component < n1 { use crate::SwizzleComponent as Sc; let size = match self.temp_bytes.len() { 2 => crate::VectorSize::Bi, 3 => crate::VectorSize::Tri, _ => crate::VectorSize::Quad, }; let mut pattern = [Sc::X; 4]; for (pat, index) in pattern.iter_mut().zip(self.temp_bytes.drain(..)) { *pat = match index { 0 => Sc::X, 1 => Sc::Y, 2 => Sc::Z, _ => Sc::W, }; } crate::Expression::Swizzle { size, vector: v1_handle, pattern, } } else { // Fall back to access + compose let mut components = Vec::with_capacity(self.temp_bytes.len()); for index in self.temp_bytes.drain(..).map(|i| i as u32) { let expr = if index < n1 { crate::Expression::AccessIndex { base: v1_handle, index, } } else if index < n1 + n2 { crate::Expression::AccessIndex { base: v2_handle, index: index - n1, } } else { return Err(Error::InvalidAccessIndex(index)); }; components.push(ctx.expressions.append(expr, span)); } crate::Expression::Compose { ty: self.lookup_type.lookup(result_type_id)?.handle, components, } }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::Bitcast | Op::ConvertSToF | Op::ConvertUToF | Op::ConvertFToU | Op::ConvertFToS | Op::FConvert | Op::UConvert | Op::SConvert => { inst.expect(4)?; let result_type_id = self.next()?; let result_id = self.next()?; let value_id = self.next()?; let value_lexp = self.lookup_expression.lookup(value_id)?; let ty_lookup = self.lookup_type.lookup(result_type_id)?; let scalar = match ctx.module.types[ty_lookup.handle].inner { crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } | crate::TypeInner::Matrix { scalar, .. 
} => scalar, _ => return Err(Error::InvalidAsType(ty_lookup.handle)), }; let expr = crate::Expression::As { expr: get_expr_handle!(value_id, value_lexp), kind: scalar.kind, convert: if scalar.kind == crate::ScalarKind::Bool { Some(crate::BOOL_WIDTH) } else if inst.op == Op::Bitcast { None } else { Some(scalar.width) }, }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::FunctionCall => { inst.expect_at_least(4)?; let result_type_id = self.next()?; let result_id = self.next()?; let func_id = self.next()?; let mut arguments = Vec::with_capacity(inst.wc as usize - 4); for _ in 0..arguments.capacity() { let arg_id = self.next()?; let lexp = self.lookup_expression.lookup(arg_id)?; arguments.push(get_expr_handle!(arg_id, lexp)); } block.extend(emitter.finish(ctx.expressions)); // We just need a unique handle here, nothing more. let function = self.add_call(ctx.function_id, func_id); let result = if self.lookup_void_type == Some(result_type_id) { None } else { let expr_handle = ctx .expressions .append(crate::Expression::CallResult(function), span); self.lookup_expression.insert( result_id, LookupExpression { handle: expr_handle, type_id: result_type_id, block_id, }, ); Some(expr_handle) }; block.push( crate::Statement::Call { function, arguments, result, }, span, ); emitter.start(ctx.expressions); } Op::ExtInst => { use crate::MathFunction as Mf; use spirv::GlslStd450Op as Glo; let base_wc = 5; inst.expect_at_least(base_wc)?; let result_type_id = self.next()?; let result_id = self.next()?; let set_id = self.next()?; if Some(set_id) == self.ext_non_semantic_id { for _ in 0..inst.wc - 4 { self.next()?; } continue; } else if Some(set_id) != self.ext_glsl_id { return Err(Error::UnsupportedExtInstSet(set_id)); } let inst_id = self.next()?; let gl_op = Glo::from_u32(inst_id).ok_or(Error::UnsupportedExtInst(inst_id))?; let fun = match gl_op { Glo::Round => Mf::Round, 
Glo::RoundEven => Mf::Round, Glo::Trunc => Mf::Trunc, Glo::FAbs | Glo::SAbs => Mf::Abs, Glo::FSign | Glo::SSign => Mf::Sign, Glo::Floor => Mf::Floor, Glo::Ceil => Mf::Ceil, Glo::Fract => Mf::Fract, Glo::Sin => Mf::Sin, Glo::Cos => Mf::Cos, Glo::Tan => Mf::Tan, Glo::Asin => Mf::Asin, Glo::Acos => Mf::Acos, Glo::Atan => Mf::Atan, Glo::Sinh => Mf::Sinh, Glo::Cosh => Mf::Cosh, Glo::Tanh => Mf::Tanh, Glo::Atan2 => Mf::Atan2, Glo::Asinh => Mf::Asinh, Glo::Acosh => Mf::Acosh, Glo::Atanh => Mf::Atanh, Glo::Radians => Mf::Radians, Glo::Degrees => Mf::Degrees, Glo::Pow => Mf::Pow, Glo::Exp => Mf::Exp, Glo::Log => Mf::Log, Glo::Exp2 => Mf::Exp2, Glo::Log2 => Mf::Log2, Glo::Sqrt => Mf::Sqrt, Glo::InverseSqrt => Mf::InverseSqrt, Glo::MatrixInverse => Mf::Inverse, Glo::Determinant => Mf::Determinant, Glo::ModfStruct => Mf::Modf, Glo::FMin | Glo::UMin | Glo::SMin | Glo::NMin => Mf::Min, Glo::FMax | Glo::UMax | Glo::SMax | Glo::NMax => Mf::Max, Glo::FClamp | Glo::UClamp | Glo::SClamp | Glo::NClamp => Mf::Clamp, Glo::FMix => Mf::Mix, Glo::Step => Mf::Step, Glo::SmoothStep => Mf::SmoothStep, Glo::Fma => Mf::Fma, Glo::FrexpStruct => Mf::Frexp, Glo::Ldexp => Mf::Ldexp, Glo::Length => Mf::Length, Glo::Distance => Mf::Distance, Glo::Cross => Mf::Cross, Glo::Normalize => Mf::Normalize, Glo::FaceForward => Mf::FaceForward, Glo::Reflect => Mf::Reflect, Glo::Refract => Mf::Refract, Glo::PackUnorm4x8 => Mf::Pack4x8unorm, Glo::PackSnorm4x8 => Mf::Pack4x8snorm, Glo::PackHalf2x16 => Mf::Pack2x16float, Glo::PackUnorm2x16 => Mf::Pack2x16unorm, Glo::PackSnorm2x16 => Mf::Pack2x16snorm, Glo::UnpackUnorm4x8 => Mf::Unpack4x8unorm, Glo::UnpackSnorm4x8 => Mf::Unpack4x8snorm, Glo::UnpackHalf2x16 => Mf::Unpack2x16float, Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm, Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm, Glo::FindILsb => Mf::FirstTrailingBit, Glo::FindUMsb | Glo::FindSMsb => Mf::FirstLeadingBit, // TODO: https://github.com/gfx-rs/naga/issues/2526 Glo::Modf | Glo::Frexp => return 
Err(Error::UnsupportedExtInst(inst_id)), Glo::IMix | Glo::PackDouble2x32 | Glo::UnpackDouble2x32 | Glo::InterpolateAtCentroid | Glo::InterpolateAtSample | Glo::InterpolateAtOffset => { return Err(Error::UnsupportedExtInst(inst_id)) } }; let arg_count = fun.argument_count(); inst.expect(base_wc + arg_count as u16)?; let arg = { let arg_id = self.next()?; let lexp = self.lookup_expression.lookup(arg_id)?; get_expr_handle!(arg_id, lexp) }; let arg1 = if arg_count > 1 { let arg_id = self.next()?; let lexp = self.lookup_expression.lookup(arg_id)?; Some(get_expr_handle!(arg_id, lexp)) } else { None }; let arg2 = if arg_count > 2 { let arg_id = self.next()?; let lexp = self.lookup_expression.lookup(arg_id)?; Some(get_expr_handle!(arg_id, lexp)) } else { None }; let arg3 = if arg_count > 3 { let arg_id = self.next()?; let lexp = self.lookup_expression.lookup(arg_id)?; Some(get_expr_handle!(arg_id, lexp)) } else { None }; let expr = crate::Expression::Math { fun, arg, arg1, arg2, arg3, }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } // Relational and Logical Instructions Op::LogicalNot => { inst.expect(4)?; parse_expr_op!(crate::UnaryOperator::LogicalNot, UNARY)?; } Op::LogicalOr => { inst.expect(5)?; parse_expr_op!(crate::BinaryOperator::LogicalOr, BINARY)?; } Op::LogicalAnd => { inst.expect(5)?; parse_expr_op!(crate::BinaryOperator::LogicalAnd, BINARY)?; } Op::SGreaterThan | Op::SGreaterThanEqual | Op::SLessThan | Op::SLessThanEqual => { inst.expect(5)?; self.parse_expr_int_comparison( ctx, &mut emitter, &mut block, block_id, body_idx, map_binary_operator(inst.op)?, crate::ScalarKind::Sint, )?; } Op::UGreaterThan | Op::UGreaterThanEqual | Op::ULessThan | Op::ULessThanEqual => { inst.expect(5)?; self.parse_expr_int_comparison( ctx, &mut emitter, &mut block, block_id, body_idx, map_binary_operator(inst.op)?, crate::ScalarKind::Uint, )?; } Op::FOrdEqual | Op::FUnordEqual | 
Op::FOrdNotEqual | Op::FUnordNotEqual | Op::FOrdLessThan | Op::FUnordLessThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan | Op::FOrdLessThanEqual | Op::FUnordLessThanEqual | Op::FOrdGreaterThanEqual | Op::FUnordGreaterThanEqual | Op::LogicalEqual | Op::LogicalNotEqual => { inst.expect(5)?; let operator = map_binary_operator(inst.op)?; parse_expr_op!(operator, BINARY)?; } Op::Any | Op::All | Op::IsNan | Op::IsInf | Op::IsFinite | Op::IsNormal => { inst.expect(4)?; let result_type_id = self.next()?; let result_id = self.next()?; let arg_id = self.next()?; let arg_lexp = self.lookup_expression.lookup(arg_id)?; let arg_handle = get_expr_handle!(arg_id, arg_lexp); let expr = crate::Expression::Relational { fun: map_relational_fun(inst.op)?, argument: arg_handle, }; self.lookup_expression.insert( result_id, LookupExpression { handle: ctx.expressions.append(expr, span), type_id: result_type_id, block_id, }, ); } Op::Kill => { inst.expect(1)?; break Some(crate::Statement::Kill); } Op::Unreachable => { inst.expect(1)?; break None; } Op::Return => { inst.expect(1)?; break Some(crate::Statement::Return { value: None }); } Op::ReturnValue => { inst.expect(2)?; let value_id = self.next()?; let value_lexp = self.lookup_expression.lookup(value_id)?; let value_handle = get_expr_handle!(value_id, value_lexp); break Some(crate::Statement::Return { value: Some(value_handle), }); } Op::Branch => { inst.expect(2)?; let target_id = self.next()?; // If this is a branch to a merge or continue block, then // that ends the current body. // // Why can we count on finding an entry here when it's // needed? SPIR-V requires dominators to appear before // blocks they dominate, so we will have visited a // structured control construct's header block before // anything that could exit it. 
if let Some(info) = ctx.mergers.get(&target_id) { block.extend(emitter.finish(ctx.expressions)); ctx.blocks.insert(block_id, block); let body = &mut ctx.bodies[body_idx]; body.data.push(BodyFragment::BlockId(block_id)); merger(body, info); return Ok(()); } // If `target_id` has no entry in `ctx.body_for_label`, then // this must be the only branch to it: // // - We've already established that it's not anybody's merge // block. // // - It can't be a switch case. Only switch header blocks // and other switch cases can branch to a switch case. // Switch header blocks must dominate all their cases, so // they must appear in the file before them, and when we // see `Op::Switch` we populate `ctx.body_for_label` for // every switch case. // // Thus, `target_id` must be a simple extension of the // current block, which we dominate, so we know we'll // encounter it later in the file. ctx.body_for_label.entry(target_id).or_insert(body_idx); break None; } Op::BranchConditional => { inst.expect_at_least(4)?; let condition = { let condition_id = self.next()?; let lexp = self.lookup_expression.lookup(condition_id)?; get_expr_handle!(condition_id, lexp) }; // HACK(eddyb) Naga doesn't seem to have this helper, // so it's declared on the fly here for convenience. #[derive(Copy, Clone)] struct BranchTarget { label_id: spirv::Word, merge_info: Option<MergeBlockInformation>, } let branch_target = |label_id| BranchTarget { label_id, merge_info: ctx.mergers.get(&label_id).copied(), }; let true_target = branch_target(self.next()?); let false_target = branch_target(self.next()?); // Consume branch weights for _ in 4..inst.wc { let _ = self.next()?; } // Handle `OpBranchConditional`s used at the end of a loop // body's "continuing" section as a "conditional backedge", // i.e. a `do`-`while` condition, or `break if` in WGSL. // HACK(eddyb) this has to go to the parent *twice*, because // `OpLoopMerge` left the "continuing" section nested in the // loop body in terms of `parent`, but not `BodyFragment`. 
let parent_body_idx = ctx.bodies[body_idx].parent; let parent_parent_body_idx = ctx.bodies[parent_body_idx].parent; match ctx.bodies[parent_parent_body_idx].data[..] { // The `OpLoopMerge`'s `continuing` block and the loop's // backedge block may not be the same, but they'll both // belong to the same body. [.., BodyFragment::Loop { body: loop_body_idx, continuing: loop_continuing_idx, break_if: ref mut break_if_slot @ None, }] if body_idx == loop_continuing_idx => { // Try both orderings of break-vs-backedge, because // SPIR-V is symmetrical here, unlike WGSL `break if`. let break_if_cond = [true, false].into_iter().find_map(|true_breaks| { let (break_candidate, backedge_candidate) = if true_breaks { (true_target, false_target) } else { (false_target, true_target) }; if break_candidate.merge_info != Some(MergeBlockInformation::LoopMerge) { return None; } // HACK(eddyb) since Naga doesn't explicitly track // backedges, this is checking for the outcome of // `OpLoopMerge` below (even if it looks weird). let backedge_candidate_is_backedge = backedge_candidate.merge_info.is_none() && ctx.body_for_label.get(&backedge_candidate.label_id) == Some(&loop_body_idx); if !backedge_candidate_is_backedge { return None; } Some(if true_breaks { condition } else { ctx.expressions.append( crate::Expression::Unary { op: crate::UnaryOperator::LogicalNot, expr: condition, }, span, ) }) }); if let Some(break_if_cond) = break_if_cond { *break_if_slot = Some(break_if_cond); // This `OpBranchConditional` ends the "continuing" // section of the loop body as normal, with the // `break if` condition having been stashed above. break None; } } _ => {} } block.extend(emitter.finish(ctx.expressions)); ctx.blocks.insert(block_id, block); let body = &mut ctx.bodies[body_idx]; body.data.push(BodyFragment::BlockId(block_id)); let same_target = true_target.label_id == false_target.label_id; // Start a body block for the `accept` branch. 
let accept = ctx.bodies.len(); let mut accept_block = Body::with_parent(body_idx); // If the `OpBranchConditional` target is somebody else's // merge or continue block, then put a `Break` or `Continue` // statement in this new body block. if let Some(info) = true_target.merge_info { merger( match same_target { true => &mut ctx.bodies[body_idx], false => &mut accept_block, }, &info, ) } else { // Note the body index for the block we're branching to. let prev = ctx.body_for_label.insert( true_target.label_id, match same_target { true => body_idx, false => accept, }, ); debug_assert!(prev.is_none()); } if same_target { return Ok(()); } ctx.bodies.push(accept_block); // Handle the `reject` branch just like the `accept` branch. let reject = ctx.bodies.len(); let mut reject_block = Body::with_parent(body_idx); if let Some(info) = false_target.merge_info { merger(&mut reject_block, &info) } else { let prev = ctx.body_for_label.insert(false_target.label_id, reject); debug_assert!(prev.is_none()); } ctx.bodies.push(reject_block); let body = &mut ctx.bodies[body_idx]; body.data.push(BodyFragment::If { condition, accept, reject, }); return Ok(()); } Op::Switch => { inst.expect_at_least(3)?; let selector = self.next()?; let default_id = self.next()?; // If the previous instruction was an `OpSelectionMerge`, then we must // promote its `MergeBlockInformation` to a `SwitchMerge`. if let Some(merge) = selection_merge_block { ctx.mergers .insert(merge, MergeBlockInformation::SwitchMerge); } let default = ctx.bodies.len(); ctx.bodies.push(Body::with_parent(body_idx)); ctx.body_for_label.entry(default_id).or_insert(default); let selector_lexp = &self.lookup_expression[&selector]; let selector_lty = self.lookup_type.lookup(selector_lexp.type_id)?; let selector_handle = get_expr_handle!(selector, selector_lexp); let selector = match ctx.module.types[selector_lty.handle].inner { crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Uint, width: _, }) => { // IR expects a signed 
integer, so do a bitcast ctx.expressions.append( crate::Expression::As { kind: crate::ScalarKind::Sint, expr: selector_handle, convert: None, }, span, ) } crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Sint, width: _, }) => selector_handle, ref other => unimplemented!("Unexpected selector {:?}", other), }; // Clear cases left over from a previous switch so they don't leak into this one self.switch_cases.clear(); for _ in 0..(inst.wc - 3) / 2 { let literal = self.next()?; let target = self.next()?; let case_body_idx = ctx.bodies.len(); // Check whether a previous case already targets this block ID; if so, // group the literals together so they can be reordered later and no // unintended fallthrough occurs. if let Some(&mut (_, ref mut literals)) = self.switch_cases.get_mut(&target) { literals.push(literal as i32); continue; } let mut body = Body::with_parent(body_idx); if let Some(info) = ctx.mergers.get(&target) { merger(&mut body, info); } ctx.bodies.push(body); ctx.body_for_label.entry(target).or_insert(case_body_idx); // Register this target block ID as processed, recording its assigned // body index and its first case value. self.switch_cases .insert(target, (case_body_idx, vec![literal as i32])); } // Loop through the collected target blocks, creating a new case for each // literal that points to it. Only one case gets the real body; the others // are empty cases that fall through to it, so they all execute the same // body without duplicating code. // // Since `switch_cases` is an index map, insertion order is preserved. This // matters because SPIR-V defines the fallthrough order in the switch // instruction. 
let mut cases = Vec::with_capacity((inst.wc as usize - 3) / 2); for &(case_body_idx, ref literals) in self.switch_cases.values() { let value = literals[0]; for &literal in literals.iter().skip(1) { let empty_body_idx = ctx.bodies.len(); let body = Body::with_parent(body_idx); ctx.bodies.push(body); cases.push((literal, empty_body_idx)); } cases.push((value, case_body_idx)); } block.extend(emitter.finish(ctx.expressions)); let body = &mut ctx.bodies[body_idx]; ctx.blocks.insert(block_id, block); // Reserve room for the two fragments pushed below body.data.reserve(2); body.data.push(BodyFragment::BlockId(block_id)); body.data.push(BodyFragment::Switch { selector, cases, default, }); return Ok(()); } Op::SelectionMerge => { inst.expect(3)?; let merge_block_id = self.next()?; // TODO: Selection Control Mask let _selection_control = self.next()?; // Indicate that the merge block is a continuation of the // current `Body`. ctx.body_for_label.entry(merge_block_id).or_insert(body_idx); // Let subsequent branches to the merge block know that // they've reached the end of the selection construct. ctx.mergers .insert(merge_block_id, MergeBlockInformation::SelectionMerge); selection_merge_block = Some(merge_block_id); } Op::LoopMerge => { inst.expect_at_least(4)?; let merge_block_id = self.next()?; let continuing = self.next()?; // TODO: Loop Control Parameters for _ in 0..inst.wc - 3 { self.next()?; } // Indicate that the merge block is a continuation of the // current `Body`. ctx.body_for_label.entry(merge_block_id).or_insert(body_idx); // Let subsequent branches to the merge block know that // they're `Break` statements.
ctx.mergers .insert(merge_block_id, MergeBlockInformation::LoopMerge); let loop_body_idx = ctx.bodies.len(); ctx.bodies.push(Body::with_parent(body_idx)); let continue_idx = ctx.bodies.len(); // The continue block inherits the scope of the loop body ctx.bodies.push(Body::with_parent(loop_body_idx)); ctx.body_for_label.entry(continuing).or_insert(continue_idx); // Let subsequent branches to the continue block know that // they're `Continue` statements. ctx.mergers .insert(continuing, MergeBlockInformation::LoopContinue); // The loop header always belongs to the loop body ctx.body_for_label.insert(block_id, loop_body_idx); let parent_body = &mut ctx.bodies[body_idx]; parent_body.data.push(BodyFragment::Loop { body: loop_body_idx, continuing: continue_idx, break_if: None, }); body_idx = loop_body_idx; } Op::DPdxCoarse => { parse_expr_op!( crate::DerivativeAxis::X, crate::DerivativeControl::Coarse, DERIVATIVE )?; } Op::DPdyCoarse => { parse_expr_op!( crate::DerivativeAxis::Y, crate::DerivativeControl::Coarse, DERIVATIVE )?; } Op::FwidthCoarse => { parse_expr_op!( crate::DerivativeAxis::Width, crate::DerivativeControl::Coarse, DERIVATIVE )?; } Op::DPdxFine => { parse_expr_op!( crate::DerivativeAxis::X, crate::DerivativeControl::Fine, DERIVATIVE )?; } Op::DPdyFine => { parse_expr_op!( crate::DerivativeAxis::Y, crate::DerivativeControl::Fine, DERIVATIVE )?; } Op::FwidthFine => { parse_expr_op!( crate::DerivativeAxis::Width, crate::DerivativeControl::Fine, DERIVATIVE )?; } Op::DPdx => { parse_expr_op!( crate::DerivativeAxis::X, crate::DerivativeControl::None, DERIVATIVE )?; } Op::DPdy => { parse_expr_op!( crate::DerivativeAxis::Y, crate::DerivativeControl::None, DERIVATIVE )?; } Op::Fwidth => { parse_expr_op!( crate::DerivativeAxis::Width, crate::DerivativeControl::None, DERIVATIVE )?; } Op::ArrayLength => { inst.expect(5)?; let result_type_id = self.next()?; let result_id = self.next()?; let structure_id = self.next()?; let member_index = self.next()?; // We're assuming 
that the validation pass, if it's run, will catch if the // wrong types or parameters are supplied here. let structure_ptr = self.lookup_expression.lookup(structure_id)?; let structure_handle = get_expr_handle!(structure_id, structure_ptr); let member_ptr = ctx.expressions.append( crate::Expression::AccessIndex { base: structure_handle, index: member_index, }, span, ); let length = ctx .expressions .append(crate::Expression::ArrayLength(member_ptr), span); self.lookup_expression.insert( result_id, LookupExpression { handle: length, type_id: result_type_id, block_id, }, ); } Op::CopyMemory => { inst.expect_at_least(3)?; let target_id = self.next()?; let source_id = self.next()?; let _memory_access = if inst.wc != 3 { inst.expect(4)?; spirv::MemoryAccess::from_bits(self.next()?) .ok_or(Error::InvalidParameter(Op::CopyMemory))? } else { spirv::MemoryAccess::NONE }; // TODO: check if the source and target types are the same? let target = self.lookup_expression.lookup(target_id)?; let target_handle = get_expr_handle!(target_id, target); let source = self.lookup_expression.lookup(source_id)?; let source_handle = get_expr_handle!(source_id, source); // Lower the copy as a `Load` from the source pointer followed by a // `Store` to the target pointer; the memory-access operand is ignored.
let value_expr = ctx.expressions.append( crate::Expression::Load { pointer: source_handle, }, span, ); block.extend(emitter.finish(ctx.expressions)); block.push( crate::Statement::Store { pointer: target_handle, value: value_expr, }, span, ); emitter.start(ctx.expressions); } Op::ControlBarrier => { inst.expect(4)?; let exec_scope_id = self.next()?; let _mem_scope_raw = self.next()?; let semantics_id = self.next()?; let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; let semantics_const = self.lookup_constant.lookup(semantics_id)?; let exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; if exec_scope == spirv::Scope::Workgroup as u32 || exec_scope == spirv::Scope::Subgroup as u32 { let mut flags = crate::Barrier::empty(); flags.set( crate::Barrier::STORAGE, semantics & spirv::MemorySemantics::UNIFORM_MEMORY.bits() != 0, ); flags.set( crate::Barrier::WORK_GROUP, semantics & (spirv::MemorySemantics::WORKGROUP_MEMORY).bits() != 0, ); flags.set( crate::Barrier::SUB_GROUP, semantics & spirv::MemorySemantics::SUBGROUP_MEMORY.bits() != 0, ); flags.set( crate::Barrier::TEXTURE, semantics & spirv::MemorySemantics::IMAGE_MEMORY.bits() != 0, ); block.extend(emitter.finish(ctx.expressions)); block.push(crate::Statement::ControlBarrier(flags), span); emitter.start(ctx.expressions); } else { log::warn!("Unsupported barrier execution scope: {exec_scope}"); } } Op::MemoryBarrier => { inst.expect(3)?; let mem_scope_id = self.next()?; let semantics_id = self.next()?; let mem_scope_const = self.lookup_constant.lookup(mem_scope_id)?; let semantics_const = self.lookup_constant.lookup(semantics_id)?; let mem_scope = resolve_constant(ctx.gctx(), &mem_scope_const.inner) .ok_or(Error::InvalidBarrierScope(mem_scope_id))?; let semantics = resolve_constant(ctx.gctx(), 
&semantics_const.inner) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; let mut flags = if mem_scope == spirv::Scope::Device as u32 { crate::Barrier::STORAGE } else if mem_scope == spirv::Scope::Workgroup as u32 { crate::Barrier::WORK_GROUP } else if mem_scope == spirv::Scope::Subgroup as u32 { crate::Barrier::SUB_GROUP } else { crate::Barrier::empty() }; flags.set( crate::Barrier::STORAGE, semantics & spirv::MemorySemantics::UNIFORM_MEMORY.bits() != 0, ); flags.set( crate::Barrier::WORK_GROUP, semantics & (spirv::MemorySemantics::WORKGROUP_MEMORY).bits() != 0, ); flags.set( crate::Barrier::SUB_GROUP, semantics & spirv::MemorySemantics::SUBGROUP_MEMORY.bits() != 0, ); flags.set( crate::Barrier::TEXTURE, semantics & spirv::MemorySemantics::IMAGE_MEMORY.bits() != 0, ); block.extend(emitter.finish(ctx.expressions)); block.push(crate::Statement::MemoryBarrier(flags), span); emitter.start(ctx.expressions); } Op::CopyObject => { inst.expect(4)?; let result_type_id = self.next()?; let result_id = self.next()?; let operand_id = self.next()?; let lookup = self.lookup_expression.lookup(operand_id)?; let handle = get_expr_handle!(operand_id, lookup); self.lookup_expression.insert( result_id, LookupExpression { handle, type_id: result_type_id, block_id, }, ); } Op::GroupNonUniformBallot => { inst.expect(5)?; block.extend(emitter.finish(ctx.expressions)); let result_type_id = self.next()?; let result_id = self.next()?; let exec_scope_id = self.next()?; let predicate_id = self.next()?; let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; let predicate = if self .lookup_constant .lookup(predicate_id) .ok() .filter(|predicate_const| match predicate_const.inner { Constant::Constant(constant) => matches!( 
ctx.gctx().global_expressions[ctx.gctx().constants[constant].init], crate::Expression::Literal(crate::Literal::Bool(true)), ), Constant::Override(_) => false, }) .is_some() { None } else { let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); Some(predicate_handle) }; let result_handle = ctx .expressions .append(crate::Expression::SubgroupBallotResult, span); self.lookup_expression.insert( result_id, LookupExpression { handle: result_handle, type_id: result_type_id, block_id, }, ); block.push( crate::Statement::SubgroupBallot { result: result_handle, predicate, }, span, ); emitter.start(ctx.expressions); } Op::GroupNonUniformAll | Op::GroupNonUniformAny | Op::GroupNonUniformIAdd | Op::GroupNonUniformFAdd | Op::GroupNonUniformIMul | Op::GroupNonUniformFMul | Op::GroupNonUniformSMax | Op::GroupNonUniformUMax | Op::GroupNonUniformFMax | Op::GroupNonUniformSMin | Op::GroupNonUniformUMin | Op::GroupNonUniformFMin | Op::GroupNonUniformBitwiseAnd | Op::GroupNonUniformBitwiseOr | Op::GroupNonUniformBitwiseXor | Op::GroupNonUniformLogicalAnd | Op::GroupNonUniformLogicalOr | Op::GroupNonUniformLogicalXor => { block.extend(emitter.finish(ctx.expressions)); inst.expect( if matches!(inst.op, Op::GroupNonUniformAll | Op::GroupNonUniformAny) { 5 } else { 6 }, )?; let result_type_id = self.next()?; let result_id = self.next()?; let exec_scope_id = self.next()?; let collective_op_id = match inst.op { Op::GroupNonUniformAll | Op::GroupNonUniformAny => { crate::CollectiveOperation::Reduce } _ => { let group_op_id = self.next()?; match spirv::GroupOperation::from_u32(group_op_id) { Some(spirv::GroupOperation::Reduce) => { crate::CollectiveOperation::Reduce } Some(spirv::GroupOperation::InclusiveScan) => { crate::CollectiveOperation::InclusiveScan } Some(spirv::GroupOperation::ExclusiveScan) => { crate::CollectiveOperation::ExclusiveScan } _ => return 
Err(Error::UnsupportedGroupOperation(group_op_id)), } } }; let argument_id = self.next()?; let argument_lookup = self.lookup_expression.lookup(argument_id)?; let argument_handle = get_expr_handle!(argument_id, argument_lookup); let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; let op_id = match inst.op { Op::GroupNonUniformAll => crate::SubgroupOperation::All, Op::GroupNonUniformAny => crate::SubgroupOperation::Any, Op::GroupNonUniformIAdd | Op::GroupNonUniformFAdd => { crate::SubgroupOperation::Add } Op::GroupNonUniformIMul | Op::GroupNonUniformFMul => { crate::SubgroupOperation::Mul } Op::GroupNonUniformSMax | Op::GroupNonUniformUMax | Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, Op::GroupNonUniformSMin | Op::GroupNonUniformUMin | Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, Op::GroupNonUniformBitwiseAnd | Op::GroupNonUniformLogicalAnd => { crate::SubgroupOperation::And } Op::GroupNonUniformBitwiseOr | Op::GroupNonUniformLogicalOr => { crate::SubgroupOperation::Or } Op::GroupNonUniformBitwiseXor | Op::GroupNonUniformLogicalXor => { crate::SubgroupOperation::Xor } _ => unreachable!(), }; let result_type = self.lookup_type.lookup(result_type_id)?; let result_handle = ctx.expressions.append( crate::Expression::SubgroupOperationResult { ty: result_type.handle, }, span, ); self.lookup_expression.insert( result_id, LookupExpression { handle: result_handle, type_id: result_type_id, block_id, }, ); block.push( crate::Statement::SubgroupCollectiveOperation { result: result_handle, op: op_id, collective_op: collective_op_id, argument: argument_handle, }, span, ); emitter.start(ctx.expressions); } Op::GroupNonUniformBroadcastFirst | Op::GroupNonUniformBroadcast | Op::GroupNonUniformShuffle | Op::GroupNonUniformShuffleDown | 
Op::GroupNonUniformShuffleUp | Op::GroupNonUniformShuffleXor | Op::GroupNonUniformQuadBroadcast => { inst.expect(if matches!(inst.op, Op::GroupNonUniformBroadcastFirst) { 5 } else { 6 })?; block.extend(emitter.finish(ctx.expressions)); let result_type_id = self.next()?; let result_id = self.next()?; let exec_scope_id = self.next()?; let argument_id = self.next()?; let argument_lookup = self.lookup_expression.lookup(argument_id)?; let argument_handle = get_expr_handle!(argument_id, argument_lookup); let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; let mode = if matches!(inst.op, Op::GroupNonUniformBroadcastFirst) { crate::GatherMode::BroadcastFirst } else { let index_id = self.next()?; let index_lookup = self.lookup_expression.lookup(index_id)?; let index_handle = get_expr_handle!(index_id, index_lookup); match inst.op { Op::GroupNonUniformBroadcast => { crate::GatherMode::Broadcast(index_handle) } Op::GroupNonUniformShuffle => crate::GatherMode::Shuffle(index_handle), Op::GroupNonUniformShuffleDown => { crate::GatherMode::ShuffleDown(index_handle) } Op::GroupNonUniformShuffleUp => { crate::GatherMode::ShuffleUp(index_handle) } Op::GroupNonUniformShuffleXor => { crate::GatherMode::ShuffleXor(index_handle) } Op::GroupNonUniformQuadBroadcast => { crate::GatherMode::QuadBroadcast(index_handle) } _ => unreachable!(), } }; let result_type = self.lookup_type.lookup(result_type_id)?; let result_handle = ctx.expressions.append( crate::Expression::SubgroupOperationResult { ty: result_type.handle, }, span, ); self.lookup_expression.insert( result_id, LookupExpression { handle: result_handle, type_id: result_type_id, block_id, }, ); block.push( crate::Statement::SubgroupGather { result: result_handle, mode, argument: argument_handle, }, span, ); 
emitter.start(ctx.expressions); } Op::GroupNonUniformQuadSwap => { inst.expect(6)?; block.extend(emitter.finish(ctx.expressions)); let result_type_id = self.next()?; let result_id = self.next()?; let exec_scope_id = self.next()?; let argument_id = self.next()?; let direction_id = self.next()?; let argument_lookup = self.lookup_expression.lookup(argument_id)?; let argument_handle = get_expr_handle!(argument_id, argument_lookup); let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; let direction_const = self.lookup_constant.lookup(direction_id)?; let direction_const = resolve_constant(ctx.gctx(), &direction_const.inner) .ok_or(Error::InvalidOperand)?; let direction = match direction_const { 0 => crate::Direction::X, 1 => crate::Direction::Y, 2 => crate::Direction::Diagonal, _ => return Err(Error::InvalidOperand), }; let result_type = self.lookup_type.lookup(result_type_id)?; let result_handle = ctx.expressions.append( crate::Expression::SubgroupOperationResult { ty: result_type.handle, }, span, ); self.lookup_expression.insert( result_id, LookupExpression { handle: result_handle, type_id: result_type_id, block_id, }, ); block.push( crate::Statement::SubgroupGather { mode: crate::GatherMode::QuadSwap(direction), result: result_handle, argument: argument_handle, }, span, ); emitter.start(ctx.expressions); } Op::AtomicLoad => { inst.expect(6)?; let start = self.data_offset; let result_type_id = self.next()?; let result_id = self.next()?; let pointer_id = self.next()?; let _scope_id = self.next()?; let _memory_semantics_id = self.next()?; let span = self.span_from_with_op(start); log::trace!("\t\t\tlooking up expr {pointer_id:?}"); let p_lexp_handle = get_expr_handle!(pointer_id, self.lookup_expression.lookup(pointer_id)?); // Create an expression for our result let expr =
crate::Expression::Load { pointer: p_lexp_handle, }; let handle = ctx.expressions.append(expr, span); self.lookup_expression.insert( result_id, LookupExpression { handle, type_id: result_type_id, block_id, }, ); // Store any associated global variables so we can upgrade their types later self.record_atomic_access(ctx, p_lexp_handle)?; } Op::AtomicStore => { inst.expect(5)?; let start = self.data_offset; let pointer_id = self.next()?; let _scope_id = self.next()?; let _memory_semantics_id = self.next()?; let value_id = self.next()?; let span = self.span_from_with_op(start); log::trace!("\t\t\tlooking up pointer expr {pointer_id:?}"); let p_lexp_handle = get_expr_handle!(pointer_id, self.lookup_expression.lookup(pointer_id)?); log::trace!("\t\t\tlooking up value expr {value_id:?}"); let v_lexp_handle = get_expr_handle!(value_id, self.lookup_expression.lookup(value_id)?); block.extend(emitter.finish(ctx.expressions)); // Create a statement for the op itself let stmt = crate::Statement::Store { pointer: p_lexp_handle, value: v_lexp_handle, }; block.push(stmt, span); emitter.start(ctx.expressions); // Store any associated global variables so we can upgrade their types later self.record_atomic_access(ctx, p_lexp_handle)?; } Op::AtomicIIncrement | Op::AtomicIDecrement => { inst.expect(6)?; let start = self.data_offset; let result_type_id = self.next()?; let result_id = self.next()?; let pointer_id = self.next()?; let _scope_id = self.next()?; let _memory_semantics_id = self.next()?; let span = self.span_from_with_op(start); let (p_exp_h, p_base_ty_h) = self.get_exp_and_base_ty_handles( pointer_id, ctx, &mut emitter, &mut block, body_idx, )?; block.extend(emitter.finish(ctx.expressions)); // Create an expression for our result let r_lexp_handle = { let expr = crate::Expression::AtomicResult { ty: p_base_ty_h, comparison: false, }; let handle = ctx.expressions.append(expr, span); self.lookup_expression.insert( result_id, LookupExpression { handle, type_id: result_type_id,
block_id, }, ); handle }; emitter.start(ctx.expressions); // Create a literal "1" to use as our value let one_lexp_handle = make_index_literal( ctx, 1, &mut block, &mut emitter, p_base_ty_h, result_type_id, span, )?; // Create a statement for the op itself let stmt = crate::Statement::Atomic { pointer: p_exp_h, fun: match inst.op { Op::AtomicIIncrement => crate::AtomicFunction::Add, _ => crate::AtomicFunction::Subtract, }, value: one_lexp_handle, result: Some(r_lexp_handle), }; block.push(stmt, span); // Store any associated global variables so we can upgrade their types later self.record_atomic_access(ctx, p_exp_h)?; } Op::AtomicCompareExchange => { inst.expect(9)?; let start = self.data_offset; let result_type_id = self.next()?; let result_id = self.next()?; let pointer_id = self.next()?; let _memory_scope_id = self.next()?; let _equal_memory_semantics_id = self.next()?; let _unequal_memory_semantics_id = self.next()?; let value_id = self.next()?; let comparator_id = self.next()?; let span = self.span_from_with_op(start); let (p_exp_h, p_base_ty_h) = self.get_exp_and_base_ty_handles( pointer_id, ctx, &mut emitter, &mut block, body_idx, )?; log::trace!("\t\t\tlooking up value expr {value_id:?}"); let v_lexp_handle = get_expr_handle!(value_id, self.lookup_expression.lookup(value_id)?); log::trace!("\t\t\tlooking up comparator expr {comparator_id:?}"); let c_lexp_handle = get_expr_handle!( comparator_id, self.lookup_expression.lookup(comparator_id)? ); // We know from the SPIR-V spec that the result type must be an integer // scalar, and we'll need the type itself to get a handle to the atomic // result struct. let crate::TypeInner::Scalar(scalar) = ctx.module.types[p_base_ty_h].inner else { return Err( crate::front::atomic_upgrade::Error::CompareExchangeNonScalarBaseType .into(), ); }; // Get a handle to the atomic result struct type.
let atomic_result_struct_ty_h = ctx.module.generate_predeclared_type( crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar), ); block.extend(emitter.finish(ctx.expressions)); // Create an expression for our atomic result let atomic_lexp_handle = { let expr = crate::Expression::AtomicResult { ty: atomic_result_struct_ty_h, comparison: true, }; ctx.expressions.append(expr, span) }; // Create a dot accessor that extracts the `old_value` member (index 0) // from the result struct `__atomic_compare_exchange_result`, and use it // as the expression for `result_id` { let expr = crate::Expression::AccessIndex { base: atomic_lexp_handle, index: 0, }; let handle = ctx.expressions.append(expr, span); // Use this dot accessor as the result id's expression let _ = self.lookup_expression.insert( result_id, LookupExpression { handle, type_id: result_type_id, block_id, }, ); } emitter.start(ctx.expressions); // Create a statement for the op itself let stmt = crate::Statement::Atomic { pointer: p_exp_h, fun: crate::AtomicFunction::Exchange { compare: Some(c_lexp_handle), }, value: v_lexp_handle, result: Some(atomic_lexp_handle), }; block.push(stmt, span); // Store any associated global variables so we can upgrade their types later self.record_atomic_access(ctx, p_exp_h)?; } Op::AtomicExchange | Op::AtomicIAdd | Op::AtomicISub | Op::AtomicSMin | Op::AtomicUMin | Op::AtomicSMax | Op::AtomicUMax | Op::AtomicAnd | Op::AtomicOr | Op::AtomicXor | Op::AtomicFAddEXT => self.parse_atomic_expr_with_value( inst, &mut emitter, ctx, &mut block, block_id, body_idx, match inst.op { Op::AtomicExchange => crate::AtomicFunction::Exchange { compare: None }, Op::AtomicIAdd | Op::AtomicFAddEXT => crate::AtomicFunction::Add, Op::AtomicISub => crate::AtomicFunction::Subtract, Op::AtomicSMin => crate::AtomicFunction::Min, Op::AtomicUMin => crate::AtomicFunction::Min, Op::AtomicSMax => crate::AtomicFunction::Max, Op::AtomicUMax => crate::AtomicFunction::Max, Op::AtomicAnd => 
crate::AtomicFunction::And, Op::AtomicOr =>
crate::AtomicFunction::InclusiveOr, Op::AtomicXor => crate::AtomicFunction::ExclusiveOr, _ => unreachable!(), }, )?, _ => { return Err(Error::UnsupportedInstruction(self.state, inst.op)); } } }; block.extend(emitter.finish(ctx.expressions)); if let Some(stmt) = terminator { block.push(stmt, crate::Span::default()); } // Save this block fragment in `ctx.blocks`, and mark it to be // incorporated into the current body at `Statement` assembly time. ctx.blocks.insert(block_id, block); let body = &mut ctx.bodies[body_idx]; body.data.push(BodyFragment::BlockId(block_id)); Ok(()) } } fn make_index_literal( ctx: &mut BlockContext, index: u32, block: &mut crate::Block, emitter: &mut crate::proc::Emitter, index_type: Handle<crate::Type>, index_type_id: spirv::Word, span: crate::Span, ) -> Result<Handle<crate::Expression>, Error> { block.extend(emitter.finish(ctx.expressions)); let literal = match ctx.module.types[index_type].inner.scalar_kind() { Some(crate::ScalarKind::Uint) => crate::Literal::U32(index), Some(crate::ScalarKind::Sint) => crate::Literal::I32(index as i32), _ => return Err(Error::InvalidIndexType(index_type_id)), }; let expr = ctx .expressions .append(crate::Expression::Literal(literal), span); emitter.start(ctx.expressions); Ok(expr) } naga-29.0.3/src/front/spv/null.rs000064400000000000000000000024741046102023000147030ustar 00000000000000use alloc::vec; use super::Error; use crate::arena::{Arena, Handle}; /// Create a default value for an output built-in. pub fn generate_default_built_in( built_in: Option<crate::BuiltIn>, ty: Handle<crate::Type>, global_expressions: &mut Arena<crate::Expression>, span: crate::Span, ) -> Result<Handle<crate::Expression>, Error> { let expr = match built_in { Some(crate::BuiltIn::Position { ..
}) => { let zero = global_expressions .append(crate::Expression::Literal(crate::Literal::F32(0.0)), span); let one = global_expressions .append(crate::Expression::Literal(crate::Literal::F32(1.0)), span); crate::Expression::Compose { ty, components: vec![zero, zero, zero, one], } } Some(crate::BuiltIn::PointSize) => crate::Expression::Literal(crate::Literal::F32(1.0)), Some(crate::BuiltIn::FragDepth) => crate::Expression::Literal(crate::Literal::F32(0.0)), Some(crate::BuiltIn::SampleMask) => { crate::Expression::Literal(crate::Literal::U32(u32::MAX)) } // Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path _ => crate::Expression::ZeroValue(ty), }; Ok(global_expressions.append(expr, span)) } naga-29.0.3/src/front/type_gen.rs000064400000000000000000000567371046102023000147460ustar 00000000000000/*! Type generators. */ use alloc::{string::ToString, vec}; use crate::{arena::Handle, span::Span}; impl crate::Module { /// Populate this module's [`SpecialTypes::ray_desc`] type. /// /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of /// an [`Initialize`] [`RayQuery`] statement. In WGSL, it is a struct type /// referred to as `RayDesc`. /// /// Backends consume values of this type to drive platform APIs, so if you /// change any of its fields, you must update the backends to match. Look for /// backend code dealing with [`RayQueryFunction::Initialize`].
/// /// [`SpecialTypes::ray_desc`]: crate::SpecialTypes::ray_desc /// [`descriptor`]: crate::RayQueryFunction::Initialize::descriptor /// [`Initialize`]: crate::RayQueryFunction::Initialize /// [`RayQuery`]: crate::Statement::RayQuery /// [`RayQueryFunction::Initialize`]: crate::RayQueryFunction::Initialize pub fn generate_ray_desc_type(&mut self) -> Handle<crate::Type> { if let Some(handle) = self.special_types.ray_desc { return handle; } let ty_flag = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::U32), }, Span::UNDEFINED, ); let ty_scalar = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::F32), }, Span::UNDEFINED, ); let ty_vector = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Vector { size: crate::VectorSize::Tri, scalar: crate::Scalar::F32, }, }, Span::UNDEFINED, ); let handle = self.types.insert( crate::Type { name: Some("RayDesc".to_string()), inner: crate::TypeInner::Struct { members: vec![ crate::StructMember { name: Some("flags".to_string()), ty: ty_flag, binding: None, offset: 0, }, crate::StructMember { name: Some("cull_mask".to_string()), ty: ty_flag, binding: None, offset: 4, }, crate::StructMember { name: Some("tmin".to_string()), ty: ty_scalar, binding: None, offset: 8, }, crate::StructMember { name: Some("tmax".to_string()), ty: ty_scalar, binding: None, offset: 12, }, crate::StructMember { name: Some("origin".to_string()), ty: ty_vector, binding: None, offset: 16, }, crate::StructMember { name: Some("dir".to_string()), ty: ty_vector, binding: None, offset: 32, }, ], span: 48, }, }, Span::UNDEFINED, ); self.special_types.ray_desc = Some(handle); handle } /// Populate this module's `SpecialTypes::ray_vertex_return` type, an array of /// three `vec3<f32>` values with a stride of 16, and return its handle. 
pub fn generate_vertex_return_type(&mut self) -> Handle<crate::Type> { if let Some(handle) = self.special_types.ray_vertex_return { return handle; } let ty_vec3f = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Vector {
size: crate::VectorSize::Tri, scalar: crate::Scalar::F32, }, }, Span::UNDEFINED, ); let array = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Array { base: ty_vec3f, size: crate::ArraySize::Constant(core::num::NonZeroU32::new(3).unwrap()), stride: 16, }, }, Span::UNDEFINED, ); self.special_types.ray_vertex_return = Some(array); array } /// Populate this module's [`SpecialTypes::ray_intersection`] type. /// /// [`SpecialTypes::ray_intersection`] is the type of a /// `RayQueryGetIntersection` expression. In WGSL, it is a struct type /// referred to as `RayIntersection`. /// /// Backends construct values of this type based on platform APIs, so if you /// change any of its fields, you must update the backends to match. Look for /// the backend's handling for [`Expression::RayQueryGetIntersection`]. /// /// [`SpecialTypes::ray_intersection`]: crate::SpecialTypes::ray_intersection /// [`Expression::RayQueryGetIntersection`]: crate::Expression::RayQueryGetIntersection pub fn generate_ray_intersection_type(&mut self) -> Handle<crate::Type> { if let Some(handle) = self.special_types.ray_intersection { return handle; } let ty_flag = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::U32), }, Span::UNDEFINED, ); let ty_scalar = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::F32), }, Span::UNDEFINED, ); let ty_barycentrics = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::F32, }, }, Span::UNDEFINED, ); let ty_bool = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::BOOL), }, Span::UNDEFINED, ); let ty_transform = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Matrix { columns: crate::VectorSize::Quad, rows: crate::VectorSize::Tri, scalar: crate::Scalar::F32, }, }, Span::UNDEFINED, ); let handle = self.types.insert( crate::Type { name:
Some("RayIntersection".to_string()), inner: crate::TypeInner::Struct { members: vec![ crate::StructMember { name: Some("kind".to_string()), ty: ty_flag, binding: None, offset: 0, }, crate::StructMember { name: Some("t".to_string()), ty: ty_scalar, binding: None, offset: 4, }, crate::StructMember { name: Some("instance_custom_data".to_string()), ty: ty_flag, binding: None, offset: 8, }, crate::StructMember { name: Some("instance_index".to_string()), ty: ty_flag, binding: None, offset: 12, }, crate::StructMember { name: Some("sbt_record_offset".to_string()), ty: ty_flag, binding: None, offset: 16, }, crate::StructMember { name: Some("geometry_index".to_string()), ty: ty_flag, binding: None, offset: 20, }, crate::StructMember { name: Some("primitive_index".to_string()), ty: ty_flag, binding: None, offset: 24, }, crate::StructMember { name: Some("barycentrics".to_string()), ty: ty_barycentrics, binding: None, offset: 28, }, crate::StructMember { name: Some("front_face".to_string()), ty: ty_bool, binding: None, offset: 36, }, crate::StructMember { name: Some("object_to_world".to_string()), ty: ty_transform, binding: None, offset: 48, }, crate::StructMember { name: Some("world_to_object".to_string()), ty: ty_transform, binding: None, offset: 112, }, ], span: 176, }, }, Span::UNDEFINED, ); self.special_types.ray_intersection = Some(handle); handle } /// Generate [`SpecialTypes::external_texture_params`] and /// [`SpecialTypes::external_texture_transfer_function`]. /// /// Other than the WGSL backend, every backend that supports external /// textures does so by lowering them to a set of ordinary textures and /// some parameters saying how to sample from them. These types are used /// for said parameters. Note that they are not used by the IR, but /// generated purely as a convenience for the backends. 
/// /// [`SpecialTypes::external_texture_params`]: crate::ir::SpecialTypes::external_texture_params /// [`SpecialTypes::external_texture_transfer_function`]: crate::ir::SpecialTypes::external_texture_transfer_function pub fn generate_external_texture_types(&mut self) { if self.special_types.external_texture_params.is_some() { return; } let ty_f32 = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::F32), }, Span::UNDEFINED, ); let ty_u32 = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::U32), }, Span::UNDEFINED, ); let ty_vec2u = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Vector { size: crate::VectorSize::Bi, scalar: crate::Scalar::U32, }, }, Span::UNDEFINED, ); let ty_mat3x2f = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Matrix { columns: crate::VectorSize::Tri, rows: crate::VectorSize::Bi, scalar: crate::Scalar::F32, }, }, Span::UNDEFINED, ); let ty_mat3x3f = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Matrix { columns: crate::VectorSize::Tri, rows: crate::VectorSize::Tri, scalar: crate::Scalar::F32, }, }, Span::UNDEFINED, ); let ty_mat4x4f = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Matrix { columns: crate::VectorSize::Quad, rows: crate::VectorSize::Quad, scalar: crate::Scalar::F32, }, }, Span::UNDEFINED, ); let transfer_fn_handle = self.types.insert( crate::Type { name: Some("NagaExternalTextureTransferFn".to_string()), inner: crate::TypeInner::Struct { members: vec![ crate::StructMember { name: Some("a".to_string()), ty: ty_f32, binding: None, offset: 0, }, crate::StructMember { name: Some("b".to_string()), ty: ty_f32, binding: None, offset: 4, }, crate::StructMember { name: Some("g".to_string()), ty: ty_f32, binding: None, offset: 8, }, crate::StructMember { name: Some("k".to_string()), ty: ty_f32, binding: None, offset: 12, }, ], span: 16, }, }, Span::UNDEFINED, ); 
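// Illustrative cross-check (hypothetical helper, not used by naga): the member
// offsets and span hard-coded for `NagaExternalTextureParams` below appear to
// follow WGSL's general alignment and size rules. Each member starts at the
// next multiple of its alignment, and the struct size is rounded up to the
// largest member alignment. The `(align, size)` pairs assumed here are WGSL's
// values for mat4x4<f32> (16, 64), mat3x3<f32> (16, 48), the 16-byte
// transfer-function struct (4, 16), mat3x2<f32> (8, 24), vec2<u32> (8, 8) and
// u32 (4, 4). Fixed-size arrays are used because this file does not import
// `Vec`.
#[allow(dead_code)]
fn sketch_wgsl_struct_layout<const N: usize>(members: [(u32, u32); N]) -> ([u32; N], u32) {
    // Round `value` up to the next multiple of `align` (`align` must be nonzero).
    let round_up = |value: u32, align: u32| (value + align - 1) / align * align;
    let mut offsets = [0u32; N];
    let mut end = 0u32;
    let mut struct_align = 1u32;
    for (i, &(align, size)) in members.iter().enumerate() {
        offsets[i] = round_up(end, align);
        end = offsets[i] + size;
        struct_align = struct_align.max(align);
    }
    (offsets, round_up(end, struct_align))
}
// Fed the eight members of `NagaExternalTextureParams` in order, this yields
// offsets [0, 64, 112, 128, 144, 168, 192, 200] and a span of 208. The same
// rule reproduces `RayDesc` above: four 4-byte scalars followed by two
// `vec3<f32>` members (16, 12) give offsets [0, 4, 8, 12, 16, 32] and span 48.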
self.special_types.external_texture_transfer_function = Some(transfer_fn_handle); let params_handle = self.types.insert( crate::Type { name: Some("NagaExternalTextureParams".to_string()), inner: crate::TypeInner::Struct { members: vec![ crate::StructMember { name: Some("yuv_conversion_matrix".to_string()), ty: ty_mat4x4f, binding: None, offset: 0, }, crate::StructMember { name: Some("gamut_conversion_matrix".to_string()), ty: ty_mat3x3f, binding: None, offset: 64, }, crate::StructMember { name: Some("src_tf".to_string()), ty: transfer_fn_handle, binding: None, offset: 112, }, crate::StructMember { name: Some("dst_tf".to_string()), ty: transfer_fn_handle, binding: None, offset: 128, }, crate::StructMember { name: Some("sample_transform".to_string()), ty: ty_mat3x2f, binding: None, offset: 144, }, crate::StructMember { name: Some("load_transform".to_string()), ty: ty_mat3x2f, binding: None, offset: 168, }, crate::StructMember { name: Some("size".to_string()), ty: ty_vec2u, binding: None, offset: 192, }, crate::StructMember { name: Some("num_planes".to_string()), ty: ty_u32, binding: None, offset: 200, }, ], span: 208, }, }, Span::UNDEFINED, ); self.special_types.external_texture_params = Some(params_handle); } /// Populate this module's [`SpecialTypes::predeclared_types`] type and return the handle. 
/// /// [`SpecialTypes::predeclared_types`]: crate::SpecialTypes::predeclared_types pub fn generate_predeclared_type( &mut self, special_type: crate::PredeclaredType, ) -> Handle<crate::Type> { if let Some(value) = self.special_types.predeclared_types.get(&special_type) { return *value; } let name = special_type.struct_name(); let ty = match special_type { crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => { let bool_ty = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar::BOOL), }, Span::UNDEFINED, ); let scalar_ty = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(scalar), }, Span::UNDEFINED, ); crate::Type { name: Some(name), inner: crate::TypeInner::Struct { members: vec![ crate::StructMember { name: Some("old_value".to_string()), ty: scalar_ty, binding: None, offset: 0, }, crate::StructMember { name: Some("exchanged".to_string()), ty: bool_ty, binding: None, offset: scalar.width as u32, }, ], span: scalar.width as u32 * 2, }, } } crate::PredeclaredType::ModfResult { size, scalar } => { let float_ty = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(scalar), }, Span::UNDEFINED, ); let (member_ty, second_offset) = if let Some(size) = size { let vec_ty = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Vector { size, scalar }, }, Span::UNDEFINED, ); (vec_ty, size as u32 * scalar.width as u32) } else { (float_ty, scalar.width as u32) }; crate::Type { name: Some(name), inner: crate::TypeInner::Struct { members: vec![ crate::StructMember { name: Some("fract".to_string()), ty: member_ty, binding: None, offset: 0, }, crate::StructMember { name: Some("whole".to_string()), ty: member_ty, binding: None, offset: second_offset, }, ], span: second_offset * 2, }, } } crate::PredeclaredType::FrexpResult { size, scalar } => { let float_ty = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(scalar), }, Span::UNDEFINED, ); let int_ty
= self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Scalar(crate::Scalar { kind: crate::ScalarKind::Sint, width: scalar.width, }), }, Span::UNDEFINED, ); let (fract_member_ty, exp_member_ty, second_offset) = if let Some(size) = size { let vec_float_ty = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Vector { size, scalar }, }, Span::UNDEFINED, ); let vec_int_ty = self.types.insert( crate::Type { name: None, inner: crate::TypeInner::Vector { size, scalar: crate::Scalar { kind: crate::ScalarKind::Sint, width: scalar.width, }, }, }, Span::UNDEFINED, ); (vec_float_ty, vec_int_ty, size as u32 * scalar.width as u32) } else { (float_ty, int_ty, scalar.width as u32) }; crate::Type { name: Some(name), inner: crate::TypeInner::Struct { members: vec![ crate::StructMember { name: Some("fract".to_string()), ty: fract_member_ty, binding: None, offset: 0, }, crate::StructMember { name: Some("exp".to_string()), ty: exp_member_ty, binding: None, offset: second_offset, }, ], span: second_offset * 2, }, } } }; let handle = self.types.insert(ty, Span::UNDEFINED); self.special_types .predeclared_types .insert(special_type, handle); handle } } naga-29.0.3/src/front/wgsl/error.rs000064400000000000000000001752751046102023000152400ustar 00000000000000//! Formatting WGSL front end error messages. 
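As a cross-check on the offsets hard-coded in `generate_external_texture_types`, here is a minimal sketch that recomputes them, assuming the WGSL specification's `AlignOf`/`SizeOf` rules for 4-byte scalars, vectors, and column-major matrices. The `Ty` enum and the `round_up`, `vec_layout`, `type_layout`, and `layout` helpers are hypothetical illustrations, not part of naga's API.

```rust
// Hypothetical model of the member types used by NagaExternalTextureParams;
// this is not naga's API.
#[derive(Clone, Copy)]
enum Ty {
    F32,
    U32,
    Vec2U,
    /// Column-major f32 matrix: `cols` columns, each a `rows`-component vector.
    MatF32 { cols: u32, rows: u32 },
    /// A struct whose alignment and size are already known.
    Struct { align: u32, size: u32 },
}

/// Round `value` up to the next multiple of `align`.
fn round_up(value: u32, align: u32) -> u32 {
    (value + align - 1) / align * align
}

/// (AlignOf, SizeOf) of a vector of `n` 4-byte components (n = 2, 3 or 4).
fn vec_layout(n: u32) -> (u32, u32) {
    let align = if n == 2 { 8 } else { 16 }; // vec3 and vec4 align to 16
    (align, 4 * n)
}

/// (AlignOf, SizeOf) of a single member type.
fn type_layout(ty: Ty) -> (u32, u32) {
    match ty {
        Ty::F32 | Ty::U32 => (4, 4),
        Ty::Vec2U => vec_layout(2),
        Ty::MatF32 { cols, rows } => {
            // Laid out like array<vecR<f32>, C>: each column padded to its alignment.
            let (align, size) = vec_layout(rows);
            (align, cols * round_up(size, align))
        }
        Ty::Struct { align, size } => (align, size),
    }
}

/// Place members in order, each at the next suitably aligned offset.
/// Returns the member offsets and the struct span, which is the end of
/// the last member rounded up to the largest member alignment.
fn layout(members: &[Ty]) -> (Vec<u32>, u32) {
    let mut offsets = Vec::with_capacity(members.len());
    let mut end = 0;
    let mut struct_align = 1;
    for &ty in members {
        let (align, size) = type_layout(ty);
        let offset = round_up(end, align);
        offsets.push(offset);
        end = offset + size;
        struct_align = struct_align.max(align);
    }
    (offsets, round_up(end, struct_align))
}
```

For example, `layout(&[Ty::F32; 4])` yields offsets `[0, 4, 8, 12]` and span 16, matching `NagaExternalTextureTransferFn`. Feeding it the eight `NagaExternalTextureParams` members in order (with each transfer-function member as `Ty::Struct { align: 4, size: 16 }`) yields offsets 0, 64, 112, 128, 144, 168, 192, 200 and a span of 208, the same values naga inserts.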
use crate::common::wgsl::TryToWgsl; use crate::diagnostic_filter::ConflictingDiagnosticRuleError; use crate::error::replace_control_chars; use crate::proc::{Alignment, ConstantEvaluatorError, ResolveError}; use crate::{Scalar, SourceLocation, Span}; use super::parse::directive::enable_extension::{EnableExtension, UnimplementedEnableExtension}; use super::parse::directive::language_extension::{ LanguageExtension, UnimplementedLanguageExtension, }; use super::parse::lexer::Token; use codespan_reporting::diagnostic::{Diagnostic, Label}; use codespan_reporting::files::SimpleFile; use codespan_reporting::term; use thiserror::Error; use alloc::{ borrow::Cow, boxed::Box, format, string::{String, ToString}, vec, vec::Vec, }; use core::ops::Range; #[derive(Clone, Debug)] pub struct ParseError { message: String, // The first span should be the primary span, and the other ones should be complementary. labels: Vec<(Span, Cow<'static, str>)>, notes: Vec<String>, } impl ParseError { pub fn labels(&self) -> impl ExactSizeIterator<Item = (Span, &str)> + '_ { self.labels .iter() .map(|&(span, ref msg)| (span, msg.as_ref())) } pub fn message(&self) -> &str { &self.message } fn diagnostic(&self) -> Diagnostic<()> { let diagnostic = Diagnostic::error() .with_message(self.message.to_string()) .with_labels( self.labels .iter() .filter_map(|label| label.0.to_range().map(|range| (label, range))) .map(|(label, range)| { Label::primary((), range).with_message(label.1.to_string()) }) .collect(), ) .with_notes( self.notes .iter() .map(|note| format!("note: {note}")) .collect(), ); diagnostic } /// Emits a summary of the error to the standard error stream. #[cfg(feature = "stderr")] pub fn emit_to_stderr(&self, source: &str) { self.emit_to_stderr_with_path(source, "wgsl") } /// Emits a summary of the error to the standard error stream. 
#[cfg(feature = "stderr")] pub fn emit_to_stderr_with_path<P>(&self, source: &str, path: P) where P: AsRef<std::path::Path>, { let path = path.as_ref().display().to_string(); let files = SimpleFile::new(path, replace_control_chars(source)); let config = term::Config::default(); cfg_if::cfg_if! { if #[cfg(feature = "termcolor")] { let writer = term::termcolor::StandardStream::stderr(term::termcolor::ColorChoice::Auto); term::emit_to_write_style(&mut writer.lock(), &config, &files, &self.diagnostic()) .expect("cannot write error"); } else { let writer = std::io::stderr(); term::emit_to_io_write(&mut writer.lock(), &config, &files, &self.diagnostic()) .expect("cannot write error"); } } } /// Emits a summary of the error to a string. pub fn emit_to_string(&self, source: &str) -> String { self.emit_to_string_with_path(source, "wgsl") } /// Emits a summary of the error to a string. pub fn emit_to_string_with_path<P>(&self, source: &str, path: P) -> String where P: AsRef<std::path::Path>, { let path = path.as_ref().display().to_string(); let files = SimpleFile::new(path, replace_control_chars(source)); let config = term::Config::default(); let mut writer = crate::error::DiagnosticBuffer::new(); writer .emit_to_self(&config, &files, &self.diagnostic()) .expect("cannot write error"); writer.into_string() } /// Returns a [`SourceLocation`] for the first label in the error message. 
pub fn location(&self, source: &str) -> Option<SourceLocation> { self.labels.first().map(|label| label.0.location(source)) } } impl core::fmt::Display for ParseError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "{}", self.message) } } impl core::error::Error for ParseError {} #[derive(Copy, Clone, Debug, PartialEq)] pub enum ExpectedToken<'a> { Token(Token<'a>), Identifier, AfterIdentListComma, AfterIdentListArg, /// LHS expression (identifier component_or_swizzle_specifier?, (`lhs_expression`) component_or_swizzle_specifier?, &`lhs_expression`, *`lhs_expression`) LhsExpression, /// Expected: constant, parenthesized expression, identifier PrimaryExpression, /// Expected: assignment, increment/decrement expression Assignment, /// Expected: 'case', 'default', '}' SwitchItem, /// Expected: ',', ')' WorkgroupSizeSeparator, /// Expected: 'struct', 'let', 'var', 'type', ';', 'fn', eof GlobalItem, /// Access of `var`, `let`, `const`. Variable, /// Access of a function Function, /// The `diagnostic` identifier of the `@diagnostic(…)` attribute.
DiagnosticAttribute, /// statement Statement, /// for loop init statement (variable_or_value_statement, variable_updating_statement, func_call_statement) ForInit, /// for loop update statement (variable_updating_statement, func_call_statement) ForUpdate, } #[derive(Clone, Copy, Debug, Error, PartialEq)] pub enum NumberError { #[error("invalid numeric literal format")] Invalid, #[error("numeric literal not representable by target type")] NotRepresentable, } #[derive(Copy, Clone, Debug, PartialEq)] pub enum InvalidAssignmentType { Other, Swizzle, ImmutableBinding(Span), } #[derive(Clone, Debug)] pub(crate) enum Error<'a> { Unexpected(Span, ExpectedToken<'a>), UnexpectedComponents(Span), UnexpectedOperationInConstContext(Span), BadNumber(Span, NumberError), BadMatrixScalarKind(Span, Scalar), BadAccessor(Span), BadTexture(Span), BadTypeCast { span: Span, from_type: String, to_type: String, }, NotStorageTexture(Span), BadTextureSampleType { span: Span, scalar: Scalar, }, BadIncrDecrReferenceType(Span), InvalidResolve(ResolveError), /// A break if appeared outside of a continuing block InvalidBreakIf(Span), InvalidGatherComponent(Span), InvalidConstructorComponentType(Span, i32), InvalidIdentifierUnderscore(Span), ReservedIdentifierPrefix(Span), UnknownAddressSpace(Span), InvalidLocalVariableAddressSpace(Span), UnknownRayFlag(Span), RepeatedAttribute(Span), UnknownAttribute(Span), UnknownBuiltin(Span), UnknownAccess(Span), UnknownIdent(Span, &'a str), UnknownScalarType(Span), UnknownStorageFormat(Span), UnknownConservativeDepth(Span), UnknownEnableExtension(Span, &'a str), UnknownLanguageExtension(Span, &'a str), UnknownDiagnosticRuleName(Span), SizeAttributeTooLow(Span, u32), AlignAttributeTooLow(Span, Alignment), NonPowerOfTwoAlignAttribute(Span), InconsistentBinding(Span), TypeNotConstructible(Span), TypeNotInferable(Span), InitializationTypeMismatch { name: Span, expected: String, got: String, }, DeclMissingTypeAndInit(Span), MissingAttribute(&'static str, Span), 
InvalidAddrOfOperand(Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), InvalidRayQueryPointer(Span), NotPointer(Span), NotReference(&'static str, Span), InvalidAssignment { span: Span, ty: InvalidAssignmentType, }, ReservedKeyword(Span), /// Redefinition of an identifier (used for both module-scope and local redefinitions). Redefinition { /// Span of the identifier in the previous definition. previous: Span, /// Span of the identifier in the new definition. current: Span, }, /// A declaration refers to itself directly. RecursiveDeclaration { /// The location of the name of the declaration. ident: Span, /// The point at which it is used. usage: Span, }, /// A declaration refers to itself indirectly, through one or more other /// definitions. CyclicDeclaration { /// The location of the name of some declaration in the cycle. ident: Span, /// The edges of the cycle of references. /// /// Each `(decl, reference)` pair indicates that the declaration whose /// name is `decl` has an identifier at `reference` whose definition is /// the next declaration in the cycle. The last pair's `reference` is /// the same identifier as `ident`, above. path: Box<[(Span, Span)]>, }, InvalidSwitchSelector { span: Span, }, InvalidSwitchCase { span: Span, }, SwitchCaseTypeMismatch { span: Span, }, CalledEntryPoint(Span), CalledLocalDecl(Span), WrongArgumentCount { span: Span, expected: Range<u32>, found: u32, }, /// No overload of this function accepts this many arguments. TooManyArguments { /// The name of the function being called. function: String, /// The function name in the call expression. call_span: Span, /// The first argument that is unacceptable. arg_span: Span, /// Maximum number of arguments accepted by any overload of /// this function. max_arguments: u32, }, /// A value passed to a builtin function has a type that is not /// accepted by any overload of the function. WrongArgumentType { /// The name of the function being called.
function: String, /// The function name in the call expression. call_span: Span, /// The first argument whose type is unacceptable. arg_span: Span, /// The index of the first argument whose type is unacceptable. arg_index: u32, /// That argument's actual type. arg_ty: String, /// The set of argument types that would have been accepted for /// this argument, given the prior arguments. allowed: Vec<String>, }, /// A value passed to a builtin function has a type that is not /// accepted, given the earlier arguments' types. InconsistentArgumentType { /// The name of the function being called. function: String, /// The function name in the call expression. call_span: Span, /// The first unacceptable argument. arg_span: Span, /// The index of the first unacceptable argument. arg_index: u32, /// The actual type of the first unacceptable argument. arg_ty: String, /// The prior argument whose type made the `arg_span` argument /// unacceptable. inconsistent_span: Span, /// The index of the `inconsistent_span` argument. inconsistent_index: u32, /// The type of the `inconsistent_span` argument. inconsistent_ty: String, /// The types that would have been accepted instead of the /// first unacceptable argument.
allowed: Vec<String>, }, FunctionReturnsVoid(Span), FunctionMustUseUnused(Span), FunctionMustUseReturnsVoid(Span, Span), InvalidWorkGroupUniformLoad(Span), Internal(&'static str), ExpectedConstExprConcreteIntegerScalar(Span), ExpectedNonNegative(Span), ExpectedPositiveArrayLength(Span), MissingWorkgroupSize(Span), ConstantEvaluatorError(Box<ConstantEvaluatorError>, Span), AutoConversion(Box<AutoConversionError>), AutoConversionLeafScalar(Box<AutoConversionLeafScalarError>), ConcretizationFailed(Box<ConcretizationFailedError>), ExceededLimitForNestedBraces { span: Span, limit: u8, }, PipelineConstantIDValue(Span), NotBool(Span), ConstAssertFailed(Span), DirectiveAfterFirstGlobalDecl { directive_span: Span, }, EnableExtensionNotYetImplemented { kind: UnimplementedEnableExtension, span: Span, }, EnableExtensionNotEnabled { kind: EnableExtension, span: Span, }, EnableExtensionNotSupported { kind: EnableExtension, span: Span, }, LanguageExtensionNotYetImplemented { kind: UnimplementedLanguageExtension, span: Span, }, DiagnosticInvalidSeverity { severity_control_name_span: Span, }, DiagnosticDuplicateTriggeringRule(ConflictingDiagnosticRuleError), DiagnosticAttributeNotYetImplementedAtParseSite { site_name_plural: &'static str, spans: Vec<Span>, }, DiagnosticAttributeNotSupported { on_what: DiagnosticAttributeNotSupportedPosition, spans: Vec<Span>, }, SelectUnexpectedArgumentType { arg_span: Span, arg_type: String, }, SelectRejectAndAcceptHaveNoCommonType { reject_span: Span, reject_type: String, accept_span: Span, accept_type: String, }, ExpectedGlobalVariable { name_span: Span, }, StructMemberTooLarge { member_name_span: Span, }, TypeTooLarge { span: Span, }, UnderspecifiedCooperativeMatrix, InvalidCooperativeLoadType(Span), UnsupportedCooperativeScalar(Span), UnexpectedIdentForEnumerant(Span), UnexpectedExprForEnumerant(Span), UnusedArgsForTemplate(Vec<Span>), UnexpectedTemplate(Span), MissingTemplateArg { span: Span, description: &'static str, }, UnexpectedExprForTypeExpression(Span), 
MissingIncomingPayload(Span), } impl From<ConflictingDiagnosticRuleError> for Error<'_> { fn from(value: ConflictingDiagnosticRuleError) ->
Self { Self::DiagnosticDuplicateTriggeringRule(value) } } /// Used for diagnostic refinement in [`Error::DiagnosticAttributeNotSupported`]. #[derive(Clone, Copy, Debug)] pub(crate) enum DiagnosticAttributeNotSupportedPosition { SemicolonInModulePosition, Other { display_plural: &'static str }, } impl From<&'static str> for DiagnosticAttributeNotSupportedPosition { fn from(display_plural: &'static str) -> Self { Self::Other { display_plural } } } #[derive(Clone, Debug)] pub(crate) struct AutoConversionError { pub dest_span: Span, pub dest_type: String, pub source_span: Span, pub source_type: String, } #[derive(Clone, Debug)] pub(crate) struct AutoConversionLeafScalarError { pub dest_span: Span, pub dest_scalar: String, pub source_span: Span, pub source_type: String, } #[derive(Clone, Debug)] pub(crate) struct ConcretizationFailedError { pub expr_span: Span, pub expr_type: String, pub concretization_preferences: Vec<(String, ConstantEvaluatorError)>, } impl<'a> Error<'a> { #[cold] #[inline(never)] pub(crate) fn as_parse_error(&self, source: &'a str) -> ParseError { match *self { Error::Unexpected(unexpected_span, expected) => { let expected_str = match expected { ExpectedToken::Token(token) => match token { Token::Separator(c) => format!("`{c}`"), Token::Paren(c) => format!("`{c}`"), Token::Attribute => "@".to_string(), Token::Number(_) => "number".to_string(), Token::Word(s) => s.to_string(), Token::Operation(c) => format!("operation (`{c}`)"), Token::LogicalOperation(c) => format!("logical operation (`{c}`)"), Token::ShiftOperation(c) => format!("bitshift (`{c}{c}`)"), Token::AssignmentOperation(c) if c == '<' || c == '>' => { format!("bitshift (`{c}{c}=`)") } Token::AssignmentOperation(c) => format!("operation (`{c}=`)"), Token::IncrementOperation => "increment operation".to_string(), Token::DecrementOperation => "decrement operation".to_string(), Token::Arrow => "->".to_string(), Token::TemplateArgsStart => "template args start".to_string(), 
Token::TemplateArgsEnd => "template args end".to_string(), Token::Unknown(c) => format!("unknown (`{c}`)"), Token::Trivia => "trivia".to_string(), Token::DocComment(s) => format!("doc comment ('{s}')"), Token::ModuleDocComment(s) => format!("module doc comment ('{s}')"), Token::End => "end".to_string(), }, ExpectedToken::Identifier => "identifier".to_string(), ExpectedToken::LhsExpression => "LHS expression (identifier component_or_swizzle_specifier?, (`lhs_expression`) component_or_swizzle_specifier?, &`lhs_expression`, *`lhs_expression`)".to_string(), ExpectedToken::PrimaryExpression => "expression".to_string(), ExpectedToken::Assignment => "assignment or increment/decrement".to_string(), ExpectedToken::SwitchItem => concat!( "switch item (`case` or `default`) or a closing curly bracket ", "to signify the end of the switch statement (`}`)" ) .to_string(), ExpectedToken::WorkgroupSizeSeparator => { "workgroup size separator (`,`) or a closing parenthesis".to_string() } ExpectedToken::GlobalItem => concat!( "global item (`struct`, `const`, `var`, `alias`, ", "`fn`, `diagnostic`, `enable`, `requires`, `;`) ", "or the end of the file" ) .to_string(), ExpectedToken::Variable => "variable access".to_string(), ExpectedToken::Function => "function name".to_string(), ExpectedToken::AfterIdentListArg => { "next argument, trailing comma, or end of list (`,` or `;`)".to_string() } ExpectedToken::AfterIdentListComma => { "next argument or end of list (`;`)".to_string() } ExpectedToken::DiagnosticAttribute => { "the `diagnostic` attribute identifier".to_string() } ExpectedToken::Statement => "statement".to_string(), ExpectedToken::ForInit => "for loop initializer statement (`var`/`let`/`const` declaration, assignment, `i++`/`i--` statement, function call)".to_string(), ExpectedToken::ForUpdate => "for loop update statement (assignment, `i++`/`i--` statement, function call)".to_string(), }; ParseError { message: format!( "expected {}, found {:?}", expected_str, 
&source[unexpected_span], ), labels: vec![(unexpected_span, format!("expected {expected_str}").into())], notes: vec![], } } Error::UnexpectedComponents(bad_span) => ParseError { message: "unexpected components".to_string(), labels: vec![(bad_span, "unexpected components".into())], notes: vec![], }, Error::UnexpectedOperationInConstContext(span) => ParseError { message: "this operation is not supported in a const context".to_string(), labels: vec![(span, "operation not supported here".into())], notes: vec![], }, Error::BadNumber(bad_span, ref err) => ParseError { message: format!("{}: `{}`", err, &source[bad_span],), labels: vec![(bad_span, err.to_string().into())], notes: vec![], }, Error::BadMatrixScalarKind(span, scalar) => ParseError { message: format!( "matrix scalar type must be floating-point, but found `{}`", scalar.to_wgsl_for_diagnostics() ), labels: vec![(span, "must be floating-point (e.g. `f32`)".into())], notes: vec![], }, Error::BadAccessor(accessor_span) => ParseError { message: format!("invalid field accessor `{}`", &source[accessor_span],), labels: vec![(accessor_span, "invalid accessor".into())], notes: vec![], }, Error::UnknownIdent(ident_span, ident) => ParseError { message: format!("no definition in scope for identifier: `{ident}`"), labels: vec![(ident_span, "unknown identifier".into())], notes: vec![], }, Error::UnknownScalarType(bad_span) => ParseError { message: format!("unknown scalar type: `{}`", &source[bad_span]), labels: vec![(bad_span, "unknown scalar type".into())], notes: vec!["Valid scalar types are f32, f64, i32, u32, bool".into()], }, Error::NotStorageTexture(bad_span) => ParseError { message: "textureStore can only be applied to storage textures".to_string(), labels: vec![(bad_span, "not a storage texture".into())], notes: vec![], }, Error::BadTextureSampleType { span, scalar } => ParseError { message: format!( "texture sample type must be one of f32, i32 or u32, but found {}", scalar.to_wgsl_for_diagnostics() ), labels: 
vec![(span, "must be one of f32, i32 or u32".into())], notes: vec![], }, Error::BadIncrDecrReferenceType(span) => ParseError { message: concat!( "increment/decrement operation requires ", "reference type to be one of i32 or u32" ) .to_string(), labels: vec![(span, "must be a reference type of i32 or u32".into())], notes: vec![], }, Error::BadTexture(bad_span) => ParseError { message: format!( "expected an image, but found `{}` which is not an image", &source[bad_span] ), labels: vec![(bad_span, "not an image".into())], notes: vec![], }, Error::BadTypeCast { span, ref from_type, ref to_type, } => { let msg = format!("cannot cast a {from_type} to a {to_type}"); ParseError { message: msg.clone(), labels: vec![(span, msg.into())], notes: vec![], } } Error::InvalidResolve(ref resolve_error) => ParseError { message: resolve_error.to_string(), labels: vec![], notes: vec![], }, Error::InvalidBreakIf(bad_span) => ParseError { message: "A break if is only allowed in a continuing block".to_string(), labels: vec![(bad_span, "not in a continuing block".into())], notes: vec![], }, Error::InvalidGatherComponent(bad_span) => ParseError { message: format!( "textureGather component `{}` doesn't exist, must be 0, 1, 2, or 3", &source[bad_span] ), labels: vec![(bad_span, "invalid component".into())], notes: vec![], }, Error::InvalidConstructorComponentType(bad_span, component) => ParseError { message: format!("invalid type for constructor component at index [{component}]"), labels: vec![(bad_span, "invalid component type".into())], notes: vec![], }, Error::InvalidIdentifierUnderscore(bad_span) => ParseError { message: "Identifier can't be `_`".to_string(), labels: vec![(bad_span, "invalid identifier".into())], notes: vec![ "Use phony assignment instead (`_ =` notice the absence of `let` or `var`)" .to_string(), ], }, Error::ReservedIdentifierPrefix(bad_span) => ParseError { message: format!( "Identifier starts with a reserved prefix: `{}`", &source[bad_span] ), labels: vec![(bad_span, 
"invalid identifier".into())], notes: vec![], }, Error::UnknownAddressSpace(bad_span) => ParseError { message: format!("unknown address space: `{}`", &source[bad_span]), labels: vec![(bad_span, "unknown address space".into())], notes: vec![], }, Error::InvalidLocalVariableAddressSpace(bad_span) => ParseError { message: format!("invalid address space for local variable: `{}`", &source[bad_span]), labels: vec![(bad_span, "local variables can only use 'function' address space".into())], notes: vec![], }, Error::UnknownRayFlag(bad_span) => ParseError { message: format!("unknown ray flag: `{}`", &source[bad_span]), labels: vec![(bad_span, "unknown ray flag".into())], notes: vec![], }, Error::RepeatedAttribute(bad_span) => ParseError { message: format!("repeated attribute: `{}`", &source[bad_span]), labels: vec![(bad_span, "repeated attribute".into())], notes: vec![], }, Error::UnknownAttribute(bad_span) => ParseError { message: format!("unknown attribute: `{}`", &source[bad_span]), labels: vec![(bad_span, "unknown attribute".into())], notes: vec![], }, Error::UnknownBuiltin(bad_span) => ParseError { message: format!("unknown builtin: `{}`", &source[bad_span]), labels: vec![(bad_span, "unknown builtin".into())], notes: vec![], }, Error::UnknownAccess(bad_span) => ParseError { message: format!("unknown access: `{}`", &source[bad_span]), labels: vec![(bad_span, "unknown access".into())], notes: vec![], }, Error::UnknownStorageFormat(bad_span) => ParseError { message: format!("unknown storage format: `{}`", &source[bad_span]), labels: vec![(bad_span, "unknown storage format".into())], notes: vec![], }, Error::UnknownConservativeDepth(bad_span) => ParseError { message: format!("unknown conservative depth: `{}`", &source[bad_span]), labels: vec![(bad_span, "unknown conservative depth".into())], notes: vec![], }, Error::UnknownEnableExtension(span, word) => ParseError { message: format!("unknown enable-extension `{word}`"), labels: vec![(span, "".into())], notes: vec![ "See 
available extensions at ." .into(), ], }, Error::UnknownLanguageExtension(span, name) => ParseError { message: format!("unknown language extension `{name}`"), labels: vec![(span, "".into())], notes: vec![concat!( "See available extensions at ", "." ) .into()], }, Error::UnknownDiagnosticRuleName(span) => ParseError { message: format!("unknown `diagnostic(…)` rule name `{}`", &source[span]), labels: vec![(span, "not a valid diagnostic rule name".into())], notes: vec![concat!( "See available trigger rules at ", "." ) .into()], }, Error::SizeAttributeTooLow(bad_span, min_size) => ParseError { message: format!("struct member size must be at least {min_size}"), labels: vec![(bad_span, format!("must be at least {min_size}").into())], notes: vec![], }, Error::AlignAttributeTooLow(bad_span, min_align) => ParseError { message: format!("struct member alignment must be at least {min_align}"), labels: vec![(bad_span, format!("must be at least {min_align}").into())], notes: vec![], }, Error::NonPowerOfTwoAlignAttribute(bad_span) => ParseError { message: "struct member alignment must be a power of 2".to_string(), labels: vec![(bad_span, "must be a power of 2".into())], notes: vec![], }, Error::InconsistentBinding(span) => ParseError { message: "input/output binding is not consistent".to_string(), labels: vec![(span, "input/output binding is not consistent".into())], notes: vec![], }, Error::TypeNotConstructible(span) => ParseError { message: format!("type `{}` is not constructible", &source[span]), labels: vec![(span, "type is not constructible".into())], notes: vec![], }, Error::TypeNotInferable(span) => ParseError { message: "type can't be inferred".to_string(), labels: vec![(span, "type can't be inferred".into())], notes: vec![], }, Error::InitializationTypeMismatch { name, ref expected, ref got, } => ParseError { message: format!( "the type of `{}` is expected to be `{}`, but got `{}`", &source[name], expected, got, ), labels: vec![(name, format!("definition of `{}`", 
&source[name]).into())], notes: vec![], }, Error::DeclMissingTypeAndInit(name_span) => ParseError { message: format!( "declaration of `{}` needs a type specifier or initializer", &source[name_span] ), labels: vec![(name_span, "needs a type specifier or initializer".into())], notes: vec![], }, Error::MissingAttribute(name, name_span) => ParseError { message: format!( "variable `{}` needs a '{}' attribute", &source[name_span], name ), labels: vec![( name_span, format!("definition of `{}`", &source[name_span]).into(), )], notes: vec![], }, Error::InvalidAddrOfOperand(span) => ParseError { message: "cannot take the address of a vector component".to_string(), labels: vec![(span, "invalid operand for address-of".into())], notes: vec![], }, Error::InvalidAtomicPointer(span) => ParseError { message: "atomic operation is done on a pointer to a non-atomic".to_string(), labels: vec![(span, "atomic pointer is invalid".into())], notes: vec![], }, Error::InvalidAtomicOperandType(span) => ParseError { message: "atomic operand type is inconsistent with the operation".to_string(), labels: vec![(span, "atomic operand type is invalid".into())], notes: vec![], }, Error::InvalidRayQueryPointer(span) => ParseError { message: "ray query operation is done on a pointer to a non-ray-query".to_string(), labels: vec![(span, "ray query pointer is invalid".into())], notes: vec![], }, Error::NotPointer(span) => ParseError { message: "the operand of the `*` operator must be a pointer".to_string(), labels: vec![(span, "expression is not a pointer".into())], notes: vec![], }, Error::NotReference(what, span) => ParseError { message: format!("{what} must be a reference"), labels: vec![(span, "expression is not a reference".into())], notes: vec![], }, Error::InvalidAssignment { span, ty } => { let (extra_label, notes) = match ty { InvalidAssignmentType::Swizzle => ( None, vec![ "WGSL does not support assignments to swizzles".into(), "consider assigning each component individually".into(), ], ), 
InvalidAssignmentType::ImmutableBinding(binding_span) => ( Some((binding_span, "this is an immutable binding".into())), vec![format!( "consider declaring `{}` with `var` instead of `let`", &source[binding_span] )], ), InvalidAssignmentType::Other => (None, vec![]), }; ParseError { message: "invalid left-hand side of assignment".into(), labels: core::iter::once((span, "cannot assign to this expression".into())) .chain(extra_label) .collect(), notes, } } Error::ReservedKeyword(name_span) => ParseError { message: format!("name `{}` is a reserved keyword", &source[name_span]), labels: vec![( name_span, format!("definition of `{}`", &source[name_span]).into(), )], notes: vec![], }, Error::Redefinition { previous, current } => ParseError { message: format!("redefinition of `{}`", &source[current]), labels: vec![ ( current, format!("redefinition of `{}`", &source[current]).into(), ), ( previous, format!("previous definition of `{}`", &source[previous]).into(), ), ], notes: vec![], }, Error::RecursiveDeclaration { ident, usage } => ParseError { message: format!("declaration of `{}` is recursive", &source[ident]), labels: vec![(ident, "".into()), (usage, "uses itself here".into())], notes: vec![], }, Error::CyclicDeclaration { ident, ref path } => ParseError { message: format!("declaration of `{}` is cyclic", &source[ident]), labels: path .iter() .enumerate() .flat_map(|(i, &(ident, usage))| { [ (ident, "".into()), ( usage, if i == path.len() - 1 { "ending the cycle".into() } else { format!("uses `{}`", &source[ident]).into() }, ), ] }) .collect(), notes: vec![], }, Error::InvalidSwitchSelector { span } => ParseError { message: "invalid `switch` selector".to_string(), labels: vec![( span, "`switch` selector must be a scalar integer" .into(), )], notes: vec![], }, Error::InvalidSwitchCase { span } => ParseError { message: "invalid `switch` case selector value".to_string(), labels: vec![( span, "`switch` case selector must be a scalar integer const expression" .into(), )], 
notes: vec![], }, Error::SwitchCaseTypeMismatch { span } => ParseError { message: "invalid `switch` case selector value".to_string(), labels: vec![( span, "`switch` case selector must have the same type as the `switch` selector expression" .into(), )], notes: vec![], }, Error::CalledEntryPoint(span) => ParseError { message: "entry point cannot be called".to_string(), labels: vec![(span, "entry point cannot be called".into())], notes: vec![], }, Error::CalledLocalDecl(span) => ParseError { message: "local declaration cannot be called".to_string(), labels: vec![(span, "local declaration cannot be called".into())], notes: vec![], }, Error::WrongArgumentCount { span, ref expected, found, } => ParseError { message: format!( "wrong number of arguments: expected {}, found {}", if expected.len() < 2 { format!("{}", expected.start) } else { format!("{}..{}", expected.start, expected.end) }, found ), labels: vec![(span, "wrong number of arguments".into())], notes: vec![], }, Error::TooManyArguments { ref function, call_span, arg_span, max_arguments, } => ParseError { message: format!("too many arguments passed to `{function}`"), labels: vec![ (call_span, "".into()), (arg_span, format!("unexpected argument #{}", max_arguments + 1).into()) ], notes: vec![ format!("The `{function}` function accepts at most {max_arguments} argument(s)") ], }, Error::WrongArgumentType { ref function, call_span, arg_span, arg_index, ref arg_ty, ref allowed, } => { let message = format!( "wrong type passed as argument #{} to `{function}`", arg_index + 1, ); let labels = vec![ (call_span, "".into()), (arg_span, format!("argument #{} has type `{arg_ty}`", arg_index + 1).into()) ]; let mut notes = vec![]; notes.push(format!("`{function}` accepts the following types for argument #{}:", arg_index + 1)); notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}"))); ParseError { message, labels, notes } }, Error::InconsistentArgumentType { ref function, call_span, arg_span, arg_index, ref arg_ty, 
inconsistent_span, inconsistent_index, ref inconsistent_ty, ref allowed } => { let message = format!( "inconsistent type passed as argument #{} to `{function}`", arg_index + 1, ); let labels = vec![ (call_span, "".into()), (arg_span, format!("argument #{} has type {arg_ty}", arg_index + 1).into()), (inconsistent_span, format!( "this argument has type {inconsistent_ty}, which constrains subsequent arguments" ).into()), ]; let mut notes = vec![ format!("Because argument #{} has type {inconsistent_ty}, only the following types", inconsistent_index + 1), format!("(or types that automatically convert to them) are accepted for argument #{}:", arg_index + 1), ]; notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}"))); ParseError { message, labels, notes } } Error::FunctionReturnsVoid(span) => ParseError { message: "function does not return any value".to_string(), labels: vec![(span, "".into())], notes: vec![ "perhaps you meant to call the function in a separate statement?".into(), ], }, Error::FunctionMustUseUnused(call) => ParseError { message: "unused return value from function annotated with @must_use".into(), labels: vec![(call, "".into())], notes: vec![ format!( "function '{}' is declared with `@must_use` attribute", &source[call], ), "use a phony assignment or declare a value using the function call as the initializer".into(), ], }, Error::FunctionMustUseReturnsVoid(attr, signature) => ParseError { message: "function annotated with @must_use but does not return any value".into(), labels: vec![ (attr, "".into()), (signature, "".into()), ], notes: vec![ "declare a return type or remove the attribute".into(), ], }, Error::InvalidWorkGroupUniformLoad(span) => ParseError { message: "incorrect type passed to workgroupUniformLoad".into(), labels: vec![(span, "".into())], notes: vec!["passed type must be a workgroup pointer".into()], }, Error::Internal(message) => ParseError { message: "internal WGSL front end error".to_string(), labels: vec![], notes: 
vec![message.into()], }, Error::ExpectedConstExprConcreteIntegerScalar(span) => ParseError { message: concat!( "must be a const-expression that ", "resolves to a concrete integer scalar (`u32` or `i32`)" ) .to_string(), labels: vec![(span, "must resolve to `u32` or `i32`".into())], notes: vec![], }, Error::ExpectedNonNegative(span) => ParseError { message: "must be non-negative (>= 0)".to_string(), labels: vec![(span, "must be non-negative".into())], notes: vec![], }, Error::ExpectedPositiveArrayLength(span) => ParseError { message: "array element count must be positive (> 0)".to_string(), labels: vec![(span, "must be positive".into())], notes: vec![], }, Error::ConstantEvaluatorError(ref e, span) => ParseError { message: e.to_string(), labels: vec![(span, "see msg".into())], notes: vec![], }, Error::MissingWorkgroupSize(span) => ParseError { message: "workgroup size is missing on compute shader entry point".to_string(), labels: vec![( span, "must be paired with a `@workgroup_size` attribute".into(), )], notes: vec![], }, Error::AutoConversion(ref error) => { // destructuring ensures all fields are handled let AutoConversionError { dest_span, ref dest_type, source_span, ref source_type, } = **error; ParseError { message: format!( "automatic conversions cannot convert `{source_type}` to `{dest_type}`" ), labels: vec![ ( dest_span, format!("a value of type {dest_type} is required here").into(), ), ( source_span, format!("this expression has type {source_type}").into(), ), ], notes: vec![], } } Error::AutoConversionLeafScalar(ref error) => { let AutoConversionLeafScalarError { dest_span, ref dest_scalar, source_span, ref source_type, } = **error; ParseError { message: format!( "automatic conversions cannot convert elements of `{source_type}` to `{dest_scalar}`" ), labels: vec![ ( dest_span, format!( "a value with elements of type {dest_scalar} is required here" ) .into(), ), ( source_span, format!("this expression has type {source_type}").into(), ), ], notes: vec![], 
} } Error::ConcretizationFailed(ref error) => { let ConcretizationFailedError { expr_span, ref expr_type, ref concretization_preferences, } = **error; ParseError { message: "failed to convert expression to a concrete type".to_string(), labels: vec![( expr_span, format!("this expression has type {expr_type}").into(), )], notes: concretization_preferences .iter() .map(|&(ref scalar, ref err)| format!("the expression couldn't be converted to have {scalar} scalar type: {err}") ) .collect(), } } Error::ExceededLimitForNestedBraces { span, limit } => ParseError { message: "brace nesting limit reached".into(), labels: vec![(span, "limit reached at this brace".into())], notes: vec![format!("nesting limit is currently set to {limit}")], }, Error::PipelineConstantIDValue(span) => ParseError { message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(), labels: vec![(span, "must be between 0 and 65535 inclusive".into())], notes: vec![], }, Error::NotBool(span) => ParseError { message: "must be a const-expression that resolves to a `bool`".to_string(), labels: vec![(span, "must resolve to `bool`".into())], notes: vec![], }, Error::ConstAssertFailed(span) => ParseError { message: "`const_assert` failure".to_string(), labels: vec![(span, "evaluates to `false`".into())], notes: vec![], }, Error::DirectiveAfterFirstGlobalDecl { directive_span } => ParseError { message: "expected global declaration, but found a global directive".into(), labels: vec![( directive_span, "written after first global declaration".into(), )], notes: vec![concat!( "global directives are only allowed before global declarations; ", "maybe hoist this closer to the top of the shader module?" 
) .into()], }, Error::EnableExtensionNotYetImplemented { kind, span } => ParseError { message: format!( "the `{}` enable-extension is not yet supported", EnableExtension::Unimplemented(kind).to_ident() ), labels: vec![( span, concat!( "this enable-extension specifies standard functionality ", "which is not yet implemented in Naga" ) .into(), )], notes: vec![format!( concat!( "Let Naga maintainers know that you ran into this at ", ", ", "so they can prioritize it!" ), kind.tracking_issue_num() )], }, Error::EnableExtensionNotEnabled { kind, span } => ParseError { message: format!("the `{}` enable extension is not enabled", kind.to_ident()), labels: vec![( span, format!( concat!( "the `{}` \"Enable Extension\" is needed for this functionality, ", "but it is not currently enabled." ), kind.to_ident() ) .into(), )], notes: if let EnableExtension::Unimplemented(kind) = kind { vec![format!( concat!( "This \"Enable Extension\" is not yet implemented. ", "Let Naga maintainers know that you ran into this at ", ", ", "so they can prioritize it!" ), kind.tracking_issue_num() )] } else { vec![ format!( "You can enable this extension by adding `enable {};` at the top of the shader, before any other items.", kind.to_ident() ), ] }, }, Error::EnableExtensionNotSupported { kind, span } => ParseError { message: format!( "the `{}` extension is not supported in the current environment", kind.to_ident() ), labels: vec![( span, "unsupported enable-extension".into(), )], notes: vec![], }, Error::LanguageExtensionNotYetImplemented { kind, span } => ParseError { message: format!( "the `{}` language extension is not yet supported", LanguageExtension::Unimplemented(kind).to_ident() ), labels: vec![(span, "".into())], notes: vec![format!( concat!( "Let Naga maintainers know that you ran into this at ", ", ", "so they can prioritize it!" 
), kind.tracking_issue_num() )], }, Error::DiagnosticInvalidSeverity { severity_control_name_span, } => ParseError { message: "invalid `diagnostic(…)` severity".into(), labels: vec![( severity_control_name_span, "not a valid severity level".into(), )], notes: vec![concat!( "See available severities at ", "." ) .into()], }, Error::DiagnosticDuplicateTriggeringRule(ConflictingDiagnosticRuleError { triggering_rule_spans, }) => { let [first_span, second_span] = triggering_rule_spans; ParseError { message: "found conflicting `diagnostic(…)` rule(s)".into(), labels: vec![ (first_span, "first rule".into()), (second_span, "second rule".into()), ], notes: vec![ concat!( "Multiple `diagnostic(…)` rules with the same rule name ", "conflict unless they are directives and the severity is the same.", ) .into(), "You should delete the rule you don't want.".into(), ], } } Error::DiagnosticAttributeNotYetImplementedAtParseSite { site_name_plural, ref spans, } => ParseError { message: "`@diagnostic(…)` attribute(s) not yet implemented".into(), labels: { let mut spans = spans.iter().cloned(); let first = spans .next() .map(|span| { ( span, format!("can't use this on {site_name_plural} (yet)").into(), ) }) .expect("internal error: diag. attr. rejection on empty map"); core::iter::once(first) .chain(spans.map(|span| (span, "".into()))) .collect() }, notes: vec![format!(concat!( "Let Naga maintainers know that you ran into this at ", ", ", "so they can prioritize it!" ))], }, Error::DiagnosticAttributeNotSupported { on_what, ref spans } => { // In this case the user may have intended to create a global diagnostic filter directive, // so display a note to them suggesting the correct syntax. let intended_diagnostic_directive = match on_what { DiagnosticAttributeNotSupportedPosition::SemicolonInModulePosition => true, DiagnosticAttributeNotSupportedPosition::Other { .. 
} => false, }; let on_what_plural = match on_what { DiagnosticAttributeNotSupportedPosition::SemicolonInModulePosition => { "semicolons" } DiagnosticAttributeNotSupportedPosition::Other { display_plural } => { display_plural } }; ParseError { message: format!( "`@diagnostic(…)` attribute(s) on {on_what_plural} are not supported", ), labels: spans .iter() .cloned() .map(|span| (span, "".into())) .collect(), notes: vec![ concat!( "`@diagnostic(…)` attributes are only permitted on `fn`s, ", "some statements, and `switch`/`loop` bodies." ) .into(), { if intended_diagnostic_directive { concat!( "If you meant to declare a diagnostic filter that ", "applies to the entire module, move this line to ", "the top of the file and remove the `@` symbol." ) .into() } else { concat!( "These attributes are well-formed, ", "you likely just need to move them." ) .into() } }, ], } } Error::SelectUnexpectedArgumentType { arg_span, ref arg_type } => ParseError { message: "unexpected argument type for `select` call".into(), labels: vec![(arg_span, format!("this value of type {arg_type}").into())], notes: vec!["expected a scalar or a `vecN` of scalars".into()], }, Error::SelectRejectAndAcceptHaveNoCommonType { reject_span, ref reject_type, accept_span, ref accept_type, } => ParseError { message: "type mismatch for reject and accept values in `select` call".into(), labels: vec![ (reject_span, format!("reject value of type {reject_type}").into()), (accept_span, format!("accept value of type {accept_type}").into()), ], notes: vec![], }, Error::ExpectedGlobalVariable { name_span } => ParseError { message: "expected global variable".to_string(), labels: vec![(name_span, "variable used here".into())], notes: vec![], }, Error::StructMemberTooLarge { member_name_span } => ParseError { message: "struct member is too large".into(), labels: vec![(member_name_span, "this member exceeds the maximum size".into())], notes: vec![format!( "the maximum size is {} bytes", crate::valid::MAX_TYPE_SIZE )], }, 
Error::TypeTooLarge { span } => ParseError { message: "type is too large".into(), labels: vec![(span, "this type exceeds the maximum size".into())], notes: vec![format!( "the maximum size is {} bytes", crate::valid::MAX_TYPE_SIZE )], }, Error::UnderspecifiedCooperativeMatrix => ParseError { message: "cooperative matrix constructor is underspecified".into(), labels: vec![], notes: vec![format!("must be F32")], }, Error::InvalidCooperativeLoadType(span) => ParseError { message: "cooperative load should have a generic type for coop_mat".into(), labels: vec![(span, "type needs the coop_mat<...>".into())], notes: vec![format!("must be a valid cooperative type")], }, Error::UnsupportedCooperativeScalar(span) => ParseError { message: "cooperative scalar type is not supported".into(), labels: vec![(span, "type needs the scalar type specified".into())], notes: vec![format!("must be F32")], }, Error::UnexpectedIdentForEnumerant(ident_span) => ParseError { message: format!( "identifier `{}` resolves to a declaration", &source[ident_span] ), labels: vec![(ident_span, "needs to resolve to a predeclared enumerant".into())], notes: vec![], }, Error::UnexpectedExprForEnumerant(expr_span) => ParseError { message: "unexpected expression".to_string(), labels: vec![(expr_span, "needs to be an identifier resolving to a predeclared enumerant".into())], notes: vec![], }, Error::UnusedArgsForTemplate(ref expr_spans) => ParseError { message: "unused expressions for template".to_string(), labels: expr_spans.iter().cloned().map(|span| -> (_, _){ (span, "unused".into()) }).collect(), notes: vec![], }, Error::UnexpectedTemplate(span) => ParseError { message: "unexpected template".to_string(), labels: vec![(span, "expected identifier".into())], notes: vec![], }, Error::MissingTemplateArg { span, description: arg, } => ParseError { message: format!( "`{}` needs a template argument specified: {arg}", &source[span] ), labels: vec![(span, "is missing a template argument".into())], notes: vec![], }, 
Error::UnexpectedExprForTypeExpression(expr_span) => ParseError { message: "unexpected expression".to_string(), labels: vec![(expr_span, "needs to be an identifier resolving to a type declaration (alias or struct) or predeclared type(-generator)".into())], notes: vec![], }, Error::MissingIncomingPayload(span) => ParseError { message: "incoming payload is missing on a `closest_hit`, `any_hit` or `miss` shader entry point".to_string(), labels: vec![( span, "must be paired with a `@incoming_payload` attribute".into(), )], notes: vec![], }, } } } naga-29.0.3/src/front/wgsl/index.rs000064400000000000000000000177271046102023000152130ustar 00000000000000use alloc::{boxed::Box, vec, vec::Vec}; use super::{Error, Result}; use crate::front::wgsl::parse::ast; use crate::{FastHashMap, Handle, Span}; /// A `GlobalDecl` list in which each definition occurs before all its uses. pub struct Index<'a> { dependency_order: Vec<Handle<ast::GlobalDecl<'a>>>, } impl<'a> Index<'a> { /// Generate an `Index` for the given translation unit. /// /// Perform a topological sort on `tu`'s global declarations, placing /// referents before the definitions that refer to them. /// /// Return an error if the graph of references between declarations contains /// any cycles. pub fn generate(tu: &ast::TranslationUnit<'a>) -> Result<'a, Self> { // Produce a map from global definitions' names to their `Handle`s. // While doing so, reject conflicting definitions.
let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default()); for (handle, decl) in tu.decls.iter() { if let Some(ident) = decl_ident(decl) { let name = ident.name; if let Some(old) = globals.insert(name, handle) { return Err(Box::new(Error::Redefinition { previous: decl_ident(&tu.decls[old]) .expect("decl should have ident for redefinition") .span, current: ident.span, })); } } } let len = tu.decls.len(); let solver = DependencySolver { globals: &globals, module: tu, visited: vec![false; len], temp_visited: vec![false; len], path: Vec::new(), out: Vec::with_capacity(len), }; let dependency_order = solver.solve()?; Ok(Self { dependency_order }) } /// Iterate over `GlobalDecl`s, visiting each definition before all its uses. /// /// Produce handles for all of the `GlobalDecl`s of the `TranslationUnit` /// passed to `Index::generate`, ordered so that a given declaration is /// produced before any other declaration that uses it. pub fn visit_ordered(&self) -> impl Iterator<Item = Handle<ast::GlobalDecl<'a>>> + '_ { self.dependency_order.iter().copied() } } /// An edge from a reference to its referent in the current depth-first /// traversal. /// /// This is like `ast::Dependency`, except that we've determined which /// `GlobalDecl` it refers to. struct ResolvedDependency<'a> { /// The referent of some identifier used in the current declaration. decl: Handle<ast::GlobalDecl<'a>>, /// Where that use occurs within the current declaration. usage: Span, } /// Local state for ordering a `TranslationUnit`'s module-scope declarations. /// /// Values of this type are used temporarily by `Index::generate` /// to perform a depth-first sort on the declarations. /// Technically, what we want is a topological sort; performing it depth-first /// has one key benefit: it is much more efficient at storing /// the path of each node for error generation. struct DependencySolver<'source, 'temp> { /// A map from module-scope definitions' names to their handles.
globals: &'temp FastHashMap<&'source str, Handle<ast::GlobalDecl<'source>>>, /// The translation unit whose declarations we're ordering. module: &'temp ast::TranslationUnit<'source>, /// For each handle, whether we have pushed it onto `out` yet. visited: Vec<bool>, /// For each handle, whether it is a predecessor in the current depth-first /// traversal. This is used to detect cycles in the reference graph. temp_visited: Vec<bool>, /// The current path in our depth-first traversal. Used for generating /// error messages for non-trivial reference cycles. path: Vec<ResolvedDependency<'source>>, /// The list of declaration handles, with declarations before uses. out: Vec<Handle<ast::GlobalDecl<'source>>>, } impl<'a> DependencySolver<'a, '_> { /// Produce the sorted list of declaration handles, and check for cycles. fn solve(mut self) -> Result<'a, Vec<Handle<ast::GlobalDecl<'a>>>> { for (id, _) in self.module.decls.iter() { if self.visited[id.index()] { continue; } self.dfs(id)?; } Ok(self.out) } /// Ensure that all declarations used by `id` have been added to the /// ordering, and then append `id` itself. fn dfs(&mut self, id: Handle<ast::GlobalDecl<'a>>) -> Result<'a, ()> { let decl = &self.module.decls[id]; let id_usize = id.index(); self.temp_visited[id_usize] = true; for dep in decl.dependencies.iter() { if let Some(&dep_id) = self.globals.get(dep.ident) { self.path.push(ResolvedDependency { decl: dep_id, usage: dep.usage, }); let dep_id_usize = dep_id.index(); if self.temp_visited[dep_id_usize] { // Found a cycle. return if dep_id == id { // A declaration refers to itself directly. Err(Box::new(Error::RecursiveDeclaration { ident: decl_ident(decl).expect("decl should have ident").span, usage: dep.usage, })) } else { // A declaration refers to itself indirectly, through // one or more other definitions. Report the entire path // of references.
let start_at = self .path .iter() .rev() .enumerate() .find_map(|(i, dep)| (dep.decl == dep_id).then_some(i)) .unwrap_or(0); Err(Box::new(Error::CyclicDeclaration { ident: decl_ident(&self.module.decls[dep_id]) .expect("decl should have ident") .span, path: self.path[start_at..] .iter() .map(|curr_dep| { let curr_id = curr_dep.decl; let curr_decl = &self.module.decls[curr_id]; ( decl_ident(curr_decl).expect("decl should have ident").span, curr_dep.usage, ) }) .collect(), })) }; } else if !self.visited[dep_id_usize] { self.dfs(dep_id)?; } // Remove this edge from the current path. self.path.pop(); } // Ignore unresolved identifiers; they may be predeclared objects. } // Remove this node from the current path. self.temp_visited[id_usize] = false; // Now everything this declaration uses has been visited, and is already // present in `out`. That means we can append this one to the // ordering, and mark it as visited. self.out.push(id); self.visited[id_usize] = true; Ok(()) } } const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> Option<ast::Ident<'a>> { match decl.kind { ast::GlobalDeclKind::Fn(ref f) => Some(f.name), ast::GlobalDeclKind::Var(ref v) => Some(v.name), ast::GlobalDeclKind::Const(ref c) => Some(c.name), ast::GlobalDeclKind::Override(ref o) => Some(o.name), ast::GlobalDeclKind::Struct(ref s) => Some(s.name), ast::GlobalDeclKind::Type(ref t) => Some(t.name), ast::GlobalDeclKind::ConstAssert(_) => None, } } naga-29.0.3/src/front/wgsl/lower/construction.rs000064400000000000000000000577051046102023000177660ustar 00000000000000use alloc::{boxed::Box, vec, vec::Vec}; use core::num::NonZeroU32; use crate::common::wgsl::{TryToWgsl, TypeContext}; use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; use crate::front::wgsl::parse::ast; use crate::{Handle, Span}; /// A [`constructor built-in function`].
/// /// WGSL has two types of such functions: /// /// - Those that fully specify the type being constructed, like /// `vec3<f32>(x,y,z)`, which obviously constructs a `vec3<f32>`. /// /// - Those that leave the component type of the composite being constructed /// implicit, to be inferred from the argument types, like `vec3(x,y,z)`, /// which constructs a `vec3<T>` where `T` is the type of `x`, `y`, and `z`. /// /// This enum represents both cases. The `PartialFoo` variants /// represent the second case, where the component type is implicit. /// /// [`constructor built-in function`]: https://gpuweb.github.io/gpuweb/wgsl/#constructor-builtin-function pub enum Constructor<T> { /// A vector construction whose component type is inferred from the /// argument: `vec3(1.0)`. PartialVector { size: crate::VectorSize }, /// A matrix construction whose component type is inferred from the /// argument: `mat2x2(1,2,3,4)`. PartialMatrix { columns: crate::VectorSize, rows: crate::VectorSize, }, /// An array whose component type and size are inferred from the arguments: /// `array(3,4,5)`. PartialArray, /// A known Naga type. /// /// When we match on this type, we need to see the `TypeInner` here, but at /// the point that we build this value we'll still need mutable access to /// the module later. To avoid borrowing from the module, the type parameter /// `T` is `Handle<Type>` initially. Then we use `borrow_inner` to produce a /// version holding a tuple `(Handle<Type>, &TypeInner)`. Type(T), } impl Constructor<Handle<crate::Type>> { /// Return an equivalent `Constructor` value that includes borrowed /// `TypeInner` values alongside any type handles. /// /// The returned form is more convenient to match on, since the patterns /// can actually see what the handle refers to.
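///
/// For illustration only, here is a minimal, self-contained sketch of the same
/// "handle first, borrow later" pattern. It assumes a toy arena of shapes in
/// place of Naga's type arena; `Arena`, `Shape`, `ToyHandle`, `Ctor`, and
/// `describe` are hypothetical names, not Naga APIs.
///
/// ```rust
/// // Hypothetical stand-ins for Naga's type arena, types, and handles.
/// enum Shape {
///     Scalar,
///     Vector(u8),
/// }
///
/// struct Arena {
///     items: Vec<Shape>,
/// }
///
/// #[derive(Clone, Copy)]
/// struct ToyHandle(usize);
///
/// // Generic over its payload, like `Constructor<T>`.
/// enum Ctor<T> {
///     Partial,
///     Known(T),
/// }
///
/// impl Ctor<ToyHandle> {
///     // Pair each handle with a shared borrow of the value it refers to.
///     fn borrow_inner(self, arena: &Arena) -> Ctor<(ToyHandle, &Shape)> {
///         match self {
///             Ctor::Partial => Ctor::Partial,
///             Ctor::Known(h) => Ctor::Known((h, &arena.items[h.0])),
///         }
///     }
/// }
///
/// fn describe(ctor: Ctor<ToyHandle>, arena: &Arena) -> String {
///     // Patterns can now look through the handle at the borrowed value.
///     match ctor.borrow_inner(arena) {
///         Ctor::Partial => "inferred".to_string(),
///         Ctor::Known((_, &Shape::Scalar)) => "scalar".to_string(),
///         Ctor::Known((_, &Shape::Vector(n))) => format!("vec{n}"),
///     }
/// }
/// ```
///
/// Keeping the handle-only form until just before the `match` leaves the arena
/// free to be borrowed mutably while the arguments are still being lowered.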
fn borrow_inner( self, module: &crate::Module, ) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> { match self { Constructor::PartialVector { size } => Constructor::PartialVector { size }, Constructor::PartialMatrix { columns, rows } => { Constructor::PartialMatrix { columns, rows } } Constructor::PartialArray => Constructor::PartialArray, Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)), } } } enum Components<'a> { None, One { component: Handle<crate::Expression>, span: Span, ty_inner: &'a crate::TypeInner, }, Many { components: Vec<Handle<crate::Expression>>, spans: Vec<Span>, }, } impl Components<'_> { fn into_components_vec(self) -> Vec<Handle<crate::Expression>> { match self { Self::None => vec![], Self::One { component, .. } => vec![component], Self::Many { components, .. } => components, } } } impl<'source> Lowerer<'source, '_> { /// Generate Naga IR for a type constructor expression. /// /// The `constructor` value represents the head of the constructor /// expression, which is at least a hint of which type is being built; if /// it's one of the `Partial` variants, we need to consider the argument /// types as well. /// /// This is used for [`Call`] expressions, once we've determined that /// the "callable" (in WGSL spec terms) is actually a type. /// /// [`Call`]: ast::Expression::Call pub fn construct( &mut self, span: Span, constructor: Constructor<Handle<crate::Type>>, ty_span: Span, components: &[Handle<ast::Expression<'source>>], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Handle<crate::Expression>> { use crate::proc::TypeResolution as Tr; let components = match *components { [] => Components::None, [component] => { let span = ctx.ast_expressions.get_span(component); let component = self.expression_for_abstract(component, ctx)?; let ty_inner = super::resolve_inner!(ctx, component); Components::One { component, span, ty_inner, } } ref ast_components @ [_, _, ..]
=> { let components = ast_components .iter() .map(|&expr| self.expression_for_abstract(expr, ctx)) .collect::<Result<_>>()?; let spans = ast_components .iter() .map(|&expr| ctx.ast_expressions.get_span(expr)) .collect(); for &component in &components { ctx.grow_types(component)?; } Components::Many { components, spans } } }; // Even though we computed `constructor` above, wait until now to borrow // a reference to the `TypeInner`, so that the component-handling code // above can have mutable access to the type arena. let constructor = constructor.borrow_inner(ctx.module); let expr; match (components, constructor) { // Zero-value constructor with explicit type. (Components::None, Constructor::Type((result_ty, inner))) if inner.is_constructible(&ctx.module.types) => { expr = crate::Expression::ZeroValue(result_ty); } // Zero-value constructor, vector with type inference (Components::None, Constructor::PartialVector { size }) => { // vec2(), vec3(), vec4() return vectors of abstractInts; the same // is not true of the similar constructors for matrices or arrays. // See https://www.w3.org/TR/WGSL/#vec2-builtin et seq. let result_ty = ctx.module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Vector { size, scalar: crate::Scalar::ABSTRACT_INT, }, }, span, ); expr = crate::Expression::ZeroValue(result_ty); } // Zero-value constructor, matrix or array with type inference (Components::None, Constructor::PartialMatrix { .. } | Constructor::PartialArray) => { // We have no arguments from which to infer the result type, so // partial constructors aren't acceptable here. return Err(Box::new(Error::TypeNotInferable(ty_span))); } // Scalar constructor & conversion (scalar -> scalar) ( Components::One { component, ty_inner: &crate::TypeInner::Scalar { .. }, ..
}, Constructor::Type((_, &crate::TypeInner::Scalar(scalar))), ) => { expr = crate::Expression::As { expr: component, kind: scalar.kind, convert: Some(scalar.width), }; } // Vector conversion (vector -> vector) ( Components::One { component, ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, .. }, Constructor::Type(( _, &crate::TypeInner::Vector { size: dst_size, scalar: dst_scalar, }, )), ) if dst_size == src_size => { expr = crate::Expression::As { expr: component, kind: dst_scalar.kind, convert: Some(dst_scalar.width), }; } // Vector conversion (vector -> vector) - partial ( Components::One { component, ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, .. }, Constructor::PartialVector { size: dst_size }, ) if dst_size == src_size => { // This is a trivial conversion: the sizes match, and a Partial // constructor doesn't specify a scalar type, so nothing can // possibly happen. return Ok(component); } // Matrix conversion (matrix -> matrix) ( Components::One { component, ty_inner: &crate::TypeInner::Matrix { columns: src_columns, rows: src_rows, .. }, .. }, Constructor::Type(( _, &crate::TypeInner::Matrix { columns: dst_columns, rows: dst_rows, scalar: dst_scalar, }, )), ) if dst_columns == src_columns && dst_rows == src_rows => { expr = crate::Expression::As { expr: component, kind: dst_scalar.kind, convert: Some(dst_scalar.width), }; } // Matrix conversion (matrix -> matrix) - partial ( Components::One { component, ty_inner: &crate::TypeInner::Matrix { columns: src_columns, rows: src_rows, .. }, .. }, Constructor::PartialMatrix { columns: dst_columns, rows: dst_rows, }, ) if dst_columns == src_columns && dst_rows == src_rows => { // This is a trivial conversion: the sizes match, and a Partial // constructor doesn't specify a scalar type, so nothing can // possibly happen. return Ok(component); } // Vector constructor (splat) - infer type ( Components::One { component, ty_inner: &crate::TypeInner::Scalar { .. }, .. 
}, Constructor::PartialVector { size }, ) => { expr = crate::Expression::Splat { size, value: component, }; } // Vector constructor (splat) ( Components::One { mut component, ty_inner: &crate::TypeInner::Scalar(component_scalar), span, }, Constructor::Type(( type_handle, &crate::TypeInner::Vector { size, scalar: vec_scalar, }, )), ) => { // Splat only allows automatic conversions of the component's scalar. if !component_scalar.automatically_converts_to(vec_scalar) { let component_ty = &ctx.typifier()[component]; let arg_ty = ctx.type_resolution_to_string(component_ty); return Err(Box::new(Error::WrongArgumentType { function: ctx.type_to_string(type_handle), call_span: ty_span, arg_span: span, arg_index: 0, arg_ty, allowed: vec![vec_scalar.to_wgsl_for_diagnostics()], })); } ctx.convert_slice_to_common_leaf_scalar( core::slice::from_mut(&mut component), vec_scalar, )?; expr = crate::Expression::Splat { size, value: component, }; } // Vector constructor (by elements), partial ( Components::Many { mut components, spans, }, Constructor::PartialVector { size }, ) => { let consensus_scalar = ctx .automatic_conversion_consensus(None, &components) .map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; let inner = consensus_scalar.to_inner_vector(size); let ty = ctx.ensure_type_exists(inner); expr = crate::Expression::Compose { ty, components }; } // Vector constructor (by elements), full type given ( Components::Many { mut components, .. }, Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. 
})), ) => { ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?; expr = crate::Expression::Compose { ty, components }; } // Matrix constructor (by elements), partial ( Components::Many { mut components, spans, }, Constructor::PartialMatrix { columns, rows }, ) if components.len() == columns as usize * rows as usize => { let consensus_scalar = ctx .automatic_conversion_consensus( Some(crate::Scalar::ABSTRACT_FLOAT), &components, ) .map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows)); let components = components .chunks(rows as usize) .map(|vec_components| { ctx.append_expression( crate::Expression::Compose { ty: vec_ty, components: Vec::from(vec_components), }, Default::default(), ) }) .collect::<Result<Vec<_>>>()?; let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { columns, rows, scalar: consensus_scalar, }); expr = crate::Expression::Compose { ty, components }; } // Matrix constructor (by elements), type given ( Components::Many { mut components, ..
}, Constructor::Type(( _, &crate::TypeInner::Matrix { columns, rows, scalar, }, )), ) if components.len() == columns as usize * rows as usize => { let element = Tr::Value(crate::TypeInner::Scalar(scalar)); ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?; let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows)); let components = components .chunks(rows as usize) .map(|vec_components| { ctx.append_expression( crate::Expression::Compose { ty: vec_ty, components: Vec::from(vec_components), }, Default::default(), ) }) .collect::>>()?; let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { columns, rows, scalar, }); expr = crate::Expression::Compose { ty, components }; } // Matrix constructor (by columns), partial ( Components::Many { mut components, spans, }, Constructor::PartialMatrix { columns, rows }, ) if components.len() == columns as usize => { let consensus_scalar = ctx .automatic_conversion_consensus( Some(crate::Scalar::ABSTRACT_FLOAT), &components, ) .map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; let component_ty = crate::TypeInner::Vector { size: rows, scalar: consensus_scalar, }; ctx.try_automatic_conversions_slice( &mut components, &Tr::Value(component_ty), ty_span, )?; let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { columns, rows, scalar: consensus_scalar, }); expr = crate::Expression::Compose { ty, components }; } // Matrix constructor (by columns), type given ( Components::Many { mut components, .. 
}, Constructor::Type(( ty, &crate::TypeInner::Matrix { columns, rows, scalar, }, )), ) if components.len() == columns as usize => { let component_ty = crate::TypeInner::Vector { size: rows, scalar }; ctx.try_automatic_conversions_slice( &mut components, &Tr::Value(component_ty), ty_span, )?; expr = crate::Expression::Compose { ty, components }; } // Array constructor - infer type (components, Constructor::PartialArray) => { let mut components = components.into_components_vec(); if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(None, &components) { // Note that this will *not* necessarily convert all the // components to the same type! The `automatic_conversion_consensus` // method only considers the parameters' leaf scalar // types; the parameters themselves could be any mix of // vectors, matrices, and scalars. // // But *if* it is possible for this array construction // expression to be well-typed at all, then all the // parameters must have the same type constructors (vec, // matrix, scalar) applied to their leaf scalars, so // reconciling their scalars is always the right thing to // do. And if this array construction is not well-typed, // these conversions will not make it so, and we can let // validation catch the error. ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; } else { // There's no consensus scalar. Emit the `Compose` // expression anyway, and let validation catch the problem. } let base = ctx.register_type(components[0])?; let inner = crate::TypeInner::Array { base, size: crate::ArraySize::Constant( NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(), ), stride: { ctx.layouter.update(ctx.module.to_ctx()).unwrap(); ctx.layouter[base].to_stride() }, }; let ty = ctx.ensure_type_exists(inner); expr = crate::Expression::Compose { ty, components }; } // Array constructor, explicit type. ( components, Constructor::Type((ty, inner @ &crate::TypeInner::Array { base, .. 
})), ) if inner.is_constructible(&ctx.module.types) => { let mut components = components.into_components_vec(); ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?; expr = crate::Expression::Compose { ty, components }; } // Struct constructor ( components, Constructor::Type((ty, inner @ &crate::TypeInner::Struct { ref members, .. })), ) if inner.is_constructible(&ctx.module.types) => { let mut components = components.into_components_vec(); let struct_ty_span = ctx.module.types.get_span(ty); // Make a vector of the members' type handles in advance, to // avoid borrowing `members` from `ctx` while we generate // new code. let members: Vec> = members.iter().map(|m| m.ty).collect(); for (component, &ty) in components.iter_mut().zip(&members) { *component = ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?; } expr = crate::Expression::Compose { ty, components }; } // ERRORS // Bad conversion (type cast) ( Components::One { span, component, .. }, Constructor::Type(( ty, &(crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. }), )), ) => { let component_ty = &ctx.typifier()[component]; let from_type = ctx.type_resolution_to_string(component_ty); return Err(Box::new(Error::BadTypeCast { span, from_type, to_type: ctx.type_to_string(ty), })); } // Too many parameters for scalar constructor ( Components::Many { spans, .. }, Constructor::Type((_, &crate::TypeInner::Scalar { .. })), ) => { let span = spans[1].until(spans.last().unwrap()); return Err(Box::new(Error::UnexpectedComponents(span))); } // Other types can't be constructed _ => return Err(Box::new(Error::TypeNotConstructible(ty_span))), } let expr = ctx.append_expression(expr, span)?; Ok(expr) } } naga-29.0.3/src/front/wgsl/lower/conversion.rs000064400000000000000000000451131046102023000174070ustar 00000000000000//! WGSL's automatic conversions for abstract types. 
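// Informal summary of the leaf-scalar rules implemented later in this file.
// This is a reading of `Scalar::automatic_conversion_combine`, not normative
// WGSL text:
//
//   same kind, same width                -> that scalar (no conversion)
//   same kind, different widths          -> no common scalar
//   AbstractInt   + AbstractFloat        -> AbstractFloat
//   AbstractInt   + concrete int or float -> the concrete scalar
//   AbstractFloat + concrete float       -> the concrete float
//   AbstractFloat + concrete int         -> no common scalar
//   bool          + any other kind       -> no common scalar
//   two different concrete kinds         -> no common scalar
//
// Examples for `automatic_conversion_consensus` with `base == None`:
// - For the components of `vec3(1, 2.5, 3)`, it finds AbstractFloat.
// - For `vec2(1, 2.0f)`, it finds f32.
// - For `vec2(1.5, 2u)`, it returns `Err(1)`, the index of the `2u` component.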
use alloc::{boxed::Box, string::String, vec::Vec}; use crate::common::wgsl::{TryToWgsl, TypeContext}; use crate::front::wgsl::error::{ AutoConversionError, AutoConversionLeafScalarError, ConcretizationFailedError, }; use crate::front::wgsl::Result; use crate::{Handle, Span}; impl<'source> super::ExpressionContext<'source, '_, '_> { /// Try to use WGSL's automatic conversions to convert `expr` to `goal_ty`. /// /// If no conversions are necessary, return `expr` unchanged. /// /// If automatic conversions cannot convert `expr` to `goal_ty`, return an /// [`AutoConversion`] error. /// /// Although the Load Rule is one of the automatic conversions, this /// function assumes it has already been applied if appropriate, as /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. /// /// [`AutoConversion`]: super::Error::AutoConversion pub fn try_automatic_conversions( &mut self, expr: Handle, goal_ty: &crate::proc::TypeResolution, goal_span: Span, ) -> Result<'source, Handle> { let expr_span = self.get_expression_span(expr); // Keep the TypeResolution so we can get type names for // structs in error messages. let expr_resolution = super::resolve!(self, expr); let types = &self.module.types; let expr_inner = expr_resolution.inner_with(types); let goal_inner = goal_ty.inner_with(types); // We can only convert abstract types, so if `expr` is not abstract do not even // attempt conversion. This allows the validator to catch type errors correctly // rather than them being misreported as type conversion errors. // If the type is an array (of an array, etc) then we must check whether the // type of the innermost array's base type is abstract. if !expr_inner.is_abstract(types) { return Ok(expr); } // If `expr` already has the requested type, we're done. 
if self.module.compare_types(expr_resolution, goal_ty) { return Ok(expr); } let (_expr_scalar, goal_scalar) = match expr_inner.automatically_converts_to(goal_inner, types) { Some(scalars) => scalars, None => { let source_type = self.type_resolution_to_string(expr_resolution); let dest_type = self.type_resolution_to_string(goal_ty); return Err(Box::new(super::Error::AutoConversion(Box::new( AutoConversionError { dest_span: goal_span, dest_type, source_span: expr_span, source_type, }, )))); } }; self.convert_leaf_scalar(expr, expr_span, goal_scalar) } /// Try to convert `expr`'s leaf scalar to `goal_scalar` using automatic conversions. /// /// If no conversions are necessary, return `expr` unchanged. /// /// If automatic conversions cannot convert `expr` to `goal_scalar`, return /// an [`AutoConversionLeafScalar`] error. /// /// Although the Load Rule is one of the automatic conversions, this /// function assumes it has already been applied if appropriate, as /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. 
/// /// [`AutoConversionLeafScalar`]: super::Error::AutoConversionLeafScalar pub fn try_automatic_conversion_for_leaf_scalar( &mut self, expr: Handle, goal_scalar: crate::Scalar, goal_span: Span, ) -> Result<'source, Handle> { let expr_span = self.get_expression_span(expr); let expr_resolution = super::resolve!(self, expr); let types = &self.module.types; let expr_inner = expr_resolution.inner_with(types); let make_error = || { let source_type = self.type_resolution_to_string(expr_resolution); super::Error::AutoConversionLeafScalar(Box::new(AutoConversionLeafScalarError { dest_span: goal_span, dest_scalar: goal_scalar.to_wgsl_for_diagnostics(), source_span: expr_span, source_type, })) }; let expr_scalar = match expr_inner.automatically_convertible_scalar(&self.module.types) { Some(scalar) => scalar, None => return Err(Box::new(make_error())), }; if expr_scalar == goal_scalar { return Ok(expr); } if !expr_scalar.automatically_converts_to(goal_scalar) { return Err(Box::new(make_error())); } assert!(expr_scalar.is_abstract()); self.convert_leaf_scalar(expr, expr_span, goal_scalar) } fn convert_leaf_scalar( &mut self, expr: Handle, expr_span: Span, goal_scalar: crate::Scalar, ) -> Result<'source, Handle> { let expr_inner = super::resolve_inner!(self, expr); if let crate::TypeInner::Array { .. } = *expr_inner { self.as_const_evaluator() .cast_array(expr, goal_scalar, expr_span) .map_err(|err| { Box::new(super::Error::ConstantEvaluatorError(err.into(), expr_span)) }) } else { let cast = crate::Expression::As { expr, kind: goal_scalar.kind, convert: Some(goal_scalar.width), }; self.append_expression(cast, expr_span) } } /// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions. 
    pub fn try_automatic_conversions_slice(
        &mut self,
        exprs: &mut [Handle<crate::Expression>],
        goal_ty: &crate::proc::TypeResolution,
        goal_span: Span,
    ) -> Result<'source, ()> {
        for expr in exprs.iter_mut() {
            *expr = self.try_automatic_conversions(*expr, goal_ty, goal_span)?;
        }

        Ok(())
    }

    /// Apply WGSL's automatic conversions to a vector constructor's arguments.
    ///
    /// When calling a vector constructor like `vec3<f32>(...)`, the parameters
    /// can be a mix of scalars and vectors, with the latter being spread out to
    /// contribute each of their components as a component of the new value.
    /// When the element type is explicit, as with `<f32>` in the example above,
    /// WGSL's automatic conversions should convert abstract scalar and vector
    /// parameters to the constructor's required scalar type.
    pub fn try_automatic_conversions_for_vector(
        &mut self,
        exprs: &mut [Handle<crate::Expression>],
        goal_scalar: crate::Scalar,
        goal_span: Span,
    ) -> Result<'source, ()> {
        use crate::proc::TypeResolution as Tr;
        use crate::TypeInner as Ti;
        let goal_scalar_res = Tr::Value(Ti::Scalar(goal_scalar));

        for (i, expr) in exprs.iter_mut().enumerate() {
            // Keep the TypeResolution so we can get full type names
            // in error messages.
            let expr_resolution = super::resolve!(self, *expr);
            let types = &self.module.types;
            let expr_inner = expr_resolution.inner_with(types);

            match *expr_inner {
                Ti::Scalar(_) => {
                    *expr = self.try_automatic_conversions(*expr, &goal_scalar_res, goal_span)?;
                }
                Ti::Vector { size, scalar: _ } => {
                    let goal_vector_res = Tr::Value(Ti::Vector {
                        size,
                        scalar: goal_scalar,
                    });
                    *expr = self.try_automatic_conversions(*expr, &goal_vector_res, goal_span)?;
                }
                _ => {
                    let span = self.get_expression_span(*expr);
                    return Err(Box::new(super::Error::InvalidConstructorComponentType(
                        span, i as i32,
                    )));
                }
            }
        }

        Ok(())
    }

    /// Convert `expr` to the leaf scalar type `goal`.
    pub fn convert_to_leaf_scalar(
        &mut self,
        expr: &mut Handle<crate::Expression>,
        goal: crate::Scalar,
    ) -> Result<'source, ()> {
        let inner = super::resolve_inner!(self, *expr);
        // Do nothing if `inner` doesn't even have leaf scalars;
        // it's a type error that validation will catch.
        if inner.scalar() != Some(goal) {
            let cast = crate::Expression::As {
                expr: *expr,
                kind: goal.kind,
                convert: Some(goal.width),
            };
            let expr_span = self.get_expression_span(*expr);
            *expr = self.append_expression(cast, expr_span)?;
        }

        Ok(())
    }

    /// Convert all expressions in `exprs` to a common scalar type.
    ///
    /// Note that the caller is responsible for making sure these
    /// conversions are actually justified. This function simply
    /// generates `As` expressions, regardless of whether they are
    /// permitted WGSL automatic conversions. Callers intending to
    /// implement automatic conversions need to determine for
    /// themselves whether the casts we generate are justified,
    /// perhaps by calling `TypeInner::automatically_converts_to` or
    /// `Scalar::automatic_conversion_combine`.
    pub fn convert_slice_to_common_leaf_scalar(
        &mut self,
        exprs: &mut [Handle<crate::Expression>],
        goal: crate::Scalar,
    ) -> Result<'source, ()> {
        for expr in exprs.iter_mut() {
            self.convert_to_leaf_scalar(expr, goal)?;
        }

        Ok(())
    }

    /// Return an expression for the concretized value of `expr`.
    ///
    /// If `expr` is already concrete, return it unchanged.
pub fn concretize( &mut self, expr: Handle, ) -> Result<'source, Handle> { let inner = super::resolve_inner!(self, expr); if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) { use crate::ScalarKind as Sk; let concretization_preferences = match scalar.kind { // already concrete Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => return Ok(expr), Sk::AbstractInt => { [crate::Scalar::I32, crate::Scalar::U32, crate::Scalar::F32].as_slice() } Sk::AbstractFloat => [crate::Scalar::F32].as_slice(), }; let expr_span = self.get_expression_span(expr); let mut errors = Vec::new(); for concrete_scalar in concretization_preferences { match self .as_const_evaluator() .cast_array(expr, *concrete_scalar, expr_span) { Ok(expr) => return Ok(expr), Err(err) => { errors.push((concrete_scalar.to_wgsl_for_diagnostics(), err)); } } } if !errors.is_empty() { // A `TypeResolution` includes the type's full name, if // it has one. Also, avoid holding the borrow of `inner` // across the call to `cast_array`. let expr_type = &self.typifier()[expr]; return Err(Box::new(super::Error::ConcretizationFailed(Box::new( ConcretizationFailedError { expr_span, expr_type: self.type_resolution_to_string(expr_type), concretization_preferences: errors, }, )))); } } Ok(expr) } /// Find the consensus scalar of `components` under WGSL's automatic /// conversions. /// /// If `components` can all be converted to any common scalar via /// WGSL's automatic conversions, return the best such scalar. /// /// The `components` slice must not be empty. All elements' types must /// have been resolved. /// /// If `components` are definitely not acceptable as arguments to such /// constructors, return `Err(i)`, where `i` is the index in /// `components` of some problematic argument. /// /// If `base` is `Some(scalar)`, the consensus scalar must also be /// compatible with that `scalar`. This is used to restrict matrix /// initializers to floating-point types. 
/// /// This function doesn't fully type-check the arguments - it only /// considers their leaf scalar types. This means it may return `Ok` /// even when the Naga validator will reject the resulting /// construction expression later. pub fn automatic_conversion_consensus<'handle, I>( &self, base: Option, components: I, ) -> core::result::Result where I: IntoIterator>, I::IntoIter: Clone, // for debugging { let types = &self.module.types; let components_iter = components.into_iter(); log::debug!( "wgsl automatic_conversion_consensus: {}", components_iter .clone() .map(|&expr| { let res = &self.typifier()[expr]; self.type_resolution_to_string(res) }) .collect::>() .join(", ") ); let mut components_iter = components_iter .map(|&c| self.typifier()[c].inner_with(types).scalar()) .enumerate(); let base = base .or_else(|| components_iter.next().unwrap().1) .ok_or(0usize)?; let best = components_iter.try_fold(base, |best, (i, scalar)| { scalar .and_then(|scalar| best.automatic_conversion_combine(scalar)) .ok_or(i) })?; log::debug!(" consensus: {}", best.to_wgsl_for_diagnostics()); Ok(best) } } impl crate::TypeInner { fn automatically_convertible_scalar( &self, types: &crate::UniqueArena, ) -> Option { use crate::TypeInner as Ti; match *self { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { Some(scalar) } Ti::CooperativeMatrix { .. } => None, Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), Ti::Atomic(_) | Ti::Pointer { .. } | Ti::ValuePointer { .. } | Ti::Struct { .. } | Ti::Image { .. } | Ti::Sampler { .. } | Ti::AccelerationStructure { .. } | Ti::RayQuery { .. } | Ti::BindingArray { .. } => None, } } /// Return the leaf scalar type of `pointer`. /// /// `pointer` must be a `TypeInner` representing a pointer type. pub fn pointer_automatically_convertible_scalar( &self, types: &crate::UniqueArena, ) -> Option { use crate::TypeInner as Ti; match *self { Ti::Scalar(scalar) | Ti::Vector { scalar, .. 
} | Ti::Matrix { scalar, .. } => { Some(scalar) } Ti::CooperativeMatrix { .. } => None, Ti::Atomic(_) => None, Ti::Pointer { base, .. } | Ti::Array { base, .. } => { types[base].inner.automatically_convertible_scalar(types) } Ti::ValuePointer { scalar, .. } => Some(scalar), Ti::Struct { .. } | Ti::Image { .. } | Ti::Sampler { .. } | Ti::AccelerationStructure { .. } | Ti::RayQuery { .. } | Ti::BindingArray { .. } => None, } } } impl crate::Scalar { /// Find the common type of `self` and `other` under WGSL's /// automatic conversions. /// /// If there are any scalars to which WGSL's automatic conversions /// will convert both `self` and `other`, return the best such /// scalar. Otherwise, return `None`. pub const fn automatic_conversion_combine(self, other: Self) -> Option { use crate::ScalarKind as Sk; match (self.kind, other.kind) { // When the kinds match... (Sk::AbstractFloat, Sk::AbstractFloat) | (Sk::AbstractInt, Sk::AbstractInt) | (Sk::Sint, Sk::Sint) | (Sk::Uint, Sk::Uint) | (Sk::Float, Sk::Float) | (Sk::Bool, Sk::Bool) => { if self.width == other.width { // ... either no conversion is necessary ... Some(self) } else { // ... or no conversion is possible. // We never convert concrete to concrete, and // abstract types should have only one size. None } } // AbstractInt converts to AbstractFloat. (Sk::AbstractFloat, Sk::AbstractInt) => Some(self), (Sk::AbstractInt, Sk::AbstractFloat) => Some(other), // AbstractFloat converts to Float. (Sk::AbstractFloat, Sk::Float) => Some(other), (Sk::Float, Sk::AbstractFloat) => Some(self), // AbstractInt converts to concrete integer or float. (Sk::AbstractInt, Sk::Uint | Sk::Sint | Sk::Float) => Some(other), (Sk::Uint | Sk::Sint | Sk::Float, Sk::AbstractInt) => Some(self), // AbstractFloat can't be reconciled with concrete integer types. (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => { None } // Nothing can be reconciled with `bool`. 
            (Sk::Bool, _) | (_, Sk::Bool) => None,

            // Different concrete types cannot be reconciled.
            (Sk::Sint | Sk::Uint | Sk::Float, Sk::Sint | Sk::Uint | Sk::Float) => None,
        }
    }

    /// Return `true` if automatic conversions will convert `self` to `goal`.
    pub fn automatically_converts_to(self, goal: Self) -> bool {
        self.automatic_conversion_combine(goal) == Some(goal)
    }

    pub(in crate::front::wgsl) const fn concretize(self) -> Self {
        use crate::ScalarKind as Sk;
        match self.kind {
            Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self,
            Sk::AbstractInt => Self::I32,
            Sk::AbstractFloat => Self::F32,
        }
    }
}
naga-29.0.3/src/front/wgsl/lower/mod.rs000064400000000000000000006112211046102023000160000ustar 00000000000000
use alloc::{
    borrow::ToOwned,
    boxed::Box,
    format,
    string::{String, ToString},
    vec::Vec,
};
use core::num::NonZeroU32;

use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType};
use crate::front::wgsl::index::Index;
use crate::front::wgsl::parse::directive::enable_extension::EnableExtensions;
use crate::front::wgsl::parse::number::Number;
use crate::front::wgsl::parse::{ast, conv};
use crate::front::wgsl::Result;
use crate::front::Typifier;
use crate::{
    common::wgsl::{TryToWgsl, TypeContext},
    compact::KeepUnused,
};
use crate::{common::ForDebugWithTypes, proc::LayoutErrorInner};
use crate::{ir, proc};
use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span};

use construction::Constructor;
use template_list::TemplateListIter;

mod construction;
mod conversion;
mod template_list;

/// Resolves the inner type of a given expression.
///
/// Expects a &mut [`ExpressionContext`] and a [`Handle`].
///
/// Returns a &[`ir::TypeInner`].
///
/// Ideally, we would simply have a function that takes a `&mut ExpressionContext`
/// and returns a `&TypeResolution`.
Unfortunately, this leads the borrow checker /// to conclude that the mutable borrow lasts for as long as we are using the /// `&TypeResolution`, so we can't use the `ExpressionContext` for anything else - /// like, say, resolving another operand's type. Using a macro that expands to /// two separate calls, only the first of which needs a `&mut`, /// lets the borrow checker see that the mutable borrow is over. macro_rules! resolve_inner { ($ctx:ident, $expr:expr) => {{ $ctx.grow_types($expr)?; $ctx.typifier()[$expr].inner_with(&$ctx.module.types) }}; } pub(super) use resolve_inner; /// Resolves the inner types of two given expressions. /// /// Expects a &mut [`ExpressionContext`] and two [`Handle`]s. /// /// Returns a tuple containing two &[`ir::TypeInner`]. /// /// See the documentation of [`resolve_inner!`] for why this macro is necessary. macro_rules! resolve_inner_binary { ($ctx:ident, $left:expr, $right:expr) => {{ $ctx.grow_types($left)?; $ctx.grow_types($right)?; ( $ctx.typifier()[$left].inner_with(&$ctx.module.types), $ctx.typifier()[$right].inner_with(&$ctx.module.types), ) }}; } /// Resolves the type of a given expression. /// /// Expects a &mut [`ExpressionContext`] and a [`Handle`]. /// /// Returns a &[`TypeResolution`]. /// /// See the documentation of [`resolve_inner!`] for why this macro is necessary. /// /// [`TypeResolution`]: proc::TypeResolution macro_rules! resolve { ($ctx:ident, $expr:expr) => {{ let expr = $expr; $ctx.grow_types(expr)?; &$ctx.typifier()[expr] }}; } pub(super) use resolve; /// State for constructing a `ir::Module`. pub struct GlobalContext<'source, 'temp, 'out> { enable_extensions: EnableExtensions, /// The `TranslationUnit`'s expressions arena. ast_expressions: &'temp Arena>, // Naga IR values. /// The map from the names of module-scope declarations to the Naga IR /// `Handle`s we have built for them, owned by `Lowerer::lower`. globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, /// The module we're constructing. 
module: &'out mut ir::Module, const_typifier: &'temp mut Typifier, layouter: &'temp mut proc::Layouter, global_expression_kind_tracker: &'temp mut proc::ExpressionKindTracker, } impl<'source> GlobalContext<'source, '_, '_> { const fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { ExpressionContext { enable_extensions: self.enable_extensions, ast_expressions: self.ast_expressions, globals: self.globals, module: self.module, const_typifier: self.const_typifier, layouter: self.layouter, expr_type: ExpressionContextType::Constant(None), global_expression_kind_tracker: self.global_expression_kind_tracker, } } const fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> { ExpressionContext { enable_extensions: self.enable_extensions, ast_expressions: self.ast_expressions, globals: self.globals, module: self.module, const_typifier: self.const_typifier, layouter: self.layouter, expr_type: ExpressionContextType::Override, global_expression_kind_tracker: self.global_expression_kind_tracker, } } fn ensure_type_exists( &mut self, name: Option, inner: ir::TypeInner, ) -> Handle { self.module .types .insert(ir::Type { inner, name }, Span::UNDEFINED) } } /// State for lowering a statement within a function. pub struct StatementContext<'source, 'temp, 'out> { enable_extensions: EnableExtensions, // WGSL AST values. /// A reference to [`TranslationUnit::expressions`] for the translation unit /// we're lowering. /// /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions ast_expressions: &'temp Arena>, // Naga IR values. /// The map from the names of module-scope declarations to the Naga IR /// `Handle`s we have built for them, owned by `Lowerer::lower`. globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, /// A map from each `ast::Local` handle to the Naga expression /// we've built for it: /// /// - WGSL function arguments become Naga [`FunctionArgument`] expressions. 
    ///
    /// - WGSL `var` declarations become Naga [`LocalVariable`] expressions.
    ///
    /// - WGSL `let` declarations become arbitrary Naga expressions.
    ///
    /// This always borrows the `local_table` local variable in
    /// [`Lowerer::function`].
    ///
    /// [`LocalVariable`]: ir::Expression::LocalVariable
    /// [`FunctionArgument`]: ir::Expression::FunctionArgument
    local_table:
        &'temp mut FastHashMap<Handle<ast::Local>, Declared<Typed<Handle<ir::Expression>>>>,

    const_typifier: &'temp mut Typifier,
    typifier: &'temp mut Typifier,
    layouter: &'temp mut proc::Layouter,
    function: &'out mut ir::Function,
    /// Stores the names of expressions that are assigned in `let` statements.
    /// Also stores the spans of the names, for use in errors.
    named_expressions: &'out mut FastIndexMap<Handle<ir::Expression>, (String, Span)>,
    module: &'out mut ir::Module,

    /// Which `Expression`s in `self.naga_expressions` are const expressions, in
    /// the WGSL sense.
    ///
    /// According to the WGSL spec, a const expression must not refer to any
    /// `let` declarations, even if those declarations' initializers are
    /// themselves const expressions. So this tracker is not simply concerned
    /// with the form of the expressions; it is also tracking whether WGSL says
    /// we should consider them to be const. See the use of `force_non_const` in
    /// the code for lowering `let` bindings.
local_expression_kind_tracker: &'temp mut proc::ExpressionKindTracker, global_expression_kind_tracker: &'temp mut proc::ExpressionKindTracker, } impl<'a, 'temp> StatementContext<'a, 'temp, '_> { const fn as_const<'t>( &'t mut self, block: &'t mut ir::Block, emitter: &'t mut proc::Emitter, ) -> ExpressionContext<'a, 't, 't> where 'temp: 't, { ExpressionContext { enable_extensions: self.enable_extensions, globals: self.globals, ast_expressions: self.ast_expressions, const_typifier: self.const_typifier, layouter: self.layouter, global_expression_kind_tracker: self.global_expression_kind_tracker, module: self.module, expr_type: ExpressionContextType::Constant(Some(LocalExpressionContext { local_table: self.local_table, function: self.function, block, emitter, typifier: self.typifier, local_expression_kind_tracker: self.local_expression_kind_tracker, })), } } const fn as_expression<'t>( &'t mut self, block: &'t mut ir::Block, emitter: &'t mut proc::Emitter, ) -> ExpressionContext<'a, 't, 't> where 'temp: 't, { ExpressionContext { enable_extensions: self.enable_extensions, globals: self.globals, ast_expressions: self.ast_expressions, const_typifier: self.const_typifier, layouter: self.layouter, global_expression_kind_tracker: self.global_expression_kind_tracker, module: self.module, expr_type: ExpressionContextType::Runtime(LocalExpressionContext { local_table: self.local_table, function: self.function, block, emitter, typifier: self.typifier, local_expression_kind_tracker: self.local_expression_kind_tracker, }), } } #[allow(dead_code)] const fn as_global(&mut self) -> GlobalContext<'a, '_, '_> { GlobalContext { enable_extensions: self.enable_extensions, ast_expressions: self.ast_expressions, globals: self.globals, module: self.module, const_typifier: self.const_typifier, layouter: self.layouter, global_expression_kind_tracker: self.global_expression_kind_tracker, } } fn invalid_assignment_type(&self, expr: Handle) -> InvalidAssignmentType { if let Some(&(_, span)) = 
self.named_expressions.get(&expr) { InvalidAssignmentType::ImmutableBinding(span) } else { match self.function.expressions[expr] { ir::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle, ir::Expression::Access { base, .. } => self.invalid_assignment_type(base), ir::Expression::AccessIndex { base, .. } => self.invalid_assignment_type(base), _ => InvalidAssignmentType::Other, } } } } pub struct LocalExpressionContext<'temp, 'out> { /// A map from [`ast::Local`] handles to the Naga expressions we've built for them. /// /// This is always [`StatementContext::local_table`] for the /// enclosing statement; see that documentation for details. local_table: &'temp FastHashMap, Declared>>>, function: &'out mut ir::Function, block: &'temp mut ir::Block, emitter: &'temp mut proc::Emitter, typifier: &'temp mut Typifier, /// Which `Expression`s in `self.naga_expressions` are const expressions, in /// the WGSL sense. /// /// See [`StatementContext::local_expression_kind_tracker`] for details. local_expression_kind_tracker: &'temp mut proc::ExpressionKindTracker, } /// The type of Naga IR expression we are lowering an [`ast::Expression`] to. pub enum ExpressionContextType<'temp, 'out> { /// We are lowering to an arbitrary runtime expression, to be /// included in a function's body. /// /// The given [`LocalExpressionContext`] holds information about local /// variables, arguments, and other definitions available only to runtime /// expressions, not constant or override expressions. Runtime(LocalExpressionContext<'temp, 'out>), /// We are lowering to a constant expression, to be included in the module's /// constant expression arena. /// /// Everything global constant expressions are allowed to refer to is /// available in the [`ExpressionContext`], but local constant expressions can /// also refer to other Constant(Option>), /// We are lowering to an override expression, to be included in the module's /// constant expression arena. 
/// /// Everything override expressions are allowed to refer to is /// available in the [`ExpressionContext`], so this variant /// carries no further information. Override, } /// State for lowering an [`ast::Expression`] to Naga IR. /// /// [`ExpressionContext`]s come in two kinds, distinguished by /// the value of the [`expr_type`] field: /// /// - A [`Runtime`] context contributes [`naga::Expression`]s to a [`naga::Function`]'s /// runtime expression arena. /// /// - A [`Constant`] context contributes [`naga::Expression`]s to a [`naga::Module`]'s /// constant expression arena. /// /// [`ExpressionContext`]s are constructed in restricted ways: /// /// - To get a [`Runtime`] [`ExpressionContext`], call /// [`StatementContext::as_expression`]. /// /// - To get a [`Constant`] [`ExpressionContext`], call /// [`GlobalContext::as_const`]. /// /// - You can demote a [`Runtime`] context to a [`Constant`] context /// by calling [`as_const`], but there's no way to go in the other /// direction, producing a runtime context from a constant one. This /// is because runtime expressions can refer to constant /// expressions, via [`Expression::Constant`], but constant /// expressions can't refer to a function's expressions. /// /// Not to be confused with `wgsl::parse::ExpressionContext`, which is /// for parsing the `ast::Expression` in the first place. /// /// [`expr_type`]: ExpressionContext::expr_type /// [`Runtime`]: ExpressionContextType::Runtime /// [`naga::Expression`]: ir::Expression /// [`naga::Function`]: ir::Function /// [`Constant`]: ExpressionContextType::Constant /// [`naga::Module`]: ir::Module /// [`as_const`]: ExpressionContext::as_const /// [`Expression::Constant`]: ir::Expression::Constant pub struct ExpressionContext<'source, 'temp, 'out> { enable_extensions: EnableExtensions, // WGSL AST values. ast_expressions: &'temp Arena>, // Naga IR values. 
/// The map from the names of module-scope declarations to the Naga IR /// `Handle`s we have built for them, owned by `Lowerer::lower`. globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, /// The IR [`Module`] we're constructing. /// /// [`Module`]: ir::Module module: &'out mut ir::Module, /// Type judgments for [`module::global_expressions`]. /// /// [`module::global_expressions`]: ir::Module::global_expressions const_typifier: &'temp mut Typifier, layouter: &'temp mut proc::Layouter, global_expression_kind_tracker: &'temp mut proc::ExpressionKindTracker, /// Whether we are lowering a constant expression or a general /// runtime expression, and the data needed in each case. expr_type: ExpressionContextType<'temp, 'out>, } impl TypeContext for ExpressionContext<'_, '_, '_> { fn lookup_type(&self, handle: Handle) -> &ir::Type { &self.module.types[handle] } fn type_name(&self, handle: Handle) -> &str { self.module.types[handle] .name .as_deref() .unwrap_or("{anonymous type}") } fn write_override( &self, handle: Handle, out: &mut W, ) -> core::fmt::Result { match self.module.overrides[handle].name { Some(ref name) => out.write_str(name), None => write!(out, "{{anonymous override {handle:?}}}"), } } fn write_unnamed_struct( &self, _: &ir::TypeInner, _: &mut W, ) -> core::fmt::Result { unreachable!("the WGSL front end should always know the type name"); } } impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { const fn is_runtime(&self) -> bool { match self.expr_type { ExpressionContextType::Runtime(_) => true, ExpressionContextType::Constant(_) | ExpressionContextType::Override => false, } } #[allow(dead_code)] const fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { ExpressionContext { enable_extensions: self.enable_extensions, globals: self.globals, ast_expressions: self.ast_expressions, const_typifier: self.const_typifier, layouter: self.layouter, module: self.module, expr_type: ExpressionContextType::Constant(match 
                self.expr_type {
                ExpressionContextType::Runtime(ref mut local_expression_context)
                | ExpressionContextType::Constant(Some(ref mut local_expression_context)) => {
                    Some(LocalExpressionContext {
                        local_table: local_expression_context.local_table,
                        function: local_expression_context.function,
                        block: local_expression_context.block,
                        emitter: local_expression_context.emitter,
                        typifier: local_expression_context.typifier,
                        local_expression_kind_tracker: local_expression_context
                            .local_expression_kind_tracker,
                    })
                }
                ExpressionContextType::Constant(None) | ExpressionContextType::Override => None,
            }),
            global_expression_kind_tracker: self.global_expression_kind_tracker,
        }
    }

    const fn as_global(&mut self) -> GlobalContext<'source, '_, '_> {
        GlobalContext {
            enable_extensions: self.enable_extensions,
            ast_expressions: self.ast_expressions,
            globals: self.globals,
            module: self.module,
            const_typifier: self.const_typifier,
            layouter: self.layouter,
            global_expression_kind_tracker: self.global_expression_kind_tracker,
        }
    }

    const fn as_const_evaluator(&mut self) -> proc::ConstantEvaluator<'_> {
        match self.expr_type {
            ExpressionContextType::Runtime(ref mut rctx) => {
                proc::ConstantEvaluator::for_wgsl_function(
                    self.module,
                    &mut rctx.function.expressions,
                    rctx.local_expression_kind_tracker,
                    self.layouter,
                    rctx.emitter,
                    rctx.block,
                    false,
                )
            }
            ExpressionContextType::Constant(Some(ref mut rctx)) => {
                proc::ConstantEvaluator::for_wgsl_function(
                    self.module,
                    &mut rctx.function.expressions,
                    rctx.local_expression_kind_tracker,
                    self.layouter,
                    rctx.emitter,
                    rctx.block,
                    true,
                )
            }
            ExpressionContextType::Constant(None) => proc::ConstantEvaluator::for_wgsl_module(
                self.module,
                self.global_expression_kind_tracker,
                self.layouter,
                false,
            ),
            ExpressionContextType::Override => proc::ConstantEvaluator::for_wgsl_module(
                self.module,
                self.global_expression_kind_tracker,
                self.layouter,
                true,
            ),
        }
    }

    /// Return a wrapper around `value` suitable for formatting.
    ///
    /// Return a wrapper around `value` that implements
    /// [`core::fmt::Display`] in a form suitable for use in
    /// diagnostic messages.
    const fn as_diagnostic_display<T>(
        &self,
        value: T,
    ) -> crate::common::DiagnosticDisplay<(T, proc::GlobalCtx<'_>)> {
        let ctx = self.module.to_ctx();
        crate::common::DiagnosticDisplay((value, ctx))
    }

    fn append_expression(
        &mut self,
        expr: ir::Expression,
        span: Span,
    ) -> Result<'source, Handle<ir::Expression>> {
        let mut eval = self.as_const_evaluator();
        eval.try_eval_and_append(expr, span)
            .map_err(|e| Box::new(Error::ConstantEvaluatorError(e.into(), span)))
    }

    fn get_const_val<T: TryFrom<ir::Literal>>(
        &self,
        handle: Handle<ir::Expression>,
    ) -> core::result::Result<T, proc::ConstValueError> {
        match self.expr_type {
            ExpressionContextType::Runtime(ref ctx) => {
                if !ctx.local_expression_kind_tracker.is_const(handle) {
                    return Err(proc::ConstValueError::NonConst);
                }

                self.module
                    .to_ctx()
                    .get_const_val_from(handle, &ctx.function.expressions)
            }
            ExpressionContextType::Constant(Some(ref ctx)) => {
                assert!(ctx.local_expression_kind_tracker.is_const(handle));
                self.module
                    .to_ctx()
                    .get_const_val_from(handle, &ctx.function.expressions)
            }
            ExpressionContextType::Constant(None) => self.module.to_ctx().get_const_val(handle),
            ExpressionContextType::Override => Err(proc::ConstValueError::NonConst),
        }
    }

    /// Return `true` if `handle` is a constant expression.
    fn is_const(&self, handle: Handle<ir::Expression>) -> bool {
        use ExpressionContextType as Ect;
        match self.expr_type {
            Ect::Runtime(ref ctx) | Ect::Constant(Some(ref ctx)) => {
                ctx.local_expression_kind_tracker.is_const(handle)
            }
            Ect::Constant(None) | Ect::Override => {
                self.global_expression_kind_tracker.is_const(handle)
            }
        }
    }

    fn get_expression_span(&self, handle: Handle<ir::Expression>) -> Span {
        match self.expr_type {
            ExpressionContextType::Runtime(ref ctx)
            | ExpressionContextType::Constant(Some(ref ctx)) => {
                ctx.function.expressions.get_span(handle)
            }
            ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
                self.module.global_expressions.get_span(handle)
            }
        }
    }

    const fn typifier(&self) -> &Typifier {
        match self.expr_type {
            ExpressionContextType::Runtime(ref ctx)
            | ExpressionContextType::Constant(Some(ref ctx)) => ctx.typifier,
            ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
                self.const_typifier
            }
        }
    }

    fn get(&self, handle: Handle<ir::Expression>) -> &crate::Expression {
        match self.expr_type {
            ExpressionContextType::Runtime(ref ctx)
            | ExpressionContextType::Constant(Some(ref ctx)) => &ctx.function.expressions[handle],
            ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
                &self.module.global_expressions[handle]
            }
        }
    }

    fn local(
        &mut self,
        local: &Handle<ast::Local>,
        span: Span,
    ) -> Result<'source, Typed<Handle<ir::Expression>>> {
        match self.expr_type {
            ExpressionContextType::Runtime(ref ctx) => Ok(ctx.local_table[local].runtime()),
            ExpressionContextType::Constant(Some(ref ctx)) => ctx.local_table[local]
                .const_time()
                .ok_or(Box::new(Error::UnexpectedOperationInConstContext(span))),
            _ => Err(Box::new(Error::UnexpectedOperationInConstContext(span))),
        }
    }

    fn runtime_expression_ctx(
        &mut self,
        span: Span,
    ) -> Result<'source, &mut LocalExpressionContext<'temp, 'out>> {
        match self.expr_type {
            ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
            ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
                Err(Box::new(Error::UnexpectedOperationInConstContext(span)))
            }
        }
    }

    fn with_nested_runtime_expression_ctx<'a, F, T>(
        &mut self,
        span: Span,
        f: F,
    ) -> Result<'source, (T, crate::Block)>
    where
        for<'t> F: FnOnce(&mut ExpressionContext<'source, 't, 't>) -> Result<'source, T>,
    {
        let mut block = crate::Block::new();
        let rctx = match self.expr_type {
            ExpressionContextType::Runtime(ref mut rctx) => Ok(rctx),
            ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
                Err(Error::UnexpectedOperationInConstContext(span))
            }
        }?;

        rctx.block
            .extend(rctx.emitter.finish(&rctx.function.expressions));
        rctx.emitter.start(&rctx.function.expressions);

        let nested_rctx = LocalExpressionContext {
            local_table: rctx.local_table,
            function: rctx.function,
            block: &mut block,
            emitter: rctx.emitter,
            typifier: rctx.typifier,
            local_expression_kind_tracker: rctx.local_expression_kind_tracker,
        };
        let mut nested_ctx = ExpressionContext {
            enable_extensions: self.enable_extensions,
            expr_type: ExpressionContextType::Runtime(nested_rctx),
            ast_expressions: self.ast_expressions,
            globals: self.globals,
            module: self.module,
            const_typifier: self.const_typifier,
            layouter: self.layouter,
            global_expression_kind_tracker: self.global_expression_kind_tracker,
        };

        let ret = f(&mut nested_ctx)?;

        block.extend(rctx.emitter.finish(&rctx.function.expressions));
        rctx.emitter.start(&rctx.function.expressions);

        Ok((ret, block))
    }

    fn gather_component(
        &mut self,
        expr: Handle<ir::Expression>,
        component_span: Span,
        gather_span: Span,
    ) -> Result<'source, ir::SwizzleComponent> {
        match self.expr_type {
            ExpressionContextType::Runtime(ref rctx) => {
                if !rctx.local_expression_kind_tracker.is_const(expr) {
                    return Err(Box::new(Error::ExpectedConstExprConcreteIntegerScalar(
                        component_span,
                    )));
                }

                let index = self
                    .module
                    .to_ctx()
                    .get_const_val_from::<u32, _>(expr, &rctx.function.expressions)
                    .map_err(|err| match err {
                        proc::ConstValueError::NonConst | proc::ConstValueError::InvalidType => {
                            Error::ExpectedConstExprConcreteIntegerScalar(component_span)
                        }
                        proc::ConstValueError::Negative => {
                            Error::ExpectedNonNegative(component_span)
                        }
                    })?;
                ir::SwizzleComponent::XYZW
                    .get(index as usize)
                    .copied()
                    .ok_or(Box::new(Error::InvalidGatherComponent(component_span)))
            }
            // This means a `gather` operation appeared in a constant expression.
            // This error refers to the `gather` itself, not its "component" argument.
            ExpressionContextType::Constant(_) | ExpressionContextType::Override => Err(Box::new(
                Error::UnexpectedOperationInConstContext(gather_span),
            )),
        }
    }

    /// Determine the type of `handle`, and add it to the module's arena.
    ///
    /// If you just need a `TypeInner` for `handle`'s type, use the
    /// [`resolve_inner!`] macro instead. This function
    /// should only be used when the type of `handle` needs to appear
    /// in the module's final `Arena<Type>`, for example, if you're
    /// creating a [`LocalVariable`] whose type is inferred from its
    /// initializer.
    ///
    /// [`LocalVariable`]: ir::LocalVariable
    fn register_type(
        &mut self,
        handle: Handle<ir::Expression>,
    ) -> Result<'source, Handle<ir::Type>> {
        self.grow_types(handle)?;
        // This is equivalent to calling ExpressionContext::typifier(),
        // except that this lets the borrow checker see that it's okay
        // to also borrow self.module.types mutably below.
        let typifier = match self.expr_type {
            ExpressionContextType::Runtime(ref ctx)
            | ExpressionContextType::Constant(Some(ref ctx)) => ctx.typifier,
            ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
                &*self.const_typifier
            }
        };
        Ok(typifier.register_type(handle, &mut self.module.types))
    }

    /// Resolve the types of all expressions up through `handle`.
    ///
    /// Ensure that [`self.typifier`] has a [`TypeResolution`] for
    /// every expression in `self.function.expressions`.
    ///
    /// This does not add types to any arena. The [`Typifier`]
    /// documentation explains the steps we take to avoid filling
    /// arenas with intermediate types.
    ///
    /// This function takes `&mut self`, so it can't conveniently
    /// return a shared reference to the resulting `TypeResolution`:
    /// the shared reference would extend the mutable borrow, and you
    /// wouldn't be able to use `self` for anything else. Instead, you
    /// should use [`register_type`] or one of [`resolve!`],
    /// [`resolve_inner!`] or [`resolve_inner_binary!`].
    ///
    /// [`self.typifier`]: ExpressionContext::typifier
    /// [`TypeResolution`]: proc::TypeResolution
    /// [`register_type`]: Self::register_type
    /// [`Typifier`]: Typifier
    fn grow_types(&mut self, handle: Handle<ir::Expression>) -> Result<'source, &mut Self> {
        let empty_arena = Arena::new();
        let resolve_ctx;
        let typifier;
        let expressions;
        match self.expr_type {
            ExpressionContextType::Runtime(ref mut ctx)
            | ExpressionContextType::Constant(Some(ref mut ctx)) => {
                resolve_ctx = proc::ResolveContext::with_locals(
                    self.module,
                    &ctx.function.local_variables,
                    &ctx.function.arguments,
                );
                typifier = &mut *ctx.typifier;
                expressions = &ctx.function.expressions;
            }
            ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
                resolve_ctx = proc::ResolveContext::with_locals(self.module, &empty_arena, &[]);
                typifier = self.const_typifier;
                expressions = &self.module.global_expressions;
            }
        };
        typifier
            .grow(handle, expressions, &resolve_ctx)
            .map_err(Error::InvalidResolve)?;

        Ok(self)
    }

    fn image_data(
        &mut self,
        image: Handle<ir::Expression>,
        span: Span,
    ) -> Result<'source, (ir::ImageClass, bool)> {
        match *resolve_inner!(self, image) {
            ir::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)),
            _ => Err(Box::new(Error::BadTexture(span))),
        }
    }

    fn prepare_args<'b>(
        &mut self,
        args: &'b [Handle<ast::Expression<'source>>],
        min_args: u32,
        span: Span,
    ) -> ArgumentContext<'b, 'source> {
        ArgumentContext {
            args: args.iter(),
            min_args,
            args_used: 0,
            total_args: args.len() as u32,
            span,
        }
    }

    /// Insert splats, if needed by the non-'*' operations.
    ///
    /// See the "Binary arithmetic expressions with mixed scalar and vector operands"
    /// table in the WebGPU Shading Language specification for relevant operators.
    ///
    /// Multiply is not handled here as backends are expected to handle vec*scalar
    /// operations, so inserting splats into the IR increases size needlessly.
    fn binary_op_splat(
        &mut self,
        op: ir::BinaryOperator,
        left: &mut Handle<ir::Expression>,
        right: &mut Handle<ir::Expression>,
    ) -> Result<'source, ()> {
        if matches!(
            op,
            ir::BinaryOperator::Add
                | ir::BinaryOperator::Subtract
                | ir::BinaryOperator::Divide
                | ir::BinaryOperator::Modulo
        ) {
            match resolve_inner_binary!(self, *left, *right) {
                (&ir::TypeInner::Vector { size, .. }, &ir::TypeInner::Scalar { .. }) => {
                    *right = self.append_expression(
                        ir::Expression::Splat {
                            size,
                            value: *right,
                        },
                        self.get_expression_span(*right),
                    )?;
                }
                (&ir::TypeInner::Scalar { .. }, &ir::TypeInner::Vector { size, .. }) => {
                    *left = self.append_expression(
                        ir::Expression::Splat { size, value: *left },
                        self.get_expression_span(*left),
                    )?;
                }
                _ => {}
            }
        }

        Ok(())
    }

    /// Add a single expression to the expression table that is not covered by `self.emitter`.
    ///
    /// This is useful for `CallResult` and `AtomicResult` expressions, which should not be covered by
    /// `Emit` statements.
    fn interrupt_emitter(
        &mut self,
        expression: ir::Expression,
        span: Span,
    ) -> Result<'source, Handle<ir::Expression>> {
        match self.expr_type {
            ExpressionContextType::Runtime(ref mut rctx)
            | ExpressionContextType::Constant(Some(ref mut rctx)) => {
                rctx.block
                    .extend(rctx.emitter.finish(&rctx.function.expressions));
            }
            ExpressionContextType::Constant(None) | ExpressionContextType::Override => {}
        }
        let result = self.append_expression(expression, span);
        match self.expr_type {
            ExpressionContextType::Runtime(ref mut rctx)
            | ExpressionContextType::Constant(Some(ref mut rctx)) => {
                rctx.emitter.start(&rctx.function.expressions);
            }
            ExpressionContextType::Constant(None) | ExpressionContextType::Override => {}
        }
        result
    }

    /// Apply the WGSL Load Rule to `expr`.
    ///
    /// If `expr` has type `ref<AS, T, AM>`, perform a load to produce a value of type
    /// `T`. Otherwise, return `expr` unchanged.
    fn apply_load_rule(
        &mut self,
        expr: Typed<Handle<ir::Expression>>,
    ) -> Result<'source, Handle<ir::Expression>> {
        match expr {
            Typed::Reference(pointer) => {
                let load = ir::Expression::Load { pointer };
                let span = self.get_expression_span(pointer);
                self.append_expression(load, span)
            }
            Typed::Plain(handle) => Ok(handle),
        }
    }

    fn ensure_type_exists(&mut self, inner: ir::TypeInner) -> Handle<ir::Type> {
        self.as_global().ensure_type_exists(None, inner)
    }

    /// Check that `expr` is an identifier resolving to a predeclared enumerant.
    ///
    /// The identifier must not have any template parameters.
    ///
    /// Return the name of the identifier, together with its span.
    ///
    /// Actually, this only checks that the identifier refers to some
    /// predeclared object, not necessarily an enumerant. This should be good
    /// enough, since the caller is going to compare the name against some list
    /// of permitted enumerants anyway.
    fn enumerant(
        &self,
        expr: Handle<ast::Expression<'source>>,
    ) -> Result<'source, (&'source str, Span)> {
        let span = self.ast_expressions.get_span(expr);
        let expr = &self.ast_expressions[expr];

        let ast::Expression::Ident(ref ident) = *expr else {
            return Err(Box::new(Error::UnexpectedExprForEnumerant(span)));
        };

        let ast::TemplateElaboratedIdent {
            ident: ast::IdentExpr::Unresolved(name),
            ref template_list,
            ..
        } = *ident
        else {
            return Err(Box::new(Error::UnexpectedIdentForEnumerant(span)));
        };

        if self.globals.get(name).is_some() {
            return Err(Box::new(Error::UnexpectedIdentForEnumerant(span)));
        }

        if !template_list.is_empty() {
            return Err(Box::new(Error::UnexpectedTemplate(span)));
        }

        Ok((name, span))
    }

    fn var_address_space(
        &self,
        template_list: &[Handle<ast::Expression<'source>>],
    ) -> Result<'source, ir::AddressSpace> {
        let mut tl = TemplateListIter::new(Span::UNDEFINED, template_list);
        let mut address_space = tl.maybe_address_space(self)?;
        if let Some(ref mut address_space) = address_space {
            tl.maybe_access_mode(address_space, self)?;
        }
        tl.finish(self)?;
        Ok(address_space.unwrap_or(ir::AddressSpace::Handle))
    }
}

struct ArgumentContext<'ctx, 'source> {
    args: core::slice::Iter<'ctx, Handle<ast::Expression<'source>>>,
    min_args: u32,
    args_used: u32,
    total_args: u32,
    span: Span,
}

impl<'source> ArgumentContext<'_, 'source> {
    pub fn finish(self) -> Result<'source, ()> {
        if self.args.len() == 0 {
            Ok(())
        } else {
            Err(Box::new(Error::WrongArgumentCount {
                found: self.total_args,
                expected: self.min_args..self.args_used + 1,
                span: self.span,
            }))
        }
    }

    pub fn next(&mut self) -> Result<'source, Handle<ast::Expression<'source>>> {
        match self.args.next().copied() {
            Some(arg) => {
                self.args_used += 1;
                Ok(arg)
            }
            None => Err(Box::new(Error::WrongArgumentCount {
                found: self.total_args,
                expected: self.min_args..self.args_used + 1,
                span: self.span,
            })),
        }
    }
}

#[derive(Debug, Copy, Clone)]
enum Declared<T> {
    /// Value declared as const
    Const(T),

    /// Value declared as non-const
    Runtime(T),
}

impl<T> Declared<T> {
    fn runtime(self) -> T {
        match self {
            Declared::Const(t) | Declared::Runtime(t) => t,
        }
    }

    fn const_time(self) -> Option<T> {
        match self {
            Declared::Const(t) => Some(t),
            Declared::Runtime(_) => None,
        }
    }
}

/// WGSL type annotations on expressions, types, values, etc.
///
/// Naga and WGSL types are very close, but Naga lacks WGSL's `ref` types, which
/// we need to know to apply the Load Rule.
/// This enum carries some WGSL or Naga
/// datum along with enough information to determine its corresponding WGSL
/// type.
///
/// The `T` type parameter can be any expression-like thing:
///
/// - `Typed<Handle<ir::Type>>` can represent a full WGSL type. For example,
///   given some Naga `Pointer` type `ptr`, a WGSL reference type is a
///   `Typed::Reference(ptr)` whereas a WGSL pointer type is a
///   `Typed::Plain(ptr)`.
///
/// - `Typed<Expression>` or `Typed<Handle<Expression>>` can
///   represent references similarly.
///
/// Use the `map` and `try_map` methods to convert from one expression
/// representation to another.
///
/// [`Expression`]: ir::Expression
#[derive(Debug, Copy, Clone)]
enum Typed<T> {
    /// A WGSL reference.
    Reference(T),

    /// A WGSL plain type.
    Plain(T),
}

impl<T> Typed<T> {
    fn map<U>(self, mut f: impl FnMut(T) -> U) -> Typed<U> {
        match self {
            Self::Reference(v) => Typed::Reference(f(v)),
            Self::Plain(v) => Typed::Plain(f(v)),
        }
    }

    fn try_map<U, E>(
        self,
        mut f: impl FnMut(T) -> core::result::Result<U, E>,
    ) -> core::result::Result<Typed<U>, E> {
        Ok(match self {
            Self::Reference(expr) => Typed::Reference(f(expr)?),
            Self::Plain(expr) => Typed::Plain(f(expr)?),
        })
    }

    fn ref_or<E>(self, error: E) -> core::result::Result<T, E> {
        match self {
            Self::Reference(v) => Ok(v),
            Self::Plain(_) => Err(error),
        }
    }
}

/// A single vector component or swizzle.
///
/// This represents the things that can appear after the `.` in a vector access
/// expression: either a single component name, or a series of them,
/// representing a swizzle.
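The accessor rules that `Components::new` enforces can be illustrated with a standalone sketch: a name must have one to four letters, every letter must be a component letter (`rgba` are aliases for `xyzw`), and a multi-letter swizzle may not mix the two letter sets. `Comp`, `Access`, `letter` and `parse_access` are hypothetical names, not Naga API. Returning `None` stands in for `Error::BadAccessor`, and a `Vec` replaces Naga's fixed `[SwizzleComponent; 4]` plus `VectorSize` pair.

```rust
/// Hypothetical stand-in for `ir::SwizzleComponent` (X = 0 .. W = 3).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Comp {
    X = 0,
    Y = 1,
    Z = 2,
    W = 3,
}

/// Hypothetical stand-in for `Components`.
#[derive(Debug, PartialEq, Eq)]
enum Access {
    /// One component, as its index.
    Single(u32),
    /// Two to four components.
    Swizzle(Vec<Comp>),
}

/// Map one accessor letter to a component; `rgba` alias `xyzw`.
fn letter(c: char) -> Option<Comp> {
    match c {
        'x' | 'r' => Some(Comp::X),
        'y' | 'g' => Some(Comp::Y),
        'z' | 'b' => Some(Comp::Z),
        'w' | 'a' => Some(Comp::W),
        _ => None,
    }
}

/// Parse the text after the `.`; `None` means "bad accessor".
fn parse_access(name: &str) -> Option<Access> {
    // Every character must be a component letter.
    let comps = name.chars().map(letter).collect::<Option<Vec<Comp>>>()?;
    match comps.len() {
        1 => Some(Access::Single(comps[0] as u32)),
        2..=4 => {
            // A swizzle must use only `xyzw` letters or only `rgba` letters.
            let xyzw = name.chars().all(|c| matches!(c, 'x' | 'y' | 'z' | 'w'));
            let rgba = name.chars().all(|c| matches!(c, 'r' | 'g' | 'b' | 'a'));
            if xyzw || rgba {
                Some(Access::Swizzle(comps))
            } else {
                None
            }
        }
        // Empty names and names longer than four letters are rejected.
        _ => None,
    }
}
```

For example, `parse_access("wzy")` yields a three-component swizzle, while `parse_access("xg")` is rejected because it mixes the two letter sets.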
enum Components {
    Single(u32),
    Swizzle {
        size: ir::VectorSize,
        pattern: [ir::SwizzleComponent; 4],
    },
}

impl Components {
    const fn letter_component(letter: char) -> Option<ir::SwizzleComponent> {
        use ir::SwizzleComponent as Sc;
        match letter {
            'x' | 'r' => Some(Sc::X),
            'y' | 'g' => Some(Sc::Y),
            'z' | 'b' => Some(Sc::Z),
            'w' | 'a' => Some(Sc::W),
            _ => None,
        }
    }

    fn single_component(name: &str, name_span: Span) -> Result<'_, u32> {
        let ch = name.chars().next().ok_or(Error::BadAccessor(name_span))?;
        match Self::letter_component(ch) {
            Some(sc) => Ok(sc as u32),
            None => Err(Box::new(Error::BadAccessor(name_span))),
        }
    }

    /// Construct a `Components` value from a 'member' name, like `"wzy"` or `"x"`.
    ///
    /// Use `name_span` for reporting errors in parsing the component string.
    fn new(name: &str, name_span: Span) -> Result<'_, Self> {
        let size = match name.len() {
            1 => return Ok(Components::Single(Self::single_component(name, name_span)?)),
            2 => ir::VectorSize::Bi,
            3 => ir::VectorSize::Tri,
            4 => ir::VectorSize::Quad,
            _ => return Err(Box::new(Error::BadAccessor(name_span))),
        };

        let mut pattern = [ir::SwizzleComponent::X; 4];
        for (comp, ch) in pattern.iter_mut().zip(name.chars()) {
            *comp = Self::letter_component(ch).ok_or(Error::BadAccessor(name_span))?;
        }

        if name.chars().all(|c| matches!(c, 'x' | 'y' | 'z' | 'w'))
            || name.chars().all(|c| matches!(c, 'r' | 'g' | 'b' | 'a'))
        {
            Ok(Components::Swizzle { size, pattern })
        } else {
            Err(Box::new(Error::BadAccessor(name_span)))
        }
    }
}

/// An `ast::GlobalDecl` for which we have built the Naga IR equivalent.
enum LoweredGlobalDecl {
    Function {
        handle: Handle<ir::Function>,
        must_use: bool,
    },
    Var(Handle<ir::GlobalVariable>),
    Const(Handle<ir::Constant>),
    Override(Handle<ir::Override>),
    Type(Handle<ir::Type>),
    EntryPoint(usize),
}

enum Texture {
    Gather,
    GatherCompare,
    Sample,
    SampleBias,
    SampleCompare,
    SampleCompareLevel,
    SampleGrad,
    SampleLevel,
    SampleBaseClampToEdge,
}

impl Texture {
    pub fn map(word: &str) -> Option<Self> {
        Some(match word {
            "textureGather" => Self::Gather,
            "textureGatherCompare" => Self::GatherCompare,
            "textureSample" => Self::Sample,
            "textureSampleBias" => Self::SampleBias,
            "textureSampleCompare" => Self::SampleCompare,
            "textureSampleCompareLevel" => Self::SampleCompareLevel,
            "textureSampleGrad" => Self::SampleGrad,
            "textureSampleLevel" => Self::SampleLevel,
            "textureSampleBaseClampToEdge" => Self::SampleBaseClampToEdge,
            _ => return None,
        })
    }

    pub const fn min_argument_count(&self) -> u32 {
        match *self {
            Self::Gather => 3,
            Self::GatherCompare => 4,
            Self::Sample => 3,
            Self::SampleBias => 5,
            Self::SampleCompare => 5,
            Self::SampleCompareLevel => 5,
            Self::SampleGrad => 6,
            Self::SampleLevel => 5,
            Self::SampleBaseClampToEdge => 3,
        }
    }
}

enum SubgroupGather {
    BroadcastFirst,
    Broadcast,
    Shuffle,
    ShuffleDown,
    ShuffleUp,
    ShuffleXor,
    QuadBroadcast,
}

impl SubgroupGather {
    pub fn map(word: &str) -> Option<Self> {
        Some(match word {
            "subgroupBroadcastFirst" => Self::BroadcastFirst,
            "subgroupBroadcast" => Self::Broadcast,
            "subgroupShuffle" => Self::Shuffle,
            "subgroupShuffleDown" => Self::ShuffleDown,
            "subgroupShuffleUp" => Self::ShuffleUp,
            "subgroupShuffleXor" => Self::ShuffleXor,
            "quadBroadcast" => Self::QuadBroadcast,
            _ => return None,
        })
    }
}

/// Whether a declaration accepts abstract types, or concretizes.
enum AbstractRule {
    /// This declaration concretizes its initialization expression.
    Concretize,
    /// This declaration can accept initializers with abstract types.
    Allow,
}

/// Whether `@must_use` applies to a call expression.
#[derive(Debug, Copy, Clone)]
enum MustUse {
    Yes,
    No,
}

impl From<bool> for MustUse {
    fn from(value: bool) -> Self {
        if value {
            MustUse::Yes
        } else {
            MustUse::No
        }
    }
}

pub struct Lowerer<'source, 'temp> {
    index: &'temp Index<'source>,
}

impl<'source, 'temp> Lowerer<'source, 'temp> {
    pub const fn new(index: &'temp Index<'source>) -> Self {
        Self { index }
    }

    pub fn lower(&mut self, tu: ast::TranslationUnit<'source>) -> Result<'source, ir::Module> {
        let mut module = ir::Module {
            diagnostic_filters: tu.diagnostic_filters,
            diagnostic_filter_leaf: tu.diagnostic_filter_leaf,
            ..Default::default()
        };

        let mut ctx = GlobalContext {
            enable_extensions: tu.enable_extensions,
            ast_expressions: &tu.expressions,
            globals: &mut FastHashMap::default(),
            module: &mut module,
            const_typifier: &mut Typifier::new(),
            layouter: &mut proc::Layouter::default(),
            global_expression_kind_tracker: &mut proc::ExpressionKindTracker::new(),
        };
        if !tu.doc_comments.is_empty() {
            ctx.module.get_or_insert_default_doc_comments().module =
                tu.doc_comments.iter().map(|s| s.to_string()).collect();
        }

        for decl_handle in self.index.visit_ordered() {
            let span = tu.decls.get_span(decl_handle);
            let decl = &tu.decls[decl_handle];

            match decl.kind {
                ast::GlobalDeclKind::Fn(ref f) => {
                    let lowered_decl = self.function(f, span, &mut ctx)?;
                    if !f.doc_comments.is_empty() {
                        match lowered_decl {
                            LoweredGlobalDecl::Function { handle, ..
                            } => {
                                ctx.module
                                    .get_or_insert_default_doc_comments()
                                    .functions
                                    .insert(
                                        handle,
                                        f.doc_comments.iter().map(|s| s.to_string()).collect(),
                                    );
                            }
                            LoweredGlobalDecl::EntryPoint(index) => {
                                ctx.module
                                    .get_or_insert_default_doc_comments()
                                    .entry_points
                                    .insert(
                                        index,
                                        f.doc_comments.iter().map(|s| s.to_string()).collect(),
                                    );
                            }
                            _ => {}
                        }
                    }
                    ctx.globals.insert(f.name.name, lowered_decl);
                }
                ast::GlobalDeclKind::Var(ref v) => {
                    let explicit_ty =
                        v.ty.as_ref()
                            .map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const()))
                            .transpose()?;

                    let (ty, initializer) = self.type_and_init(
                        v.name,
                        v.init,
                        explicit_ty,
                        AbstractRule::Concretize,
                        &mut ctx.as_override(),
                    )?;

                    let binding = if let Some(ref binding) = v.binding {
                        Some(ir::ResourceBinding {
                            group: self.const_u32(binding.group, &mut ctx.as_const())?.0,
                            binding: self.const_u32(binding.binding, &mut ctx.as_const())?.0,
                        })
                    } else {
                        None
                    };

                    let space = ctx.as_const().var_address_space(&v.template_list)?;

                    let handle = ctx.module.global_variables.append(
                        ir::GlobalVariable {
                            name: Some(v.name.name.to_string()),
                            space,
                            binding,
                            ty,
                            init: initializer,
                            memory_decorations: v.memory_decorations,
                        },
                        span,
                    );

                    if !v.doc_comments.is_empty() {
                        ctx.module
                            .get_or_insert_default_doc_comments()
                            .global_variables
                            .insert(
                                handle,
                                v.doc_comments.iter().map(|s| s.to_string()).collect(),
                            );
                    }
                    ctx.globals
                        .insert(v.name.name, LoweredGlobalDecl::Var(handle));
                }
                ast::GlobalDeclKind::Const(ref c) => {
                    let mut ectx = ctx.as_const();

                    let explicit_ty =
                        c.ty.as_ref()
                            .map(|ast| self.resolve_ast_type(ast, &mut ectx))
                            .transpose()?;

                    let (ty, init) = self.type_and_init(
                        c.name,
                        Some(c.init),
                        explicit_ty,
                        AbstractRule::Allow,
                        &mut ectx,
                    )?;
                    let init = init.expect("Global const must have init");

                    let handle = ctx.module.constants.append(
                        ir::Constant {
                            name: Some(c.name.name.to_string()),
                            ty,
                            init,
                        },
                        span,
                    );

                    ctx.globals
                        .insert(c.name.name, LoweredGlobalDecl::Const(handle));
                    if !c.doc_comments.is_empty() {
                        ctx.module
                            .get_or_insert_default_doc_comments()
                            .constants
                            .insert(
                                handle,
                                c.doc_comments.iter().map(|s| s.to_string()).collect(),
                            );
                    }
                }
                ast::GlobalDeclKind::Override(ref o) => {
                    let explicit_ty =
                        o.ty.as_ref()
                            .map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const()))
                            .transpose()?;

                    let mut ectx = ctx.as_override();

                    let (ty, init) = self.type_and_init(
                        o.name,
                        o.init,
                        explicit_ty,
                        AbstractRule::Concretize,
                        &mut ectx,
                    )?;

                    let id =
                        o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
                            .transpose()?;

                    let id = if let Some((id, id_span)) = id {
                        Some(
                            u16::try_from(id)
                                .map_err(|_| Error::PipelineConstantIDValue(id_span))?,
                        )
                    } else {
                        None
                    };

                    let handle = ctx.module.overrides.append(
                        ir::Override {
                            name: Some(o.name.name.to_string()),
                            id,
                            ty,
                            init,
                        },
                        span,
                    );

                    ctx.globals
                        .insert(o.name.name, LoweredGlobalDecl::Override(handle));
                }
                ast::GlobalDeclKind::Struct(ref s) => {
                    let handle = self.r#struct(s, span, &mut ctx)?;
                    ctx.globals
                        .insert(s.name.name, LoweredGlobalDecl::Type(handle));
                    if !s.doc_comments.is_empty() {
                        ctx.module
                            .get_or_insert_default_doc_comments()
                            .types
                            .insert(
                                handle,
                                s.doc_comments.iter().map(|s| s.to_string()).collect(),
                            );
                    }
                }
                ast::GlobalDeclKind::Type(ref alias) => {
                    let ty = self.resolve_named_ast_type(
                        &alias.ty,
                        alias.name.name.to_string(),
                        &mut ctx.as_const(),
                    )?;
                    ctx.globals
                        .insert(alias.name.name, LoweredGlobalDecl::Type(ty));
                }
                ast::GlobalDeclKind::ConstAssert(condition) => {
                    let condition = self.expression(condition, &mut ctx.as_const())?;

                    let span = ctx.module.global_expressions.get_span(condition);
                    match ctx
                        .module
                        .to_ctx()
                        .get_const_val_from(condition, &ctx.module.global_expressions)
                    {
                        Ok(true) => Ok(()),
                        Ok(false) => Err(Error::ConstAssertFailed(span)),
                        Err(proc::ConstValueError::NonConst | proc::ConstValueError::Negative) => {
                            unreachable!()
                        }
                        Err(proc::ConstValueError::InvalidType) => Err(Error::NotBool(span)),
                    }?;
                }
            }
        }

        // Constant evaluation may leave abstract-typed literals and
        // compositions in expression arenas, so we need to compact the module
        // to remove unused
        // expressions and types.
        crate::compact::compact(&mut module, KeepUnused::Yes);

        Ok(module)
    }

    /// Obtain (inferred) type and initializer after automatic conversion
    fn type_and_init(
        &mut self,
        name: ast::Ident<'source>,
        init: Option<Handle<ast::Expression<'source>>>,
        explicit_ty: Option<Handle<ir::Type>>,
        abstract_rule: AbstractRule,
        ectx: &mut ExpressionContext<'source, '_, '_>,
    ) -> Result<'source, (Handle<ir::Type>, Option<Handle<ir::Expression>>)> {
        let ty;
        let initializer;
        match (init, explicit_ty) {
            (Some(init), Some(explicit_ty)) => {
                let init = self.expression_for_abstract(init, ectx)?;
                let ty_res = proc::TypeResolution::Handle(explicit_ty);
                let init = ectx
                    .try_automatic_conversions(init, &ty_res, name.span)
                    .map_err(|error| match *error {
                        Error::AutoConversion(e) => Box::new(Error::InitializationTypeMismatch {
                            name: name.span,
                            expected: e.dest_type,
                            got: e.source_type,
                        }),
                        _ => error,
                    })?;

                let init_ty = ectx.register_type(init)?;
                if !ectx.module.compare_types(
                    &proc::TypeResolution::Handle(explicit_ty),
                    &proc::TypeResolution::Handle(init_ty),
                ) {
                    return Err(Box::new(Error::InitializationTypeMismatch {
                        name: name.span,
                        expected: ectx.type_to_string(explicit_ty),
                        got: ectx.type_to_string(init_ty),
                    }));
                }
                ty = explicit_ty;
                initializer = Some(init);
            }
            (Some(init), None) => {
                let mut init = self.expression_for_abstract(init, ectx)?;
                if let AbstractRule::Concretize = abstract_rule {
                    init = ectx.concretize(init)?;
                }
                ty = ectx.register_type(init)?;
                initializer = Some(init);
            }
            (None, Some(explicit_ty)) => {
                ty = explicit_ty;
                initializer = None;
            }
            (None, None) => return Err(Box::new(Error::DeclMissingTypeAndInit(name.span))),
        }
        Ok((ty, initializer))
    }

    fn function(
        &mut self,
        f: &ast::Function<'source>,
        span: Span,
        ctx: &mut GlobalContext<'source, '_, '_>,
    ) -> Result<'source, LoweredGlobalDecl> {
        let mut local_table = FastHashMap::default();
        let mut expressions = Arena::new();
        let mut named_expressions = FastIndexMap::default();
        let mut local_expression_kind_tracker = proc::ExpressionKindTracker::new();

        let arguments = f
            .arguments
            .iter()
            .enumerate()
            .map(|(i, arg)| -> Result<'_, _> {
                let ty = self.resolve_ast_type(&arg.ty, &mut ctx.as_const())?;
                let expr =
                    expressions.append(ir::Expression::FunctionArgument(i as u32), arg.name.span);
                local_table.insert(arg.handle, Declared::Runtime(Typed::Plain(expr)));
                named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
                local_expression_kind_tracker.insert(expr, proc::ExpressionKind::Runtime);

                Ok(ir::FunctionArgument {
                    name: Some(arg.name.name.to_string()),
                    ty,
                    binding: self.binding(&arg.binding, ty, ctx)?,
                })
            })
            .collect::<Result<Vec<_>>>()?;

        let result = f
            .result
            .as_ref()
            .map(|res| -> Result<'_, _> {
                let ty = self.resolve_ast_type(&res.ty, &mut ctx.as_const())?;
                Ok(ir::FunctionResult {
                    ty,
                    binding: self.binding(&res.binding, ty, ctx)?,
                })
            })
            .transpose()?;

        let mut function = ir::Function {
            name: Some(f.name.name.to_string()),
            arguments,
            result,
            local_variables: Arena::new(),
            expressions,
            named_expressions: crate::NamedExpressions::default(),
            body: ir::Block::default(),
            diagnostic_filter_leaf: f.diagnostic_filter_leaf,
        };

        let mut typifier = Typifier::default();
        let mut stmt_ctx = StatementContext {
            enable_extensions: ctx.enable_extensions,
            local_table: &mut local_table,
            globals: ctx.globals,
            ast_expressions: ctx.ast_expressions,
            const_typifier: ctx.const_typifier,
            typifier: &mut typifier,
            layouter: ctx.layouter,
            function: &mut function,
            named_expressions: &mut named_expressions,
            module: ctx.module,
            local_expression_kind_tracker: &mut local_expression_kind_tracker,
            global_expression_kind_tracker: ctx.global_expression_kind_tracker,
        };
        let mut body = self.block(&f.body, false, &mut stmt_ctx)?;
        proc::ensure_block_returns(&mut body);

        function.body = body;
        function.named_expressions = named_expressions
            .into_iter()
            .map(|(key, (name, _))| (key, name))
            .collect();

        if let Some(ref entry) = f.entry_point {
            let (workgroup_size, workgroup_size_overrides) =
                if let Some(workgroup_size) = entry.workgroup_size {
                    // TODO: replace with try_map once stabilized
                    let
                        mut workgroup_size_out = [1; 3];
                    let mut workgroup_size_overrides_out = [None; 3];
                    for (i, size) in workgroup_size.into_iter().enumerate() {
                        if let Some(size_expr) = size {
                            match self.const_u32(size_expr, &mut ctx.as_const()) {
                                Ok(value) => {
                                    workgroup_size_out[i] = value.0;
                                }
                                Err(err) => {
                                    if let Error::ConstantEvaluatorError(ref ty, _) = *err {
                                        match **ty {
                                            proc::ConstantEvaluatorError::OverrideExpr => {
                                                workgroup_size_overrides_out[i] =
                                                    Some(self.workgroup_size_override(
                                                        size_expr,
                                                        &mut ctx.as_override(),
                                                    )?);
                                            }
                                            _ => {
                                                return Err(err);
                                            }
                                        }
                                    } else {
                                        return Err(err);
                                    }
                                }
                            }
                        }
                    }
                    if workgroup_size_overrides_out.iter().all(|x| x.is_none()) {
                        (workgroup_size_out, None)
                    } else {
                        (workgroup_size_out, Some(workgroup_size_overrides_out))
                    }
                } else {
                    ([0; 3], None)
                };

            let mesh_info = if let Some((var_name, var_span)) = entry.mesh_output_variable {
                let var = match ctx.globals.get(var_name) {
                    Some(&LoweredGlobalDecl::Var(handle)) => handle,
                    Some(_) => {
                        return Err(Box::new(Error::ExpectedGlobalVariable {
                            name_span: var_span,
                        }))
                    }
                    None => return Err(Box::new(Error::UnknownIdent(var_span, var_name))),
                };

                let mut info = ctx.module.analyze_mesh_shader_info(var);
                if let Some(h) = info.1[0] {
                    info.0.max_vertices_override = Some(
                        ctx.module
                            .global_expressions
                            .append(crate::Expression::Override(h), Span::UNDEFINED),
                    );
                }
                if let Some(h) = info.1[1] {
                    info.0.max_primitives_override = Some(
                        ctx.module
                            .global_expressions
                            .append(crate::Expression::Override(h), Span::UNDEFINED),
                    );
                }

                Some(info.0)
            } else {
                None
            };
            let task_payload = if let Some((var_name, var_span)) = entry.task_payload {
                Some(match ctx.globals.get(var_name) {
                    Some(&LoweredGlobalDecl::Var(handle)) => handle,
                    Some(_) => {
                        return Err(Box::new(Error::ExpectedGlobalVariable {
                            name_span: var_span,
                        }))
                    }
                    None => return Err(Box::new(Error::UnknownIdent(var_span, var_name))),
                })
            } else {
                None
            };
            let incoming_ray_payload =
                if let Some((var_name, var_span)) = entry.ray_incoming_payload {
                    Some(match ctx.globals.get(var_name) {
                        Some(&LoweredGlobalDecl::Var(handle)) => handle,
                        Some(_) => {
                            return Err(Box::new(Error::ExpectedGlobalVariable {
                                name_span: var_span,
                            }))
                        }
                        None => return Err(Box::new(Error::UnknownIdent(var_span, var_name))),
                    })
                } else {
                    None
                };

            ctx.module.entry_points.push(ir::EntryPoint {
                name: f.name.name.to_string(),
                stage: entry.stage,
                early_depth_test: entry.early_depth_test,
                workgroup_size,
                workgroup_size_overrides,
                function,
                mesh_info,
                task_payload,
                incoming_ray_payload,
            });
            Ok(LoweredGlobalDecl::EntryPoint(
                ctx.module.entry_points.len() - 1,
            ))
        } else {
            let handle = ctx.module.functions.append(function, span);
            Ok(LoweredGlobalDecl::Function {
                handle,
                must_use: f.result.as_ref().is_some_and(|res| res.must_use),
            })
        }
    }

    fn workgroup_size_override(
        &mut self,
        size_expr: Handle<ast::Expression<'source>>,
        ctx: &mut ExpressionContext<'source, '_, '_>,
    ) -> Result<'source, Handle<ir::Expression>> {
        let span = ctx.ast_expressions.get_span(size_expr);
        let expr = self.expression(size_expr, ctx)?;
        match resolve_inner!(ctx, expr).scalar_kind().ok_or(0) {
            Ok(ir::ScalarKind::Sint) | Ok(ir::ScalarKind::Uint) => Ok(expr),
            _ => Err(Box::new(Error::ExpectedConstExprConcreteIntegerScalar(
                span,
            ))),
        }
    }

    fn block(
        &mut self,
        b: &ast::Block<'source>,
        is_inside_loop: bool,
        ctx: &mut StatementContext<'source, '_, '_>,
    ) -> Result<'source, ir::Block> {
        let mut block = ir::Block::default();

        for stmt in b.stmts.iter() {
            self.statement(stmt, &mut block, is_inside_loop, ctx)?;
        }

        Ok(block)
    }

    fn statement(
        &mut self,
        stmt: &ast::Statement<'source>,
        block: &mut ir::Block,
        is_inside_loop: bool,
        ctx: &mut StatementContext<'source, '_, '_>,
    ) -> Result<'source, ()> {
        let out = match stmt.kind {
            ast::StatementKind::Block(ref block) => {
                let block = self.block(block, is_inside_loop, ctx)?;
                ir::Statement::Block(block)
            }
            ast::StatementKind::LocalDecl(ref decl) => match *decl {
                ast::LocalDecl::Let(ref l) => {
                    let mut emitter = proc::Emitter::default();
                    emitter.start(&ctx.function.expressions);

                    let explicit_ty = l
                        .ty
                        .as_ref()
                        .map(|ty|
                            self.resolve_ast_type(ty, &mut ctx.as_const(block, &mut emitter)))
                        .transpose()?;

                    let mut ectx = ctx.as_expression(block, &mut emitter);
                    let (ty, initializer) = self.type_and_init(
                        l.name,
                        Some(l.init),
                        explicit_ty,
                        AbstractRule::Concretize,
                        &mut ectx,
                    )?;

                    // We have this special check here for `let` declarations because the
                    // validator doesn't check them (they are comingled with other things in
                    // `named_expressions`; see ).
                    // The check could go in `type_and_init`, but then we'd have to
                    // distinguish whether override-sized is allowed. The error ought to use
                    // the type's span, but `module.types.get_span(ty)` is `Span::UNDEFINED`
                    // (see ).
                    if ctx.module.types[ty]
                        .inner
                        .is_dynamically_sized(&ctx.module.types)
                    {
                        return Err(Box::new(Error::TypeNotConstructible(l.name.span)));
                    }

                    // We passed `Some()` to `type_and_init`, so we
                    // will get a lowered initializer expression back.
                    let initializer =
                        initializer.expect("type_and_init did not return an initializer");

                    // The WGSL spec says that any expression that refers to a
                    // `let`-bound variable is not a const expression. This
                    // affects when errors must be reported, so we can't even
                    // treat suitable `let` bindings as constant as an
                    // optimization.
                    ctx.local_expression_kind_tracker
                        .force_non_const(initializer);

                    block.extend(emitter.finish(&ctx.function.expressions));
                    ctx.local_table
                        .insert(l.handle, Declared::Runtime(Typed::Plain(initializer)));
                    ctx.named_expressions
                        .insert(initializer, (l.name.name.to_string(), l.name.span));

                    return Ok(());
                }
                ast::LocalDecl::Var(ref v) => {
                    let mut emitter = proc::Emitter::default();
                    emitter.start(&ctx.function.expressions);

                    let explicit_ty =
                        v.ty.as_ref()
                            .map(|ast| {
                                self.resolve_ast_type(ast, &mut ctx.as_const(block, &mut emitter))
                            })
                            .transpose()?;

                    let mut ectx = ctx.as_expression(block, &mut emitter);
                    let (ty, initializer) = self.type_and_init(
                        v.name,
                        v.init,
                        explicit_ty,
                        AbstractRule::Concretize,
                        &mut ectx,
                    )?;

                    let (const_initializer, initializer) = {
                        match initializer {
                            Some(init) => {
                                // It's not correct to hoist the initializer up
                                // to the top of the function if:
                                // - the initialization is inside a loop, and should
                                //   take place on every iteration, or
                                // - the initialization is not a constant
                                //   expression, so its value depends on the
                                //   state at the point of initialization.
                                if is_inside_loop
                                    || !ctx.local_expression_kind_tracker.is_const_or_override(init)
                                {
                                    (None, Some(init))
                                } else {
                                    (Some(init), None)
                                }
                            }
                            None => (None, None),
                        }
                    };

                    let var = ctx.function.local_variables.append(
                        ir::LocalVariable {
                            name: Some(v.name.name.to_string()),
                            ty,
                            init: const_initializer,
                        },
                        stmt.span,
                    );

                    let handle = ctx
                        .as_expression(block, &mut emitter)
                        .interrupt_emitter(ir::Expression::LocalVariable(var), Span::UNDEFINED)?;
                    block.extend(emitter.finish(&ctx.function.expressions));
                    ctx.local_table
                        .insert(v.handle, Declared::Runtime(Typed::Reference(handle)));

                    match initializer {
                        Some(initializer) => ir::Statement::Store {
                            pointer: handle,
                            value: initializer,
                        },
                        None => return Ok(()),
                    }
                }
                ast::LocalDecl::Const(ref c) => {
                    let mut emitter = proc::Emitter::default();
                    emitter.start(&ctx.function.expressions);

                    let ectx = &mut ctx.as_const(block, &mut emitter);

                    let explicit_ty =
                        c.ty.as_ref()
                            .map(|ast| self.resolve_ast_type(ast, &mut ectx.as_const()))
                            .transpose()?;

                    let (_ty, init) = self.type_and_init(
                        c.name,
                        Some(c.init),
                        explicit_ty,
                        AbstractRule::Allow,
                        &mut ectx.as_const(),
                    )?;
                    let init = init.expect("Local const must have init");

                    block.extend(emitter.finish(&ctx.function.expressions));
                    ctx.local_table
                        .insert(c.handle, Declared::Const(Typed::Plain(init)));
                    return Ok(());
                }
            },
            ast::StatementKind::If {
                condition,
                ref accept,
                ref reject,
            } => {
                let mut emitter = proc::Emitter::default();
                emitter.start(&ctx.function.expressions);

                let condition =
                    self.expression(condition, &mut ctx.as_expression(block, &mut emitter))?;
                block.extend(emitter.finish(&ctx.function.expressions));

                let accept = self.block(accept, is_inside_loop, ctx)?;
                let reject = self.block(reject, is_inside_loop, ctx)?;

                ir::Statement::If {
                    condition,
                    accept,
                    reject,
                }
            }
            ast::StatementKind::Switch {
                selector,
                ref cases,
            } => {
                let mut emitter = proc::Emitter::default();
                emitter.start(&ctx.function.expressions);

                let mut ectx = ctx.as_expression(block, &mut emitter);

                // Determine the scalar type of the selector and case expressions, find the
                // consensus type for automatic conversion, then convert them.
                let (mut exprs, spans) = core::iter::once(selector)
                    .chain(cases.iter().filter_map(|case| match case.value {
                        ast::SwitchValue::Expr(expr) => Some(expr),
                        ast::SwitchValue::Default => None,
                    }))
                    .enumerate()
                    .map(|(i, expr)| {
                        let span = ectx.ast_expressions.get_span(expr);
                        let expr = self.expression_for_abstract(expr, &mut ectx)?;
                        let ty = resolve_inner!(ectx, expr);
                        match *ty {
                            ir::TypeInner::Scalar(
                                ir::Scalar::I32 | ir::Scalar::U32 | ir::Scalar::ABSTRACT_INT,
                            ) => Ok((expr, span)),
                            _ => match i {
                                0 => Err(Box::new(Error::InvalidSwitchSelector { span })),
                                _ => Err(Box::new(Error::InvalidSwitchCase { span })),
                            },
                        }
                    })
                    .collect::<Result<(Vec<_>, Vec<_>)>>()?;

                let mut consensus =
                    ectx.automatic_conversion_consensus(None, &exprs)
                        .map_err(|span_idx| Error::SwitchCaseTypeMismatch {
                            span: spans[span_idx],
                        })?;
                // Concretize to I32 if the selector and all cases were abstract
                if consensus == ir::Scalar::ABSTRACT_INT {
                    consensus = ir::Scalar::I32;
                }
                for expr in &mut exprs {
                    ectx.convert_to_leaf_scalar(expr, consensus)?;
                }

                block.extend(emitter.finish(&ctx.function.expressions));

                let mut exprs = exprs.into_iter();
                let selector = exprs
                    .next()
                    .expect("First element should be selector expression");

                let cases = cases
                    .iter()
                    .map(|case| {
                        Ok(ir::SwitchCase {
                            value: match case.value {
                                ast::SwitchValue::Expr(expr) => {
                                    let span = ctx.ast_expressions.get_span(expr);
                                    let expr = exprs.next().expect(
                                        "Should yield expression for each SwitchValue::Expr case",
                                    );
                                    match ctx
                                        .module
                                        .to_ctx()
                                        .get_const_val_from(expr, &ctx.function.expressions)
                                    {
                                        Ok(ir::Literal::I32(value)) => ir::SwitchValue::I32(value),
                                        Ok(ir::Literal::U32(value)) => ir::SwitchValue::U32(value),
                                        _ => {
                                            return Err(Box::new(Error::InvalidSwitchCase {
                                                span,
                                            }));
                                        }
                                    }
                                }
                                ast::SwitchValue::Default => ir::SwitchValue::Default,
                            },
                            body: self.block(&case.body, is_inside_loop, ctx)?,
                            fall_through: case.fall_through,
                        })
                    })
                    .collect::<Result<_>>()?;

                ir::Statement::Switch { selector, cases }
            }
            ast::StatementKind::Loop {
                ref body,
                ref continuing,
                break_if,
            } => {
                let body = self.block(body, true, ctx)?;
                let mut continuing = self.block(continuing, true, ctx)?;

                let mut emitter = proc::Emitter::default();
                emitter.start(&ctx.function.expressions);
                let break_if = break_if
                    .map(|expr| {
                        self.expression(expr, &mut ctx.as_expression(&mut continuing, &mut emitter))
                    })
                    .transpose()?;
                continuing.extend(emitter.finish(&ctx.function.expressions));

                ir::Statement::Loop {
                    body,
                    continuing,
                    break_if,
                }
            }
            ast::StatementKind::Break => ir::Statement::Break,
            ast::StatementKind::Continue => ir::Statement::Continue,
            ast::StatementKind::Return { value: ast_value } => {
                let mut emitter = proc::Emitter::default();
                emitter.start(&ctx.function.expressions);

                let value;
                if let Some(ast_expr) = ast_value {
                    let result_ty = ctx.function.result.as_ref().map(|r| r.ty);
                    let mut ectx = ctx.as_expression(block, &mut emitter);
                    let expr = self.expression_for_abstract(ast_expr, &mut ectx)?;

                    if let Some(result_ty) = result_ty {
                        let mut ectx = ctx.as_expression(block, &mut emitter);
                        let resolution = proc::TypeResolution::Handle(result_ty);
                        let converted =
                            ectx.try_automatic_conversions(expr, &resolution, Span::default())?;
                        value = Some(converted);
                    } else {
                        value = Some(expr);
                    }
                } else {
                    value = None;
                }
                block.extend(emitter.finish(&ctx.function.expressions));

                ir::Statement::Return { value }
            }
            ast::StatementKind::Kill => ir::Statement::Kill,
            ast::StatementKind::Call(ref call_phrase) => {
                let mut emitter = proc::Emitter::default();
                emitter.start(&ctx.function.expressions);

                let _ = self.call(
                    call_phrase,
                    stmt.span,
                    &mut ctx.as_expression(block, &mut emitter),
                    true,
                )?;
                block.extend(emitter.finish(&ctx.function.expressions));
                return Ok(());
            }
            ast::StatementKind::Assign {
                target: ast_target,
                op,
                value,
            } => {
                let mut emitter = proc::Emitter::default();
                emitter.start(&ctx.function.expressions);
                let target_span =
                    ctx.ast_expressions.get_span(ast_target);

                let mut ectx = ctx.as_expression(block, &mut emitter);
                let target = self.expression_for_reference(ast_target, &mut ectx)?;
                let target_handle = match target {
                    Typed::Reference(handle) => handle,
                    Typed::Plain(handle) => {
                        let ty = ctx.invalid_assignment_type(handle);
                        return Err(Box::new(Error::InvalidAssignment {
                            span: target_span,
                            ty,
                        }));
                    }
                };

                // Usually the value needs to be converted to match the type of
                // the memory view you're assigning it to. The bit shift
                // operators are exceptions, in that the right operand is always
                // a `u32` or `vecN<u32>`.
                let target_scalar = match op {
                    Some(ir::BinaryOperator::ShiftLeft | ir::BinaryOperator::ShiftRight) => {
                        Some(ir::Scalar::U32)
                    }
                    _ => resolve_inner!(ectx, target_handle)
                        .pointer_automatically_convertible_scalar(&ectx.module.types),
                };

                // Need to emit the LHS _before_ the RHS so that it is evaluated first.
                let op_assign = if let Some(op) = op {
                    Some((op, ectx.apply_load_rule(target)?))
                } else {
                    None
                };

                let value = self.expression_for_abstract(value, &mut ectx)?;
                let mut value = match target_scalar {
                    Some(target_scalar) => ectx.try_automatic_conversion_for_leaf_scalar(
                        value,
                        target_scalar,
                        target_span,
                    )?,
                    None => value,
                };

                let value = match op_assign {
                    Some((op, mut left)) => {
                        ectx.binary_op_splat(op, &mut left, &mut value)?;
                        ectx.append_expression(
                            ir::Expression::Binary {
                                op,
                                left,
                                right: value,
                            },
                            stmt.span,
                        )?
                    }
                    None => value,
                };
                block.extend(emitter.finish(&ctx.function.expressions));

                ir::Statement::Store {
                    pointer: target_handle,
                    value,
                }
            }
            ast::StatementKind::Increment(value) | ast::StatementKind::Decrement(value) => {
                let mut emitter = proc::Emitter::default();
                emitter.start(&ctx.function.expressions);

                let op = match stmt.kind {
                    ast::StatementKind::Increment(_) => ir::BinaryOperator::Add,
                    ast::StatementKind::Decrement(_) => ir::BinaryOperator::Subtract,
                    _ => unreachable!(),
                };

                let value_span = ctx.ast_expressions.get_span(value);
                let target = self
                    .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
                let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?;

                let mut ectx = ctx.as_expression(block, &mut emitter);
                let scalar = match *resolve_inner!(ectx, target_handle) {
                    ir::TypeInner::ValuePointer {
                        size: None, scalar, ..
                    } => scalar,
                    ir::TypeInner::Pointer { base, .. } => match ectx.module.types[base].inner {
                        ir::TypeInner::Scalar(scalar) => scalar,
                        _ => return Err(Box::new(Error::BadIncrDecrReferenceType(value_span))),
                    },
                    _ => return Err(Box::new(Error::BadIncrDecrReferenceType(value_span))),
                };
                let literal = match scalar.kind {
                    ir::ScalarKind::Sint | ir::ScalarKind::Uint => ir::Literal::one(scalar)
                        .ok_or(Error::BadIncrDecrReferenceType(value_span))?,
                    _ => return Err(Box::new(Error::BadIncrDecrReferenceType(value_span))),
                };

                let right =
                    ectx.interrupt_emitter(ir::Expression::Literal(literal), Span::UNDEFINED)?;
                let rctx = ectx.runtime_expression_ctx(stmt.span)?;
                let left = rctx.function.expressions.append(
                    ir::Expression::Load {
                        pointer: target_handle,
                    },
                    value_span,
                );
                let value = rctx
                    .function
                    .expressions
                    .append(ir::Expression::Binary { op, left, right }, stmt.span);
                rctx.local_expression_kind_tracker
                    .insert(left, proc::ExpressionKind::Runtime);
                rctx.local_expression_kind_tracker
                    .insert(value, proc::ExpressionKind::Runtime);

                block.extend(emitter.finish(&ctx.function.expressions));
                ir::Statement::Store {
                    pointer: target_handle,
                    value,
                }
            }
            ast::StatementKind::ConstAssert(condition) => {
                let mut emitter = proc::Emitter::default();
                emitter.start(&ctx.function.expressions);

                let condition =
                    self.expression(condition, &mut ctx.as_const(block, &mut emitter))?;

                let span = ctx.function.expressions.get_span(condition);
                match ctx
                    .module
                    .to_ctx()
                    .get_const_val_from(condition, &ctx.function.expressions)
                {
                    Ok(true) => Ok(()),
                    Ok(false) => Err(Error::ConstAssertFailed(span)),
                    Err(proc::ConstValueError::NonConst | proc::ConstValueError::Negative) => {
                        unreachable!()
                    }
                    Err(proc::ConstValueError::InvalidType) => Err(Error::NotBool(span)),
                }?;

                block.extend(emitter.finish(&ctx.function.expressions));

                return Ok(());
            }
            ast::StatementKind::Phony(expr) => {
                // Remember the RHS of the phony assignment as a named expression. This
                // is important (1) to preserve the RHS for validation, (2) to track any
                // referenced globals.
                let mut emitter = proc::Emitter::default();
                emitter.start(&ctx.function.expressions);

                let value = self.expression(expr, &mut ctx.as_expression(block, &mut emitter))?;
                block.extend(emitter.finish(&ctx.function.expressions));
                ctx.named_expressions
                    .insert(value, ("phony".to_string(), stmt.span));
                return Ok(());
            }
        };

        block.push(out, stmt.span);

        Ok(())
    }

    /// Lower `expr` and apply the Load Rule if possible.
    ///
    /// For the time being, this concretizes abstract values, to support
    /// consumers that haven't been adapted to consume them yet. Consumers
    /// prepared for abstract values can call [`expression_for_abstract`].
    ///
    /// [`expression_for_abstract`]: Lowerer::expression_for_abstract
    fn expression(
        &mut self,
        expr: Handle<ast::Expression<'source>>,
        ctx: &mut ExpressionContext<'source, '_, '_>,
    ) -> Result<'source, Handle<ir::Expression>> {
        let expr = self.expression_for_abstract(expr, ctx)?;
        ctx.concretize(expr)
    }

    fn expression_for_abstract(
        &mut self,
        expr: Handle<ast::Expression<'source>>,
        ctx: &mut ExpressionContext<'source, '_, '_>,
    ) -> Result<'source, Handle<ir::Expression>> {
        let expr = self.expression_for_reference(expr, ctx)?;
        ctx.apply_load_rule(expr)
    }

    fn expression_with_leaf_scalar(
        &mut self,
        expr: Handle<ast::Expression<'source>>,
        scalar: ir::Scalar,
        ctx: &mut ExpressionContext<'source, '_, '_>,
    ) -> Result<'source, Handle<ir::Expression>> {
        let unconverted = self.expression_for_abstract(expr, ctx)?;
        ctx.try_automatic_conversion_for_leaf_scalar(unconverted, scalar, Span::default())
    }

    fn expression_for_reference(
        &mut self,
        expr: Handle<ast::Expression<'source>>,
        ctx: &mut ExpressionContext<'source, '_, '_>,
    ) -> Result<'source, Typed<Handle<ir::Expression>>> {
        let span = ctx.ast_expressions.get_span(expr);
        let expr = &ctx.ast_expressions[expr];

        let expr: Typed<ir::Expression> = match *expr {
            ast::Expression::Literal(literal) => {
                let literal = match literal {
                    ast::Literal::Number(Number::F16(f)) => ir::Literal::F16(f),
                    ast::Literal::Number(Number::F32(f)) => ir::Literal::F32(f),
                    ast::Literal::Number(Number::I32(i)) => ir::Literal::I32(i),
                    ast::Literal::Number(Number::U32(u)) => ir::Literal::U32(u),
                    ast::Literal::Number(Number::I64(i)) => ir::Literal::I64(i),
                    ast::Literal::Number(Number::U64(u)) => ir::Literal::U64(u),
                    ast::Literal::Number(Number::F64(f)) => ir::Literal::F64(f),
                    ast::Literal::Number(Number::AbstractInt(i)) => ir::Literal::AbstractInt(i),
                    ast::Literal::Number(Number::AbstractFloat(f)) => ir::Literal::AbstractFloat(f),
                    ast::Literal::Bool(b) => ir::Literal::Bool(b),
                };
                let handle = ctx.interrupt_emitter(ir::Expression::Literal(literal), span)?;
                return Ok(Typed::Plain(handle));
            }
            ast::Expression::Ident(ast::TemplateElaboratedIdent {
                ref template_list,
                ..
            }) if !template_list.is_empty() => {
                return Err(Box::new(Error::UnexpectedTemplate(span)))
            }
            ast::Expression::Ident(ast::TemplateElaboratedIdent {
                ident: ast::IdentExpr::Local(local),
                ..
            }) => {
                return ctx.local(&local, span);
            }
            ast::Expression::Ident(ast::TemplateElaboratedIdent {
                ident: ast::IdentExpr::Unresolved(name),
                ..
            }) => {
                let global = ctx
                    .globals
                    .get(name)
                    .ok_or(Error::UnknownIdent(span, name))?;
                let expr = match *global {
                    LoweredGlobalDecl::Var(handle) => {
                        let expr = ir::Expression::GlobalVariable(handle);
                        let v = &ctx.module.global_variables[handle];
                        match v.space {
                            ir::AddressSpace::Handle => Typed::Plain(expr),
                            _ => Typed::Reference(expr),
                        }
                    }
                    LoweredGlobalDecl::Const(handle) => {
                        Typed::Plain(ir::Expression::Constant(handle))
                    }
                    LoweredGlobalDecl::Override(handle) => {
                        Typed::Plain(ir::Expression::Override(handle))
                    }
                    LoweredGlobalDecl::Function { .. }
                    | LoweredGlobalDecl::Type(_)
                    | LoweredGlobalDecl::EntryPoint(_) => {
                        return Err(Box::new(Error::Unexpected(span, ExpectedToken::Variable)));
                    }
                };

                return expr.try_map(|handle| ctx.interrupt_emitter(handle, span));
            }
            ast::Expression::Unary { op, expr } => {
                let expr = self.expression_for_abstract(expr, ctx)?;
                Typed::Plain(ir::Expression::Unary { op, expr })
            }
            ast::Expression::AddrOf(expr) => {
                // The `&` operator simply converts a reference to a pointer. And since a
                // reference is required, the Load Rule is not applied.
                match self.expression_for_reference(expr, ctx)? {
                    Typed::Reference(handle) => {
                        let expr = &ctx.runtime_expression_ctx(span)?.function.expressions[handle];
                        if let &ir::Expression::Access { base, .. }
                        | &ir::Expression::AccessIndex { base, .. } = expr
                        {
                            if let Some(ty) = resolve_inner!(ctx, base).pointer_base_type() {
                                if matches!(
                                    *ty.inner_with(&ctx.module.types),
                                    ir::TypeInner::Vector { .. },
                                ) {
                                    return Err(Box::new(Error::InvalidAddrOfOperand(
                                        ctx.get_expression_span(handle),
                                    )));
                                }
                            }
                        }
                        // No code is generated. We just declare the reference a pointer now.
                        return Ok(Typed::Plain(handle));
                    }
                    Typed::Plain(_) => {
                        return Err(Box::new(Error::NotReference(
                            "the operand of the `&` operator",
                            span,
                        )));
                    }
                }
            }
            ast::Expression::Deref(expr) => {
                // The pointer we dereference must be loaded.
                let pointer = self.expression(expr, ctx)?;

                if resolve_inner!(ctx, pointer).pointer_space().is_none() {
                    return Err(Box::new(Error::NotPointer(span)));
                }

                // No code is generated. We just declare the pointer a reference now.
                return Ok(Typed::Reference(pointer));
            }
            ast::Expression::Binary { op, left, right } => {
                self.binary(op, left, right, span, ctx)?
            }
            ast::Expression::Call(ref call_phrase) => {
                let handle = self
                    .call(call_phrase, span, ctx, false)?
                    .ok_or(Error::FunctionReturnsVoid(span))?;
                return Ok(Typed::Plain(handle));
            }
            ast::Expression::Index { base, index } => {
                let mut lowered_base = self.expression_for_reference(base, ctx)?;
                let index = self.expression(index, ctx)?;

                //
                // Declare pointer as reference
                if let Typed::Plain(handle) = lowered_base {
                    if resolve_inner!(ctx, handle).pointer_space().is_some() {
                        lowered_base = Typed::Reference(handle);
                    }
                }

                lowered_base.try_map(|base| match ctx.get_const_val(index).ok() {
                    Some(index) => Ok::<_, Box<Error>>(ir::Expression::AccessIndex { base, index }),
                    None => {
                        // When an abstract array value e is indexed by an expression
                        // that is not a const-expression, then the array is concretized
                        // before the index is applied.
                        // https://www.w3.org/TR/WGSL/#array-access-expr
                        // Also applies to vectors and matrices.
                        let base = ctx.concretize(base)?;
                        Ok(ir::Expression::Access { base, index })
                    }
                })?
            }
            ast::Expression::Member { base, ref field } => {
                let mut lowered_base = self.expression_for_reference(base, ctx)?;

                //
                // Declare pointer as reference
                if let Typed::Plain(handle) = lowered_base {
                    if resolve_inner!(ctx, handle).pointer_space().is_some() {
                        lowered_base = Typed::Reference(handle);
                    }
                }

                let temp_ty;
                let composite_type: &ir::TypeInner = match lowered_base {
                    Typed::Reference(handle) => {
                        temp_ty = resolve_inner!(ctx, handle)
                            .pointer_base_type()
                            .expect("In Typed::Reference(handle), handle must be a Naga pointer");
                        temp_ty.inner_with(&ctx.module.types)
                    }

                    Typed::Plain(handle) => {
                        resolve_inner!(ctx, handle)
                    }
                };

                let access = match *composite_type {
                    ir::TypeInner::Struct { ref members, .. } => {
                        let index = members
                            .iter()
                            .position(|m| m.name.as_deref() == Some(field.name))
                            .ok_or(Error::BadAccessor(field.span))?
                            as u32;

                        lowered_base.map(|base| ir::Expression::AccessIndex { base, index })
                    }
                    ir::TypeInner::Vector { size: vec_size, .. } => {
                        match Components::new(field.name, field.span)? {
                            Components::Swizzle { size, pattern } => {
                                for &component in pattern[..size as usize].iter() {
                                    if component as u8 >= vec_size as u8 {
                                        return Err(Box::new(Error::BadAccessor(field.span)));
                                    }
                                }
                                Typed::Plain(ir::Expression::Swizzle {
                                    size,
                                    vector: ctx.apply_load_rule(lowered_base)?,
                                    pattern,
                                })
                            }
                            Components::Single(index) => {
                                if index >= vec_size as u32 {
                                    return Err(Box::new(Error::BadAccessor(field.span)));
                                }

                                lowered_base.map(|base| ir::Expression::AccessIndex { base, index })
                            }
                        }
                    }
                    _ => return Err(Box::new(Error::BadAccessor(field.span))),
                };

                access
            }
        };

        expr.try_map(|handle| ctx.append_expression(handle, span))
    }

    /// Generate IR for the short-circuiting operators `&&` and `||`.
    ///
    /// `binary` has already lowered the LHS expression and resolved its type.
    fn logical(
        &mut self,
        op: crate::BinaryOperator,
        left: Handle<crate::Expression>,
        right: Handle<ast::Expression<'source>>,
        span: Span,
        ctx: &mut ExpressionContext<'source, '_, '_>,
    ) -> Result<'source, Typed<crate::Expression>> {
        debug_assert!(
            op == crate::BinaryOperator::LogicalAnd || op == crate::BinaryOperator::LogicalOr
        );

        if ctx.is_runtime() {
            // To simulate short-circuiting behavior, we want to generate IR
            // like the following for `&&`. For `||`, the condition is `!_lhs`
            // and the else value is `true`.
            //
            //     var _e0: bool;
            //     if _lhs {
            //         _e0 = _rhs;
            //     } else {
            //         _e0 = false;
            //     }

            let (condition, else_val) = if op == crate::BinaryOperator::LogicalAnd {
                let condition = left;
                let else_val = ctx.append_expression(
                    crate::Expression::Literal(crate::Literal::Bool(false)),
                    span,
                )?;
                (condition, else_val)
            } else {
                let condition = ctx.append_expression(
                    crate::Expression::Unary {
                        op: crate::UnaryOperator::LogicalNot,
                        expr: left,
                    },
                    span,
                )?;
                let else_val = ctx.append_expression(
                    crate::Expression::Literal(crate::Literal::Bool(true)),
                    span,
                )?;
                (condition, else_val)
            };

            let bool_ty = ctx.ensure_type_exists(crate::TypeInner::Scalar(crate::Scalar::BOOL));

            let rctx = ctx.runtime_expression_ctx(span)?;
            let result_var = rctx.function.local_variables.append(
                crate::LocalVariable {
                    name: None,
                    ty: bool_ty,
                    init: None,
                },
                span,
            );
            let pointer =
                ctx.append_expression(crate::Expression::LocalVariable(result_var), span)?;

            let (right, mut accept) = ctx.with_nested_runtime_expression_ctx(span, |ctx| {
                let right = self.expression_for_abstract(right, ctx)?;
                ctx.grow_types(right)?;
                Ok(right)
            })?;

            accept.push(
                crate::Statement::Store {
                    pointer,
                    value: right,
                },
                span,
            );

            let mut reject = crate::Block::with_capacity(1);
            reject.push(
                crate::Statement::Store {
                    pointer,
                    value: else_val,
                },
                span,
            );

            let rctx = ctx.runtime_expression_ctx(span)?;
            rctx.block.push(
                crate::Statement::If {
                    condition,
                    accept,
                    reject,
                },
                span,
            );

            Ok(Typed::Reference(crate::Expression::LocalVariable(
                result_var,
            )))
        } else {
            let left_val: Option<bool> =
                ctx.get_const_val(left).ok();

            if left_val.is_some_and(|left_val| {
                op == crate::BinaryOperator::LogicalAnd && !left_val
                    || op == crate::BinaryOperator::LogicalOr && left_val
            }) {
                // Short-circuit behavior: don't evaluate the RHS.
                // TODO(https://github.com/gfx-rs/wgpu/issues/8440): We shouldn't ignore the
                // RHS completely, it should still be type-checked. Preserving it for type
                // checking is a bit tricky, because we're trying to produce an expression
                // for a const context, but the RHS is allowed to have things that aren't
                // const.
                Ok(Typed::Plain(ctx.get(left).clone()))
            } else {
                // Evaluate the RHS and construct the entire binary expression as we
                // normally would. This case applies to well-formed constant logical
                // expressions that don't short-circuit (handled by the constant evaluator
                // shortly), to override expressions (handled when overrides are processed)
                // and to non-well-formed expressions (rejected by type checking).
                let right = self.expression_for_abstract(right, ctx)?;
                ctx.grow_types(right)?;
                Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
            }
        }
    }

    fn type_expression(
        &mut self,
        expr: Handle<ast::Expression<'source>>,
        ctx: &mut ExpressionContext<'source, '_, '_>,
    ) -> Result<'source, Handle<ir::Type>> {
        let span = ctx.ast_expressions.get_span(expr);
        let expr = &ctx.ast_expressions[expr];
        let ident = match *expr {
            ast::Expression::Ident(ref ident) => ident,
            _ => return Err(Box::new(Error::UnexpectedExprForTypeExpression(span))),
        };
        self.type_specifier(ident, ctx, None)
    }

    fn type_specifier(
        &mut self,
        ident: &ast::TemplateElaboratedIdent<'source>,
        ctx: &mut ExpressionContext<'source, '_, '_>,
        alias_name: Option<String>,
    ) -> Result<'source, Handle<ir::Type>> {
        let &ast::TemplateElaboratedIdent {
            ref ident,
            ident_span,
            ref template_list,
            ..
        } = ident;

        let ident = match *ident {
            ast::IdentExpr::Unresolved(ident) => ident,
            ast::IdentExpr::Local(_) => {
                // Since WGSL only supports module-scope type definitions and
                // aliases, a local identifier can't possibly refer to a type.
                return Err(Box::new(Error::UnexpectedExprForTypeExpression(ident_span)));
            }
        };

        let mut tl = TemplateListIter::new(ident_span, template_list);

        if let Some(global) = ctx.globals.get(ident) {
            let &LoweredGlobalDecl::Type(handle) = global else {
                return Err(Box::new(Error::UnexpectedExprForTypeExpression(ident_span)));
            };

            // Type generators can only be predeclared, so since `ident` refers
            // to a module-scope declaration, the template parameter list should
            // be empty.
            tl.finish(ctx)?;

            return Ok(handle);
        }

        // If `ident` doesn't resolve to a module-scope declaration, then it
        // must resolve to a predeclared type or type generator.
        let ty = conv::map_predeclared_type(&ctx.enable_extensions, ident_span, ident)?
            .ok_or_else(|| Box::new(Error::UnknownIdent(ident_span, ident)))?;
        let ty = self.finalize_type(ctx, ty, &mut tl, alias_name)?;

        tl.finish(ctx)?;

        Ok(ty)
    }

    /// Construct an [`ir::Type`] from a [`conv::PredeclaredType`] and a list of
    /// template parameters.
    ///
    /// If we're processing a type alias, then `alias_name` is the name we
    /// should use in the new `ir::Type`.
    ///
    /// For example, when parsing `vec3<f32>`, the caller would pass:
    ///
    /// - for `ty`, [`TypeGenerator::Vector`], and
    ///
    /// - for `tl`, an iterator producing a single [`Expression::Ident`] representing `f32`.
    ///
    /// From those arguments this function will return a handle for the
    /// [`ir::Type`] representing `vec3<f32>`.
    ///
    /// [`TypeGenerator::Vector`]: conv::TypeGenerator::Vector
    /// [`Expression::Ident`]: crate::front::wgsl::parse::ast::Expression::Ident
    fn finalize_type(
        &mut self,
        ctx: &mut ExpressionContext<'source, '_, '_>,
        ty: conv::PredeclaredType,
        tl: &mut TemplateListIter<'_, 'source>,
        alias_name: Option<String>,
    ) -> Result<'source, Handle<ir::Type>> {
        let ty = match ty {
            conv::PredeclaredType::TypeInner(ty_inner) => {
                if let ir::TypeInner::Image {
                    class: ir::ImageClass::External,
                    ..
                } = ty_inner
                {
                    // Other than the WGSL backend, every backend that supports
                    // external textures does so by lowering them to a set of
                    // ordinary textures and some parameters saying how to
                    // sample from them. We don't know which backend will
                    // consume the `Module` we're building, but in case it's not
                    // WGSL, populate `SpecialTypes::external_texture_params`
                    // and `SpecialTypes::external_texture_transfer_function`
                    // with the types the backend will use for the parameter
                    // buffer.
                    //
                    // Neither of these are the type we are lowering here:
                    // that's an ordinary `TypeInner::Image`. But the fact we
                    // are lowering a `texture_external` implies the backends
                    // may need these additional types too.
                    ctx.module.generate_external_texture_types();
                }
                ctx.as_global().ensure_type_exists(alias_name, ty_inner)
            }
            conv::PredeclaredType::RayDesc => ctx.module.generate_ray_desc_type(),
            conv::PredeclaredType::RayIntersection => ctx.module.generate_ray_intersection_type(),
            conv::PredeclaredType::TypeGenerator(type_generator) => {
                let ty_inner = match type_generator {
                    conv::TypeGenerator::Vector { size } => {
                        let (scalar, _) = tl.scalar_ty(self, ctx)?;
                        ir::TypeInner::Vector { size, scalar }
                    }
                    conv::TypeGenerator::Matrix { columns, rows } => {
                        let (scalar, span) = tl.scalar_ty(self, ctx)?;
                        if scalar.kind != ir::ScalarKind::Float {
                            return Err(Box::new(Error::BadMatrixScalarKind(span, scalar)));
                        }
                        ir::TypeInner::Matrix {
                            columns,
                            rows,
                            scalar,
                        }
                    }
                    conv::TypeGenerator::Array => {
                        let base = tl.ty(self, ctx)?;
                        let size = tl.maybe_array_size(self, ctx)?;

                        // Determine the size of the base type, if needed.
                        ctx.layouter.update(ctx.module.to_ctx()).map_err(|err| {
                            let LayoutErrorInner::TooLarge = err.inner else {
                                unreachable!("unexpected layout error: {err:?}");
                            };
                            // Lots of type definitions don't get spans, so this error
                            // message may not be very useful.
Box::new(Error::TypeTooLarge { span: ctx.module.types.get_span(err.ty), }) })?; let stride = ctx.layouter[base].to_stride(); ir::TypeInner::Array { base, size, stride } } conv::TypeGenerator::Atomic => { let (scalar, _) = tl.scalar_ty(self, ctx)?; ir::TypeInner::Atomic(scalar) } conv::TypeGenerator::Pointer => { let mut space = tl.address_space(ctx)?; let base = tl.ty(self, ctx)?; tl.maybe_access_mode(&mut space, ctx)?; ir::TypeInner::Pointer { base, space } } conv::TypeGenerator::SampledTexture { dim, arrayed, multi, } => { let (scalar, span) = tl.scalar_ty(self, ctx)?; let ir::Scalar { kind, width } = scalar; if width != 4 { return Err(Box::new(Error::BadTextureSampleType { span, scalar })); } ir::TypeInner::Image { dim, arrayed, class: ir::ImageClass::Sampled { kind, multi }, } } conv::TypeGenerator::StorageTexture { dim, arrayed } => { let format = tl.storage_format(ctx)?; let access = tl.access_mode(ctx)?; ir::TypeInner::Image { dim, arrayed, class: ir::ImageClass::Storage { format, access }, } } conv::TypeGenerator::BindingArray => { let base = tl.ty(self, ctx)?; let size = tl.maybe_array_size(self, ctx)?; ir::TypeInner::BindingArray { base, size } } conv::TypeGenerator::AccelerationStructure => { let vertex_return = tl.maybe_vertex_return(ctx)?; ir::TypeInner::AccelerationStructure { vertex_return } } conv::TypeGenerator::RayQuery => { let vertex_return = tl.maybe_vertex_return(ctx)?; ir::TypeInner::RayQuery { vertex_return } } conv::TypeGenerator::CooperativeMatrix { columns, rows } => { let (ty, span) = tl.ty_with_span(self, ctx)?; let ir::TypeInner::Scalar(scalar) = ctx.module.types[ty].inner else { return Err(Box::new(Error::UnsupportedCooperativeScalar(span))); }; let role = tl.cooperative_role(ctx)?; ir::TypeInner::CooperativeMatrix { columns, rows, scalar, role, } } }; ctx.as_global().ensure_type_exists(alias_name, ty_inner) } }; Ok(ty) } fn binary( &mut self, op: ir::BinaryOperator, left: Handle>, right: Handle>, span: Span, ctx: &mut 
ExpressionContext<'source, '_, '_>, ) -> Result<'source, Typed<ir::Expression>> { if op == ir::BinaryOperator::LogicalAnd || op == ir::BinaryOperator::LogicalOr { let left = self.expression_for_abstract(left, ctx)?; ctx.grow_types(left)?; if !matches!( resolve_inner!(ctx, left), &ir::TypeInner::Scalar(ir::Scalar::BOOL) ) { // Pass it through as-is; it will fail validation. let right = self.expression_for_abstract(right, ctx)?; ctx.grow_types(right)?; Ok(Typed::Plain(crate::Expression::Binary { op, left, right })) } else { self.logical(op, left, right, span, ctx) } } else { // Load both operands. let mut left = self.expression_for_abstract(left, ctx)?; let mut right = self.expression_for_abstract(right, ctx)?; // Convert `scalar op vector` to `vector op vector` by introducing // `Splat` expressions. ctx.binary_op_splat(op, &mut left, &mut right)?; // Apply automatic conversions. match op { ir::BinaryOperator::ShiftLeft | ir::BinaryOperator::ShiftRight => { // Shift operators require the right operand to be `u32` or // `vecN<u32>`. We can let the validator sort out vector length // issues, but the right operand must be, or convert to, a u32 leaf // scalar. right = ctx.try_automatic_conversion_for_leaf_scalar(right, ir::Scalar::U32, span)?; // Additionally, we must concretize the left operand if the right operand // is not a const-expression. // See https://www.w3.org/TR/WGSL/#overload-resolution-section. // // 2. Eliminate any candidate where one of its subexpressions resolves to // an abstract type after feasible automatic conversions, but another of // the candidate’s subexpressions is not a const-expression. // // We only have to explicitly do so for shifts as their operands may be // of different types - for other binary ops this is achieved by finding // the conversion consensus for both operands. if !ctx.is_const(right) { left = ctx.concretize(left)?; } } // All other operators follow the same pattern: reconcile the // scalar leaf types.
If there's no reconciliation possible, // leave the expressions as they are: validation will report the // problem. _ => { ctx.grow_types(left)?; ctx.grow_types(right)?; if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(None, [left, right].iter()) { ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?; ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?; } } } Ok(Typed::Plain(ir::Expression::Binary { op, left, right })) } } /// Generate Naga IR for a call to a WGSL builtin function. #[allow(clippy::too_many_arguments)] fn call_builtin<'phrase>( &mut self, function_name: &'source str, function_span: Span, arguments: &[Handle<ast::Expression<'source>>], template_params: &mut TemplateListIter<'phrase, 'source>, call_span: Span, ctx: &mut ExpressionContext<'source, '_, '_>, is_statement: bool, ) -> Result<'source, Option<(Handle<ir::Expression>, MustUse)>> { let (expr, must_use) = if let Some(fun) = conv::map_relational_fun(function_name) { let mut args = ctx.prepare_args(arguments, 1, function_span); let argument = self.expression(args.next()?, ctx)?; args.finish()?; // Check for no-op all(bool) and any(bool): let argument_unmodified = matches!( fun, ir::RelationalFunction::All | ir::RelationalFunction::Any ) && { matches!( resolve_inner!(ctx, argument), &ir::TypeInner::Scalar(ir::Scalar { kind: ir::ScalarKind::Bool, ..
}) ) }; if argument_unmodified { return Ok(Some((argument, MustUse::Yes))); } else { (ir::Expression::Relational { fun, argument }, MustUse::Yes) } } else if let Some((axis, ctrl)) = conv::map_derivative(function_name) { let mut args = ctx.prepare_args(arguments, 1, function_span); let expr = self.expression(args.next()?, ctx)?; args.finish()?; ( ir::Expression::Derivative { axis, ctrl, expr }, MustUse::Yes, ) } else if let Some(fun) = conv::map_standard_fun(function_name) { ( self.math_function_helper(function_span, fun, arguments, ctx)?, MustUse::Yes, ) } else if let Some(fun) = Texture::map(function_name) { ( self.texture_sample_helper(fun, arguments, function_span, ctx)?, MustUse::Yes, ) } else if let Some((op, cop)) = conv::map_subgroup_operation(function_name) { return Ok(Some(( self.subgroup_operation_helper(function_span, op, cop, arguments, ctx)?, MustUse::Yes, ))); } else if let Some(mode) = SubgroupGather::map(function_name) { return Ok(Some(( self.subgroup_gather_helper(function_span, mode, arguments, ctx)?, MustUse::Yes, ))); } else if let Some(fun) = ir::AtomicFunction::map(function_name) { return Ok(self .atomic_helper(function_span, fun, arguments, is_statement, ctx)? .map(|result| (result, MustUse::No))); } else { match function_name { "bitcast" => { let ty = template_params.ty(self, ctx)?; let mut args = ctx.prepare_args(arguments, 1, function_span); let expr = self.expression(args.next()?, ctx)?; args.finish()?; let element_scalar = match ctx.module.types[ty].inner { ir::TypeInner::Scalar(scalar) => scalar, ir::TypeInner::Vector { scalar, .. 
} => scalar, _ => { let ty_resolution = resolve!(ctx, expr); return Err(Box::new(Error::BadTypeCast { from_type: ctx.type_resolution_to_string(ty_resolution), span: function_span, to_type: ctx.type_to_string(ty), })); } }; ( ir::Expression::As { expr, kind: element_scalar.kind, convert: None, }, MustUse::Yes, ) } "coopLoad" | "coopLoadT" => { let row_major = function_name.ends_with("T"); let (matrix_ty, matrix_span) = template_params.ty_with_span(self, ctx)?; let mut args = ctx.prepare_args(arguments, 1, call_span); let pointer = self.expression(args.next()?, ctx)?; let (columns, rows, role) = match ctx.module.types[matrix_ty].inner { ir::TypeInner::CooperativeMatrix { columns, rows, role, .. } => (columns, rows, role), _ => return Err(Box::new(Error::InvalidCooperativeLoadType(matrix_span))), }; let stride = if args.total_args > 1 { self.expression(args.next()?, ctx)? } else { // Infer the stride from the matrix type let stride = if row_major { columns as u32 } else { rows as u32 }; ctx.append_expression( ir::Expression::Literal(ir::Literal::U32(stride)), Span::UNDEFINED, )? 
}; args.finish()?; ( crate::Expression::CooperativeLoad { columns, rows, role, data: crate::CooperativeData { pointer, stride, row_major, }, }, MustUse::Yes, ) } "select" => { let mut args = ctx.prepare_args(arguments, 3, function_span); let reject_orig = args.next()?; let accept_orig = args.next()?; let mut values = [ self.expression_for_abstract(reject_orig, ctx)?, self.expression_for_abstract(accept_orig, ctx)?, ]; let condition = self.expression(args.next()?, ctx)?; args.finish()?; let diagnostic_details = |ctx: &ExpressionContext<'_, '_, '_>, ty_res: &proc::TypeResolution, orig_expr| { ( ctx.ast_expressions.get_span(orig_expr), format!("`{}`", ctx.as_diagnostic_display(ty_res)), ) }; for (&value, orig_value) in values.iter().zip([reject_orig, accept_orig]) { let value_ty_res = resolve!(ctx, value); if value_ty_res .inner_with(&ctx.module.types) .vector_size_and_scalar() .is_none() { let (arg_span, arg_type) = diagnostic_details(ctx, value_ty_res, orig_value); return Err(Box::new(Error::SelectUnexpectedArgumentType { arg_span, arg_type, })); } } let mut consensus_scalar = ctx .automatic_conversion_consensus(None, &values) .map_err(|_idx| { let [reject, accept] = values; let [(reject_span, reject_type), (accept_span, accept_type)] = [(reject_orig, reject), (accept_orig, accept)].map( |(orig_expr, expr)| { let ty_res = &ctx.typifier()[expr]; diagnostic_details(ctx, ty_res, orig_expr) }, ); Error::SelectRejectAndAcceptHaveNoCommonType { reject_span, reject_type, accept_span, accept_type, } })?; if !ctx.is_const(condition) { consensus_scalar = consensus_scalar.concretize(); } ctx.convert_slice_to_common_leaf_scalar(&mut values, consensus_scalar)?; let [reject, accept] = values; ( ir::Expression::Select { reject, accept, condition, }, MustUse::Yes, ) } "arrayLength" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let expr = self.expression(args.next()?, ctx)?; args.finish()?; (ir::Expression::ArrayLength(expr), MustUse::Yes) } "atomicLoad" => { 
let mut args = ctx.prepare_args(arguments, 1, function_span); let (pointer, _scalar) = self.atomic_pointer(args.next()?, ctx)?; args.finish()?; (ir::Expression::Load { pointer }, MustUse::No) } "atomicStore" => { let mut args = ctx.prepare_args(arguments, 2, function_span); let (pointer, scalar) = self.atomic_pointer(args.next()?, ctx)?; let value = self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; args.finish()?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); rctx.emitter.start(&rctx.function.expressions); rctx.block .push(ir::Statement::Store { pointer, value }, function_span); return Ok(None); } "atomicCompareExchangeWeak" => { let mut args = ctx.prepare_args(arguments, 3, function_span); let (pointer, scalar) = self.atomic_pointer(args.next()?, ctx)?; let compare = self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; let value = args.next()?; let value_span = ctx.ast_expressions.get_span(value); let value = self.expression_with_leaf_scalar(value, scalar, ctx)?; args.finish()?; let expression = match *resolve_inner!(ctx, value) { ir::TypeInner::Scalar(scalar) => ir::Expression::AtomicResult { ty: ctx.module.generate_predeclared_type( ir::PredeclaredType::AtomicCompareExchangeWeakResult(scalar), ), comparison: true, }, _ => return Err(Box::new(Error::InvalidAtomicOperandType(value_span))), }; let result = ctx.interrupt_emitter(expression, function_span)?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( ir::Statement::Atomic { pointer, fun: ir::AtomicFunction::Exchange { compare: Some(compare), }, value, result: Some(result), }, function_span, ); return Ok(Some((result, MustUse::No))); } "textureAtomicMin" | "textureAtomicMax" | "textureAtomicAdd" | "textureAtomicAnd" | "textureAtomicOr" | "textureAtomicXor" => { let mut args = ctx.prepare_args(arguments, 3, function_span); let image = args.next()?; let image_span = 
ctx.ast_expressions.get_span(image); let image = self.expression(image, ctx)?; let coordinate = self.expression(args.next()?, ctx)?; let (_, arrayed) = ctx.image_data(image, image_span)?; let array_index = arrayed .then(|| { args.min_args += 1; self.expression(args.next()?, ctx) }) .transpose()?; let value = self.expression(args.next()?, ctx)?; args.finish()?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); rctx.emitter.start(&rctx.function.expressions); let stmt = ir::Statement::ImageAtomic { image, coordinate, array_index, fun: match function_name { "textureAtomicMin" => ir::AtomicFunction::Min, "textureAtomicMax" => ir::AtomicFunction::Max, "textureAtomicAdd" => ir::AtomicFunction::Add, "textureAtomicAnd" => ir::AtomicFunction::And, "textureAtomicOr" => ir::AtomicFunction::InclusiveOr, "textureAtomicXor" => ir::AtomicFunction::ExclusiveOr, _ => unreachable!(), }, value, }; rctx.block.push(stmt, function_span); return Ok(None); } "storageBarrier" => { ctx.prepare_args(arguments, 0, function_span).finish()?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( ir::Statement::ControlBarrier(ir::Barrier::STORAGE), function_span, ); return Ok(None); } "workgroupBarrier" => { ctx.prepare_args(arguments, 0, function_span).finish()?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( ir::Statement::ControlBarrier(ir::Barrier::WORK_GROUP), function_span, ); return Ok(None); } "subgroupBarrier" => { ctx.prepare_args(arguments, 0, function_span).finish()?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( ir::Statement::ControlBarrier(ir::Barrier::SUB_GROUP), function_span, ); return Ok(None); } "textureBarrier" => { ctx.prepare_args(arguments, 0, function_span).finish()?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( ir::Statement::ControlBarrier(ir::Barrier::TEXTURE), function_span, ); return Ok(None); } 
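// For example, given a hypothetical `var<workgroup> flag: atomic<u32>;`,
// `workgroupUniformLoad(&flag)` lowers to a `WorkGroupUniformLoad` statement
// whose `WorkGroupUniformLoadResult` has type `u32` rather than `atomic<u32>`;
// a pointer to a non-atomic type keeps its base type as the result type.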
"workgroupUniformLoad" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let expr = args.next()?; args.finish()?; let pointer = self.expression(expr, ctx)?; let result_ty = match *resolve_inner!(ctx, pointer) { ir::TypeInner::Pointer { base, space: ir::AddressSpace::WorkGroup, } => match ctx.module.types[base].inner { // Match `Expression::Load` semantics: // loading through a pointer to `atomic<T>` produces a `T`. ir::TypeInner::Atomic(scalar) => ctx.module.types.insert( ir::Type { name: None, inner: ir::TypeInner::Scalar(scalar), }, function_span, ), _ => base, }, ir::TypeInner::ValuePointer { size, scalar, space: ir::AddressSpace::WorkGroup, } => ctx.module.types.insert( ir::Type { name: None, inner: match size { Some(size) => ir::TypeInner::Vector { size, scalar }, None => ir::TypeInner::Scalar(scalar), }, }, function_span, ), _ => { let span = ctx.ast_expressions.get_span(expr); return Err(Box::new(Error::InvalidWorkGroupUniformLoad(span))); } }; let result = ctx.interrupt_emitter( ir::Expression::WorkGroupUniformLoadResult { ty: result_ty }, function_span, )?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( ir::Statement::WorkGroupUniformLoad { pointer, result }, function_span, ); return Ok(Some((result, MustUse::Yes))); } "textureStore" => { let mut args = ctx.prepare_args(arguments, 3, function_span); let image = args.next()?; let image_span = ctx.ast_expressions.get_span(image); let image = self.expression(image, ctx)?; let coordinate = self.expression(args.next()?, ctx)?; let (class, arrayed) = ctx.image_data(image, image_span)?; let array_index = arrayed .then(|| { args.min_args += 1; self.expression(args.next()?, ctx) }) .transpose()?; let scalar = if let ir::ImageClass::Storage { format, ..
} = class { format.into() } else { return Err(Box::new(Error::NotStorageTexture(image_span))); }; let value = self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; args.finish()?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); rctx.emitter.start(&rctx.function.expressions); let stmt = ir::Statement::ImageStore { image, coordinate, array_index, value, }; rctx.block.push(stmt, function_span); return Ok(None); } "textureLoad" => { let mut args = ctx.prepare_args(arguments, 2, function_span); let image = args.next()?; let image_span = ctx.ast_expressions.get_span(image); let image = self.expression(image, ctx)?; let coordinate = self.expression(args.next()?, ctx)?; let (class, arrayed) = ctx.image_data(image, image_span)?; let array_index = arrayed .then(|| { args.min_args += 1; self.expression(args.next()?, ctx) }) .transpose()?; let level = class .is_mipmapped() .then(|| { args.min_args += 1; self.expression(args.next()?, ctx) }) .transpose()?; let sample = class .is_multisampled() .then(|| self.expression(args.next()?, ctx)) .transpose()?; args.finish()?; ( ir::Expression::ImageLoad { image, coordinate, array_index, level, sample, }, MustUse::Yes, ) } "textureDimensions" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let image = self.expression(args.next()?, ctx)?; let level = args .next() .map(|arg| self.expression(arg, ctx)) .ok() .transpose()?; args.finish()?; ( ir::Expression::ImageQuery { image, query: ir::ImageQuery::Size { level }, }, MustUse::Yes, ) } "textureNumLevels" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let image = self.expression(args.next()?, ctx)?; args.finish()?; ( ir::Expression::ImageQuery { image, query: ir::ImageQuery::NumLevels, }, MustUse::Yes, ) } "textureNumLayers" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let image = self.expression(args.next()?, ctx)?; args.finish()?; ( 
ir::Expression::ImageQuery { image, query: ir::ImageQuery::NumLayers, }, MustUse::Yes, ) } "textureNumSamples" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let image = self.expression(args.next()?, ctx)?; args.finish()?; ( ir::Expression::ImageQuery { image, query: ir::ImageQuery::NumSamples, }, MustUse::Yes, ) } "rayQueryInitialize" => { let mut args = ctx.prepare_args(arguments, 3, function_span); let query = self.ray_query_pointer(args.next()?, ctx)?; let acceleration_structure = self.expression(args.next()?, ctx)?; let descriptor = self.expression(args.next()?, ctx)?; args.finish()?; let _ = ctx.module.generate_ray_desc_type(); let fun = ir::RayQueryFunction::Initialize { acceleration_structure, descriptor, }; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); rctx.emitter.start(&rctx.function.expressions); rctx.block .push(ir::Statement::RayQuery { query, fun }, function_span); return Ok(None); } "getCommittedHitVertexPositions" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let query = self.ray_query_pointer(args.next()?, ctx)?; args.finish()?; let _ = ctx.module.generate_vertex_return_type(); ( ir::Expression::RayQueryVertexPositions { query, committed: true, }, MustUse::No, ) } "getCandidateHitVertexPositions" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let query = self.ray_query_pointer(args.next()?, ctx)?; args.finish()?; let _ = ctx.module.generate_vertex_return_type(); ( ir::Expression::RayQueryVertexPositions { query, committed: false, }, MustUse::No, ) } "rayQueryProceed" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let query = self.ray_query_pointer(args.next()?, ctx)?; args.finish()?; let result = ctx .interrupt_emitter(ir::Expression::RayQueryProceedResult, function_span)?; let fun = ir::RayQueryFunction::Proceed { result }; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block 
.push(ir::Statement::RayQuery { query, fun }, function_span); return Ok(Some((result, MustUse::No))); } "rayQueryGenerateIntersection" => { let mut args = ctx.prepare_args(arguments, 2, function_span); let query = self.ray_query_pointer(args.next()?, ctx)?; let hit_t = self.expression(args.next()?, ctx)?; args.finish()?; let fun = ir::RayQueryFunction::GenerateIntersection { hit_t }; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, function_span); return Ok(None); } "rayQueryConfirmIntersection" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let query = self.ray_query_pointer(args.next()?, ctx)?; args.finish()?; let fun = ir::RayQueryFunction::ConfirmIntersection; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, function_span); return Ok(None); } "rayQueryTerminate" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let query = self.ray_query_pointer(args.next()?, ctx)?; args.finish()?; let fun = ir::RayQueryFunction::Terminate; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, function_span); return Ok(None); } "rayQueryGetCommittedIntersection" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let query = self.ray_query_pointer(args.next()?, ctx)?; args.finish()?; let _ = ctx.module.generate_ray_intersection_type(); ( ir::Expression::RayQueryGetIntersection { query, committed: true, }, MustUse::No, ) } "rayQueryGetCandidateIntersection" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let query = self.ray_query_pointer(args.next()?, ctx)?; args.finish()?; let _ = ctx.module.generate_ray_intersection_type(); ( ir::Expression::RayQueryGetIntersection { query, committed: false, }, MustUse::No, ) } "subgroupBallot" => { let mut args = ctx.prepare_args(arguments, 0, function_span); let predicate = if 
arguments.len() == 1 { Some(self.expression(args.next()?, ctx)?) } else { None }; args.finish()?; let result = ctx.interrupt_emitter(ir::Expression::SubgroupBallotResult, function_span)?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( ir::Statement::SubgroupBallot { result, predicate }, function_span, ); return Ok(Some((result, MustUse::Yes))); } "quadSwapX" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let argument = self.expression(args.next()?, ctx)?; args.finish()?; let ty = ctx.register_type(argument)?; let result = ctx.interrupt_emitter( crate::Expression::SubgroupOperationResult { ty }, function_span, )?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( crate::Statement::SubgroupGather { mode: crate::GatherMode::QuadSwap(crate::Direction::X), argument, result, }, function_span, ); return Ok(Some((result, MustUse::Yes))); } "quadSwapY" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let argument = self.expression(args.next()?, ctx)?; args.finish()?; let ty = ctx.register_type(argument)?; let result = ctx.interrupt_emitter( crate::Expression::SubgroupOperationResult { ty }, function_span, )?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( crate::Statement::SubgroupGather { mode: crate::GatherMode::QuadSwap(crate::Direction::Y), argument, result, }, function_span, ); return Ok(Some((result, MustUse::Yes))); } "quadSwapDiagonal" => { let mut args = ctx.prepare_args(arguments, 1, function_span); let argument = self.expression(args.next()?, ctx)?; args.finish()?; let ty = ctx.register_type(argument)?; let result = ctx.interrupt_emitter( crate::Expression::SubgroupOperationResult { ty }, function_span, )?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( crate::Statement::SubgroupGather { mode: crate::GatherMode::QuadSwap(crate::Direction::Diagonal), argument, result, }, function_span, ); return Ok(Some((result, MustUse::Yes))); } 
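// When the optional stride argument of `coopStore`/`coopStoreT` is omitted,
// it is inferred from the matrix type and emitted as a `u32` literal:
// `coopStoreT` (row-major) uses the column count, `coopStore` (column-major)
// uses the row count, and a placeholder stride of 0 is used if `target` is not
// a cooperative matrix.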
"coopStore" | "coopStoreT" => { let row_major = function_name.ends_with("T"); let mut args = ctx.prepare_args(arguments, 2, function_span); let target = self.expression(args.next()?, ctx)?; let pointer = self.expression(args.next()?, ctx)?; let stride = if args.total_args > 2 { self.expression(args.next()?, ctx)? } else { // Infer the stride from the matrix type let stride = match *resolve_inner!(ctx, target) { ir::TypeInner::CooperativeMatrix { columns, rows, .. } => { if row_major { columns as u32 } else { rows as u32 } } _ => 0, }; ctx.append_expression( ir::Expression::Literal(ir::Literal::U32(stride)), Span::UNDEFINED, )? }; args.finish()?; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block.push( crate::Statement::CooperativeStore { target, data: crate::CooperativeData { pointer, stride, row_major, }, }, function_span, ); return Ok(None); } "coopMultiplyAdd" => { let mut args = ctx.prepare_args(arguments, 3, function_span); let a = self.expression(args.next()?, ctx)?; let b = self.expression(args.next()?, ctx)?; let c = self.expression(args.next()?, ctx)?; args.finish()?; ( ir::Expression::CooperativeMultiplyAdd { a, b, c }, MustUse::Yes, ) } "traceRay" => { let mut args = ctx.prepare_args(arguments, 3, function_span); let acceleration_structure = self.expression(args.next()?, ctx)?; let descriptor = self.expression(args.next()?, ctx)?; let payload = self.expression(args.next()?, ctx)?; args.finish()?; let _ = ctx.module.generate_ray_desc_type(); let fun = ir::RayPipelineFunction::TraceRay { acceleration_structure, descriptor, payload, }; let rctx = ctx.runtime_expression_ctx(function_span)?; rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); rctx.emitter.start(&rctx.function.expressions); rctx.block .push(ir::Statement::RayPipelineFunction(fun), function_span); return Ok(None); } _ => return Err(Box::new(Error::UnknownIdent(function_span, function_name))), } }; let expr = ctx.append_expression(expr, function_span)?; 
Ok(Some((expr, must_use))) } /// Generate Naga IR for call expressions and statements, and type /// constructor expressions. /// /// The "function" being called is simply an `Ident` that we know refers to /// some module-scope definition. /// /// - If it is the name of a type, then the expression is a type constructor /// expression: either constructing a value from components, a conversion /// expression, or a zero-value expression. /// /// - If it is the name of a function, then we're generating a [`Call`] /// statement. We may be in the midst of generating code for an /// expression, in which case we must generate an `Emit` statement to /// force evaluation of the IR expressions we've generated so far, add the /// `Call` statement to the current block, and then resume generating /// expressions. /// /// [`Call`]: ir::Statement::Call fn call( &mut self, call_phrase: &ast::CallPhrase<'source>, span: Span, ctx: &mut ExpressionContext<'source, '_, '_>, is_statement: bool, ) -> Result<'source, Option<Handle<ir::Expression>>> { let function_name = match call_phrase.function.ident { ast::IdentExpr::Unresolved(name) => name, ast::IdentExpr::Local(_) => { return Err(Box::new(Error::CalledLocalDecl( call_phrase.function.ident_span, ))) } }; let mut function_span = call_phrase.function.ident_span; function_span.subsume(call_phrase.function.template_list_span); let arguments = call_phrase.arguments.as_slice(); let mut tl = TemplateListIter::new(function_span, &call_phrase.function.template_list); let result = match ctx.globals.get(function_name) { Some(&LoweredGlobalDecl::Type(ty)) => { // user-declared types can't make use of template lists tl.finish(ctx)?; let handle = self.construct(span, Constructor::Type(ty), function_span, arguments, ctx)?; Some((handle, MustUse::Yes)) } Some( &LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Override(_) | &LoweredGlobalDecl::Var(_), ) => { return Err(Box::new(Error::Unexpected( function_span, ExpectedToken::Function, ))) }
Some(&LoweredGlobalDecl::EntryPoint(_)) => { return Err(Box::new(Error::CalledEntryPoint(function_span))); } Some(&LoweredGlobalDecl::Function { handle: function, must_use, }) => { // user-declared functions can't make use of template lists tl.finish(ctx)?; let arguments = arguments .iter() .enumerate() .map(|(i, &arg)| { // Try to convert abstract values to the known argument types let Some(&ir::FunctionArgument { ty: parameter_ty, .. }) = ctx.module.functions[function].arguments.get(i) else { // Wrong number of arguments... just concretize the type here // and let the validator report the error. return self.expression(arg, ctx); }; let expr = self.expression_for_abstract(arg, ctx)?; ctx.try_automatic_conversions( expr, &proc::TypeResolution::Handle(parameter_ty), ctx.ast_expressions.get_span(arg), ) }) .collect::<Result<Vec<_>>>()?; let has_result = ctx.module.functions[function].result.is_some(); let rctx = ctx.runtime_expression_ctx(span)?; // Always flush the emitter before a function call: every argument must be emitted before the `Call` statement. rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); let result = has_result.then(|| { let result = rctx .function .expressions .append(ir::Expression::CallResult(function), span); rctx.local_expression_kind_tracker .insert(result, proc::ExpressionKind::Runtime); (result, must_use.into()) }); rctx.emitter.start(&rctx.function.expressions); rctx.block.push( ir::Statement::Call { function, arguments, result: result.map(|(expr, _)| expr), }, span, ); result } None => { // If the name refers to a predeclared type, this is a construction expression.
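// For example, `vec3(1.0, 2.0, 3.0)` has no template list, so it becomes
// `Constructor::PartialVector` and the component type is inferred from the
// arguments, whereas `vec3<f32>(1.0, 2.0, 3.0)` is finalized to a full type
// first. Names that are not predeclared types, such as `max`, fall through
// to `call_builtin`.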
let ty = conv::map_predeclared_type( &ctx.enable_extensions, function_span, function_name, )?; if let Some(ty) = ty { let empty_template_list = call_phrase.function.template_list.is_empty(); let constructor_ty = match ty { conv::PredeclaredType::TypeGenerator(conv::TypeGenerator::Vector { size, }) if empty_template_list => Constructor::PartialVector { size }, conv::PredeclaredType::TypeGenerator(conv::TypeGenerator::Matrix { columns, rows, }) if empty_template_list => Constructor::PartialMatrix { columns, rows }, conv::PredeclaredType::TypeGenerator(conv::TypeGenerator::Array) if empty_template_list => { Constructor::PartialArray } conv::PredeclaredType::TypeGenerator( conv::TypeGenerator::CooperativeMatrix { .. }, ) if empty_template_list => { return Err(Box::new(Error::UnderspecifiedCooperativeMatrix)); } _ => Constructor::Type(self.finalize_type(ctx, ty, &mut tl, None)?), }; tl.finish(ctx)?; let handle = self.construct(span, constructor_ty, function_span, arguments, ctx)?; Some((handle, MustUse::Yes)) } else { // Otherwise, it must be a call to a builtin function. let result = self.call_builtin( function_name, function_span, arguments, &mut tl, span, ctx, is_statement, )?; tl.finish(ctx)?; result } } }; let result_used = !is_statement; if matches!(result, Some((_, MustUse::Yes))) && !result_used { return Err(Box::new(Error::FunctionMustUseUnused(function_span))); } Ok(result.map(|(expr, _)| expr)) } /// Generate a Naga IR [`Math`] expression. /// /// Generate Naga IR for a call to the [`MathFunction`] `fun`, whose /// unlowered arguments are `ast_arguments`. /// /// The `span` argument should give the span of the function name in the /// call expression. 
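///
/// For example, for WGSL `clamp(x, 0.0, 1.0)` with `x: f32`, overload
/// resolution selects the `f32` overload, the two AbstractFloat literals are
/// converted to `f32`, and the result is a `Math` expression with
/// `fun: Clamp`, `arg1` and `arg2` set to the converted literals, and
/// `arg3: None`.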
/// /// [`Math`]: ir::Expression::Math /// [`MathFunction`]: ir::MathFunction fn math_function_helper( &mut self, span: Span, fun: ir::MathFunction, ast_arguments: &[Handle<ast::Expression<'source>>], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, ir::Expression> { let mut lowered_arguments = Vec::with_capacity(ast_arguments.len()); for &arg in ast_arguments { let lowered = self.expression_for_abstract(arg, ctx)?; ctx.grow_types(lowered)?; lowered_arguments.push(lowered); } let fun_overloads = fun.overloads(); let rule = self.resolve_overloads(span, fun, fun_overloads, &lowered_arguments, ctx)?; self.apply_automatic_conversions_for_call(&rule, &mut lowered_arguments, ctx)?; // If this function returns a predeclared type, register it // in `Module::special_types`. The typifier will expect to // be able to find it there. if let proc::Conclusion::Predeclared(predeclared) = rule.conclusion { ctx.module.generate_predeclared_type(predeclared); } Ok(ir::Expression::Math { fun, arg: lowered_arguments[0], arg1: lowered_arguments.get(1).cloned(), arg2: lowered_arguments.get(2).cloned(), arg3: lowered_arguments.get(3).cloned(), }) } /// Choose the right overload for a function call. /// /// Return a [`Rule`] representing the most preferred overload in /// `overloads` to apply to `arguments`, or return an error explaining why /// the call is not valid. /// /// Use `fun` to identify the function being called in error messages; /// `span` should be the span of the function name in the call expression.
/// /// [`Rule`]: proc::Rule fn resolve_overloads<O, F>( &self, span: Span, fun: F, overloads: O, arguments: &[Handle<ir::Expression>], ctx: &ExpressionContext<'source, '_, '_>, ) -> Result<'source, proc::Rule> where O: proc::OverloadSet, F: TryToWgsl + core::fmt::Debug + Copy, { let mut remaining_overloads = overloads.clone(); let min_arguments = remaining_overloads.min_arguments(); let max_arguments = remaining_overloads.max_arguments(); if arguments.len() < min_arguments { return Err(Box::new(Error::WrongArgumentCount { span, expected: min_arguments as u32..max_arguments as u32, found: arguments.len() as u32, })); } if arguments.len() > max_arguments { return Err(Box::new(Error::TooManyArguments { function: fun.to_wgsl_for_diagnostics(), call_span: span, arg_span: ctx.get_expression_span(arguments[max_arguments]), max_arguments: max_arguments as _, })); } log::debug!( "Initial overloads: {:#?}", remaining_overloads.for_debug(&ctx.module.types) ); for (arg_index, &arg) in arguments.iter().enumerate() { let arg_type_resolution = &ctx.typifier()[arg]; let arg_inner = arg_type_resolution.inner_with(&ctx.module.types); log::debug!( "Supplying argument {arg_index} of type {:?}", arg_type_resolution.for_debug(&ctx.module.types) ); let next_remaining_overloads = remaining_overloads.arg(arg_index, arg_inner, &ctx.module.types); // If any argument is not a constant expression, then no overloads // that accept abstract values should be considered. // (`OverloadSet::concrete_only` is supposed to help impose this // restriction.) However, no `MathFunction` accepts a mix of // abstract and concrete arguments, so we don't need to worry // about that here. log::debug!( "Remaining overloads: {:#?}", next_remaining_overloads.for_debug(&ctx.module.types) ); // If the set of remaining overloads is empty, then this argument's type // was unacceptable. Diagnose the problem and produce an error message.
if next_remaining_overloads.is_empty() { let function = fun.to_wgsl_for_diagnostics(); let call_span = span; let arg_span = ctx.get_expression_span(arg); let arg_ty = ctx.as_diagnostic_display(arg_type_resolution).to_string(); // Is this type *ever* permitted for the arg_index'th argument? // For example, `bool` is never permitted for `max`. let only_this_argument = overloads.arg(arg_index, arg_inner, &ctx.module.types); if only_this_argument.is_empty() { // No overload of `fun` accepts this type as the // arg_index'th argument. Determine the set of types that // would ever be allowed there. let allowed: Vec = overloads .allowed_args(arg_index, &ctx.module.to_ctx()) .iter() .map(|ty| ctx.type_resolution_to_string(ty)) .collect(); if allowed.is_empty() { // No overload of `fun` accepts any argument at this // index, so it's a simple case of excess arguments. // However, since each `MathFunction`'s overloads all // have the same arity, we should have detected this // earlier. unreachable!("expected all overloads to have the same arity"); } // Some overloads of `fun` do accept this many arguments, // but none accept one of this type. return Err(Box::new(Error::WrongArgumentType { function, call_span, arg_span, arg_index: arg_index as u32, arg_ty, allowed, })); } // This argument's type is accepted by some overloads---just // not those overloads that remain, given the prior arguments. // For example, `max` accepts `f32` as its second argument - // but not if the first was `i32`. // Build a list of the types that would have been accepted here, // given the prior arguments. let allowed: Vec = remaining_overloads .allowed_args(arg_index, &ctx.module.to_ctx()) .iter() .map(|ty| ctx.type_resolution_to_string(ty)) .collect(); // Re-run the argument list to determine which prior argument // made this one unacceptable. 
let mut remaining_overloads = overloads; for (prior_index, &prior_expr) in arguments.iter().enumerate() { let prior_type_resolution = &ctx.typifier()[prior_expr]; let prior_ty = prior_type_resolution.inner_with(&ctx.module.types); remaining_overloads = remaining_overloads.arg(prior_index, prior_ty, &ctx.module.types); if remaining_overloads .arg(arg_index, arg_inner, &ctx.module.types) .is_empty() { // This is the argument that killed our dreams. let inconsistent_span = ctx.get_expression_span(arguments[prior_index]); let inconsistent_ty = ctx.as_diagnostic_display(prior_type_resolution).to_string(); if allowed.is_empty() { // Some overloads did accept `ty` at `arg_index`, but // given the arguments up through `prior_expr`, we see // no types acceptable at `arg_index`. This means that some // overloads expect fewer arguments than others. However, // each `MathFunction`'s overloads have the same arity, so this // should be impossible. unreachable!("expected all overloads to have the same arity"); } // Report `arg`'s type as inconsistent with `prior_expr`'s return Err(Box::new(Error::InconsistentArgumentType { function, call_span, arg_span, arg_index: arg_index as u32, arg_ty, inconsistent_span, inconsistent_index: prior_index as u32, inconsistent_ty, allowed, })); } } unreachable!("Failed to eliminate argument type when re-tried"); } remaining_overloads = next_remaining_overloads; } // Select the most preferred type rule for this call, // given the argument types supplied above. Ok(remaining_overloads.most_preferred()) } /// Apply automatic type conversions for a function call. /// /// Apply whatever automatic conversions are needed to pass `arguments` to /// the function overload described by `rule`. Update `arguments` to refer /// to the converted arguments. 
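The strategy used by `resolve_overloads` works in three steps. It narrows the candidate overloads one argument at a time. When the set becomes empty, it checks whether the offending type is ever allowed at that position. If the type is allowed somewhere, it replays the prior arguments to find the one that ruled it out. The toy model below is a hypothetical, self-contained sketch of that idea. `Ty`, `OverloadError`, and `resolve` are invented names, overloads are plain lists of parameter types, and "most preferred" is approximated by the first surviving overload. Naga's real `proc::OverloadSet` and `proc::Rule` are more elaborate.

```rust
/// Toy stand-in for argument types (hypothetical; not Naga's `TypeInner`).
#[derive(Clone, Copy, Debug, PartialEq)]
enum Ty {
    Bool,
    I32,
    U32,
    F32,
}

#[derive(Debug, PartialEq)]
enum OverloadError {
    WrongArgumentCount,
    /// No overload ever accepts this type at this position.
    WrongArgumentType { arg_index: usize },
    /// The type is accepted here in general, but not after the prior arguments.
    InconsistentArgumentType { arg_index: usize, prior_index: usize },
}

/// Narrow `overloads` one argument at a time and return the index of the
/// first surviving overload. Assumes every overload has the same arity.
fn resolve(overloads: &[Vec<Ty>], args: &[Ty]) -> Result<usize, OverloadError> {
    if overloads.iter().any(|o| o.len() != args.len()) {
        return Err(OverloadError::WrongArgumentCount);
    }
    let mut remaining: Vec<usize> = (0..overloads.len()).collect();
    for (i, &arg) in args.iter().enumerate() {
        let next: Vec<usize> = remaining
            .iter()
            .copied()
            .filter(|&o| overloads[o][i] == arg)
            .collect();
        if next.is_empty() {
            // Is `arg` ever acceptable at position `i`?
            if !overloads.iter().any(|o| o[i] == arg) {
                return Err(OverloadError::WrongArgumentType { arg_index: i });
            }
            // Replay the prior arguments to find the one that ruled `arg` out.
            let mut replay: Vec<usize> = (0..overloads.len()).collect();
            for (p, &prior) in args[..i].iter().enumerate() {
                replay.retain(|&o| overloads[o][p] == prior);
                if !replay.iter().any(|&o| overloads[o][i] == arg) {
                    return Err(OverloadError::InconsistentArgumentType {
                        arg_index: i,
                        prior_index: p,
                    });
                }
            }
            unreachable!("some prior argument must have excluded this type");
        }
        remaining = next;
    }
    remaining.first().copied().ok_or(OverloadError::WrongArgumentCount)
}
```

For a `max`-like function with overloads `(i32, i32)`, `(u32, u32)`, and `(f32, f32)`, the call `(i32, f32)` is reported as inconsistent with argument 0, while `(bool, i32)` is rejected outright at argument 0.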
fn apply_automatic_conversions_for_call( &self, rule: &proc::Rule, arguments: &mut [Handle], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, ()> { for (i, argument) in arguments.iter_mut().enumerate() { let goal_inner = rule.arguments[i].inner_with(&ctx.module.types); let converted = match goal_inner.scalar_for_conversions(&ctx.module.types) { Some(goal_scalar) => { let arg_span = ctx.get_expression_span(*argument); ctx.try_automatic_conversion_for_leaf_scalar(*argument, goal_scalar, arg_span)? } // No conversion is necessary. None => *argument, }; *argument = converted; } Ok(()) } fn atomic_pointer( &mut self, expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, (Handle, ir::Scalar)> { let span = ctx.ast_expressions.get_span(expr); let pointer = self.expression(expr, ctx)?; match *resolve_inner!(ctx, pointer) { ir::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { ir::TypeInner::Atomic(scalar) => Ok((pointer, scalar)), ref other => { log::error!("Pointer type to {other:?} passed to atomic op"); Err(Box::new(Error::InvalidAtomicPointer(span))) } }, ref other => { log::error!("Type {other:?} passed to atomic op"); Err(Box::new(Error::InvalidAtomicPointer(span))) } } } fn atomic_helper( &mut self, span: Span, fun: ir::AtomicFunction, args: &[Handle>], is_statement: bool, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Option>> { let mut args = ctx.prepare_args(args, 2, span); let (pointer, scalar) = self.atomic_pointer(args.next()?, ctx)?; let value = self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; let value_inner = resolve_inner!(ctx, value); args.finish()?; // If we don't use the return value of a 64-bit `min` or `max` // operation, generate a no-result form of the `Atomic` statement, so // that we can pass validation with only `SHADER_INT64_ATOMIC_MIN_MAX` // whenever possible. 
let is_64_bit_min_max = matches!(fun, ir::AtomicFunction::Min | ir::AtomicFunction::Max) && matches!( *value_inner, ir::TypeInner::Scalar(ir::Scalar { width: 8, .. }) ); let result = if is_64_bit_min_max && is_statement { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); rctx.emitter.start(&rctx.function.expressions); None } else { let ty = ctx.register_type(value)?; Some(ctx.interrupt_emitter( ir::Expression::AtomicResult { ty, comparison: false, }, span, )?) }; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( ir::Statement::Atomic { pointer, fun, value, result, }, span, ); Ok(result) } fn texture_sample_helper( &mut self, fun: Texture, args: &[Handle>], span: Span, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, ir::Expression> { let mut args = ctx.prepare_args(args, fun.min_argument_count(), span); fn get_image_and_span<'source>( lowerer: &mut Lowerer<'source, '_>, args: &mut ArgumentContext<'_, 'source>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, (Handle, Span)> { let image = args.next()?; let image_span = ctx.ast_expressions.get_span(image); let image = lowerer.expression_for_abstract(image, ctx)?; Ok((image, image_span)) } let image; let image_span; let gather; match fun { Texture::Gather => { let image_or_component = args.next()?; let image_or_component_span = ctx.ast_expressions.get_span(image_or_component); // Gathers from depth textures don't take an initial `component` argument. let lowered_image_or_component = self.expression(image_or_component, ctx)?; match *resolve_inner!(ctx, lowered_image_or_component) { ir::TypeInner::Image { class: ir::ImageClass::Depth { .. }, .. 
} => { image = lowered_image_or_component; image_span = image_or_component_span; gather = Some(ir::SwizzleComponent::X); } _ => { (image, image_span) = get_image_and_span(self, &mut args, ctx)?; gather = Some(ctx.gather_component( lowered_image_or_component, image_or_component_span, span, )?); } } } Texture::GatherCompare => { (image, image_span) = get_image_and_span(self, &mut args, ctx)?; gather = Some(ir::SwizzleComponent::X); } _ => { (image, image_span) = get_image_and_span(self, &mut args, ctx)?; gather = None; } }; let sampler = self.expression_for_abstract(args.next()?, ctx)?; let coordinate = self.expression_with_leaf_scalar(args.next()?, ir::Scalar::F32, ctx)?; let clamp_to_edge = matches!(fun, Texture::SampleBaseClampToEdge); let (class, arrayed) = ctx.image_data(image, image_span)?; let array_index = arrayed .then(|| self.expression(args.next()?, ctx)) .transpose()?; let level; let depth_ref; match fun { Texture::Gather => { level = ir::SampleLevel::Zero; depth_ref = None; } Texture::GatherCompare => { let reference = self.expression_with_leaf_scalar(args.next()?, ir::Scalar::F32, ctx)?; level = ir::SampleLevel::Zero; depth_ref = Some(reference); } Texture::Sample => { level = ir::SampleLevel::Auto; depth_ref = None; } Texture::SampleBias => { let bias = self.expression_with_leaf_scalar(args.next()?, ir::Scalar::F32, ctx)?; level = ir::SampleLevel::Bias(bias); depth_ref = None; } Texture::SampleCompare => { let reference = self.expression_with_leaf_scalar(args.next()?, ir::Scalar::F32, ctx)?; level = ir::SampleLevel::Auto; depth_ref = Some(reference); } Texture::SampleCompareLevel => { let reference = self.expression_with_leaf_scalar(args.next()?, ir::Scalar::F32, ctx)?; level = ir::SampleLevel::Zero; depth_ref = Some(reference); } Texture::SampleGrad => { let x = self.expression_with_leaf_scalar(args.next()?, ir::Scalar::F32, ctx)?; let y = self.expression_with_leaf_scalar(args.next()?, ir::Scalar::F32, ctx)?; level = ir::SampleLevel::Gradient { x, y 
}; depth_ref = None; } Texture::SampleLevel => { let exact = match class { // When applied to depth textures, `textureSampleLevel`'s // `level` argument is an `i32` or `u32`. ir::ImageClass::Depth { .. } => self.expression(args.next()?, ctx)?, // When applied to other sampled types, its `level` argument // is an `f32`. ir::ImageClass::Sampled { .. } => { self.expression_with_leaf_scalar(args.next()?, ir::Scalar::F32, ctx)? } // Sampling `External` textures with a specified level isn't // allowed, and sampling `Storage` textures isn't allowed at // all. Let the validator report the error. ir::ImageClass::Storage { .. } | ir::ImageClass::External => { self.expression(args.next()?, ctx)? } }; level = ir::SampleLevel::Exact(exact); depth_ref = None; } Texture::SampleBaseClampToEdge => { level = crate::SampleLevel::Zero; depth_ref = None; } }; let offset = args .next() .map(|arg| self.expression_with_leaf_scalar(arg, ir::Scalar::I32, &mut ctx.as_const())) .ok() .transpose()?; args.finish()?; Ok(ir::Expression::ImageSample { image, sampler, gather, coordinate, array_index, offset, level, depth_ref, clamp_to_edge, }) } fn subgroup_operation_helper( &mut self, span: Span, op: ir::SubgroupOperation, collective_op: ir::CollectiveOperation, arguments: &[Handle>], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Handle> { let mut args = ctx.prepare_args(arguments, 1, span); let argument = self.expression(args.next()?, ctx)?; args.finish()?; let ty = ctx.register_type(argument)?; let result = ctx.interrupt_emitter(ir::Expression::SubgroupOperationResult { ty }, span)?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( ir::Statement::SubgroupCollectiveOperation { op, collective_op, argument, result, }, span, ); Ok(result) } fn subgroup_gather_helper( &mut self, span: Span, mode: SubgroupGather, arguments: &[Handle>], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Handle> { let mut args = ctx.prepare_args(arguments, 2, span); let 
argument = self.expression(args.next()?, ctx)?; use SubgroupGather as Sg; let mode = if let Sg::BroadcastFirst = mode { ir::GatherMode::BroadcastFirst } else { let index = self.expression(args.next()?, ctx)?; match mode { Sg::BroadcastFirst => unreachable!(), Sg::Broadcast => ir::GatherMode::Broadcast(index), Sg::Shuffle => ir::GatherMode::Shuffle(index), Sg::ShuffleDown => ir::GatherMode::ShuffleDown(index), Sg::ShuffleUp => ir::GatherMode::ShuffleUp(index), Sg::ShuffleXor => ir::GatherMode::ShuffleXor(index), Sg::QuadBroadcast => ir::GatherMode::QuadBroadcast(index), } }; args.finish()?; let ty = ctx.register_type(argument)?; let result = ctx.interrupt_emitter(ir::Expression::SubgroupOperationResult { ty }, span)?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( ir::Statement::SubgroupGather { mode, argument, result, }, span, ); Ok(result) } fn r#struct( &mut self, s: &ast::Struct<'source>, span: Span, ctx: &mut GlobalContext<'source, '_, '_>, ) -> Result<'source, Handle> { let mut offset = 0; let mut struct_alignment = proc::Alignment::ONE; let mut members = Vec::with_capacity(s.members.len()); let mut doc_comments: Vec>> = Vec::new(); for member in s.members.iter() { let ty = self.resolve_ast_type(&member.ty, &mut ctx.as_const())?; ctx.layouter.update(ctx.module.to_ctx()).map_err(|err| { let LayoutErrorInner::TooLarge = err.inner else { unreachable!("unexpected layout error: {err:?}"); }; // Since anonymous types of struct members don't get a span, // associate the error with the member. The layouter could have // failed on any type that was pending layout, but if it wasn't // the current struct member, it wasn't a struct member at all, // because we resolve struct members one-by-one. if ty == err.ty { Box::new(Error::StructMemberTooLarge { member_name_span: member.name.span, }) } else { // Lots of type definitions don't get spans, so this error // message may not be very useful. 
Box::new(Error::TypeTooLarge { span: ctx.module.types.get_span(err.ty), }) } })?; let member_min_size = ctx.layouter[ty].size; let member_min_alignment = ctx.layouter[ty].alignment; let member_size = if let Some(size_expr) = member.size { let (size, span) = self.const_u32(size_expr, &mut ctx.as_const())?; if size < member_min_size { return Err(Box::new(Error::SizeAttributeTooLow(span, member_min_size))); } else { size } } else { member_min_size }; let member_alignment = if let Some(align_expr) = member.align { let (align, span) = self.const_u32(align_expr, &mut ctx.as_const())?; if let Some(alignment) = proc::Alignment::new(align) { if alignment < member_min_alignment { return Err(Box::new(Error::AlignAttributeTooLow( span, member_min_alignment, ))); } else { alignment } } else { return Err(Box::new(Error::NonPowerOfTwoAlignAttribute(span))); } } else { member_min_alignment }; let binding = self.binding(&member.binding, ty, ctx)?; offset = member_alignment.round_up(offset); struct_alignment = struct_alignment.max(member_alignment); if !member.doc_comments.is_empty() { doc_comments.push(Some( member.doc_comments.iter().map(|s| s.to_string()).collect(), )); } members.push(ir::StructMember { name: Some(member.name.name.to_owned()), ty, binding, offset, }); offset += member_size; if offset > crate::valid::MAX_TYPE_SIZE { return Err(Box::new(Error::TypeTooLarge { span })); } } let size = struct_alignment.round_up(offset); let inner = ir::TypeInner::Struct { members, span: size, }; let handle = ctx.module.types.insert( ir::Type { name: Some(s.name.name.to_string()), inner, }, span, ); for (i, c) in doc_comments.drain(..).enumerate() { if let Some(comment) = c { ctx.module .get_or_insert_default_doc_comments() .struct_members .insert((handle, i), comment); } } Ok(handle) } fn const_u32( &mut self, expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, (u32, Span)> { let span = ctx.ast_expressions.get_span(expr); let expr = self.expression(expr, 
ctx)?; let value = ctx .module .to_ctx() .get_const_val(expr) .map_err(|err| match err { proc::ConstValueError::NonConst | proc::ConstValueError::InvalidType => { Error::ExpectedConstExprConcreteIntegerScalar(span) } proc::ConstValueError::Negative => Error::ExpectedNonNegative(span), })?; Ok((value, span)) } fn array_size( &mut self, expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, ir::ArraySize> { let span = ctx.ast_expressions.get_span(expr); let const_ctx = &mut ctx.as_const(); let const_expr = self.expression(expr, const_ctx); match const_expr { Ok(value) => { let len = const_ctx.get_const_val(value).map_err(|err| { Box::new(match err { proc::ConstValueError::NonConst | proc::ConstValueError::InvalidType => { Error::ExpectedConstExprConcreteIntegerScalar(span) } proc::ConstValueError::Negative => Error::ExpectedPositiveArrayLength(span), }) })?; let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?; Ok(ir::ArraySize::Constant(size)) } Err(err) => { // If the error is simply that `expr` was an override expression, then we // can represent that as an array length. 
let Error::ConstantEvaluatorError(ref ty, _) = *err else { return Err(err); }; let proc::ConstantEvaluatorError::OverrideExpr = **ty else { return Err(err); }; Ok(ir::ArraySize::Pending(self.array_size_override( expr, &mut ctx.as_global().as_override(), span, )?)) } } } fn array_size_override( &mut self, size_expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, span: Span, ) -> Result<'source, Handle> { let expr = self.expression(size_expr, ctx)?; match resolve_inner!(ctx, expr).scalar_kind().ok_or(0) { Ok(ir::ScalarKind::Sint) | Ok(ir::ScalarKind::Uint) => Ok({ if let ir::Expression::Override(handle) = ctx.module.global_expressions[expr] { handle } else { let ty = ctx.register_type(expr)?; ctx.module.overrides.append( ir::Override { name: None, id: None, ty, init: Some(expr), }, span, ) } }), _ => Err(Box::new(Error::ExpectedConstExprConcreteIntegerScalar( span, ))), } } /// Build the Naga equivalent of a named AST type. /// /// Return a Naga `Handle` representing the front-end type /// `handle`, which should be named `name`, if given. /// /// If `handle` refers to a type cached in [`SpecialTypes`], /// `name` may be ignored. /// /// [`SpecialTypes`]: ir::SpecialTypes fn resolve_named_ast_type( &mut self, ident: &ast::TemplateElaboratedIdent<'source>, name: String, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Handle> { self.type_specifier(ident, ctx, Some(name)) } /// Return a Naga `Handle` representing the front-end type `handle`. 
fn resolve_ast_type( &mut self, ident: &ast::TemplateElaboratedIdent<'source>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Handle> { self.type_specifier(ident, ctx, None) } fn binding( &mut self, binding: &Option>, ty: Handle, ctx: &mut GlobalContext<'source, '_, '_>, ) -> Result<'source, Option> { Ok(match *binding { Some(ast::Binding::BuiltIn(b)) => Some(ir::Binding::BuiltIn(b)), Some(ast::Binding::Location { location, interpolation, sampling, blend_src, per_primitive, }) => { let blend_src = if let Some(blend_src) = blend_src { Some(self.const_u32(blend_src, &mut ctx.as_const())?.0) } else { None }; let mut binding = ir::Binding::Location { location: self.const_u32(location, &mut ctx.as_const())?.0, interpolation, sampling, blend_src, per_primitive, }; binding.apply_default_interpolation(&ctx.module.types[ty].inner); Some(binding) } None => None, }) } fn ray_query_pointer( &mut self, expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Handle> { let span = ctx.ast_expressions.get_span(expr); let pointer = self.expression(expr, ctx)?; match *resolve_inner!(ctx, pointer) { ir::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { ir::TypeInner::RayQuery { .. 
} => Ok(pointer), ref other => { log::error!("Pointer type to {other:?} passed to ray query op"); Err(Box::new(Error::InvalidRayQueryPointer(span))) } }, ref other => { log::error!("Type {other:?} passed to ray query op"); Err(Box::new(Error::InvalidRayQueryPointer(span))) } } } } impl ir::AtomicFunction { pub fn map(word: &str) -> Option { Some(match word { "atomicAdd" => ir::AtomicFunction::Add, "atomicSub" => ir::AtomicFunction::Subtract, "atomicAnd" => ir::AtomicFunction::And, "atomicOr" => ir::AtomicFunction::InclusiveOr, "atomicXor" => ir::AtomicFunction::ExclusiveOr, "atomicMin" => ir::AtomicFunction::Min, "atomicMax" => ir::AtomicFunction::Max, "atomicExchange" => ir::AtomicFunction::Exchange { compare: None }, _ => return None, }) } } naga-29.0.3/src/front/wgsl/lower/template_list.rs000064400000000000000000000152251046102023000200710ustar 00000000000000use alloc::{boxed::Box, vec::Vec}; use crate::{ front::wgsl::{ error::Error, lower::{ExpressionContext, Lowerer, Result}, parse::{ast, conv}, }, ir, Handle, Span, }; /// Iterator over a template list. /// /// All functions will attempt to consume an element in the list. /// /// Function variants prefixed with "maybe" will not return an error if there /// are no more elements left in the list. 
pub struct TemplateListIter<'iter, 'source> { ident_span: Span, template_list: core::slice::Iter<'iter, Handle>>, } impl<'iter, 'source> TemplateListIter<'iter, 'source> { pub fn new(ident_span: Span, template_list: &'iter [Handle>]) -> Self { Self { ident_span, template_list: template_list.iter(), } } pub fn finish(self, ctx: &ExpressionContext<'source, '_, '_>) -> Result<'source, ()> { let unused_args: Vec = self .template_list .map(|expr| ctx.ast_expressions.get_span(*expr)) .collect(); if unused_args.is_empty() { Ok(()) } else { Err(Box::new(Error::UnusedArgsForTemplate(unused_args))) } } fn expect_next( &mut self, description: &'static str, ) -> Result<'source, Handle>> { if let Some(expr) = self.template_list.next() { Ok(*expr) } else { Err(Box::new(Error::MissingTemplateArg { span: self.ident_span, description, })) } } pub fn ty( &mut self, lowerer: &mut Lowerer<'source, '_>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Handle> { let expr = self.expect_next("`T`, a type")?; lowerer.type_expression(expr, ctx) } /// Lower the next template list element as a type, and return its span. /// /// This returns the span of the template list element. This is generally /// different from the span of the returned `Handle`, as the /// latter may refer to the type's definition, not its use in the template list. 
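The split between required elements and optional "maybe" elements described above can be shown with a miniature cursor. This is a hypothetical sketch: `ListIter`, `ListError`, and `maybe_next` are invented for illustration, and the elements are plain strings rather than AST expression handles. The sketch mirrors the shape of `expect_next` and `finish`. A required element that is missing is an error. A missing optional element yields `None`. Any leftover elements are reported when the cursor is finished.

```rust
/// Hypothetical toy template-list cursor over string elements.
struct ListIter<'a> {
    items: std::slice::Iter<'a, &'a str>,
}

#[derive(Debug, PartialEq)]
enum ListError {
    /// A required element was absent; carries its description.
    Missing(&'static str),
    /// Elements left over after parsing finished.
    Unused(Vec<String>),
}

impl<'a> ListIter<'a> {
    fn new(items: &'a [&'a str]) -> Self {
        Self { items: items.iter() }
    }

    /// Required element: running out of elements is an error.
    fn expect_next(&mut self, description: &'static str) -> Result<&'a str, ListError> {
        self.items.next().copied().ok_or(ListError::Missing(description))
    }

    /// Optional element: running out of elements yields `None`, not an error.
    fn maybe_next(&mut self) -> Option<&'a str> {
        self.items.next().copied()
    }

    /// Report every element that no caller consumed.
    fn finish(self) -> Result<(), ListError> {
        let unused: Vec<String> = self.items.map(|s| s.to_string()).collect();
        if unused.is_empty() {
            Ok(())
        } else {
            Err(ListError::Unused(unused))
        }
    }
}
```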
pub fn ty_with_span( &mut self, lowerer: &mut Lowerer<'source, '_>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, (Handle, Span)> { let expr = self.expect_next("`T`, a type")?; let span = ctx.ast_expressions.get_span(expr); let ty = lowerer.type_expression(expr, ctx)?; Ok((ty, span)) } pub fn scalar_ty( &mut self, lowerer: &mut Lowerer<'source, '_>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, (ir::Scalar, Span)> { let expr = self.expect_next("`T`, a scalar type")?; let ty = lowerer.type_expression(expr, ctx)?; let span = ctx.ast_expressions.get_span(expr); match ctx.module.types[ty].inner { ir::TypeInner::Scalar(scalar) => Ok((scalar, span)), _ => Err(Box::new(Error::UnknownScalarType(span))), } } pub fn maybe_array_size( &mut self, lowerer: &mut Lowerer<'source, '_>, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, ir::ArraySize> { if let Some(expr) = self.template_list.next() { lowerer.array_size(*expr, ctx) } else { Ok(ir::ArraySize::Dynamic) } } pub fn address_space( &mut self, ctx: &ExpressionContext<'source, '_, '_>, ) -> Result<'source, ir::AddressSpace> { let expr = self.expect_next("`AS`, an address space")?; let (enumerant, span) = ctx.enumerant(expr)?; conv::map_address_space(enumerant, span, &ctx.enable_extensions) } pub fn maybe_address_space( &mut self, ctx: &ExpressionContext<'source, '_, '_>, ) -> Result<'source, Option> { let Some(expr) = self.template_list.next() else { return Ok(None); }; let (enumerant, span) = ctx.enumerant(*expr)?; Ok(Some(conv::map_address_space( enumerant, span, &ctx.enable_extensions, )?)) } pub fn access_mode( &mut self, ctx: &ExpressionContext<'source, '_, '_>, ) -> Result<'source, ir::StorageAccess> { let expr = self.expect_next("`Access`, an access mode")?; let (enumerant, span) = ctx.enumerant(expr)?; conv::map_access_mode(enumerant, span) } /// Update `space` with an access mode from `self`, if appropriate. 
/// /// If `space` is [`Storage`], and there is another template parameter in /// `self`, parse it as an access mode and mutate `space` accordingly. If /// there are no more template parameters, treat that like `read`. /// /// If `space` is some other address space, don't consume any template /// parameters. /// /// [`Storage`]: ir::AddressSpace::Storage pub fn maybe_access_mode( &mut self, space: &mut ir::AddressSpace, ctx: &ExpressionContext<'source, '_, '_>, ) -> Result<'source, ()> { if let &mut ir::AddressSpace::Storage { ref mut access } = space { if let Some(expr) = self.template_list.next() { let (enumerant, span) = ctx.enumerant(*expr)?; let access_mode = conv::map_access_mode(enumerant, span)?; *access = access_mode; } else { // defaulting to `read` *access = ir::StorageAccess::LOAD } } Ok(()) } pub fn storage_format( &mut self, ctx: &ExpressionContext<'source, '_, '_>, ) -> Result<'source, ir::StorageFormat> { let expr = self.expect_next("`Format`, a texel format")?; let (enumerant, span) = ctx.enumerant(expr)?; conv::map_storage_format(enumerant, span) } pub fn maybe_vertex_return( &mut self, ctx: &ExpressionContext<'source, '_, '_>, ) -> Result<'source, bool> { let Some(expr) = self.template_list.next() else { return Ok(false); }; let (enumerant, span) = ctx.enumerant(*expr)?; conv::map_ray_flag(&ctx.enable_extensions, enumerant, span)?; Ok(true) } pub fn cooperative_role( &mut self, ctx: &ExpressionContext<'source, '_, '_>, ) -> Result<'source, crate::CooperativeRole> { let role_expr = self.expect_next("`Role`, a cooperative matrix role")?; let (enumerant, span) = ctx.enumerant(role_expr)?; let role = conv::map_cooperative_role(enumerant, span)?; Ok(role) } } naga-29.0.3/src/front/wgsl/mod.rs000064400000000000000000000060261046102023000146510ustar 00000000000000/*! Frontend for [WGSL][wgsl] (WebGPU Shading Language).
[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html */ mod error; mod index; mod lower; mod parse; #[cfg(test)] mod tests; pub use parse::directive::enable_extension::{ EnableExtension, ImplementedEnableExtension, UnimplementedEnableExtension, }; pub use crate::front::wgsl::error::ParseError; pub use crate::front::wgsl::parse::directive::language_extension::{ ImplementedLanguageExtension, LanguageExtension, UnimplementedLanguageExtension, }; pub use crate::front::wgsl::parse::Options; use alloc::boxed::Box; use thiserror::Error; use crate::front::wgsl::error::Error; use crate::front::wgsl::lower::Lowerer; use crate::front::wgsl::parse::Parser; use crate::Scalar; #[cfg(test)] use std::println; pub(crate) type Result<'a, T> = core::result::Result>>; pub struct Frontend { parser: Parser, options: Options, } impl Frontend { pub const fn new() -> Self { Self { parser: Parser::new(), options: Options::new(), } } pub const fn new_with_options(options: Options) -> Self { Self { parser: Parser::new(), options, } } pub const fn set_options(&mut self, options: Options) { self.options = options; } pub fn parse(&mut self, source: &str) -> core::result::Result { self.inner(source).map_err(|x| x.as_parse_error(source)) } fn inner<'a>(&mut self, source: &'a str) -> Result<'a, crate::Module> { let tu = self.parser.parse(source, &self.options)?; let index = index::Index::generate(&tu)?; let module = Lowerer::new(&index).lower(tu)?; Ok(module) } } ///

pub fn parse_str(source: &str) -> core::result::Result { Frontend::new().parse(source) } #[cfg(test)] #[track_caller] pub fn assert_parse_err(input: &str, snapshot: &str) { let output = parse_str(input) .expect_err("expected parser error") .emit_to_string(input); if output != snapshot { for diff in diff::lines(snapshot, &output) { match diff { diff::Result::Left(l) => println!("-{l}"), diff::Result::Both(l, _) => println!(" {l}"), diff::Result::Right(r) => println!("+{r}"), } } panic!("Error snapshot failed"); } } naga-29.0.3/src/front/wgsl/parse/ast.rs000064400000000000000000000325731046102023000160010ustar 00000000000000use alloc::vec::Vec; use core::hash::Hash; use crate::diagnostic_filter::DiagnosticFilterNode; use crate::front::wgsl::parse::directive::enable_extension::EnableExtensions; use crate::front::wgsl::parse::number::Number; use crate::{Arena, FastIndexSet, Handle, Span}; #[derive(Debug, Default)] pub struct TranslationUnit<'a> { pub enable_extensions: EnableExtensions, pub decls: Arena>, /// The common expressions arena for the entire translation unit. /// /// All functions, global initializers, array lengths, etc. store their /// expressions here. We apportion these out to individual Naga /// [`Function`]s' expression arenas at lowering time. Keeping them all in a /// single arena simplifies handling of things like array lengths (which are /// effectively global and thus don't clearly belong to any function) and /// initializers (which can appear in both function-local and module-scope /// contexts). /// /// [`Function`]: crate::Function pub expressions: Arena>, /// Arena for all diagnostic filter rules parsed in this module, including those in functions. /// /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in /// validation. pub diagnostic_filters: Arena, /// The leaf of all `diagnostic(…)` directives in this module. 
/// /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in /// validation. pub diagnostic_filter_leaf: Option>, /// Doc comments appearing first in the file. /// This serves as documentation for the whole TranslationUnit. pub doc_comments: Vec<&'a str>, } #[derive(Debug, Clone, Copy)] pub struct Ident<'a> { pub name: &'a str, pub span: Span, } /// An identifier that [resolves] to some declaration. /// /// This does not cover context-dependent names: attributes, built-in values, /// and so on. We map those to their Naga IR equivalents as soon as they're /// parsed, so they never need to appear as identifiers in the AST. /// /// [resolves]: https://gpuweb.github.io/gpuweb/wgsl/#resolves #[derive(Debug)] pub enum IdentExpr<'a> { /// An identifier referring to a module-scope declaration or predeclared /// object. /// /// We need to collect the entire module before we can resolve this, to /// distinguish between predeclared objects and module-scope declarations /// that appear after their uses. /// /// Whenever you create one of these values, you almost certainly want to /// insert the `&str` into [`ExpressionContext::unresolved`][ECu], to ensure /// that [indexing] knows that the name's declaration must be lowered before /// the one containing this use. Using [`Parser::ident_expr`][ie] to build /// `IdentExpr` will take care of that for you. /// /// [ECu]: super::ExpressionContext::unresolved /// [ie]: super::Parser::ident_expr /// [indexing]: crate::front::wgsl::index::Index::generate Unresolved(&'a str), /// An identifier that has been resolved to a non-module-scope declaration. Local(Handle), } /// An identifier with optional template parameters. /// /// Following the WGSL specification (see the [`template_list`] non-terminal), /// `TemplateElaboratedIdent` represents all template parameters as expressions: /// even parameters to type generators, like the `f32` in `vec3`, are [Type /// Expressions]. 
/// /// # Examples /// /// - A use of a global variable `colors` would be an [`Expression::Ident(v)`][EI], /// where `v` is a `TemplateElaboratedIdent` whose `ident` is /// [`IdentExpr::Unresolved("colors")`][IEU]. Lowering will resolve this to a /// reference to the global variable. /// /// - The type `f32` in a variable declaration is represented as a /// `TemplateElaboratedIdent` whose `ident` is /// [`IdentExpr::Unresolved("f32")`][IEU]. Lowering will resolve this to /// WGSL's predeclared `f32` type. /// /// - The type `vec3<f32>` can be represented as a `TemplateElaboratedIdent` /// whose `ident` is [`IdentExpr::Unresolved("vec3")`][IEU], and whose /// `template_list` has one element: an [`Expression::Ident(v)`][EI] where `v` is a /// nested `TemplateElaboratedIdent` representing `f32` as described above. /// /// - The type `array<vec3<f32>, 4>` has `"array"` as its `ident`, and then /// a two-element `template_list`: /// /// - `template_list[0]` is an [`Expression::Ident(v)`][EI] where `v` is a nested /// `TemplateElaboratedIdent` representing `vec3<f32>` as described above. /// /// - `template_list[1]` is an [`Expression`] representing `4`. /// /// After [indexing] the module to ensure that declarations appear before uses, /// lowering can see which declaration a given `TemplateElaboratedIdent`'s /// `ident` refers to. The declaration then determines how to interpret the /// `template_list`. /// /// [`template_list`]: https://gpuweb.github.io/gpuweb/wgsl/#syntax-template_list /// [Type Expressions]: https://gpuweb.github.io/gpuweb/wgsl/#type-expr /// [IEU]: IdentExpr::Unresolved /// [EI]: Expression::Ident /// [indexing]: crate::front::wgsl::index::Index::generate #[derive(Debug)] pub struct TemplateElaboratedIdent<'a> { pub ident: IdentExpr<'a>, pub ident_span: Span, /// If non-empty, the template parameters following the identifier. pub template_list: Vec<Handle<Expression<'a>>>, pub template_list_span: Span, } /// A function call or value constructor expression.
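The nesting described in the `TemplateElaboratedIdent` examples can be modeled as a small recursive tree, in which every template parameter is itself an expression. This is a hypothetical sketch: `Expr`, `Elaborated`, `ident`, and `render` are invented names. The sketch uses owned `Vec`s where Naga stores `Handle`s into an arena, and it does not perform any name resolution.

```rust
/// Hypothetical miniature of the AST shape: type names are identifiers and
/// every template parameter is an expression.
#[derive(Debug)]
enum Expr {
    Ident(Elaborated),
    Literal(u32),
}

#[derive(Debug)]
struct Elaborated {
    ident: &'static str,
    template_list: Vec<Expr>,
}

/// Convenience constructor for an identifier with template parameters.
fn ident(name: &'static str, template_list: Vec<Expr>) -> Expr {
    Expr::Ident(Elaborated { ident: name, template_list })
}

/// Print an expression back as WGSL-like source text.
fn render(expr: &Expr) -> String {
    match expr {
        Expr::Literal(n) => n.to_string(),
        Expr::Ident(e) if e.template_list.is_empty() => e.ident.to_string(),
        Expr::Ident(e) => {
            let params: Vec<String> = e.template_list.iter().map(render).collect();
            format!("{}<{}>", e.ident, params.join(", "))
        }
    }
}
```

Building `array<vec3<f32>, 4>` gives an outer `array` node with a two-element template list. The first element is a nested `vec3` node, whose own list holds `f32`, and the second element is the literal `4`.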
///
/// We can't tell whether an expression like `IDENTIFIER(EXPR, ...)` is a
/// construction expression or a function call until we know `IDENTIFIER`'s
/// definition, so we represent everything of that form as one of these
/// expressions until lowering. At that point, [`Lowerer::call`] has
/// everything's definition in hand, and can decide whether to emit a Naga
/// [`Constant`], [`As`], [`Splat`], or [`Compose`] expression.
///
/// [`Lowerer::call`]: Lowerer::call
/// [`Constant`]: crate::Expression::Constant
/// [`As`]: crate::Expression::As
/// [`Splat`]: crate::Expression::Splat
/// [`Compose`]: crate::Expression::Compose
#[derive(Debug)]
pub struct CallPhrase<'a> {
    pub function: TemplateElaboratedIdent<'a>,
    pub arguments: Vec<Handle<Expression<'a>>>,
}

/// A reference to a module-scope definition or predeclared object.
///
/// Each [`GlobalDecl`] holds a set of these values, to be resolved to
/// specific definitions later. To support de-duplication, `Eq` and
/// `Hash` on a `Dependency` value consider only the name, not the
/// source location at which the reference occurs.
#[derive(Debug)]
pub struct Dependency<'a> {
    /// The name referred to.
    pub ident: &'a str,

    /// The location at which the reference to that name occurs.
    pub usage: Span,
}

impl Hash for Dependency<'_> {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        self.ident.hash(state);
    }
}

impl PartialEq for Dependency<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.ident == other.ident
    }
}

impl Eq for Dependency<'_> {}

/// A module-scope declaration.
#[derive(Debug)]
pub struct GlobalDecl<'a> {
    pub kind: GlobalDeclKind<'a>,

    /// Names of all module-scope or predeclared objects this
    /// declaration uses.
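    ///
    /// Because [`Dependency`] implements `Hash` and `Eq` by name alone, inserting
    /// a second reference to an already-recorded name leaves the set unchanged,
    /// so the span of the first recorded use is the one that survives. The
    /// self-contained sketch below shows that behavior. It assumes a hypothetical
    /// `Dep` type standing in for `Dependency`, a `(usize, usize)` pair standing in
    /// for `Span`, and `std`'s `HashSet` standing in for `FastIndexSet`. Unlike
    /// `HashSet`, `FastIndexSet` also preserves insertion order.
    ///
    /// ```
    /// use std::collections::HashSet;
    /// use std::hash::{Hash, Hasher};
    ///
    /// // Hypothetical stand-in for `Dependency`: equality and hashing ignore `usage`.
    /// #[derive(Debug)]
    /// struct Dep<'a> {
    ///     ident: &'a str,
    ///     usage: (usize, usize),
    /// }
    ///
    /// impl Hash for Dep<'_> {
    ///     fn hash<H: Hasher>(&self, state: &mut H) {
    ///         self.ident.hash(state);
    ///     }
    /// }
    ///
    /// impl PartialEq for Dep<'_> {
    ///     fn eq(&self, other: &Self) -> bool {
    ///         self.ident == other.ident
    ///     }
    /// }
    ///
    /// impl Eq for Dep<'_> {}
    ///
    /// let mut deps = HashSet::new();
    /// assert!(deps.insert(Dep { ident: "colors", usage: (10, 16) }));
    /// // A second use of `colors` at a different location counts as a duplicate.
    /// assert!(!deps.insert(Dep { ident: "colors", usage: (40, 46) }));
    /// assert!(deps.insert(Dep { ident: "f32", usage: (20, 23) }));
    /// assert_eq!(deps.len(), 2);
    /// // `insert` does not replace an equal value, so the first span is kept.
    /// let kept = deps.get(&Dep { ident: "colors", usage: (0, 0) }).unwrap();
    /// assert_eq!(kept.usage, (10, 16));
    /// ```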
    pub dependencies: FastIndexSet<Dependency<'a>>,
}

#[derive(Debug)]
pub enum GlobalDeclKind<'a> {
    Fn(Function<'a>),
    Var(GlobalVariable<'a>),
    Const(Const<'a>),
    Override(Override<'a>),
    Struct(Struct<'a>),
    Type(TypeAlias<'a>),
    ConstAssert(Handle<Expression<'a>>),
}

#[derive(Debug)]
pub struct FunctionArgument<'a> {
    pub name: Ident<'a>,
    pub ty: TemplateElaboratedIdent<'a>,
    pub binding: Option<Binding<'a>>,
    pub handle: Handle<Local>,
}

#[derive(Debug)]
pub struct FunctionResult<'a> {
    pub ty: TemplateElaboratedIdent<'a>,
    pub binding: Option<Binding<'a>>,
    pub must_use: bool,
}

#[derive(Debug)]
pub struct EntryPoint<'a> {
    pub stage: crate::ShaderStage,
    pub early_depth_test: Option<crate::EarlyDepthTest>,
    pub workgroup_size: Option<[Option<Handle<Expression<'a>>>; 3]>,
    pub mesh_output_variable: Option<(&'a str, Span)>,
    pub task_payload: Option<(&'a str, Span)>,
    pub ray_incoming_payload: Option<(&'a str, Span)>,
}

#[cfg(doc)]
use crate::front::wgsl::lower::{LocalExpressionContext, StatementContext};

#[derive(Debug)]
pub struct Function<'a> {
    pub entry_point: Option<EntryPoint<'a>>,
    pub name: Ident<'a>,
    pub arguments: Vec<FunctionArgument<'a>>,
    pub result: Option<FunctionResult<'a>>,
    pub body: Block<'a>,
    pub diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
    pub doc_comments: Vec<&'a str>,
}

#[derive(Debug)]
pub enum Binding<'a> {
    BuiltIn(crate::BuiltIn),
    Location {
        location: Handle<Expression<'a>>,
        interpolation: Option<crate::Interpolation>,
        sampling: Option<crate::Sampling>,
        blend_src: Option<Handle<Expression<'a>>>,
        per_primitive: bool,
    },
}

#[derive(Debug)]
pub struct ResourceBinding<'a> {
    pub group: Handle<Expression<'a>>,
    pub binding: Handle<Expression<'a>>,
}

#[derive(Debug)]
pub struct GlobalVariable<'a> {
    pub name: Ident<'a>,

    /// The template list parameters for the `var`, giving the variable's
    /// address space and access mode, if present.
    pub template_list: Vec<Handle<Expression<'a>>>,

    /// The `@group` and `@binding` attributes, if present.
    pub binding: Option<ResourceBinding<'a>>,

    pub ty: Option<TemplateElaboratedIdent<'a>>,
    pub init: Option<Handle<Expression<'a>>>,
    pub doc_comments: Vec<&'a str>,
    /// Memory decorations for this variable (`@coherent`, `@volatile`).
    pub memory_decorations: crate::MemoryDecorations,
}

#[derive(Debug)]
pub struct StructMember<'a> {
    pub name: Ident<'a>,
    pub ty: TemplateElaboratedIdent<'a>,
    pub binding: Option<Binding<'a>>,
    pub align: Option<Handle<Expression<'a>>>,
    pub size: Option<Handle<Expression<'a>>>,
    pub doc_comments: Vec<&'a str>,
}

#[derive(Debug)]
pub struct Struct<'a> {
    pub name: Ident<'a>,
    pub members: Vec<StructMember<'a>>,
    pub doc_comments: Vec<&'a str>,
}

#[derive(Debug)]
pub struct TypeAlias<'a> {
    pub name: Ident<'a>,
    pub ty: TemplateElaboratedIdent<'a>,
}

#[derive(Debug)]
pub struct Const<'a> {
    pub name: Ident<'a>,
    pub ty: Option<TemplateElaboratedIdent<'a>>,
    pub init: Handle<Expression<'a>>,
    pub doc_comments: Vec<&'a str>,
}

#[derive(Debug)]
pub struct Override<'a> {
    pub name: Ident<'a>,
    pub id: Option<Handle<Expression<'a>>>,
    pub ty: Option<TemplateElaboratedIdent<'a>>,
    pub init: Option<Handle<Expression<'a>>>,
}

#[derive(Debug, Default)]
pub struct Block<'a> {
    pub stmts: Vec<Statement<'a>>,
}

#[derive(Debug)]
pub struct Statement<'a> {
    pub kind: StatementKind<'a>,
    pub span: Span,
}

#[derive(Debug)]
pub enum StatementKind<'a> {
    LocalDecl(LocalDecl<'a>),
    Block(Block<'a>),
    If {
        condition: Handle<Expression<'a>>,
        accept: Block<'a>,
        reject: Block<'a>,
    },
    Switch {
        selector: Handle<Expression<'a>>,
        cases: Vec<SwitchCase<'a>>,
    },
    Loop {
        body: Block<'a>,
        continuing: Block<'a>,
        break_if: Option<Handle<Expression<'a>>>,
    },
    Break,
    Continue,
    Return {
        value: Option<Handle<Expression<'a>>>,
    },
    Kill,
    Call(CallPhrase<'a>),
    Assign {
        target: Handle<Expression<'a>>,
        op: Option<crate::BinaryOperator>,
        value: Handle<Expression<'a>>,
    },
    Increment(Handle<Expression<'a>>),
    Decrement(Handle<Expression<'a>>),
    Phony(Handle<Expression<'a>>),
    ConstAssert(Handle<Expression<'a>>),
}

#[derive(Debug)]
pub enum SwitchValue<'a> {
    Expr(Handle<Expression<'a>>),
    Default,
}

#[derive(Debug)]
pub struct SwitchCase<'a> {
    pub value: SwitchValue<'a>,
    pub body: Block<'a>,
    pub fall_through: bool,
}

#[derive(Debug, Copy, Clone)]
pub enum Literal {
    Bool(bool),
    Number(Number),
}

#[cfg(doc)]
use crate::front::wgsl::lower::Lowerer;

#[derive(Debug)]
pub enum Expression<'a> {
    Literal(Literal),
    Ident(TemplateElaboratedIdent<'a>),
    Unary {
        op: crate::UnaryOperator,
        expr: Handle<Expression<'a>>,
    },
    AddrOf(Handle<Expression<'a>>),
    Deref(Handle<Expression<'a>>),
    Binary {
        op: crate::BinaryOperator,
        left: Handle<Expression<'a>>,
        right: Handle<Expression<'a>>,
    },
    Call(CallPhrase<'a>),
    Index {
        base: Handle<Expression<'a>>,
        index: Handle<Expression<'a>>,
    },
    Member {
        base: Handle<Expression<'a>>,
        field: Ident<'a>,
    },
}

#[derive(Debug)]
pub struct LocalVariable<'a> {
    pub name: Ident<'a>,
    pub ty: Option<TemplateElaboratedIdent<'a>>,
    pub init: Option<Handle<Expression<'a>>>,
    pub handle: Handle<Local>,
}

#[derive(Debug)]
pub struct Let<'a> {
    pub name: Ident<'a>,
    pub ty: Option<TemplateElaboratedIdent<'a>>,
    pub init: Handle<Expression<'a>>,
    pub handle: Handle<Local>,
}

#[derive(Debug)]
pub struct LocalConst<'a> {
    pub name: Ident<'a>,
    pub ty: Option<TemplateElaboratedIdent<'a>>,
    pub init: Handle<Expression<'a>>,
    pub handle: Handle<Local>,
}

#[derive(Debug)]
pub enum LocalDecl<'a> {
    Var(LocalVariable<'a>),
    Let(Let<'a>),
    Const(LocalConst<'a>),
}

/// A placeholder for a local variable declaration.
///
/// See [`super::ExpressionContext::locals`] for more information.
#[derive(Debug)]
pub struct Local;
naga-29.0.3/src/front/wgsl/parse/conv.rs000064400000000000000000000617231046102023000161560ustar 00000000000000
use crate::front::wgsl::parse::directive::enable_extension::{
    EnableExtensions, ImplementedEnableExtension,
};
use crate::front::wgsl::{Error, Result, Scalar};
use crate::{ImageClass, ImageDimension, Span, TypeInner, VectorSize};

use alloc::boxed::Box;

pub fn map_address_space<'a>(
    word: &str,
    span: Span,
    enable_extensions: &EnableExtensions,
) -> Result<'a, crate::AddressSpace> {
    match word {
        "private" => Ok(crate::AddressSpace::Private),
        "workgroup" => Ok(crate::AddressSpace::WorkGroup),
        "uniform" => Ok(crate::AddressSpace::Uniform),
        "storage" => Ok(crate::AddressSpace::Storage {
            access: crate::StorageAccess::default(),
        }),
        "immediate" => Ok(crate::AddressSpace::Immediate),
        "function" => Ok(crate::AddressSpace::Function),
        "task_payload" => {
            enable_extensions.require(ImplementedEnableExtension::WgpuMeshShader, span)?;
            Ok(crate::AddressSpace::TaskPayload)
        }
        "ray_payload" => {
            if enable_extensions.contains(ImplementedEnableExtension::WgpuRayTracingPipeline) {
                Ok(crate::AddressSpace::RayPayload)
            } else {
                Err(Box::new(Error::EnableExtensionNotEnabled {
                    span,
                    kind: ImplementedEnableExtension::WgpuRayTracingPipeline.into(),
                }))
            }
        }
        "incoming_ray_payload" => {
            if enable_extensions.contains(ImplementedEnableExtension::WgpuRayTracingPipeline) {
                Ok(crate::AddressSpace::IncomingRayPayload)
            } else {
                Err(Box::new(Error::EnableExtensionNotEnabled {
                    span,
                    kind: ImplementedEnableExtension::WgpuRayTracingPipeline.into(),
                }))
            }
        }
        _ => Err(Box::new(Error::UnknownAddressSpace(span))),
    }
}

pub fn map_access_mode(word: &str, span: Span) -> Result<'_, crate::StorageAccess> {
    match word {
        "read" => Ok(crate::StorageAccess::LOAD),
        "write" => Ok(crate::StorageAccess::STORE),
        "read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE),
        "atomic" => Ok(crate::StorageAccess::ATOMIC
            | crate::StorageAccess::LOAD
            | crate::StorageAccess::STORE),
        _ => Err(Box::new(Error::UnknownAccess(span))),
    }
}

pub fn map_ray_flag(
    enable_extensions: &EnableExtensions,
    word: &str,
    span: Span,
) -> Result<'static, ()> {
    match word {
        "vertex_return" => {
            if !enable_extensions.contains(ImplementedEnableExtension::WgpuRayQueryVertexReturn) {
                return Err(Box::new(Error::EnableExtensionNotEnabled {
                    span,
                    kind: ImplementedEnableExtension::WgpuRayQueryVertexReturn.into(),
                }));
            }
            Ok(())
        }
        _ => Err(Box::new(Error::UnknownRayFlag(span))),
    }
}

pub fn map_cooperative_role(word: &str, span: Span) -> Result<'_, crate::CooperativeRole> {
    match word {
        "A" => Ok(crate::CooperativeRole::A),
        "B" => Ok(crate::CooperativeRole::B),
        "C" => Ok(crate::CooperativeRole::C),
        _ => Err(Box::new(Error::UnknownAccess(span))),
    }
}

pub fn map_built_in(
    enable_extensions: &EnableExtensions,
    word: &str,
    span: Span,
) -> Result<'static, crate::BuiltIn> {
    let built_in = match word {
        "position" => crate::BuiltIn::Position { invariant: false },
        // vertex
        "vertex_index" => crate::BuiltIn::VertexIndex,
        "instance_index" => crate::BuiltIn::InstanceIndex,
        "view_index" => crate::BuiltIn::ViewIndex,
        "clip_distances" => crate::BuiltIn::ClipDistance,
        // fragment
        "front_facing" => crate::BuiltIn::FrontFacing,
        "frag_depth" => crate::BuiltIn::FragDepth,
        "primitive_index" => crate::BuiltIn::PrimitiveIndex,
        "draw_index" => crate::BuiltIn::DrawIndex,
        "barycentric" => crate::BuiltIn::Barycentric { perspective: true },
        "barycentric_no_perspective" => crate::BuiltIn::Barycentric { perspective: false },
        "sample_index" => crate::BuiltIn::SampleIndex,
        "sample_mask" => crate::BuiltIn::SampleMask,
        // compute
        "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
        "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
        "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
        "workgroup_id" => crate::BuiltIn::WorkGroupId,
        "num_workgroups" => crate::BuiltIn::NumWorkGroups,
        // subgroup
        "num_subgroups" => crate::BuiltIn::NumSubgroups,
        "subgroup_id" => crate::BuiltIn::SubgroupId,
        "subgroup_size" => crate::BuiltIn::SubgroupSize,
        "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
        // mesh
        "cull_primitive" => crate::BuiltIn::CullPrimitive,
        "point_index" => crate::BuiltIn::PointIndex,
        "line_indices" => crate::BuiltIn::LineIndices,
        "triangle_indices" => crate::BuiltIn::TriangleIndices,
        "mesh_task_size" => crate::BuiltIn::MeshTaskSize,
        // mesh global variable
        "vertex_count" => crate::BuiltIn::VertexCount,
        "vertices" => crate::BuiltIn::Vertices,
        "primitive_count" => crate::BuiltIn::PrimitiveCount,
        "primitives" => crate::BuiltIn::Primitives,
        // ray tracing pipeline
        "ray_invocation_id" => crate::BuiltIn::RayInvocationId,
        "num_ray_invocations" => crate::BuiltIn::NumRayInvocations,
        "instance_custom_data" => crate::BuiltIn::InstanceCustomData,
        "geometry_index" => crate::BuiltIn::GeometryIndex,
        "world_ray_origin" => crate::BuiltIn::WorldRayOrigin,
        "world_ray_direction" => crate::BuiltIn::WorldRayDirection,
        "object_ray_origin" => crate::BuiltIn::ObjectRayOrigin,
        "object_ray_direction" => crate::BuiltIn::ObjectRayDirection,
        "ray_t_min" => crate::BuiltIn::RayTmin,
        "ray_t_current_max" => crate::BuiltIn::RayTCurrentMax,
        "object_to_world" => crate::BuiltIn::ObjectToWorld,
        "world_to_object" => crate::BuiltIn::WorldToObject,
        "hit_kind" =>
            crate::BuiltIn::HitKind,
        _ => return Err(Box::new(Error::UnknownBuiltin(span))),
    };
    match built_in {
        crate::BuiltIn::ClipDistance => {
            enable_extensions.require(ImplementedEnableExtension::ClipDistances, span)?
        }
        crate::BuiltIn::PrimitiveIndex => {
            enable_extensions.require(ImplementedEnableExtension::PrimitiveIndex, span)?
        }
        crate::BuiltIn::DrawIndex => {
            enable_extensions.require(ImplementedEnableExtension::DrawIndex, span)?
        }
        crate::BuiltIn::CullPrimitive
        | crate::BuiltIn::PointIndex
        | crate::BuiltIn::LineIndices
        | crate::BuiltIn::TriangleIndices
        | crate::BuiltIn::VertexCount
        | crate::BuiltIn::Vertices
        | crate::BuiltIn::PrimitiveCount
        | crate::BuiltIn::Primitives => {
            enable_extensions.require(ImplementedEnableExtension::WgpuMeshShader, span)?
        }
        _ => {}
    }
    Ok(built_in)
}

pub fn map_interpolation(word: &str, span: Span) -> Result<'_, crate::Interpolation> {
    match word {
        "linear" => Ok(crate::Interpolation::Linear),
        "flat" => Ok(crate::Interpolation::Flat),
        "perspective" => Ok(crate::Interpolation::Perspective),
        "per_vertex" => Ok(crate::Interpolation::PerVertex),
        _ => Err(Box::new(Error::UnknownAttribute(span))),
    }
}

pub fn map_sampling(word: &str, span: Span) -> Result<'_, crate::Sampling> {
    match word {
        "center" => Ok(crate::Sampling::Center),
        "centroid" => Ok(crate::Sampling::Centroid),
        "sample" => Ok(crate::Sampling::Sample),
        "first" => Ok(crate::Sampling::First),
        "either" => Ok(crate::Sampling::Either),
        _ => Err(Box::new(Error::UnknownAttribute(span))),
    }
}

pub fn map_storage_format(word: &str, span: Span) -> Result<'_, crate::StorageFormat> {
    use crate::StorageFormat as Sf;
    Ok(match word {
        "r8unorm" => Sf::R8Unorm,
        "r8snorm" => Sf::R8Snorm,
        "r8uint" => Sf::R8Uint,
        "r8sint" => Sf::R8Sint,
        "r16unorm" => Sf::R16Unorm,
        "r16snorm" => Sf::R16Snorm,
        "r16uint" => Sf::R16Uint,
        "r16sint" => Sf::R16Sint,
        "r16float" => Sf::R16Float,
        "rg8unorm" => Sf::Rg8Unorm,
        "rg8snorm" => Sf::Rg8Snorm,
        "rg8uint" => Sf::Rg8Uint,
        "rg8sint" => Sf::Rg8Sint,
        "r32uint" => Sf::R32Uint,
        "r32sint" => Sf::R32Sint,
        "r32float" => Sf::R32Float,
        "rg16unorm" => Sf::Rg16Unorm,
        "rg16snorm" => Sf::Rg16Snorm,
        "rg16uint" => Sf::Rg16Uint,
        "rg16sint" => Sf::Rg16Sint,
        "rg16float" => Sf::Rg16Float,
        "rgba8unorm" => Sf::Rgba8Unorm,
        "rgba8snorm" => Sf::Rgba8Snorm,
        "rgba8uint" => Sf::Rgba8Uint,
        "rgba8sint" => Sf::Rgba8Sint,
        "rgb10a2uint" => Sf::Rgb10a2Uint,
        "rgb10a2unorm" => Sf::Rgb10a2Unorm,
        "rg11b10ufloat" => Sf::Rg11b10Ufloat,
        "r64uint" => Sf::R64Uint,
        "rg32uint" => Sf::Rg32Uint,
        "rg32sint" => Sf::Rg32Sint,
        "rg32float" => Sf::Rg32Float,
        "rgba16unorm" => Sf::Rgba16Unorm,
        "rgba16snorm" => Sf::Rgba16Snorm,
        "rgba16uint" => Sf::Rgba16Uint,
        "rgba16sint" => Sf::Rgba16Sint,
        "rgba16float" => Sf::Rgba16Float,
        "rgba32uint" => Sf::Rgba32Uint,
        "rgba32sint" => Sf::Rgba32Sint,
        "rgba32float" => Sf::Rgba32Float,
        "bgra8unorm" => Sf::Bgra8Unorm,
        _ => return Err(Box::new(Error::UnknownStorageFormat(span))),
    })
}

pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
    use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
    match word {
        "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
        "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
        "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
        "dpdxFine" => Some((Axis::X, Ctrl::Fine)),
        "dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
        "fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
        "dpdx" => Some((Axis::X, Ctrl::None)),
        "dpdy" => Some((Axis::Y, Ctrl::None)),
        "fwidth" => Some((Axis::Width, Ctrl::None)),
        _ => None,
    }
}

pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
    match word {
        "any" => Some(crate::RelationalFunction::Any),
        "all" => Some(crate::RelationalFunction::All),
        _ => None,
    }
}

pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
    use crate::MathFunction as Mf;
    Some(match word {
        // comparison
        "abs" => Mf::Abs,
        "min" => Mf::Min,
        "max" => Mf::Max,
        "clamp" => Mf::Clamp,
        "saturate" => Mf::Saturate,
        // trigonometry
        "cos" => Mf::Cos,
        "cosh" => Mf::Cosh,
        "sin" => Mf::Sin,
        "sinh" => Mf::Sinh,
        "tan" => Mf::Tan,
        "tanh" => Mf::Tanh,
        "acos" => Mf::Acos,
        "acosh" => Mf::Acosh,
        "asin" => Mf::Asin,
        "asinh" => Mf::Asinh,
        "atan" => Mf::Atan,
        "atanh" => Mf::Atanh,
        "atan2" => Mf::Atan2,
        "radians" => Mf::Radians,
        "degrees" => Mf::Degrees,
        // decomposition
        "ceil" => Mf::Ceil,
        "floor" => Mf::Floor,
        "round" => Mf::Round,
        "fract" => Mf::Fract,
        "trunc" => Mf::Trunc,
        "modf" => Mf::Modf,
        "frexp" => Mf::Frexp,
        "ldexp" => Mf::Ldexp,
        // exponent
        "exp" => Mf::Exp,
        "exp2" => Mf::Exp2,
        "log" => Mf::Log,
        "log2" => Mf::Log2,
        "pow" => Mf::Pow,
        // geometry
        "dot" => Mf::Dot,
        "dot4I8Packed" => Mf::Dot4I8Packed,
        "dot4U8Packed" => Mf::Dot4U8Packed,
        "cross" => Mf::Cross,
        "distance" => Mf::Distance,
        "length" => Mf::Length,
        "normalize" => Mf::Normalize,
        "faceForward" => Mf::FaceForward,
        "reflect" => Mf::Reflect,
        "refract" => Mf::Refract,
        // computational
        "sign" => Mf::Sign,
        "fma" => Mf::Fma,
        "mix" => Mf::Mix,
        "step" => Mf::Step,
        "smoothstep" => Mf::SmoothStep,
        "sqrt" => Mf::Sqrt,
        "inverseSqrt" => Mf::InverseSqrt,
        "transpose" => Mf::Transpose,
        "determinant" => Mf::Determinant,
        "quantizeToF16" => Mf::QuantizeToF16,
        // bits
        "countTrailingZeros" => Mf::CountTrailingZeros,
        "countLeadingZeros" => Mf::CountLeadingZeros,
        "countOneBits" => Mf::CountOneBits,
        "reverseBits" => Mf::ReverseBits,
        "extractBits" => Mf::ExtractBits,
        "insertBits" => Mf::InsertBits,
        "firstTrailingBit" => Mf::FirstTrailingBit,
        "firstLeadingBit" => Mf::FirstLeadingBit,
        // data packing
        "pack4x8snorm" => Mf::Pack4x8snorm,
        "pack4x8unorm" => Mf::Pack4x8unorm,
        "pack2x16snorm" => Mf::Pack2x16snorm,
        "pack2x16unorm" => Mf::Pack2x16unorm,
        "pack2x16float" => Mf::Pack2x16float,
        "pack4xI8" => Mf::Pack4xI8,
        "pack4xU8" => Mf::Pack4xU8,
        "pack4xI8Clamp" => Mf::Pack4xI8Clamp,
        "pack4xU8Clamp" => Mf::Pack4xU8Clamp,
        // data unpacking
        "unpack4x8snorm" => Mf::Unpack4x8snorm,
        "unpack4x8unorm" => Mf::Unpack4x8unorm,
        "unpack2x16snorm" => Mf::Unpack2x16snorm,
        "unpack2x16unorm" => Mf::Unpack2x16unorm,
        "unpack2x16float" => Mf::Unpack2x16float,
        "unpack4xI8" => Mf::Unpack4xI8,
        "unpack4xU8" => Mf::Unpack4xU8,
        _ => return None,
    })
}

pub fn map_conservative_depth(word: &str, span: Span) -> Result<'_, crate::ConservativeDepth> {
    use crate::ConservativeDepth as Cd;
    match word {
        "greater_equal" => Ok(Cd::GreaterEqual),
        "less_equal" => Ok(Cd::LessEqual),
        "unchanged" => Ok(Cd::Unchanged),
        _ => Err(Box::new(Error::UnknownConservativeDepth(span))),
    }
}

pub fn map_subgroup_operation(
    word: &str,
) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
    use crate::CollectiveOperation as co;
    use crate::SubgroupOperation as sg;
    Some(match word {
        "subgroupAll" => (sg::All, co::Reduce),
        "subgroupAny" => (sg::Any, co::Reduce),
        "subgroupAdd" => (sg::Add, co::Reduce),
        "subgroupMul" => (sg::Mul, co::Reduce),
        "subgroupMin" => (sg::Min, co::Reduce),
        "subgroupMax" => (sg::Max, co::Reduce),
        "subgroupAnd" => (sg::And, co::Reduce),
        "subgroupOr" => (sg::Or, co::Reduce),
        "subgroupXor" => (sg::Xor, co::Reduce),
        "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
        "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
        "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
        "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
        _ => return None,
    })
}

pub enum TypeGenerator {
    Vector {
        size: VectorSize,
    },
    Matrix {
        columns: VectorSize,
        rows: VectorSize,
    },
    Array,
    Atomic,
    Pointer,
    SampledTexture {
        dim: ImageDimension,
        arrayed: bool,
        multi: bool,
    },
    StorageTexture {
        dim: ImageDimension,
        arrayed: bool,
    },
    BindingArray,
    AccelerationStructure,
    RayQuery,
    CooperativeMatrix {
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
    },
}

pub enum PredeclaredType {
    TypeInner(TypeInner),
    RayDesc,
    RayIntersection,
    TypeGenerator(TypeGenerator),
}

impl From<TypeInner> for PredeclaredType {
    fn from(value: TypeInner) -> Self {
        Self::TypeInner(value)
    }
}

impl From<TypeGenerator> for PredeclaredType {
    fn from(value: TypeGenerator) -> Self {
        Self::TypeGenerator(value)
    }
}

pub fn map_predeclared_type(
    enable_extensions: &EnableExtensions,
    span: Span,
    word: &str,
) -> Result<'static, Option<PredeclaredType>> {
    use Scalar as Sc;
    use TypeInner as Ti;
    use VectorSize as Vs;

    #[rustfmt::skip]
    let ty = match word {
        // predeclared types

        // scalars
        "bool" => Ti::Scalar(Sc::BOOL).into(),
        "i32" => Ti::Scalar(Sc::I32).into(),
        "u32" => Ti::Scalar(Sc::U32).into(),
        "f32" => Ti::Scalar(Sc::F32).into(),
        "f16" => Ti::Scalar(Sc::F16).into(),
        "i64" => Ti::Scalar(Sc::I64).into(),
        "u64" => Ti::Scalar(Sc::U64).into(),
        "f64" => Ti::Scalar(Sc::F64).into(),
        // vector aliases
        "vec2i" => Ti::Vector { size: Vs::Bi, scalar: Sc::I32 }.into(),
        "vec3i" => Ti::Vector { size: Vs::Tri, scalar: Sc::I32 }.into(),
        "vec4i" => Ti::Vector { size: Vs::Quad, scalar: Sc::I32 }.into(),
        "vec2u" => Ti::Vector { size: Vs::Bi, scalar: Sc::U32 }.into(),
        "vec3u" => Ti::Vector { size: Vs::Tri, scalar: Sc::U32 }.into(),
        "vec4u" => Ti::Vector { size: Vs::Quad, scalar: Sc::U32 }.into(),
        "vec2f" => Ti::Vector { size: Vs::Bi, scalar: Sc::F32 }.into(),
        "vec3f" => Ti::Vector { size: Vs::Tri, scalar: Sc::F32 }.into(),
        "vec4f" => Ti::Vector { size: Vs::Quad, scalar: Sc::F32 }.into(),
        "vec2h" => Ti::Vector { size: Vs::Bi, scalar: Sc::F16 }.into(),
        "vec3h" => Ti::Vector { size: Vs::Tri, scalar: Sc::F16 }.into(),
        "vec4h" => Ti::Vector { size: Vs::Quad, scalar: Sc::F16 }.into(),
        // matrix aliases
        "mat2x2f" => Ti::Matrix { columns: Vs::Bi, rows: Vs::Bi, scalar: Sc::F32 }.into(),
        "mat2x3f" => Ti::Matrix { columns: Vs::Bi, rows: Vs::Tri, scalar: Sc::F32 }.into(),
        "mat2x4f" => Ti::Matrix { columns: Vs::Bi, rows: Vs::Quad, scalar: Sc::F32 }.into(),
        "mat3x2f" => Ti::Matrix { columns: Vs::Tri, rows: Vs::Bi, scalar: Sc::F32 }.into(),
        "mat3x3f" => Ti::Matrix { columns: Vs::Tri, rows: Vs::Tri, scalar: Sc::F32 }.into(),
        "mat3x4f" => Ti::Matrix { columns: Vs::Tri, rows: Vs::Quad, scalar: Sc::F32 }.into(),
        "mat4x2f" => Ti::Matrix { columns: Vs::Quad, rows: Vs::Bi, scalar: Sc::F32 }.into(),
        "mat4x3f" => Ti::Matrix { columns: Vs::Quad, rows: Vs::Tri, scalar: Sc::F32 }.into(),
        "mat4x4f" => Ti::Matrix { columns: Vs::Quad, rows: Vs::Quad, scalar: Sc::F32 }.into(),
        "mat2x2h" => Ti::Matrix { columns: Vs::Bi, rows: Vs::Bi, scalar: Sc::F16 }.into(),
        "mat2x3h" => Ti::Matrix { columns: Vs::Bi, rows: Vs::Tri, scalar: Sc::F16 }.into(),
        "mat2x4h" => Ti::Matrix { columns: Vs::Bi, rows: Vs::Quad, scalar: Sc::F16 }.into(),
        "mat3x2h" => Ti::Matrix { columns: Vs::Tri, rows: Vs::Bi, scalar: Sc::F16 }.into(),
        "mat3x3h" => Ti::Matrix { columns: Vs::Tri, rows: Vs::Tri, scalar: Sc::F16 }.into(),
        "mat3x4h" => Ti::Matrix { columns: Vs::Tri, rows: Vs::Quad, scalar: Sc::F16 }.into(),
        "mat4x2h" => Ti::Matrix { columns: Vs::Quad, rows: Vs::Bi, scalar: Sc::F16 }.into(),
        "mat4x3h" => Ti::Matrix { columns: Vs::Quad, rows: Vs::Tri, scalar: Sc::F16 }.into(),
        "mat4x4h" => Ti::Matrix { columns: Vs::Quad, rows: Vs::Quad, scalar: Sc::F16 }.into(),
        // samplers
        "sampler" => Ti::Sampler { comparison: false }.into(),
        "sampler_comparison" => Ti::Sampler { comparison: true }.into(),
        // depth textures
        "texture_depth_2d" => Ti::Image { dim: ImageDimension::D2, arrayed: false, class: ImageClass::Depth { multi: false } }.into(),
        "texture_depth_2d_array" => Ti::Image { dim: ImageDimension::D2, arrayed: true, class: ImageClass::Depth { multi: false } }.into(),
        "texture_depth_cube" => Ti::Image { dim: ImageDimension::Cube, arrayed: false, class: ImageClass::Depth { multi: false } }.into(),
        "texture_depth_cube_array" => Ti::Image { dim: ImageDimension::Cube, arrayed: true, class: ImageClass::Depth { multi: false } }.into(),
        "texture_depth_multisampled_2d" => Ti::Image { dim: ImageDimension::D2, arrayed: false, class: ImageClass::Depth { multi: true } }.into(),
        // external texture
        "texture_external" => Ti::Image { dim: ImageDimension::D2, arrayed: false, class: ImageClass::External }.into(),
        // ray desc
        "RayDesc" => PredeclaredType::RayDesc,
        // ray intersection
        "RayIntersection" => PredeclaredType::RayIntersection,

        // predeclared type generators

        // vector
        "vec2" => TypeGenerator::Vector { size: Vs::Bi }.into(),
        "vec3" => TypeGenerator::Vector { size: Vs::Tri }.into(),
        "vec4" => TypeGenerator::Vector { size: Vs::Quad }.into(),
        // matrix
        "mat2x2" => TypeGenerator::Matrix { columns: Vs::Bi, rows: Vs::Bi }.into(),
        "mat2x3" => TypeGenerator::Matrix { columns: Vs::Bi, rows: Vs::Tri }.into(),
        "mat2x4" => TypeGenerator::Matrix { columns: Vs::Bi, rows: Vs::Quad }.into(),
        "mat3x2" => TypeGenerator::Matrix { columns: Vs::Tri, rows: Vs::Bi }.into(),
        "mat3x3" => TypeGenerator::Matrix { columns: Vs::Tri, rows: Vs::Tri }.into(),
        "mat3x4" => TypeGenerator::Matrix { columns: Vs::Tri, rows: Vs::Quad }.into(),
        "mat4x2" => TypeGenerator::Matrix { columns: Vs::Quad, rows: Vs::Bi }.into(),
        "mat4x3" => TypeGenerator::Matrix { columns: Vs::Quad, rows: Vs::Tri }.into(),
        "mat4x4" => TypeGenerator::Matrix { columns: Vs::Quad, rows: Vs::Quad }.into(),
        // array
        "array" => TypeGenerator::Array.into(),
        // atomic
        "atomic" => TypeGenerator::Atomic.into(),
        // pointer
        "ptr" => TypeGenerator::Pointer.into(),
        // sampled textures
        "texture_1d" => TypeGenerator::SampledTexture { dim: ImageDimension::D1, arrayed: false, multi: false }.into(),
        "texture_2d" => TypeGenerator::SampledTexture { dim: ImageDimension::D2, arrayed: false, multi: false }.into(),
        "texture_2d_array" => TypeGenerator::SampledTexture { dim: ImageDimension::D2, arrayed: true, multi: false }.into(),
        "texture_3d" => TypeGenerator::SampledTexture { dim: ImageDimension::D3, arrayed: false, multi: false }.into(),
        "texture_cube" => TypeGenerator::SampledTexture { dim: ImageDimension::Cube, arrayed: false, multi: false }.into(),
        "texture_cube_array" => TypeGenerator::SampledTexture { dim: ImageDimension::Cube, arrayed: true, multi: false }.into(),
        "texture_multisampled_2d" => TypeGenerator::SampledTexture { dim: ImageDimension::D2, arrayed: false, multi: true }.into(),
        // storage textures
        "texture_storage_1d" => TypeGenerator::StorageTexture { dim: ImageDimension::D1, arrayed: false }.into(),
        "texture_storage_2d" => TypeGenerator::StorageTexture { dim: ImageDimension::D2, arrayed: false }.into(),
        "texture_storage_2d_array" => TypeGenerator::StorageTexture { dim: ImageDimension::D2, arrayed: true }.into(),
        "texture_storage_3d" => TypeGenerator::StorageTexture { dim: ImageDimension::D3, arrayed: false }.into(),
        // binding array
        "binding_array" => TypeGenerator::BindingArray.into(),
        // acceleration structure
        "acceleration_structure" => TypeGenerator::AccelerationStructure.into(),
        // ray query
        "ray_query" => TypeGenerator::RayQuery.into(),
        // cooperative matrix
        "coop_mat8x8" => TypeGenerator::CooperativeMatrix {
            columns: crate::CooperativeSize::Eight,
            rows: crate::CooperativeSize::Eight,
        }.into(),
        "coop_mat16x16" => TypeGenerator::CooperativeMatrix {
            columns: crate::CooperativeSize::Sixteen,
            rows: crate::CooperativeSize::Sixteen,
        }.into(),
        _ => return Ok(None),
    };

    // Check for the enable extension required to use this type, if any.
    // A `Some` slice must contain at least one extension: the type is accepted
    // if any of them is enabled, and the error names the first one.
    let extensions_needed: Option<&[_]> = match ty {
        PredeclaredType::TypeInner(ref ty) if ty.scalar() == Some(Sc::F16) => {
            Some(&[ImplementedEnableExtension::F16])
        }
        PredeclaredType::RayDesc
        | PredeclaredType::RayIntersection
        | PredeclaredType::TypeGenerator(TypeGenerator::AccelerationStructure)
        | PredeclaredType::TypeGenerator(TypeGenerator::RayQuery) => Some(&[
            ImplementedEnableExtension::WgpuRayQuery,
            ImplementedEnableExtension::WgpuRayTracingPipeline,
        ]),
        PredeclaredType::TypeGenerator(TypeGenerator::CooperativeMatrix { .. }) => {
            Some(&[ImplementedEnableExtension::WgpuCooperativeMatrix])
        }
        _ => None,
    };

    if let Some(extensions_needed) = extensions_needed {
        let any_extension_enabled = extensions_needed
            .iter()
            .any(|&extension_needed| enable_extensions.contains(extension_needed));
        if !any_extension_enabled {
            return Err(Box::new(Error::EnableExtensionNotEnabled {
                span,
                kind: extensions_needed[0].into(),
            }));
        }
    }

    Ok(Some(ty))
}
naga-29.0.3/src/front/wgsl/parse/directive/enable_extension.rs000064400000000000000000000272501046102023000225060ustar 00000000000000
//! `enable …;` extensions in WGSL.
//!
//! The focal point of this module is the [`EnableExtension`] API.

use crate::front::wgsl::{Error, Result};
use crate::Span;

use alloc::boxed::Box;

/// Tracks the status of every enable-extension known to Naga.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct EnableExtensions {
    wgpu_mesh_shader: bool,
    wgpu_ray_query: bool,
    wgpu_ray_query_vertex_return: bool,
    wgpu_ray_tracing_pipelines: bool,
    dual_source_blending: bool,
    /// Whether `enable f16;` was written earlier in the shader module.
    f16: bool,
    clip_distances: bool,
    wgpu_cooperative_matrix: bool,
    draw_index: bool,
    primitive_index: bool,
}

impl EnableExtensions {
    pub(crate) const fn empty() -> Self {
        Self {
            wgpu_mesh_shader: false,
            wgpu_ray_query: false,
            wgpu_ray_query_vertex_return: false,
            wgpu_ray_tracing_pipelines: false,
            f16: false,
            dual_source_blending: false,
            clip_distances: false,
            wgpu_cooperative_matrix: false,
            draw_index: false,
            primitive_index: false,
        }
    }

    /// Add an enable-extension to the set requested by a module.
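    ///
    /// The bookkeeping pattern shared by `add`, `contains`, and `require` can be
    /// sketched in isolation: one `bool` flag per extension, set by `add` and read
    /// by `contains`, with `require` turning a missing flag into an error. The
    /// `Ext` and `Exts` types below are hypothetical two-extension stand-ins, and
    /// the `String` error replaces Naga's `Error::EnableExtensionNotEnabled`.
    ///
    /// ```
    /// #[derive(Clone, Copy, Debug, PartialEq)]
    /// enum Ext {
    ///     F16,
    ///     ClipDistances,
    /// }
    ///
    /// // One flag per extension, all `false` until an `enable` directive sets them.
    /// #[derive(Default)]
    /// struct Exts {
    ///     f16: bool,
    ///     clip_distances: bool,
    /// }
    ///
    /// impl Exts {
    ///     fn add(&mut self, ext: Ext) {
    ///         // Select the flag belonging to `ext`, then set it.
    ///         let field = match ext {
    ///             Ext::F16 => &mut self.f16,
    ///             Ext::ClipDistances => &mut self.clip_distances,
    ///         };
    ///         *field = true;
    ///     }
    ///
    ///     fn contains(&self, ext: Ext) -> bool {
    ///         match ext {
    ///             Ext::F16 => self.f16,
    ///             Ext::ClipDistances => self.clip_distances,
    ///         }
    ///     }
    ///
    ///     fn require(&self, ext: Ext) -> Result<(), String> {
    ///         if self.contains(ext) {
    ///             Ok(())
    ///         } else {
    ///             Err(format!("extension {ext:?} is not enabled"))
    ///         }
    ///     }
    /// }
    ///
    /// let mut exts = Exts::default();
    /// assert!(exts.require(Ext::F16).is_err());
    /// exts.add(Ext::F16);
    /// assert!(exts.contains(Ext::F16));
    /// assert!(exts.require(Ext::F16).is_ok());
    /// assert!(!exts.contains(Ext::ClipDistances));
    /// ```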
    pub(crate) const fn add(&mut self, ext: ImplementedEnableExtension) {
        let field = match ext {
            ImplementedEnableExtension::WgpuMeshShader => &mut self.wgpu_mesh_shader,
            ImplementedEnableExtension::WgpuRayQuery => &mut self.wgpu_ray_query,
            ImplementedEnableExtension::WgpuRayQueryVertexReturn => {
                &mut self.wgpu_ray_query_vertex_return
            }
            ImplementedEnableExtension::WgpuRayTracingPipeline => {
                &mut self.wgpu_ray_tracing_pipelines
            }
            ImplementedEnableExtension::DualSourceBlending => &mut self.dual_source_blending,
            ImplementedEnableExtension::F16 => &mut self.f16,
            ImplementedEnableExtension::ClipDistances => &mut self.clip_distances,
            ImplementedEnableExtension::WgpuCooperativeMatrix => &mut self.wgpu_cooperative_matrix,
            ImplementedEnableExtension::DrawIndex => &mut self.draw_index,
            ImplementedEnableExtension::PrimitiveIndex => &mut self.primitive_index,
        };
        *field = true;
    }

    /// Query whether an enable-extension tracked here has been requested.
    pub(crate) const fn contains(&self, ext: ImplementedEnableExtension) -> bool {
        match ext {
            ImplementedEnableExtension::WgpuMeshShader => self.wgpu_mesh_shader,
            ImplementedEnableExtension::WgpuRayQuery => self.wgpu_ray_query,
            ImplementedEnableExtension::WgpuRayQueryVertexReturn => {
                self.wgpu_ray_query_vertex_return
            }
            ImplementedEnableExtension::WgpuRayTracingPipeline => self.wgpu_ray_tracing_pipelines,
            ImplementedEnableExtension::DualSourceBlending => self.dual_source_blending,
            ImplementedEnableExtension::F16 => self.f16,
            ImplementedEnableExtension::ClipDistances => self.clip_distances,
            ImplementedEnableExtension::WgpuCooperativeMatrix => self.wgpu_cooperative_matrix,
            ImplementedEnableExtension::DrawIndex => self.draw_index,
            ImplementedEnableExtension::PrimitiveIndex => self.primitive_index,
        }
    }

    /// Return an `EnableExtensionNotEnabled` error located at `span` unless
    /// `ext` has been requested.
    pub(crate) fn require(
        &self,
        ext: ImplementedEnableExtension,
        span: Span,
    ) -> Result<'static, ()> {
        if !self.contains(ext) {
            Err(Box::new(Error::EnableExtensionNotEnabled {
                span,
                kind: ext.into(),
            }))
        } else {
            Ok(())
        }
    }
}

impl Default for EnableExtensions {
    fn default() -> Self {
        Self::empty()
    }
}

/// An enable-extension not guaranteed to be present in all environments.
///
/// WGSL spec.:
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub enum EnableExtension {
    Implemented(ImplementedEnableExtension),
    Unimplemented(UnimplementedEnableExtension),
}

impl From<ImplementedEnableExtension> for EnableExtension {
    fn from(value: ImplementedEnableExtension) -> Self {
        Self::Implemented(value)
    }
}

impl EnableExtension {
    const F16: &'static str = "f16";
    const CLIP_DISTANCES: &'static str = "clip_distances";
    const DUAL_SOURCE_BLENDING: &'static str = "dual_source_blending";
    const MESH_SHADER: &'static str = "wgpu_mesh_shader";
    const RAY_QUERY: &'static str = "wgpu_ray_query";
    const RAY_QUERY_VERTEX_RETURN: &'static str = "wgpu_ray_query_vertex_return";
    const RAY_TRACING_PIPELINE: &'static str = "wgpu_ray_tracing_pipeline";
    const COOPERATIVE_MATRIX: &'static str = "wgpu_cooperative_matrix";
    const SUBGROUPS: &'static str = "subgroups";
    const PRIMITIVE_INDEX: &'static str = "primitive_index";
    const DRAW_INDEX: &'static str = "draw_index";

    /// Convert from a sentinel word in WGSL into its associated [`EnableExtension`], if possible.
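    ///
    /// A recognized word maps either to an implemented extension or to a known but
    /// not-yet-implemented one (such as `subgroups`); any other word is an error.
    /// The self-contained sketch below mirrors that three-way split and checks
    /// that recognized words round-trip through a `to_ident`-style inverse. The
    /// `Implemented` and `Ext` types are hypothetical stand-ins, and the `String`
    /// error replaces Naga's `Error::UnknownEnableExtension`.
    ///
    /// ```
    /// #[derive(Clone, Copy, Debug, PartialEq)]
    /// enum Implemented {
    ///     F16,
    ///     ClipDistances,
    /// }
    ///
    /// #[derive(Clone, Copy, Debug, PartialEq)]
    /// enum Ext {
    ///     Implemented(Implemented),
    ///     Unimplemented(&'static str),
    /// }
    ///
    /// fn from_ident(word: &str) -> Result<Ext, String> {
    ///     Ok(match word {
    ///         "f16" => Ext::Implemented(Implemented::F16),
    ///         "clip_distances" => Ext::Implemented(Implemented::ClipDistances),
    ///         // Recognized, but not supported yet.
    ///         "subgroups" => Ext::Unimplemented("subgroups"),
    ///         _ => return Err(format!("unknown enable-extension `{word}`")),
    ///     })
    /// }
    ///
    /// fn to_ident(ext: Ext) -> &'static str {
    ///     match ext {
    ///         Ext::Implemented(Implemented::F16) => "f16",
    ///         Ext::Implemented(Implemented::ClipDistances) => "clip_distances",
    ///         Ext::Unimplemented(word) => word,
    ///     }
    /// }
    ///
    /// assert_eq!(from_ident("f16"), Ok(Ext::Implemented(Implemented::F16)));
    /// assert!(matches!(from_ident("subgroups"), Ok(Ext::Unimplemented(_))));
    /// assert!(from_ident("f32").is_err());
    /// // Every recognized word maps back to itself.
    /// for word in ["f16", "clip_distances", "subgroups"] {
    ///     assert_eq!(to_ident(from_ident(word).unwrap()), word);
    /// }
    /// ```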
    pub(crate) fn from_ident(word: &str, span: Span) -> Result<'_, Self> {
        Ok(match word {
            Self::F16 => Self::Implemented(ImplementedEnableExtension::F16),
            Self::CLIP_DISTANCES => Self::Implemented(ImplementedEnableExtension::ClipDistances),
            Self::DUAL_SOURCE_BLENDING => {
                Self::Implemented(ImplementedEnableExtension::DualSourceBlending)
            }
            Self::MESH_SHADER => Self::Implemented(ImplementedEnableExtension::WgpuMeshShader),
            Self::RAY_QUERY => Self::Implemented(ImplementedEnableExtension::WgpuRayQuery),
            Self::RAY_QUERY_VERTEX_RETURN => {
                Self::Implemented(ImplementedEnableExtension::WgpuRayQueryVertexReturn)
            }
            Self::RAY_TRACING_PIPELINE => {
                Self::Implemented(ImplementedEnableExtension::WgpuRayTracingPipeline)
            }
            Self::COOPERATIVE_MATRIX => {
                Self::Implemented(ImplementedEnableExtension::WgpuCooperativeMatrix)
            }
            Self::SUBGROUPS => Self::Unimplemented(UnimplementedEnableExtension::Subgroups),
            Self::DRAW_INDEX => Self::Implemented(ImplementedEnableExtension::DrawIndex),
            Self::PRIMITIVE_INDEX => Self::Implemented(ImplementedEnableExtension::PrimitiveIndex),
            _ => return Err(Box::new(Error::UnknownEnableExtension(span, word))),
        })
    }

    /// Maps this [`EnableExtension`] into the sentinel word associated with it in WGSL.
    pub const fn to_ident(self) -> &'static str {
        match self {
            Self::Implemented(kind) => match kind {
                ImplementedEnableExtension::WgpuMeshShader => Self::MESH_SHADER,
                ImplementedEnableExtension::WgpuRayQuery => Self::RAY_QUERY,
                ImplementedEnableExtension::WgpuRayQueryVertexReturn => {
                    Self::RAY_QUERY_VERTEX_RETURN
                }
                ImplementedEnableExtension::WgpuCooperativeMatrix => Self::COOPERATIVE_MATRIX,
                ImplementedEnableExtension::DualSourceBlending => Self::DUAL_SOURCE_BLENDING,
                ImplementedEnableExtension::F16 => Self::F16,
                ImplementedEnableExtension::ClipDistances => Self::CLIP_DISTANCES,
                ImplementedEnableExtension::DrawIndex => Self::DRAW_INDEX,
                ImplementedEnableExtension::PrimitiveIndex => Self::PRIMITIVE_INDEX,
                ImplementedEnableExtension::WgpuRayTracingPipeline => Self::RAY_TRACING_PIPELINE,
            },
            Self::Unimplemented(kind) => match kind {
                UnimplementedEnableExtension::Subgroups => Self::SUBGROUPS,
            },
        }
    }
}

/// A variant of [`EnableExtension::Implemented`].
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
#[cfg_attr(test, derive(strum::VariantArray))]
pub enum ImplementedEnableExtension {
    /// Enables `f16`/`half` primitive support in all shader languages.
    ///
    /// In the WGSL standard, this corresponds to [`enable f16;`].
    ///
    /// [`enable f16;`]: https://www.w3.org/TR/WGSL/#extension-f16
    F16,
    /// Enables the `blend_src` attribute in WGSL.
    ///
    /// In the WGSL standard, this corresponds to [`enable dual_source_blending;`].
    ///
    /// [`enable dual_source_blending;`]: https://www.w3.org/TR/WGSL/#extension-dual_source_blending
    DualSourceBlending,
    /// Enables the `clip_distances` variable in WGSL.
    ///
    /// In the WGSL standard, this corresponds to [`enable clip_distances;`].
    ///
    /// [`enable clip_distances;`]: https://www.w3.org/TR/WGSL/#extension-clip_distances
    ClipDistances,
    /// Enables the `wgpu_mesh_shader` extension, native only.
    WgpuMeshShader,
    /// Enables the `wgpu_ray_query` extension, native only.
    WgpuRayQuery,
    /// Enables the `wgpu_ray_query_vertex_return` extension, native only.
    WgpuRayQueryVertexReturn,
    /// Enables the `wgpu_ray_tracing_pipeline` extension, native only.
    WgpuRayTracingPipeline,
    /// Enables the `wgpu_cooperative_matrix` extension, native only.
    WgpuCooperativeMatrix,
    /// Enables the `draw_index` builtin. This is not currently part of the WGSL
    /// specification, but may be added in the future.
    DrawIndex,
    /// Enables the `@builtin(primitive_index)` attribute in WGSL.
    ///
    /// In the WGSL standard, this corresponds to [`enable primitive_index;`].
    ///
    /// [`enable primitive_index;`]: https://www.w3.org/TR/WGSL/#extension-primitive_index
    PrimitiveIndex,
}

impl ImplementedEnableExtension {
    /// A slice of all variants of [`ImplementedEnableExtension`].
    pub const VARIANTS: &'static [Self] = &[
        Self::F16,
        Self::DualSourceBlending,
        Self::ClipDistances,
        Self::WgpuMeshShader,
        Self::WgpuRayQuery,
        Self::WgpuRayQueryVertexReturn,
        Self::WgpuRayTracingPipeline,
        Self::WgpuCooperativeMatrix,
        Self::DrawIndex,
        Self::PrimitiveIndex,
    ];

    /// Returns a slice of all variants of [`ImplementedEnableExtension`].
    pub const fn all() -> &'static [Self] {
        Self::VARIANTS
    }

    /// Returns the capability required for this enable extension.
    pub const fn capability(self) -> crate::valid::Capabilities {
        use crate::valid::Capabilities as C;
        match self {
            Self::F16 => C::SHADER_FLOAT16,
            Self::DualSourceBlending => C::DUAL_SOURCE_BLENDING,
            Self::ClipDistances => C::CLIP_DISTANCE,
            Self::WgpuMeshShader => C::MESH_SHADER,
            Self::WgpuRayQuery => C::RAY_QUERY,
            Self::WgpuRayQueryVertexReturn => C::RAY_HIT_VERTEX_POSITION,
            Self::WgpuCooperativeMatrix => C::COOPERATIVE_MATRIX,
            Self::WgpuRayTracingPipeline => C::RAY_TRACING_PIPELINE,
            Self::DrawIndex => C::DRAW_INDEX,
            Self::PrimitiveIndex => C::PRIMITIVE_INDEX,
        }
    }
}

#[test]
/// Asserts that the manually written `VARIANTS` array matches what the derived strum version
/// would produce, while still allowing strum to be a dev-only dependency.
fn test_manual_variants_array_is_correct() {
    assert_eq!(
        <ImplementedEnableExtension as strum::VariantArray>::VARIANTS,
        ImplementedEnableExtension::VARIANTS
    );
}

/// A variant of [`EnableExtension::Unimplemented`].
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub enum UnimplementedEnableExtension {
    /// Enables subgroup built-ins in all languages.
    ///
    /// In the WGSL standard, this corresponds to [`enable subgroups;`].
    ///
    /// [`enable subgroups;`]: https://www.w3.org/TR/WGSL/#extension-subgroups
    Subgroups,
}

impl UnimplementedEnableExtension {
    pub(crate) const fn tracking_issue_num(self) -> u16 {
        match self {
            Self::Subgroups => 5555,
        }
    }
}
naga-29.0.3/src/front/wgsl/parse/directive/language_extension.rs000064400000000000000000000103751046102023000230430ustar 00000000000000//! `requires …;` extensions in WGSL.
//!
//! The focal point of this module is the [`LanguageExtension`] API.

/// A language extension recognized by Naga, but not guaranteed to be present in all environments.
///
/// WGSL spec.:
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum LanguageExtension {
    Implemented(ImplementedLanguageExtension),
    Unimplemented(UnimplementedLanguageExtension),
}

impl LanguageExtension {
    const READONLY_AND_READWRITE_STORAGE_TEXTURES: &'static str =
        "readonly_and_readwrite_storage_textures";
    const PACKED4X8_INTEGER_DOT_PRODUCT: &'static str = "packed_4x8_integer_dot_product";
    const UNRESTRICTED_POINTER_PARAMETERS: &'static str = "unrestricted_pointer_parameters";
    const POINTER_COMPOSITE_ACCESS: &'static str = "pointer_composite_access";

    /// Convert from a sentinel word in WGSL into its associated [`LanguageExtension`], if possible.
    pub fn from_ident(s: &str) -> Option<Self> {
        Some(match s {
            Self::READONLY_AND_READWRITE_STORAGE_TEXTURES => {
                Self::Implemented(ImplementedLanguageExtension::ReadOnlyAndReadWriteStorageTextures)
            }
            Self::PACKED4X8_INTEGER_DOT_PRODUCT => {
                Self::Implemented(ImplementedLanguageExtension::Packed4x8IntegerDotProduct)
            }
            Self::UNRESTRICTED_POINTER_PARAMETERS => {
                Self::Unimplemented(UnimplementedLanguageExtension::UnrestrictedPointerParameters)
            }
            Self::POINTER_COMPOSITE_ACCESS => {
                Self::Implemented(ImplementedLanguageExtension::PointerCompositeAccess)
            }
            _ => return None,
        })
    }

    /// Maps this [`LanguageExtension`] into the sentinel word associated with it in WGSL.
    pub const fn to_ident(self) -> &'static str {
        match self {
            Self::Implemented(kind) => kind.to_ident(),
            Self::Unimplemented(kind) => match kind {
                UnimplementedLanguageExtension::UnrestrictedPointerParameters => {
                    Self::UNRESTRICTED_POINTER_PARAMETERS
                }
            },
        }
    }
}

/// A variant of [`LanguageExtension::Implemented`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(test, derive(strum::VariantArray))]
pub enum ImplementedLanguageExtension {
    ReadOnlyAndReadWriteStorageTextures,
    Packed4x8IntegerDotProduct,
    PointerCompositeAccess,
}

impl ImplementedLanguageExtension {
    /// A slice of all variants of [`ImplementedLanguageExtension`].
    pub const VARIANTS: &'static [Self] = &[
        Self::ReadOnlyAndReadWriteStorageTextures,
        Self::Packed4x8IntegerDotProduct,
        Self::PointerCompositeAccess,
    ];

    /// Returns a slice of all variants of [`ImplementedLanguageExtension`].
    pub const fn all() -> &'static [Self] {
        Self::VARIANTS
    }

    /// Maps this [`ImplementedLanguageExtension`] into the sentinel word associated with it in WGSL.
    pub const fn to_ident(self) -> &'static str {
        match self {
            ImplementedLanguageExtension::ReadOnlyAndReadWriteStorageTextures => {
                LanguageExtension::READONLY_AND_READWRITE_STORAGE_TEXTURES
            }
            ImplementedLanguageExtension::Packed4x8IntegerDotProduct => {
                LanguageExtension::PACKED4X8_INTEGER_DOT_PRODUCT
            }
            ImplementedLanguageExtension::PointerCompositeAccess => {
                LanguageExtension::POINTER_COMPOSITE_ACCESS
            }
        }
    }
}

#[test]
/// Asserts that the manually written `VARIANTS` array matches what the derived strum version
/// would produce, while still allowing strum to be a dev-only dependency.
fn test_manual_variants_array_is_correct() {
    assert_eq!(
        <ImplementedLanguageExtension as strum::VariantArray>::VARIANTS,
        ImplementedLanguageExtension::VARIANTS
    );
}

/// A variant of [`LanguageExtension::Unimplemented`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum UnimplementedLanguageExtension {
    UnrestrictedPointerParameters,
}

impl UnimplementedLanguageExtension {
    pub(crate) const fn tracking_issue_num(self) -> u16 {
        match self {
            Self::UnrestrictedPointerParameters => 5158,
        }
    }
}
naga-29.0.3/src/front/wgsl/parse/directive.rs000064400000000000000000000067711046102023000171710ustar 00000000000000//! WGSL directives. The focal point of this API is [`DirectiveKind`].
//!
//! See also .

pub mod enable_extension;
pub(crate) mod language_extension;

use alloc::boxed::Box;

/// A parsed sentinel word indicating the type of directive to be parsed next.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
#[cfg_attr(test, derive(strum::EnumIter))]
pub(crate) enum DirectiveKind {
    /// A [`crate::diagnostic_filter`].
    Diagnostic,
    /// An [`enable_extension`].
    Enable,
    /// A [`language_extension`].
    Requires,
}

impl DirectiveKind {
    const DIAGNOSTIC: &'static str = "diagnostic";
    const ENABLE: &'static str = "enable";
    const REQUIRES: &'static str = "requires";

    /// Convert from a sentinel word in WGSL into its associated [`DirectiveKind`], if possible.
    pub fn from_ident(s: &str) -> Option<Self> {
        Some(match s {
            Self::DIAGNOSTIC => Self::Diagnostic,
            Self::ENABLE => Self::Enable,
            Self::REQUIRES => Self::Requires,
            _ => return None,
        })
    }
}

impl crate::diagnostic_filter::Severity {
    #[cfg(feature = "wgsl-in")]
    pub(crate) fn report_wgsl_parse_diag<'a>(
        self,
        err: Box<crate::front::wgsl::error::Error<'a>>,
        source: &str,
    ) -> crate::front::wgsl::Result<'a, ()> {
        self.report_diag(err, |e, level| {
            let e = e.as_parse_error(source);
            log::log!(level, "{}", e.emit_to_string(source));
        })
    }
}

#[cfg(test)]
mod test {
    use alloc::format;

    use strum::IntoEnumIterator;

    use super::DirectiveKind;
    use crate::front::wgsl::assert_parse_err;

    #[test]
    fn directive_after_global_decl() {
        for unsupported_shader in DirectiveKind::iter() {
            let directive;
            let expected_msg;
            match unsupported_shader {
                DirectiveKind::Diagnostic => {
                    directive = "diagnostic(off,derivative_uniformity)";
                    expected_msg = "\
error: expected global declaration, but found a global directive
  ┌─ wgsl:2:1
  │
2 │ diagnostic(off,derivative_uniformity);
  │ ^^^^^^^^^^ written after first global declaration
  │
  = note: global directives are only allowed before global declarations; maybe hoist this closer to the top of the shader module?

";
                }
                DirectiveKind::Enable => {
                    directive = "enable f16";
                    expected_msg = "\
error: expected global declaration, but found a global directive
  ┌─ wgsl:2:1
  │
2 │ enable f16;
  │ ^^^^^^ written after first global declaration
  │
  = note: global directives are only allowed before global declarations; maybe hoist this closer to the top of the shader module?

";
                }
                DirectiveKind::Requires => {
                    directive = "requires readonly_and_readwrite_storage_textures";
                    expected_msg = "\
error: expected global declaration, but found a global directive
  ┌─ wgsl:2:1
  │
2 │ requires readonly_and_readwrite_storage_textures;
  │ ^^^^^^^^ written after first global declaration
  │
  = note: global directives are only allowed before global declarations; maybe hoist this closer to the top of the shader module?

";
                }
            }

            let shader = format!(
                "\
@group(0) @binding(0) var thing: i32;
{directive};
"
            );
            assert_parse_err(&shader, expected_msg);
        }
    }
}
naga-29.0.3/src/front/wgsl/parse/lexer.rs000064400000000000000000001240311046102023000163200ustar 00000000000000use super::{number::consume_number, Error, ExpectedToken, Result};
use crate::front::wgsl::error::NumberError;
use crate::front::wgsl::parse::directive::enable_extension::{
    EnableExtensions, ImplementedEnableExtension,
};
use crate::front::wgsl::parse::Number;
use crate::Span;

use alloc::{boxed::Box, vec::Vec};

pub type TokenSpan<'a> = (Token<'a>, Span);

#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Token<'a> {
    /// A separator character: `:;,`, and `.` when not part of a numeric
    /// literal.
    Separator(char),

    /// A parenthesis-like character: `()[]{}`, and also `<>`.
    ///
    /// Note that `<>` representing template argument brackets are distinguished
    /// using WGSL's [template list discovery algorithm][tlda], and are returned
    /// as [`Token::TemplateArgsStart`] and [`Token::TemplateArgsEnd`]. That is,
    /// we use `Paren` for `<>` only when they are *not* template argument brackets
    /// (for example, when they are comparison operators).
    ///
    /// [tlda]: https://gpuweb.github.io/gpuweb/wgsl/#template-list-discovery
    Paren(char),

    /// The attribute introduction character `@`.
    Attribute,

    /// A numeric literal, either integral or floating-point, including any
    /// type suffix.
    Number(core::result::Result<Number, NumberError>),

    /// An identifier, possibly a reserved word.
    Word(&'a str),

    /// A miscellaneous single-character operator, like an arithmetic unary or
    /// binary operator.
    /// This includes `=`, for assignment and initialization.
    Operation(char),

    /// Certain multi-character logical operators: `!=`, `==`, `&&`,
    /// `||`, `<=` and `>=`. The value gives the operator's first
    /// character.
    ///
    /// For `<` and `>` operators, see [`Token::Paren`].
    LogicalOperation(char),

    /// A shift operator: `>>` or `<<`.
    ShiftOperation(char),

    /// A compound assignment operator like `+=`.
    ///
    /// When the given character is `<` or `>`, those represent the left shift
    /// and right shift assignment operators, `<<=` and `>>=`.
    AssignmentOperation(char),

    /// The `++` operator.
    IncrementOperation,

    /// The `--` operator.
    DecrementOperation,

    /// The `->` token.
    Arrow,

    /// A `<` representing the start of a template argument list, according to
    /// WGSL's [template list discovery algorithm][tlda].
    ///
    /// [tlda]: https://gpuweb.github.io/gpuweb/wgsl/#template-list-discovery
    TemplateArgsStart,

    /// A `>` representing the end of a template argument list, according to
    /// WGSL's [template list discovery algorithm][tlda].
    ///
    /// [tlda]: https://gpuweb.github.io/gpuweb/wgsl/#template-list-discovery
    TemplateArgsEnd,

    /// A character that does not represent a legal WGSL token.
    Unknown(char),

    /// Comment or whitespace.
    Trivia,

    /// A doc comment, beginning with `///` or `/**`.
    DocComment(&'a str),

    /// A module-level doc comment, beginning with `//!` or `/*!`.
    ModuleDocComment(&'a str),

    /// The end of the input.
    End,
}

fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) {
    let pos = input.find(|c| !what(c)).unwrap_or(input.len());
    input.split_at(pos)
}

struct UnclosedCandidate {
    index: usize,
    depth: usize,
}

/// Produce at least one token, distinguishing [template lists] from other uses
/// of `<` and `>`.
///
/// Consume one or more tokens from `input` and store them in `tokens`, each
/// paired with the text of `input` that remains after it.
/// Apply WGSL's [template list
/// discovery algorithm] to decide what sort of tokens `<` and `>` characters in
/// the input actually represent.
///
/// Store the tokens in `tokens` in the *reverse* of the order they appear in
/// the text, such that the caller can pop from the end of the vector to see the
/// tokens in textual order.
///
/// The `tokens` vector must be empty on entry. The idea is for the caller to
/// use it as a buffer of unconsumed tokens, and call this function to refill it
/// when it's empty.
///
/// The `source` argument must be the whole original source code, used to
/// compute spans.
///
/// If `ignore_doc_comments` is true, then doc comments are returned as
/// [`Token::Trivia`], like ordinary comments.
///
/// [template lists]: https://gpuweb.github.io/gpuweb/wgsl/#template-lists-sec
/// [template list discovery algorithm]: https://gpuweb.github.io/gpuweb/wgsl/#template-list-discovery
fn discover_template_lists<'a>(
    tokens: &mut Vec<(TokenSpan<'a>, &'a str)>,
    source: &'a str,
    mut input: &'a str,
    ignore_doc_comments: bool,
) {
    assert!(tokens.is_empty());

    let mut looking_for_template_start = false;
    let mut pending: Vec<UnclosedCandidate> = Vec::new();

    // Current nesting depth of `()` and `[]` brackets. (`{}` brackets
    // exit all template list processing.)
    let mut depth = 0;

    fn pop_until(pending: &mut Vec<UnclosedCandidate>, depth: usize) {
        while pending
            .last()
            .map(|candidate| candidate.depth >= depth)
            .unwrap_or(false)
        {
            pending.pop();
        }
    }

    loop {
        // Decide whether `consume_token` should treat a `>` character as
        // `TemplateArgsEnd`, without considering the characters that follow.
        //
        // This condition matches the one that determines whether the spec's
        // template list discovery algorithm looks past a `>` character for a
        // `=`. By passing this flag to `consume_token`, we ensure it follows
        // that behavior.
        let waiting_for_template_end = pending
            .last()
            .is_some_and(|candidate| candidate.depth == depth);

        // Ask `consume_token` for the next token and add it to `tokens`, along
        // with its span.
        //
        // This means that `<` enters the buffer as `Token::Paren('<')`, the
        // ordinary comparison operator. We'll change that to
        // `Token::TemplateArgsStart` later if appropriate.
        let (token, rest) = consume_token(input, waiting_for_template_end, ignore_doc_comments);
        let span = Span::from(source.len() - input.len()..source.len() - rest.len());
        tokens.push(((token, span), rest));
        input = rest;

        // Since `consume_token` treats `<<=`, `<<` and `<=` as operators, not
        // `Token::Paren`, that takes care of the WGSL algorithm's post-'<' lookahead
        // for us.
        match token {
            Token::Word(_) => {
                looking_for_template_start = true;
                continue;
            }
            Token::Trivia | Token::DocComment(_) | Token::ModuleDocComment(_)
                if looking_for_template_start =>
            {
                continue;
            }
            Token::Paren('<') if looking_for_template_start => {
                pending.push(UnclosedCandidate {
                    index: tokens.len() - 1,
                    depth,
                });
            }
            Token::TemplateArgsEnd => {
                // The `consume_token` function only returns `TemplateArgsEnd`
                // if `waiting_for_template_end` is true, so we know `pending`
                // has a top entry at the appropriate depth.
                //
                // Find the matching `<` token and change its type to
                // `TemplateArgsStart`.
                let candidate = pending.pop().unwrap();
                let &mut ((ref mut token, _), _) = tokens.get_mut(candidate.index).unwrap();
                *token = Token::TemplateArgsStart;
            }
            Token::Paren('(' | '[') => {
                depth += 1;
            }
            Token::Paren(')' | ']') => {
                pop_until(&mut pending, depth);
                depth = depth.saturating_sub(1);
            }
            Token::Operation('=') | Token::Separator(':' | ';') | Token::Paren('{') => {
                pending.clear();
                depth = 0;
            }
            Token::LogicalOperation('&') | Token::LogicalOperation('|') => {
                pop_until(&mut pending, depth);
            }
            Token::End => break,
            _ => {}
        }

        looking_for_template_start = false;

        // The WGSL spec's template list discovery algorithm processes the
        // entire source at once, but Naga would rather limit its lookahead to
        // the actual text that could possibly be a template parameter list.
        // This is usually less than a line.
        if pending.is_empty() {
            break;
        }
    }

    tokens.reverse();
}

/// Return the token at the start of `input`.
///
/// The `waiting_for_template_end` flag enables some special handling to help out
/// `discover_template_lists`:
///
/// - If `waiting_for_template_end` is `true`, then return text starting with
///   `>` as [`Token::TemplateArgsEnd`] and consume only the `>` character,
///   regardless of what characters follow it. This is required by the [template
///   list discovery algorithm][tlda] when the `>` would end a template argument list.
///
/// - If `waiting_for_template_end` is `false`, recognize multi-character tokens
///   beginning with `>` as usual.
///
/// If `ignore_doc_comments` is true, then doc comments are returned as
/// [`Token::Trivia`], like ordinary comments.
///
/// [tlda]: https://gpuweb.github.io/gpuweb/wgsl/#template-list-discovery
fn consume_token(
    input: &str,
    waiting_for_template_end: bool,
    ignore_doc_comments: bool,
) -> (Token<'_>, &str) {
    let mut chars = input.chars();
    let cur = match chars.next() {
        Some(c) => c,
        None => return (Token::End, ""),
    };
    match cur {
        ':' | ';' | ',' => (Token::Separator(cur), chars.as_str()),
        '.'
=> {
            let og_chars = chars.as_str();
            match chars.next() {
                Some('0'..='9') => consume_number(input),
                _ => (Token::Separator(cur), og_chars),
            }
        }
        '@' => (Token::Attribute, chars.as_str()),
        '(' | ')' | '{' | '}' | '[' | ']' => (Token::Paren(cur), chars.as_str()),
        '<' | '>' => {
            let og_chars = chars.as_str();
            if cur == '>' && waiting_for_template_end {
                return (Token::TemplateArgsEnd, og_chars);
            }
            match chars.next() {
                Some('=') => (Token::LogicalOperation(cur), chars.as_str()),
                Some(c) if c == cur => {
                    let og_chars = chars.as_str();
                    match chars.next() {
                        Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
                        _ => (Token::ShiftOperation(cur), og_chars),
                    }
                }
                _ => (Token::Paren(cur), og_chars),
            }
        }
        '0'..='9' => consume_number(input),
        '/' => {
            let og_chars = chars.as_str();
            match chars.next() {
                Some('/') => {
                    let mut input_chars = input.char_indices();
                    let doc_comment_end = input_chars
                        .find_map(|(index, c)| is_comment_end(c).then_some(index))
                        .unwrap_or(input.len());
                    let token = match chars.next() {
                        Some('/') if !ignore_doc_comments => {
                            Token::DocComment(&input[..doc_comment_end])
                        }
                        Some('!') if !ignore_doc_comments => {
                            Token::ModuleDocComment(&input[..doc_comment_end])
                        }
                        _ => Token::Trivia,
                    };
                    (token, input_chars.as_str())
                }
                Some('*') => {
                    let next_c = chars.next();

                    enum CommentType {
                        Doc,
                        ModuleDoc,
                        Normal,
                    }
                    let comment_type = match next_c {
                        Some('*') if !ignore_doc_comments => CommentType::Doc,
                        Some('!') if !ignore_doc_comments => CommentType::ModuleDoc,
                        _ => CommentType::Normal,
                    };

                    let mut depth = 1;
                    let mut prev = next_c;

                    for c in &mut chars {
                        match (prev, c) {
                            (Some('*'), '/') => {
                                prev = None;
                                depth -= 1;
                                if depth == 0 {
                                    let rest = chars.as_str();
                                    let token = match comment_type {
                                        CommentType::Doc => {
                                            let doc_comment_end = input.len() - rest.len();
                                            Token::DocComment(&input[..doc_comment_end])
                                        }
                                        CommentType::ModuleDoc => {
                                            let doc_comment_end = input.len() - rest.len();
                                            Token::ModuleDocComment(&input[..doc_comment_end])
                                        }
                                        CommentType::Normal =>
                                            Token::Trivia,
                                    };
                                    return (token, rest);
                                }
                            }
                            (Some('/'), '*') => {
                                prev = None;
                                depth += 1;
                            }
                            _ => {
                                prev = Some(c);
                            }
                        }
                    }

                    (Token::End, "")
                }
                Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
                _ => (Token::Operation(cur), og_chars),
            }
        }
        '-' => {
            let og_chars = chars.as_str();
            match chars.next() {
                Some('>') => (Token::Arrow, chars.as_str()),
                Some('-') => (Token::DecrementOperation, chars.as_str()),
                Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
                _ => (Token::Operation(cur), og_chars),
            }
        }
        '+' => {
            let og_chars = chars.as_str();
            match chars.next() {
                Some('+') => (Token::IncrementOperation, chars.as_str()),
                Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
                _ => (Token::Operation(cur), og_chars),
            }
        }
        '*' | '%' | '^' => {
            let og_chars = chars.as_str();
            match chars.next() {
                Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
                _ => (Token::Operation(cur), og_chars),
            }
        }
        '~' => (Token::Operation(cur), chars.as_str()),
        '=' | '!' => {
            let og_chars = chars.as_str();
            match chars.next() {
                Some('=') => (Token::LogicalOperation(cur), chars.as_str()),
                _ => (Token::Operation(cur), og_chars),
            }
        }
        '&' | '|' => {
            let og_chars = chars.as_str();
            match chars.next() {
                Some(c) if c == cur => (Token::LogicalOperation(cur), chars.as_str()),
                Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
                _ => (Token::Operation(cur), og_chars),
            }
        }
        _ if is_blankspace(cur) => {
            let (_, rest) = consume_any(input, is_blankspace);
            (Token::Trivia, rest)
        }
        _ if is_word_start(cur) => {
            let (word, rest) = consume_any(input, is_word_part);
            (Token::Word(word), rest)
        }
        _ => (Token::Unknown(cur), chars.as_str()),
    }
}

/// Returns whether or not a char is a comment end
/// (Unicode Pattern_White_Space excluding U+0020, U+0009, U+200E and U+200F)
///
const fn is_comment_end(c: char) -> bool {
    match c {
        '\u{000a}'..='\u{000d}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true,
        _ => false,
    }
}

/// Returns whether or not a char is a blankspace (Unicode
/// Pattern_White_Space)
const fn is_blankspace(c: char) -> bool {
    match c {
        '\u{0020}'
        | '\u{0009}'..='\u{000d}'
        | '\u{0085}'
        | '\u{200e}'
        | '\u{200f}'
        | '\u{2028}'
        | '\u{2029}' => true,
        _ => false,
    }
}

/// Returns whether or not a char is a word start (Unicode XID_Start + '_')
fn is_word_start(c: char) -> bool {
    c == '_' || unicode_ident::is_xid_start(c)
}

/// Returns whether or not a char is a word part (Unicode XID_Continue)
fn is_word_part(c: char) -> bool {
    unicode_ident::is_xid_continue(c)
}

pub(in crate::front::wgsl) struct Lexer<'a> {
    /// The remaining unconsumed input.
    input: &'a str,

    /// The full original source code.
    ///
    /// We compare `input` against this to compute the lexer's current offset in
    /// the source.
    pub(in crate::front::wgsl) source: &'a str,

    /// The byte offset of the end of the most recently returned non-trivia
    /// token.
    ///
    /// This is consulted by the `span_from` function, for finding the
    /// end of the span for larger structures like expressions or
    /// statements.
    last_end_offset: usize,

    /// A stack of unconsumed tokens to which template list discovery has been
    /// applied.
    ///
    /// This is a stack: the next token is at the *end* of the vector, not the
    /// start. So tokens appear here in the reverse of the order they appear in
    /// the source.
    ///
    /// This doesn't contain the whole source, only those tokens produced by
    /// [`discover_template_lists`]'s look-ahead, or that have been produced by
    /// other look-ahead functions like `peek` and `next_if`. When this is empty,
    /// we call [`discover_template_lists`] to get more.
    tokens: Vec<(TokenSpan<'a>, &'a str)>,

    /// Whether or not to ignore doc comments.
    /// If `true`, doc comments are treated as [`Token::Trivia`].
    ignore_doc_comments: bool,

    /// The set of [enable-extensions] present in the module, determined in a pre-pass.
    ///
    /// [enable-extensions]: https://gpuweb.github.io/gpuweb/wgsl/#enable-extensions-sec
    pub(in crate::front::wgsl) enable_extensions: EnableExtensions,
}

impl<'a> Lexer<'a> {
    pub(in crate::front::wgsl) const fn new(input: &'a str, ignore_doc_comments: bool) -> Self {
        Lexer {
            input,
            source: input,
            last_end_offset: 0,
            tokens: Vec::new(),
            enable_extensions: EnableExtensions::empty(),
            ignore_doc_comments,
        }
    }

    /// Check that `extension` is enabled in `self`.
    pub(in crate::front::wgsl) fn require_enable_extension(
        &self,
        extension: ImplementedEnableExtension,
        span: Span,
    ) -> Result<'static, ()> {
        self.enable_extensions.require(extension, span)
    }

    /// Calls `inner` with this lexer and returns its result, together with the
    /// span covering everything `inner` parsed.
    ///
    /// # Examples
    /// ```ignore
    /// let mut lexer = Lexer::new("5", false);
    /// let (value, span) = lexer.capture_span(Lexer::next_uint_literal)?;
    /// assert_eq!(value, 5);
    /// ```
    #[inline]
    pub fn capture_span<T, E>(
        &mut self,
        inner: impl FnOnce(&mut Self) -> core::result::Result<T, E>,
    ) -> core::result::Result<(T, Span), E> {
        let start = self.current_byte_offset();
        let res = inner(self)?;
        let end = self.current_byte_offset();
        Ok((res, Span::from(start..end)))
    }

    pub(in crate::front::wgsl) fn start_byte_offset(&mut self) -> usize {
        loop {
            // Eat all trivia because `next` doesn't eat trailing trivia.
            let (token, rest) = consume_token(self.input, false, true);
            if let Token::Trivia = token {
                self.input = rest;
            } else {
                return self.current_byte_offset();
            }
        }
    }

    /// Collect all module doc comments until a non-doc token is found.
    pub(in crate::front::wgsl) fn accumulate_module_doc_comments(&mut self) -> Vec<&'a str> {
        let mut doc_comments = Vec::new();
        loop {
            // ignore blankspace
            self.input = consume_any(self.input, is_blankspace).1;

            let (token, rest) = consume_token(self.input, false, self.ignore_doc_comments);
            if let Token::ModuleDocComment(doc_comment) = token {
                self.input = rest;
                doc_comments.push(doc_comment);
            } else {
                return doc_comments;
            }
        }
    }

    /// Collect all doc comments until a non-doc token is found.
    pub(in crate::front::wgsl) fn accumulate_doc_comments(&mut self) -> Vec<&'a str> {
        let mut doc_comments = Vec::new();
        loop {
            // ignore blankspace
            self.input = consume_any(self.input, is_blankspace).1;

            let (token, rest) = consume_token(self.input, false, self.ignore_doc_comments);
            if let Token::DocComment(doc_comment) = token {
                self.input = rest;
                doc_comments.push(doc_comment);
            } else {
                return doc_comments;
            }
        }
    }

    const fn current_byte_offset(&self) -> usize {
        self.source.len() - self.input.len()
    }

    pub(in crate::front::wgsl) fn span_from(&self, offset: usize) -> Span {
        Span::from(offset..self.last_end_offset)
    }

    pub(in crate::front::wgsl) fn span_with_start(&self, span: Span) -> Span {
        span.until(&Span::from(0..self.last_end_offset))
    }

    /// Return the next non-whitespace token from `self`.
    ///
    /// Assume we are in a parse state where bit shift operators may
    /// occur, but not angle brackets.
    #[must_use]
    pub(in crate::front::wgsl) fn next(&mut self) -> TokenSpan<'a> {
        self.next_impl(true)
    }

    #[cfg(test)]
    pub fn next_with_unignored_doc_comments(&mut self) -> TokenSpan<'a> {
        self.next_impl(false)
    }

    /// Return the next non-whitespace token from `self`, with a span.
    fn next_impl(&mut self, ignore_doc_comments: bool) -> TokenSpan<'a> {
        loop {
            if self.tokens.is_empty() {
                discover_template_lists(
                    &mut self.tokens,
                    self.source,
                    self.input,
                    ignore_doc_comments || self.ignore_doc_comments,
                );
            }
            assert!(!self.tokens.is_empty());
            let (token, rest) = self.tokens.pop().unwrap();

            self.input = rest;
            self.last_end_offset = self.current_byte_offset();

            match token.0 {
                Token::Trivia => {}
                _ => return token,
            }
        }
    }

    #[must_use]
    pub(in crate::front::wgsl) fn peek(&mut self) -> TokenSpan<'a> {
        let input = self.input;
        let last_end_offset = self.last_end_offset;
        let token = self.next();
        self.tokens.push((token, self.input));
        self.input = input;
        self.last_end_offset = last_end_offset;
        token
    }

    /// If the next token matches `what`, consume it and return `true`;
    /// otherwise leave it unconsumed and return `false`.
    pub(in crate::front::wgsl) fn next_if(&mut self, what: Token<'_>) -> bool {
        let input = self.input;
        let last_end_offset = self.last_end_offset;
        let token = self.next();
        if token.0 == what {
            true
        } else {
            self.tokens.push((token, self.input));
            self.input = input;
            self.last_end_offset = last_end_offset;
            false
        }
    }

    pub(in crate::front::wgsl) fn expect_span(&mut self, expected: Token<'a>) -> Result<'a, Span> {
        let next = self.next();
        if next.0 == expected {
            Ok(next.1)
        } else {
            Err(Box::new(Error::Unexpected(
                next.1,
                ExpectedToken::Token(expected),
            )))
        }
    }

    pub(in crate::front::wgsl) fn expect(&mut self, expected: Token<'a>) -> Result<'a, ()> {
        self.expect_span(expected)?;
        Ok(())
    }

    pub(in crate::front::wgsl) fn next_ident_with_span(&mut self) -> Result<'a, (&'a str, Span)> {
        match self.next() {
            (Token::Word("_"), span) => Err(Box::new(Error::InvalidIdentifierUnderscore(span))),
            (Token::Word(word), span) => {
                if word.starts_with("__") {
                    Err(Box::new(Error::ReservedIdentifierPrefix(span)))
                } else {
                    Ok((word, span))
                }
            }
            (_, span) => Err(Box::new(Error::Unexpected(span, ExpectedToken::Identifier))),
        }
    }

    pub(in crate::front::wgsl) fn next_ident(&mut self) -> Result<'a, super::ast::Ident<'a>> {
        self.next_ident_with_span()
            .and_then(|(word, span)| Self::word_as_ident(word, span))
            .map(|(name, span)| super::ast::Ident { name, span })
    }

    fn word_as_ident(word: &'a str, span: Span) -> Result<'a, (&'a str, Span)> {
        if crate::keywords::wgsl::RESERVED.contains(&word) {
            Err(Box::new(Error::ReservedKeyword(span)))
        } else {
            Ok((word, span))
        }
    }

    pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<'a, ()> {
        self.expect(Token::Paren('('))
    }

    pub(in crate::front::wgsl) fn next_argument(&mut self) -> Result<'a, bool> {
        let paren = Token::Paren(')');
        if self.next_if(Token::Separator(',')) {
            Ok(!self.next_if(paren))
        } else {
            self.expect(paren).map(|()| false)
        }
    }
}

#[cfg(test)]
#[track_caller]
fn sub_test(source: &str, expected_tokens: &[Token]) {
    sub_test_with(true, source, expected_tokens);
}

#[cfg(test)]
#[track_caller]
fn sub_test_with_and_without_doc_comments(source: &str, expected_tokens: &[Token]) {
    sub_test_with(false, source, expected_tokens);
    sub_test_with(
        true,
        source,
        expected_tokens
            .iter()
            .filter(|v| !matches!(**v, Token::DocComment(_) | Token::ModuleDocComment(_)))
            .cloned()
            .collect::<Vec<_>>()
            .as_slice(),
    );
}

#[cfg(test)]
#[track_caller]
fn sub_test_with(ignore_doc_comments: bool, source: &str, expected_tokens: &[Token]) {
    let mut lex = Lexer::new(source, ignore_doc_comments);
    for &token in expected_tokens {
        assert_eq!(lex.next_with_unignored_doc_comments().0, token);
    }
    assert_eq!(lex.next().0, Token::End);
}

#[test]
fn test_numbers() {
    use half::f16;
    // WGSL spec examples //

    // decimal integer
    sub_test(
        "0x123 0X123u 1u 123 0 0i 0x3f",
        &[
            Token::Number(Ok(Number::AbstractInt(291))),
            Token::Number(Ok(Number::U32(291))),
            Token::Number(Ok(Number::U32(1))),
            Token::Number(Ok(Number::AbstractInt(123))),
            Token::Number(Ok(Number::AbstractInt(0))),
            Token::Number(Ok(Number::I32(0))),
            Token::Number(Ok(Number::AbstractInt(63))),
        ],
    );
    // decimal floating point
    sub_test(
        "0.e+4f 01.
         .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h",
        &[
            Token::Number(Ok(Number::F32(0.))),
            Token::Number(Ok(Number::AbstractFloat(1.))),
            Token::Number(Ok(Number::AbstractFloat(0.01))),
            Token::Number(Ok(Number::AbstractFloat(12.34))),
            Token::Number(Ok(Number::F32(0.))),
            Token::Number(Ok(Number::F16(f16::from_f32(0.)))),
            Token::Number(Ok(Number::AbstractFloat(0.001))),
            Token::Number(Ok(Number::AbstractFloat(43.75))),
            Token::Number(Ok(Number::F32(16.))),
            Token::Number(Ok(Number::AbstractFloat(0.1875))),
            // https://github.com/gfx-rs/wgpu/issues/7046
            Token::Number(Err(NumberError::NotRepresentable)), // Should be 0.75
            Token::Number(Ok(Number::AbstractFloat(0.12109375))),
            // https://github.com/gfx-rs/wgpu/issues/7046
            Token::Number(Err(NumberError::NotRepresentable)), // Should be 12.5
        ],
    );

    // MIN / MAX //

    // min / max decimal integer
    sub_test(
        "0i 2147483647i 2147483648i",
        &[
            Token::Number(Ok(Number::I32(0))),
            Token::Number(Ok(Number::I32(i32::MAX))),
            Token::Number(Err(NumberError::NotRepresentable)),
        ],
    );
    // min / max decimal unsigned integer
    sub_test(
        "0u 4294967295u 4294967296u",
        &[
            Token::Number(Ok(Number::U32(u32::MIN))),
            Token::Number(Ok(Number::U32(u32::MAX))),
            Token::Number(Err(NumberError::NotRepresentable)),
        ],
    );

    // min / max hexadecimal signed integer
    sub_test(
        "0x0i 0x7FFFFFFFi 0x80000000i",
        &[
            Token::Number(Ok(Number::I32(0))),
            Token::Number(Ok(Number::I32(i32::MAX))),
            Token::Number(Err(NumberError::NotRepresentable)),
        ],
    );
    // min / max hexadecimal unsigned integer
    sub_test(
        "0x0u 0xFFFFFFFFu 0x100000000u",
        &[
            Token::Number(Ok(Number::U32(u32::MIN))),
            Token::Number(Ok(Number::U32(u32::MAX))),
            Token::Number(Err(NumberError::NotRepresentable)),
        ],
    );

    // min/max decimal abstract int
    sub_test(
        "0 9223372036854775807 9223372036854775808",
        &[
            Token::Number(Ok(Number::AbstractInt(0))),
            Token::Number(Ok(Number::AbstractInt(i64::MAX))),
            Token::Number(Err(NumberError::NotRepresentable)),
        ],
    );

    // min/max hexadecimal abstract int
    sub_test(
        "0
         0x7fffffffffffffff 0x8000000000000000",
        &[
            Token::Number(Ok(Number::AbstractInt(0))),
            Token::Number(Ok(Number::AbstractInt(i64::MAX))),
            Token::Number(Err(NumberError::NotRepresentable)),
        ],
    );

    /// ≈ 2^-126 * 2^−23 (= 2^−149)
    const SMALLEST_POSITIVE_SUBNORMAL_F32: f32 = 1e-45;
    /// ≈ 2^-126 * (1 − 2^−23)
    const LARGEST_SUBNORMAL_F32: f32 = 1.1754942e-38;
    /// ≈ 2^-126
    const SMALLEST_POSITIVE_NORMAL_F32: f32 = f32::MIN_POSITIVE;
    /// ≈ 1 − 2^−24
    const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994;
    /// ≈ 1 + 2^−23
    const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001;
    /// ≈ 2^127 * (2 − 2^−23)
    const LARGEST_NORMAL_F32: f32 = f32::MAX;

    // decimal floating point
    sub_test(
        "1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f 3.40282347e+38f",
        &[
            Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
            Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
            Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
            Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
            Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
            Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
        ],
    );
    sub_test(
        "3.40282367e+38f",
        &[
            Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128
        ],
    );

    // hexadecimal floating point
    sub_test(
        "0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f 0xFFFFFFp+104f",
        &[
            Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
            Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
            Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
            Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
            Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
            Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
        ],
    );
    sub_test(
        "0x1p128f 0x1.000001p0f",
        &[
            Token::Number(Err(NumberError::NotRepresentable)), // = 2^128
            Token::Number(Err(NumberError::NotRepresentable)),
        ],
    );
}

#[test]
fn double_floats() {
    sub_test(
        "0x1.2p4lf 0x1p8lf 0.0625lf 625e-4lf 10lf 10l",
        &[
            Token::Number(Ok(Number::F64(18.0))),
            Token::Number(Ok(Number::F64(256.0))),
            Token::Number(Ok(Number::F64(0.0625))),
            Token::Number(Ok(Number::F64(0.0625))),
            Token::Number(Ok(Number::F64(10.0))),
            Token::Number(Ok(Number::AbstractInt(10))),
            Token::Word("l"),
        ],
    )
}

#[test]
fn test_tokens() {
    sub_test("id123_OK", &[Token::Word("id123_OK")]);
    sub_test(
        "92No",
        &[
            Token::Number(Ok(Number::AbstractInt(92))),
            Token::Word("No"),
        ],
    );
    sub_test(
        "2u3o",
        &[
            Token::Number(Ok(Number::U32(2))),
            Token::Number(Ok(Number::AbstractInt(3))),
            Token::Word("o"),
        ],
    );
    sub_test(
        "2.4f44po",
        &[
            Token::Number(Ok(Number::F32(2.4))),
            Token::Number(Ok(Number::AbstractInt(44))),
            Token::Word("po"),
        ],
    );
    sub_test(
        "Δέλτα réflexion Кызыл 𐰓𐰏𐰇 朝焼け سلام 검정 שָׁלוֹם गुलाबी փիրուզ",
        &[
            Token::Word("Δέλτα"),
            Token::Word("réflexion"),
            Token::Word("Кызыл"),
            Token::Word("𐰓𐰏𐰇"),
            Token::Word("朝焼け"),
            Token::Word("سلام"),
            Token::Word("검정"),
            Token::Word("שָׁלוֹם"),
            Token::Word("गुलाबी"),
            Token::Word("փիրուզ"),
        ],
    );
    sub_test("æNoø", &[Token::Word("æNoø")]);
    sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]);
    sub_test("No好", &[Token::Word("No好")]);
    sub_test("_No", &[Token::Word("_No")]);

    sub_test_with_and_without_doc_comments(
        "*/*/***/*//=/*****//",
        &[
            Token::Operation('*'),
            Token::AssignmentOperation('/'),
            Token::DocComment("/*****/"),
            Token::Operation('/'),
        ],
    );

    // Type suffixes are only allowed on hex float literals
    // if you provided an exponent.
    sub_test(
        "0x1.2f 0x1.2f 0x1.2h 0x1.2H 0x1.2lf",
        &[
            // The 'f' suffixes are taken as a hex digit:
            // the fractional part is 0x2f / 256.
            Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))),
            Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))),
            Token::Number(Ok(Number::AbstractFloat(1.125))),
            Token::Word("h"),
            Token::Number(Ok(Number::AbstractFloat(1.125))),
            Token::Word("H"),
            Token::Number(Ok(Number::AbstractFloat(1.125))),
            Token::Word("lf"),
        ],
    )
}

#[test]
fn test_variable_decl() {
    sub_test(
        "@group(0 ) var< uniform> texture: texture_multisampled_2d <f32 >;",
        &[
            Token::Attribute,
            Token::Word("group"),
            Token::Paren('('),
            Token::Number(Ok(Number::AbstractInt(0))),
            Token::Paren(')'),
            Token::Word("var"),
            Token::TemplateArgsStart,
            Token::Word("uniform"),
            Token::TemplateArgsEnd,
            Token::Word("texture"),
            Token::Separator(':'),
            Token::Word("texture_multisampled_2d"),
            Token::TemplateArgsStart,
            Token::Word("f32"),
            Token::TemplateArgsEnd,
            Token::Separator(';'),
        ],
    );
    sub_test(
        "var<storage,read_write> buffer: array<u32>;",
        &[
            Token::Word("var"),
            Token::TemplateArgsStart,
            Token::Word("storage"),
            Token::Separator(','),
            Token::Word("read_write"),
            Token::TemplateArgsEnd,
            Token::Word("buffer"),
            Token::Separator(':'),
            Token::Word("array"),
            Token::TemplateArgsStart,
            Token::Word("u32"),
            Token::TemplateArgsEnd,
            Token::Separator(';'),
        ],
    );
}

#[test]
fn test_template_list() {
    sub_test(
        "A<B||C>D",
        &[
            Token::Word("A"),
            Token::Paren('<'),
            Token::Word("B"),
            Token::LogicalOperation('|'),
            Token::Word("C"),
            Token::Paren('>'),
            Token::Word("D"),
        ],
    );
    sub_test(
        "A(B<C,D>(E))",
        &[
            Token::Word("A"),
            Token::Paren('('),
            Token::Word("B"),
            Token::TemplateArgsStart,
            Token::Word("C"),
            Token::Separator(','),
            Token::Word("D"),
            Token::TemplateArgsEnd,
            Token::Paren('('),
            Token::Word("E"),
            Token::Paren(')'),
            Token::Paren(')'),
        ],
    );
    sub_test(
        "array<i32,select(2,3,A>B)>",
        &[
            Token::Word("array"),
            Token::TemplateArgsStart,
            Token::Word("i32"),
            Token::Separator(','),
            Token::Word("select"),
            Token::Paren('('),
            Token::Number(Ok(Number::AbstractInt(2))),
            Token::Separator(','),
            Token::Number(Ok(Number::AbstractInt(3))),
            Token::Separator(','),
            Token::Word("A"),
            Token::Paren('>'),
            Token::Word("B"),
            Token::Paren(')'),
            Token::TemplateArgsEnd,
        ],
    );
    sub_test(
        "A[B<C]>D",
        &[
            Token::Word("A"),
            Token::Paren('['),
            Token::Word("B"),
            Token::Paren('<'),
            Token::Word("C"),
            Token::Paren(']'),
            Token::Paren('>'),
            Token::Word("D"),
        ],
    );
    sub_test(
        "A<B<<C>",
        &[
            Token::Word("A"),
            Token::TemplateArgsStart,
            Token::Word("B"),
            Token::ShiftOperation('<'),
            Token::Word("C"),
            Token::TemplateArgsEnd,
        ],
    );
    sub_test(
        "A<(B>=C)>",
        &[
            Token::Word("A"),
            Token::TemplateArgsStart,
            Token::Paren('('),
            Token::Word("B"),
            Token::LogicalOperation('>'),
            Token::Word("C"),
            Token::Paren(')'),
            Token::TemplateArgsEnd,
        ],
    );
    sub_test(
        "A<B>=C>",
        &[
            Token::Word("A"),
            Token::TemplateArgsStart,
            Token::Word("B"),
            Token::TemplateArgsEnd,
            Token::Operation('='),
            Token::Word("C"),
            Token::Paren('>'),
        ],
    );
}

#[test]
fn test_comments() {
    sub_test("// Single comment", &[]);

    sub_test("/* multi line comment */", &[]);
    sub_test("/* multi line comment */ // and another", &[]);
}

#[test]
fn test_doc_comments() {
    sub_test_with_and_without_doc_comments(
        "/// Single comment",
        &[Token::DocComment("/// Single comment")],
    );

    sub_test_with_and_without_doc_comments(
        "/** multi line comment */",
        &[Token::DocComment("/** multi line comment */")],
    );
    sub_test_with_and_without_doc_comments(
        "/** multi line comment */ /// and another",
        &[
            Token::DocComment("/** multi line comment */"),
            Token::DocComment("/// and another"),
        ],
    );
}

#[test]
fn test_doc_comment_nested() {
    sub_test_with_and_without_doc_comments(
        "/** a comment with nested one /** nested comment */ */ const a : i32 = 2;",
        &[
            Token::DocComment("/** a comment with nested one /** nested comment */ */"),
            Token::Word("const"),
            Token::Word("a"),
            Token::Separator(':'),
            Token::Word("i32"),
            Token::Operation('='),
            Token::Number(Ok(Number::AbstractInt(2))),
            Token::Separator(';'),
        ],
    );
}

#[test]
fn test_doc_comment_long_character() {
    sub_test_with_and_without_doc_comments(
        "/// π/2
        /// D(𝐡) = ───────────────────────────────────────────────────
        /// παₜα_b((𝐡 ⋅ 𝐭)² / αₜ²) + (𝐡 ⋅ 𝐛)² / α_b² +`
    const a : i32 = 2;",
        &[
            Token::DocComment("/// π/2"),
            Token::DocComment("/// D(𝐡) = ───────────────────────────────────────────────────"),
            Token::DocComment("/// παₜα_b((𝐡 ⋅ 𝐭)² / αₜ²) + (𝐡 ⋅ 𝐛)² / α_b² +`"),
            Token::Word("const"),
            Token::Word("a"),
            Token::Separator(':'),
            Token::Word("i32"),
            Token::Operation('='),
            Token::Number(Ok(Number::AbstractInt(2))),
            Token::Separator(';'),
        ],
    );
}

#[test]
fn test_doc_comments_module() {
    sub_test_with_and_without_doc_comments(
        "//! Comment Module
        //! Another one.
        /*! Different module comment */
        /// Trying to break module comment
        // Trying to break module comment again
        //! After a regular comment is ok.
        /*! Different module comment again */

        //! After a break is supported.
        const
        //! After anything else is not.",
        &[
            Token::ModuleDocComment("//! Comment Module"),
            Token::ModuleDocComment("//! Another one."),
            Token::ModuleDocComment("/*! Different module comment */"),
            Token::DocComment("/// Trying to break module comment"),
            Token::ModuleDocComment("//! After a regular comment is ok."),
            Token::ModuleDocComment("/*! Different module comment again */"),
            Token::ModuleDocComment("//! After a break is supported."),
            Token::Word("const"),
            Token::ModuleDocComment("//! After anything else is not."),
        ],
    );
}
naga-29.0.3/src/front/wgsl/parse/mod.rs000064400000000000000000002712501046102023000157660ustar 00000000000000
use alloc::{boxed::Box, vec::Vec};

use directive::enable_extension::ImplementedEnableExtension;

use crate::diagnostic_filter::{
    self, DiagnosticFilter, DiagnosticFilterMap, DiagnosticFilterNode, FilterableTriggeringRule,
    ShouldConflictOnFullDuplicate, StandardFilterableTriggeringRule,
};
use crate::front::wgsl::error::{DiagnosticAttributeNotSupportedPosition, Error, ExpectedToken};
use crate::front::wgsl::parse::directive::enable_extension::{EnableExtension, EnableExtensions};
use crate::front::wgsl::parse::directive::language_extension::LanguageExtension;
use crate::front::wgsl::parse::directive::DirectiveKind;
use crate::front::wgsl::parse::lexer::{Lexer, Token, TokenSpan};
use crate::front::wgsl::parse::number::Number;
use crate::front::wgsl::Result;
use crate::front::SymbolTable;
use crate::{Arena, FastHashSet, FastIndexSet, Handle, ShaderStage, Span};

pub mod ast;
pub mod conv;
pub mod directive;
pub mod lexer;
pub mod number;

/// State for constructing an AST expression.
///
/// Not to be confused with [`lower::ExpressionContext`], which is for producing
/// Naga IR from the AST we produce here.
///
/// [`lower::ExpressionContext`]: super::lower::ExpressionContext
struct ExpressionContext<'input, 'temp, 'out> {
    /// The [`TranslationUnit::expressions`] arena to which we should contribute
    /// expressions.
    ///
    /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions
    expressions: &'out mut Arena<ast::Expression<'input>>,

    /// A map from identifiers in scope to the locals/arguments they represent.
    ///
    /// The handles refer to the [`locals`] arena; see that field's
    /// documentation for details.
    ///
    /// [`locals`]: ExpressionContext::locals
    local_table: &'temp mut SymbolTable<&'input str, Handle<ast::Local>>,

    /// Local variable and function argument arena for the function we're building.
    ///
    /// Note that the [`ast::Local`] here is actually a zero-sized type. This
    /// `Arena`'s only role is to assign a unique `Handle` to each local
    /// identifier, and track its definition's span for use in diagnostics. All
    /// the detailed information about locals (names, types, etc.) is kept in
    /// the [`LocalDecl`] statements we parsed from their declarations. For
    /// arguments, that information is kept in [`arguments`].
    ///
    /// In the AST, when an [`Ident`] expression refers to a local variable or
    /// argument, its [`IdentExpr`] holds the referent's `Handle<Local>` in this
    /// arena.
    ///
    /// During lowering, [`LocalDecl`] statements add entries to a per-function
    /// table that maps `Handle<Local>` values to their Naga representations,
    /// accessed via [`StatementContext::local_table`] and
    /// [`LocalExpressionContext::local_table`]. This table is then consulted when
    /// lowering subsequent [`Ident`] expressions.
    ///
    /// [`LocalDecl`]: ast::StatementKind::LocalDecl
    /// [`arguments`]: ast::Function::arguments
    /// [`Ident`]: ast::Expression::Ident
    /// [`IdentExpr`]: ast::IdentExpr
    /// [`StatementContext::local_table`]: super::lower::StatementContext::local_table
    /// [`LocalExpressionContext::local_table`]: super::lower::LocalExpressionContext::local_table
    locals: &'out mut Arena<ast::Local>,

    /// Identifiers used by the current global declaration that have no local definition.
    ///
    /// This becomes the [`GlobalDecl`]'s [`dependencies`] set.
    ///
    /// Note that we don't know at parse time what kind of [`GlobalDecl`] the
    /// name refers to. We can't look up names until we've seen the entire
    /// translation unit.
    ///
    /// [`GlobalDecl`]: ast::GlobalDecl
    /// [`dependencies`]: ast::GlobalDecl::dependencies
    unresolved: &'out mut FastIndexSet<ast::Dependency<'input>>,
}

impl<'a> ExpressionContext<'a, '_, '_> {
    fn parse_binary_op(
        &mut self,
        lexer: &mut Lexer<'a>,
        classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>,
        mut parser: impl FnMut(&mut Lexer<'a>, &mut Self) -> Result<'a, Handle<ast::Expression<'a>>>,
    ) -> Result<'a, Handle<ast::Expression<'a>>> {
        let start = lexer.start_byte_offset();
        let mut accumulator = parser(lexer, self)?;
        while let Some(op) = classifier(lexer.peek().0) {
            let _ = lexer.next();
            let left = accumulator;
            let right = parser(lexer, self)?;
            accumulator = self.expressions.append(
                ast::Expression::Binary { op, left, right },
                lexer.span_from(start),
            );
        }
        Ok(accumulator)
    }

    fn declare_local(&mut self, name: ast::Ident<'a>) -> Result<'a, Handle<ast::Local>> {
        let handle = self.locals.append(ast::Local, name.span);
        if let Some(old) = self.local_table.add(name.name, handle) {
            Err(Box::new(Error::Redefinition {
                previous: self.locals.get_span(old),
                current: name.span,
            }))
        } else {
            Ok(handle)
        }
    }
}

/// Which grammar rule we are in the midst of parsing.
///
/// This is used for error checking. `Parser` maintains a stack of
/// these and (occasionally) checks that it is being pushed and popped
/// as expected.
#[derive(Copy, Clone, Debug, PartialEq)]
enum Rule {
    Attribute,
    VariableDecl,
    FunctionDecl,
    Block,
    Statement,
    PrimaryExpr,
    SingularExpr,
    UnaryExpr,
    GeneralExpr,
    Directive,
    GenericExpr,
    EnclosedExpr,
    LhsExpr,
}

struct ParsedAttribute<T> {
    value: Option<T>,
}

impl<T> Default for ParsedAttribute<T> {
    fn default() -> Self {
        Self { value: None }
    }
}

impl<T> ParsedAttribute<T> {
    fn set(&mut self, value: T, name_span: Span) -> Result<'static, ()> {
        if self.value.is_some() {
            return Err(Box::new(Error::RepeatedAttribute(name_span)));
        }
        self.value = Some(value);
        Ok(())
    }
}

#[derive(Default)]
struct BindingParser<'a> {
    location: ParsedAttribute<Handle<ast::Expression<'a>>>,
    built_in: ParsedAttribute<crate::BuiltIn>,
    interpolation: ParsedAttribute<crate::Interpolation>,
    sampling: ParsedAttribute<crate::Sampling>,
    invariant: ParsedAttribute<bool>,
    blend_src: ParsedAttribute<Handle<ast::Expression<'a>>>,
    per_primitive: ParsedAttribute<()>,
}

impl<'a> BindingParser<'a> {
    fn parse(
        &mut self,
        parser: &mut Parser,
        lexer: &mut Lexer<'a>,
        name: &'a str,
        name_span: Span,
        ctx: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, ()> {
        match name {
            "location" => {
                lexer.expect(Token::Paren('('))?;
                self.location
                    .set(parser.expression(lexer, ctx)?, name_span)?;
                lexer.next_if(Token::Separator(','));
                lexer.expect(Token::Paren(')'))?;
            }
            "builtin" => {
                lexer.expect(Token::Paren('('))?;
                let (raw, span) = lexer.next_ident_with_span()?;
                self.built_in.set(
                    conv::map_built_in(&lexer.enable_extensions, raw, span)?,
                    name_span,
                )?;
                lexer.next_if(Token::Separator(','));
                lexer.expect(Token::Paren(')'))?;
            }
            "interpolate" => {
                lexer.expect(Token::Paren('('))?;
                let (raw, span) = lexer.next_ident_with_span()?;
                self.interpolation
                    .set(conv::map_interpolation(raw, span)?, name_span)?;
                if lexer.next_if(Token::Separator(',')) {
                    let (raw, span) = lexer.next_ident_with_span()?;
                    self.sampling
                        .set(conv::map_sampling(raw, span)?, name_span)?;
                }
                lexer.next_if(Token::Separator(','));
                lexer.expect(Token::Paren(')'))?;
            }
            "invariant" => {
                self.invariant.set(true, name_span)?;
            }
            "blend_src" => {
                lexer.require_enable_extension(
                    ImplementedEnableExtension::DualSourceBlending,
                    name_span,
                )?;

                lexer.expect(Token::Paren('('))?;
                self.blend_src
                    .set(parser.expression(lexer, ctx)?, name_span)?;
                lexer.next_if(Token::Separator(','));
                lexer.expect(Token::Paren(')'))?;
            }
            "per_primitive" => {
                lexer.require_enable_extension(
                    ImplementedEnableExtension::WgpuMeshShader,
                    name_span,
                )?;
                self.per_primitive.set((), name_span)?;
            }
            _ => return Err(Box::new(Error::UnknownAttribute(name_span))),
        }
        Ok(())
    }

    fn finish(self, span: Span) -> Result<'a, Option<ast::Binding<'a>>> {
        match (
            self.location.value,
            self.built_in.value,
            self.interpolation.value,
            self.sampling.value,
            self.invariant.value.unwrap_or_default(),
            self.blend_src.value,
            self.per_primitive.value,
        ) {
            (None, None, None, None, false, None, None) => Ok(None),
            (Some(location), None, interpolation, sampling, false, blend_src, per_primitive) => {
                // Before handing over the completed `Module`, we call
                // `apply_default_interpolation` to ensure that the interpolation and
                // sampling have been explicitly specified on all vertex shader output and fragment
                // shader input user bindings, so leaving them potentially `None` here is fine.
                Ok(Some(ast::Binding::Location {
                    location,
                    interpolation,
                    sampling,
                    blend_src,
                    per_primitive: per_primitive.is_some(),
                }))
            }
            (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant, None, None) => {
                Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position {
                    invariant,
                })))
            }
            (None, Some(built_in), None, None, false, None, None) => {
                Ok(Some(ast::Binding::BuiltIn(built_in)))
            }
            (_, _, _, _, _, _, _) => Err(Box::new(Error::InconsistentBinding(span))),
        }
    }
}

/// Configuration for the whole parser run.
pub struct Options {
    /// Controls whether the parser should parse doc comments.
    pub parse_doc_comments: bool,
    /// Capabilities to enable during parsing.
    pub capabilities: crate::valid::Capabilities,
}

impl Options {
    /// Creates a new default [`Options`].
    pub const fn new() -> Self {
        Options {
            parse_doc_comments: false,
            capabilities: crate::valid::Capabilities::all(),
        }
    }
}

pub struct Parser {
    rules: Vec<(Rule, usize)>,
    recursion_depth: u32,
}

impl Parser {
    pub const fn new() -> Self {
        Parser {
            rules: Vec::new(),
            recursion_depth: 0,
        }
    }

    fn reset(&mut self) {
        self.rules.clear();
        self.recursion_depth = 0;
    }

    fn push_rule_span(&mut self, rule: Rule, lexer: &mut Lexer<'_>) {
        self.rules.push((rule, lexer.start_byte_offset()));
    }

    fn pop_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
        let (_, initial) = self.rules.pop().unwrap();
        lexer.span_from(initial)
    }

    fn peek_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
        let &(_, initial) = self.rules.last().unwrap();
        lexer.span_from(initial)
    }

    fn race_rules(&self, rule0: Rule, rule1: Rule) -> Option<Rule> {
        Some(
            self.rules
                .iter()
                .rev()
                .find(|&x| x.0 == rule0 || x.0 == rule1)?
                .0,
        )
    }

    fn track_recursion<'a, F, R>(&mut self, f: F) -> Result<'a, R>
    where
        F: FnOnce(&mut Self) -> Result<'a, R>,
    {
        self.recursion_depth += 1;
        if self.recursion_depth >= 256 {
            return Err(Box::new(Error::Internal("Parser recursion limit exceeded")));
        }
        let ret = f(self);
        self.recursion_depth -= 1;
        ret
    }

    fn switch_value<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, ast::SwitchValue<'a>> {
        if lexer.next_if(Token::Word("default")) {
            return Ok(ast::SwitchValue::Default);
        }

        let expr = self.expression(lexer, ctx)?;
        Ok(ast::SwitchValue::Expr(expr))
    }

    /// Expects the callee `name` to have already been consumed (it is no longer in the lexer).
    fn arguments<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, Vec<Handle<ast::Expression<'a>>>> {
        self.push_rule_span(Rule::EnclosedExpr, lexer);
        lexer.open_arguments()?;
        let mut arguments = Vec::new();
        loop {
            if !arguments.is_empty() {
                if !lexer.next_argument()?
                {
                    break;
                }
            } else if lexer.next_if(Token::Paren(')')) {
                break;
            }
            let arg = self.expression(lexer, ctx)?;
            arguments.push(arg);
        }

        self.pop_rule_span(lexer);
        Ok(arguments)
    }

    fn enclosed_expression<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, Handle<ast::Expression<'a>>> {
        self.push_rule_span(Rule::EnclosedExpr, lexer);
        let expr = self.expression(lexer, ctx)?;
        self.pop_rule_span(lexer);
        Ok(expr)
    }

    fn ident_expr<'a>(
        &mut self,
        name: &'a str,
        name_span: Span,
        ctx: &mut ExpressionContext<'a, '_, '_>,
    ) -> ast::IdentExpr<'a> {
        match ctx.local_table.lookup(name) {
            Some(&local) => ast::IdentExpr::Local(local),
            None => {
                ctx.unresolved.insert(ast::Dependency {
                    ident: name,
                    usage: name_span,
                });
                ast::IdentExpr::Unresolved(name)
            }
        }
    }

    fn primary_expression<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
        token: TokenSpan<'a>,
    ) -> Result<'a, Handle<ast::Expression<'a>>> {
        self.push_rule_span(Rule::PrimaryExpr, lexer);
        const fn literal_ray_flag<'b>(flag: crate::RayFlag) -> ast::Expression<'b> {
            ast::Expression::Literal(ast::Literal::Number(Number::U32(flag.bits())))
        }
        const fn literal_ray_intersection<'b>(
            intersection: crate::RayQueryIntersection,
        ) -> ast::Expression<'b> {
            ast::Expression::Literal(ast::Literal::Number(Number::U32(intersection as u32)))
        }

        let expr = match token {
            (Token::Paren('('), _) => {
                let expr = self.enclosed_expression(lexer, ctx)?;
                lexer.expect(Token::Paren(')'))?;
                self.pop_rule_span(lexer);
                return Ok(expr);
            }
            (Token::Word("true"), _) => ast::Expression::Literal(ast::Literal::Bool(true)),
            (Token::Word("false"), _) => ast::Expression::Literal(ast::Literal::Bool(false)),
            (Token::Number(res), span) => {
                let num = res.map_err(|err| Error::BadNumber(span, err))?;

                if let Some(enable_extension) = num.requires_enable_extension() {
                    lexer.require_enable_extension(enable_extension, span)?;
                }

                ast::Expression::Literal(ast::Literal::Number(num))
            }
            (Token::Word("RAY_FLAG_NONE"), _) =>
                literal_ray_flag(crate::RayFlag::empty()),
            (Token::Word("RAY_FLAG_FORCE_OPAQUE"), _) => {
                literal_ray_flag(crate::RayFlag::FORCE_OPAQUE)
            }
            (Token::Word("RAY_FLAG_FORCE_NO_OPAQUE"), _) => {
                literal_ray_flag(crate::RayFlag::FORCE_NO_OPAQUE)
            }
            (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => {
                literal_ray_flag(crate::RayFlag::TERMINATE_ON_FIRST_HIT)
            }
            (Token::Word("RAY_FLAG_SKIP_CLOSEST_HIT_SHADER"), _) => {
                literal_ray_flag(crate::RayFlag::SKIP_CLOSEST_HIT_SHADER)
            }
            (Token::Word("RAY_FLAG_CULL_BACK_FACING"), _) => {
                literal_ray_flag(crate::RayFlag::CULL_BACK_FACING)
            }
            (Token::Word("RAY_FLAG_CULL_FRONT_FACING"), _) => {
                literal_ray_flag(crate::RayFlag::CULL_FRONT_FACING)
            }
            (Token::Word("RAY_FLAG_CULL_OPAQUE"), _) => {
                literal_ray_flag(crate::RayFlag::CULL_OPAQUE)
            }
            (Token::Word("RAY_FLAG_CULL_NO_OPAQUE"), _) => {
                literal_ray_flag(crate::RayFlag::CULL_NO_OPAQUE)
            }
            (Token::Word("RAY_FLAG_SKIP_TRIANGLES"), _) => {
                literal_ray_flag(crate::RayFlag::SKIP_TRIANGLES)
            }
            (Token::Word("RAY_FLAG_SKIP_AABBS"), _) => literal_ray_flag(crate::RayFlag::SKIP_AABBS),
            (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => {
                literal_ray_intersection(crate::RayQueryIntersection::None)
            }
            (Token::Word("RAY_QUERY_INTERSECTION_TRIANGLE"), _) => {
                literal_ray_intersection(crate::RayQueryIntersection::Triangle)
            }
            (Token::Word("RAY_QUERY_INTERSECTION_GENERATED"), _) => {
                literal_ray_intersection(crate::RayQueryIntersection::Generated)
            }
            (Token::Word("RAY_QUERY_INTERSECTION_AABB"), _) => {
                literal_ray_intersection(crate::RayQueryIntersection::Aabb)
            }
            (Token::Word(word), span) => {
                let ident = self.template_elaborated_ident(word, span, lexer, ctx)?;

                if let Token::Paren('(') = lexer.peek().0 {
                    let arguments = self.arguments(lexer, ctx)?;
                    ast::Expression::Call(ast::CallPhrase {
                        function: ident,
                        arguments,
                    })
                } else {
                    ast::Expression::Ident(ident)
                }
            }
            other => {
                return Err(Box::new(Error::Unexpected(
                    other.1,
                    ExpectedToken::PrimaryExpression,
                )))
            }
        };
        self.pop_rule_span(lexer);
        let span =
            lexer.span_with_start(token.1);
        let expr = ctx.expressions.append(expr, span);
        Ok(expr)
    }

    fn component_or_swizzle_specifier<'a>(
        &mut self,
        expr_start: Span,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
        expr: Handle<ast::Expression<'a>>,
    ) -> Result<'a, Handle<ast::Expression<'a>>> {
        let mut expr = expr;

        loop {
            let expression = match lexer.peek().0 {
                Token::Separator('.') => {
                    let _ = lexer.next();
                    let field = lexer.next_ident()?;

                    ast::Expression::Member { base: expr, field }
                }
                Token::Paren('[') => {
                    let _ = lexer.next();
                    let index = self.enclosed_expression(lexer, ctx)?;
                    lexer.expect(Token::Paren(']'))?;

                    ast::Expression::Index { base: expr, index }
                }
                _ => break,
            };

            let span = lexer.span_with_start(expr_start);
            expr = ctx.expressions.append(expression, span);
        }

        Ok(expr)
    }

    /// Parse a `unary_expression`.
    fn unary_expression<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, Handle<ast::Expression<'a>>> {
        self.push_rule_span(Rule::UnaryExpr, lexer);

        enum UnaryOp {
            Negate,
            LogicalNot,
            BitwiseNot,
            Deref,
            AddrOf,
        }

        let mut ops = Vec::new();
        let mut expr;

        loop {
            match lexer.next() {
                (Token::Operation('-'), span) => {
                    ops.push((UnaryOp::Negate, span));
                }
                (Token::Operation('!'), span) => {
                    ops.push((UnaryOp::LogicalNot, span));
                }
                (Token::Operation('~'), span) => {
                    ops.push((UnaryOp::BitwiseNot, span));
                }
                (Token::Operation('*'), span) => {
                    ops.push((UnaryOp::Deref, span));
                }
                (Token::Operation('&'), span) => {
                    ops.push((UnaryOp::AddrOf, span));
                }
                token => {
                    expr = self.singular_expression(lexer, ctx, token)?;
                    break;
                }
            };
        }

        for (op, span) in ops.into_iter().rev() {
            let e = match op {
                UnaryOp::Negate => ast::Expression::Unary {
                    op: crate::UnaryOperator::Negate,
                    expr,
                },
                UnaryOp::LogicalNot => ast::Expression::Unary {
                    op: crate::UnaryOperator::LogicalNot,
                    expr,
                },
                UnaryOp::BitwiseNot => ast::Expression::Unary {
                    op: crate::UnaryOperator::BitwiseNot,
                    expr,
                },
                UnaryOp::Deref => ast::Expression::Deref(expr),
                UnaryOp::AddrOf => ast::Expression::AddrOf(expr),
            };
            let span =
                lexer.span_with_start(span);
            expr = ctx.expressions.append(e, span);
        }

        self.pop_rule_span(lexer);
        Ok(expr)
    }

    /// Parse a `lhs_expression`.
    ///
    /// LHS expressions only support the `&` and `*` operators and
    /// the `[]` and `.` postfix selectors.
    fn lhs_expression<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
        token: Option<TokenSpan<'a>>,
        expected_token: ExpectedToken<'a>,
    ) -> Result<'a, Handle<ast::Expression<'a>>> {
        self.track_recursion(|this| {
            this.push_rule_span(Rule::LhsExpr, lexer);
            let token = token.unwrap_or_else(|| lexer.next());
            let expr = match token {
                (Token::Operation('*'), _) => {
                    let expr =
                        this.lhs_expression(lexer, ctx, None, ExpectedToken::LhsExpression)?;
                    let expr = ast::Expression::Deref(expr);
                    let span = this.peek_rule_span(lexer);
                    ctx.expressions.append(expr, span)
                }
                (Token::Operation('&'), _) => {
                    let expr =
                        this.lhs_expression(lexer, ctx, None, ExpectedToken::LhsExpression)?;
                    let expr = ast::Expression::AddrOf(expr);
                    let span = this.peek_rule_span(lexer);
                    ctx.expressions.append(expr, span)
                }
                (Token::Paren('('), span) => {
                    let expr =
                        this.lhs_expression(lexer, ctx, None, ExpectedToken::LhsExpression)?;
                    lexer.expect(Token::Paren(')'))?;
                    this.component_or_swizzle_specifier(span, lexer, ctx, expr)?
                }
                (Token::Word(word), span) => {
                    let ident = this.ident_expr(word, span, ctx);
                    let ident = ast::TemplateElaboratedIdent {
                        ident,
                        ident_span: span,
                        template_list: Vec::new(),
                        template_list_span: Span::UNDEFINED,
                    };
                    let ident = ctx.expressions.append(ast::Expression::Ident(ident), span);
                    this.component_or_swizzle_specifier(span, lexer, ctx, ident)?
                }
                (_, span) => {
                    return Err(Box::new(Error::Unexpected(span, expected_token)));
                }
            };

            this.pop_rule_span(lexer);
            Ok(expr)
        })
    }

    /// Parse a `singular_expression`.
    fn singular_expression<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
        token: TokenSpan<'a>,
    ) -> Result<'a, Handle<ast::Expression<'a>>> {
        self.push_rule_span(Rule::SingularExpr, lexer);
        let primary_expr = self.primary_expression(lexer, ctx, token)?;
        let singular_expr =
            self.component_or_swizzle_specifier(token.1, lexer, ctx, primary_expr)?;
        self.pop_rule_span(lexer);

        Ok(singular_expr)
    }

    fn equality_expression<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        context: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, Handle<ast::Expression<'a>>> {
        // equality_expression
        context.parse_binary_op(
            lexer,
            |token| match token {
                Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal),
                Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual),
                _ => None,
            },
            // relational_expression
            |lexer, context| {
                let enclosing = self.race_rules(Rule::GenericExpr, Rule::EnclosedExpr);
                context.parse_binary_op(
                    lexer,
                    match enclosing {
                        Some(Rule::GenericExpr) => |token| match token {
                            Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual),
                            _ => None,
                        },
                        _ => |token| match token {
                            Token::Paren('<') => Some(crate::BinaryOperator::Less),
                            Token::Paren('>') => Some(crate::BinaryOperator::Greater),
                            Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual),
                            Token::LogicalOperation('>') => {
                                Some(crate::BinaryOperator::GreaterEqual)
                            }
                            _ => None,
                        },
                    },
                    // shift_expression
                    |lexer, context| {
                        context.parse_binary_op(
                            lexer,
                            match enclosing {
                                Some(Rule::GenericExpr) => |token| match token {
                                    Token::ShiftOperation('<') => {
                                        Some(crate::BinaryOperator::ShiftLeft)
                                    }
                                    _ => None,
                                },
                                _ => |token| match token {
                                    Token::ShiftOperation('<') => {
                                        Some(crate::BinaryOperator::ShiftLeft)
                                    }
                                    Token::ShiftOperation('>') => {
                                        Some(crate::BinaryOperator::ShiftRight)
                                    }
                                    _ => None,
                                },
                            },
                            // additive_expression
                            |lexer, context| {
                                context.parse_binary_op(
                                    lexer,
                                    |token| match token {
                                        Token::Operation('+') => Some(crate::BinaryOperator::Add),
                                        Token::Operation('-') => {
                                            Some(crate::BinaryOperator::Subtract)
                                        }
                                        _ => None,
                                    },
                                    // multiplicative_expression
                                    |lexer, context| {
                                        context.parse_binary_op(
                                            lexer,
                                            |token| match token {
                                                Token::Operation('*') => {
                                                    Some(crate::BinaryOperator::Multiply)
                                                }
                                                Token::Operation('/') => {
                                                    Some(crate::BinaryOperator::Divide)
                                                }
                                                Token::Operation('%') => {
                                                    Some(crate::BinaryOperator::Modulo)
                                                }
                                                _ => None,
                                            },
                                            |lexer, context| self.unary_expression(lexer, context),
                                        )
                                    },
                                )
                            },
                        )
                    },
                )
            },
        )
    }

    fn expression<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        context: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, Handle<ast::Expression<'a>>> {
        self.push_rule_span(Rule::GeneralExpr, lexer);
        // logical_or_expression
        let handle = context.parse_binary_op(
            lexer,
            |token| match token {
                Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr),
                _ => None,
            },
            // logical_and_expression
            |lexer, context| {
                context.parse_binary_op(
                    lexer,
                    |token| match token {
                        Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd),
                        _ => None,
                    },
                    // inclusive_or_expression
                    |lexer, context| {
                        context.parse_binary_op(
                            lexer,
                            |token| match token {
                                Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr),
                                _ => None,
                            },
                            // exclusive_or_expression
                            |lexer, context| {
                                context.parse_binary_op(
                                    lexer,
                                    |token| match token {
                                        Token::Operation('^') => {
                                            Some(crate::BinaryOperator::ExclusiveOr)
                                        }
                                        _ => None,
                                    },
                                    // and_expression
                                    |lexer, context| {
                                        context.parse_binary_op(
                                            lexer,
                                            |token| match token {
                                                Token::Operation('&') => {
                                                    Some(crate::BinaryOperator::And)
                                                }
                                                _ => None,
                                            },
                                            |lexer, context| {
                                                self.equality_expression(lexer, context)
                                            },
                                        )
                                    },
                                )
                            },
                        )
                    },
                )
            },
        )?;
        self.pop_rule_span(lexer);
        Ok(handle)
    }

    fn optionally_typed_ident<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, (ast::Ident<'a>, Option<ast::TemplateElaboratedIdent<'a>>)> {
        let name = lexer.next_ident()?;

        let ty = if lexer.next_if(Token::Separator(':')) {
            Some(self.type_specifier(lexer, ctx)?)
        } else {
            None
        };

        Ok((name, ty))
    }

    /// 'var' _disambiguate_template template_list?
optionally_typed_ident fn variable_decl<'a>( &mut self, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result<'a, ast::GlobalVariable<'a>> { self.push_rule_span(Rule::VariableDecl, lexer); let (template_list, _) = self.maybe_template_list(lexer, ctx)?; let (name, ty) = self.optionally_typed_ident(lexer, ctx)?; let init = if lexer.next_if(Token::Operation('=')) { let handle = self.expression(lexer, ctx)?; Some(handle) } else { None }; lexer.expect(Token::Separator(';'))?; self.pop_rule_span(lexer); Ok(ast::GlobalVariable { name, template_list, binding: None, ty, init, doc_comments: Vec::new(), memory_decorations: crate::MemoryDecorations::empty(), }) } fn struct_body<'a>( &mut self, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result<'a, Vec>> { let mut members = Vec::new(); let mut member_names = FastHashSet::default(); lexer.expect(Token::Paren('{'))?; let mut ready = true; while !lexer.next_if(Token::Paren('}')) { if !ready { return Err(Box::new(Error::Unexpected( lexer.next().1, ExpectedToken::Token(Token::Separator(',')), ))); } let doc_comments = lexer.accumulate_doc_comments(); let (mut size, mut align) = (ParsedAttribute::default(), ParsedAttribute::default()); self.push_rule_span(Rule::Attribute, lexer); let mut bind_parser = BindingParser::default(); while lexer.next_if(Token::Attribute) { match lexer.next_ident_with_span()? 
{ ("size", name_span) => { lexer.expect(Token::Paren('('))?; let expr = self.expression(lexer, ctx)?; lexer.next_if(Token::Separator(',')); lexer.expect(Token::Paren(')'))?; size.set(expr, name_span)?; } ("align", name_span) => { lexer.expect(Token::Paren('('))?; let expr = self.expression(lexer, ctx)?; lexer.next_if(Token::Separator(',')); lexer.expect(Token::Paren(')'))?; align.set(expr, name_span)?; } (word, word_span) => bind_parser.parse(self, lexer, word, word_span, ctx)?, } } let bind_span = self.pop_rule_span(lexer); let binding = bind_parser.finish(bind_span)?; let name = lexer.next_ident()?; lexer.expect(Token::Separator(':'))?; let ty = self.type_specifier(lexer, ctx)?; ready = lexer.next_if(Token::Separator(',')); members.push(ast::StructMember { name, ty, binding, size: size.value, align: align.value, doc_comments, }); if !member_names.insert(name.name) { return Err(Box::new(Error::Redefinition { previous: members .iter() .find(|x| x.name.name == name.name) .map(|x| x.name.span) .unwrap(), current: name.span, })); } } Ok(members) } fn maybe_template_list<'a>( &mut self, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result<'a, (Vec>>, Span)> { let start = lexer.start_byte_offset(); if lexer.next_if(Token::TemplateArgsStart) { let mut args = Vec::new(); args.push(self.expression(lexer, ctx)?); while lexer.next_if(Token::Separator(',')) && lexer.peek().0 != Token::TemplateArgsEnd { args.push(self.expression(lexer, ctx)?); } lexer.expect(Token::TemplateArgsEnd)?; let span = lexer.span_from(start); Ok((args, span)) } else { Ok((Vec::new(), Span::UNDEFINED)) } } fn template_elaborated_ident<'a>( &mut self, word: &'a str, span: Span, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result<'a, ast::TemplateElaboratedIdent<'a>> { let ident = self.ident_expr(word, span, ctx); let (template_list, template_list_span) = self.maybe_template_list(lexer, ctx)?; Ok(ast::TemplateElaboratedIdent { ident, ident_span: span, 
            template_list,
            template_list_span,
        })
    }

    fn type_specifier<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, ast::TemplateElaboratedIdent<'a>> {
        let (name, span) = lexer.next_ident_with_span()?;
        self.template_elaborated_ident(name, span, lexer, ctx)
    }

    /// Parses assignment, increment and decrement statements
    ///
    /// This does not consume or require a final `;` token. In the update
    /// expression of a C-style `for` loop header, there is no terminating `;`.
    fn variable_updating_statement<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
        block: &mut ast::Block<'a>,
        token: TokenSpan<'a>,
        expected_token: ExpectedToken<'a>,
    ) -> Result<'a, ()> {
        match token {
            (Token::Word("_"), span) => {
                lexer.expect(Token::Operation('='))?;
                let expr = self.expression(lexer, ctx)?;
                let span = lexer.span_with_start(span);
                block.stmts.push(ast::Statement {
                    kind: ast::StatementKind::Phony(expr),
                    span,
                });
                return Ok(());
            }
            _ => {}
        }
        let target = self.lhs_expression(lexer, ctx, Some(token), expected_token)?;

        let (op, value) = match lexer.next() {
            (Token::Operation('='), _) => {
                let value = self.expression(lexer, ctx)?;
                (None, value)
            }
            (Token::AssignmentOperation(c), _) => {
                use crate::BinaryOperator as Bo;
                let op = match c {
                    '<' => Bo::ShiftLeft,
                    '>' => Bo::ShiftRight,
                    '+' => Bo::Add,
                    '-' => Bo::Subtract,
                    '*' => Bo::Multiply,
                    '/' => Bo::Divide,
                    '%' => Bo::Modulo,
                    '&' => Bo::And,
                    '|' => Bo::InclusiveOr,
                    '^' => Bo::ExclusiveOr,
                    // Note: `consume_token` shouldn't produce any other assignment ops
                    _ => unreachable!(),
                };

                let value = self.expression(lexer, ctx)?;
                (Some(op), value)
            }
            op_token @ (Token::IncrementOperation | Token::DecrementOperation, _) => {
                let op = match op_token.0 {
                    Token::IncrementOperation => ast::StatementKind::Increment,
                    Token::DecrementOperation => ast::StatementKind::Decrement,
                    _ => unreachable!(),
                };

                let span = lexer.span_with_start(token.1);
                block.stmts.push(ast::Statement {
                    kind: op(target),
                    span,
                });
                return Ok(());
            }
            (_, span) => return Err(Box::new(Error::Unexpected(span, ExpectedToken::Assignment))),
        };

        let span = lexer.span_with_start(token.1);
        block.stmts.push(ast::Statement {
            kind: ast::StatementKind::Assign { target, op, value },
            span,
        });
        Ok(())
    }

    /// Parse a function call statement.
    ///
    /// This assumes that `token` has been consumed from the lexer.
    ///
    /// This does not consume or require a final `;` token. In the update
    /// expression of a C-style `for` loop header, there is no terminating `;`.
    fn maybe_func_call_statement<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        context: &mut ExpressionContext<'a, '_, '_>,
        block: &mut ast::Block<'a>,
        token: TokenSpan<'a>,
    ) -> Result<'a, bool> {
        let (name, name_span) = match token {
            (Token::Word(name), span) => (name, span),
            _ => return Ok(false),
        };
        let ident = self.template_elaborated_ident(name, name_span, lexer, context)?;
        if ident.template_list.is_empty() && !matches!(lexer.peek(), (Token::Paren('('), _)) {
            return Ok(false);
        }

        self.push_rule_span(Rule::SingularExpr, lexer);

        let arguments = self.arguments(lexer, context)?;
        let span = lexer.span_with_start(name_span);

        block.stmts.push(ast::Statement {
            kind: ast::StatementKind::Call(ast::CallPhrase {
                function: ident,
                arguments,
            }),
            span,
        });

        self.pop_rule_span(lexer);

        Ok(true)
    }

    /// Parses func_call_statement and variable_updating_statement
    ///
    /// This does not consume or require a final `;` token. In the update
    /// expression of a C-style `for` loop header, there is no terminating `;`.
    fn func_call_or_variable_updating_statement<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        context: &mut ExpressionContext<'a, '_, '_>,
        block: &mut ast::Block<'a>,
        token: TokenSpan<'a>,
        expected_token: ExpectedToken<'a>,
    ) -> Result<'a, ()> {
        if !self.maybe_func_call_statement(lexer, context, block, token)? {
            self.variable_updating_statement(lexer, context, block, token, expected_token)?;
        }
        Ok(())
    }

    /// Parses variable_or_value_statement, func_call_statement and variable_updating_statement.
    ///
    /// This is equivalent to the `for_init` production in the WGSL spec,
    /// but it's also used for parsing these forms when they appear within a block,
    /// hence the longer name.
    ///
    /// This does not consume the following `;` token.
    fn variable_or_value_or_func_call_or_variable_updating_statement<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
        block: &mut ast::Block<'a>,
        token: TokenSpan<'a>,
        expected_token: ExpectedToken<'a>,
    ) -> Result<'a, ()> {
        let local_decl = match token {
            (Token::Word("let"), _) => {
                let (name, given_ty) = self.optionally_typed_ident(lexer, ctx)?;

                lexer.expect(Token::Operation('='))?;
                let expr_id = self.expression(lexer, ctx)?;

                let handle = ctx.declare_local(name)?;
                ast::LocalDecl::Let(ast::Let {
                    name,
                    ty: given_ty,
                    init: expr_id,
                    handle,
                })
            }
            (Token::Word("const"), _) => {
                let (name, given_ty) = self.optionally_typed_ident(lexer, ctx)?;

                lexer.expect(Token::Operation('='))?;
                let expr_id = self.expression(lexer, ctx)?;

                let handle = ctx.declare_local(name)?;
                ast::LocalDecl::Const(ast::LocalConst {
                    name,
                    ty: given_ty,
                    init: expr_id,
                    handle,
                })
            }
            (Token::Word("var"), _) => {
                if lexer.next_if(Token::TemplateArgsStart) {
                    let (class_str, span) = lexer.next_ident_with_span()?;
                    if class_str != "function" {
                        return Err(Box::new(Error::InvalidLocalVariableAddressSpace(span)));
                    }
                    lexer.expect(Token::TemplateArgsEnd)?;
                }

                let (name, ty) = self.optionally_typed_ident(lexer, ctx)?;

                let init = if lexer.next_if(Token::Operation('=')) {
                    let init = self.expression(lexer, ctx)?;
                    Some(init)
                } else {
                    None
                };

                let handle = ctx.declare_local(name)?;
                ast::LocalDecl::Var(ast::LocalVariable {
                    name,
                    ty,
                    init,
                    handle,
                })
            }
            token => {
                return self.func_call_or_variable_updating_statement(
                    lexer,
                    ctx,
                    block,
                    token,
                    expected_token,
                );
            }
        };

        let span =
            lexer.span_with_start(token.1);
        block.stmts.push(ast::Statement {
            kind: ast::StatementKind::LocalDecl(local_decl),
            span,
        });

        Ok(())
    }

    fn statement<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
        block: &mut ast::Block<'a>,
        brace_nesting_level: u8,
    ) -> Result<'a, ()> {
        self.track_recursion(|this| {
            this.push_rule_span(Rule::Statement, lexer);
            // We peek here instead of eagerly getting the next token since
            // `Parser::block` expects its first token to be `{`.
            //
            // Most callers have a single path leading to the start of the block;
            // `statement` is the only exception where there are multiple choices.
            match lexer.peek() {
                (token, _) if is_start_of_compound_statement(token) => {
                    let (inner, span) = this.block(lexer, ctx, brace_nesting_level)?;
                    block.stmts.push(ast::Statement {
                        kind: ast::StatementKind::Block(inner),
                        span,
                    });
                    this.pop_rule_span(lexer);
                    return Ok(());
                }
                _ => {}
            }

            let kind = match lexer.next() {
                (Token::Separator(';'), _) => {
                    this.pop_rule_span(lexer);
                    return Ok(());
                }
                (Token::Word("return"), _) => {
                    let value = if lexer.peek().0 != Token::Separator(';') {
                        let handle = this.expression(lexer, ctx)?;
                        Some(handle)
                    } else {
                        None
                    };
                    lexer.expect(Token::Separator(';'))?;
                    ast::StatementKind::Return { value }
                }
                (Token::Word("if"), _) => {
                    let condition = this.expression(lexer, ctx)?;

                    let accept = this.block(lexer, ctx, brace_nesting_level)?.0;

                    let mut elsif_stack = Vec::new();
                    let mut elseif_span_start = lexer.start_byte_offset();
                    let mut reject = loop {
                        if !lexer.next_if(Token::Word("else")) {
                            break ast::Block::default();
                        }

                        if !lexer.next_if(Token::Word("if")) {
                            // ... else { ... }
                            break this.block(lexer, ctx, brace_nesting_level)?.0;
                        }

                        // ... else if (...) { ... }
                        let other_condition = this.expression(lexer, ctx)?;
                        let other_block = this.block(lexer, ctx, brace_nesting_level)?;
                        elsif_stack.push((elseif_span_start, other_condition, other_block));
                        elseif_span_start = lexer.start_byte_offset();
                    };

                    // reverse-fold the else-if blocks
                    // Note: we may consider uplifting this to the IR
                    for (other_span_start, other_cond, other_block) in
                        elsif_stack.into_iter().rev()
                    {
                        let sub_stmt = ast::StatementKind::If {
                            condition: other_cond,
                            accept: other_block.0,
                            reject,
                        };
                        reject = ast::Block::default();
                        let span = lexer.span_from(other_span_start);
                        reject.stmts.push(ast::Statement {
                            kind: sub_stmt,
                            span,
                        })
                    }

                    ast::StatementKind::If {
                        condition,
                        accept,
                        reject,
                    }
                }
                (Token::Word("switch"), _) => {
                    let selector = this.expression(lexer, ctx)?;
                    let brace_span = lexer.expect_span(Token::Paren('{'))?;
                    let brace_nesting_level =
                        Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
                    let mut cases = Vec::new();

                    loop {
                        // cases + default
                        match lexer.next() {
                            (Token::Word("case"), _) => {
                                // parse a list of values
                                let value = loop {
                                    let value = this.switch_value(lexer, ctx)?;
                                    if lexer.next_if(Token::Separator(',')) {
                                        // list of values ends with ':' or a compound statement
                                        let next_token = lexer.peek().0;
                                        if next_token == Token::Separator(':')
                                            || is_start_of_compound_statement(next_token)
                                        {
                                            break value;
                                        }
                                    } else {
                                        break value;
                                    }
                                    cases.push(ast::SwitchCase {
                                        value,
                                        body: ast::Block::default(),
                                        fall_through: true,
                                    });
                                };

                                lexer.next_if(Token::Separator(':'));

                                let body = this.block(lexer, ctx, brace_nesting_level)?.0;

                                cases.push(ast::SwitchCase {
                                    value,
                                    body,
                                    fall_through: false,
                                });
                            }
                            (Token::Word("default"), _) => {
                                lexer.next_if(Token::Separator(':'));
                                let body = this.block(lexer, ctx, brace_nesting_level)?.0;
                                cases.push(ast::SwitchCase {
                                    value: ast::SwitchValue::Default,
                                    body,
                                    fall_through: false,
                                });
                            }
                            (Token::Paren('}'), _) => break,
                            (_, span) => {
                                return Err(Box::new(Error::Unexpected(
                                    span,
                                    ExpectedToken::SwitchItem,
                                )))
                            }
                        }
                    }

                    ast::StatementKind::Switch { selector, cases }
                }
                (Token::Word("loop"), _) => this.r#loop(lexer, ctx, brace_nesting_level)?,
                (Token::Word("while"), _) => {
                    let mut body = ast::Block::default();

                    let (condition, span) =
                        lexer.capture_span(|lexer| this.expression(lexer, ctx))?;
                    let mut reject = ast::Block::default();
                    reject.stmts.push(ast::Statement {
                        kind: ast::StatementKind::Break,
                        span,
                    });

                    body.stmts.push(ast::Statement {
                        kind: ast::StatementKind::If {
                            condition,
                            accept: ast::Block::default(),
                            reject,
                        },
                        span,
                    });

                    let (block, span) = this.block(lexer, ctx, brace_nesting_level)?;
                    body.stmts.push(ast::Statement {
                        kind: ast::StatementKind::Block(block),
                        span,
                    });

                    ast::StatementKind::Loop {
                        body,
                        continuing: ast::Block::default(),
                        break_if: None,
                    }
                }
                (Token::Word("for"), _) => {
                    lexer.expect(Token::Paren('('))?;

                    ctx.local_table.push_scope();

                    if !lexer.next_if(Token::Separator(';')) {
                        let token = lexer.next();
                        this.variable_or_value_or_func_call_or_variable_updating_statement(
                            lexer,
                            ctx,
                            block,
                            token,
                            ExpectedToken::ForInit,
                        )?;
                        lexer.expect(Token::Separator(';'))?;
                    };

                    let mut body = ast::Block::default();
                    if !lexer.next_if(Token::Separator(';')) {
                        let (condition, span) =
                            lexer.capture_span(|lexer| -> Result<'_, _> {
                                let condition = this.expression(lexer, ctx)?;
                                lexer.expect(Token::Separator(';'))?;
                                Ok(condition)
                            })?;
                        let mut reject = ast::Block::default();
                        reject.stmts.push(ast::Statement {
                            kind: ast::StatementKind::Break,
                            span,
                        });
                        body.stmts.push(ast::Statement {
                            kind: ast::StatementKind::If {
                                condition,
                                accept: ast::Block::default(),
                                reject,
                            },
                            span,
                        });
                    };

                    let mut continuing = ast::Block::default();
                    if !lexer.next_if(Token::Paren(')')) {
                        let token = lexer.next();
                        this.func_call_or_variable_updating_statement(
                            lexer,
                            ctx,
                            &mut continuing,
                            token,
                            ExpectedToken::ForUpdate,
                        )?;
                        lexer.expect(Token::Paren(')'))?;
                    }

                    let (block, span) = this.block(lexer, ctx, brace_nesting_level)?;
                    body.stmts.push(ast::Statement {
                        kind: ast::StatementKind::Block(block),
                        span,
                    });
                    ctx.local_table.pop_scope();

                    ast::StatementKind::Loop {
                        body,
                        continuing,
                        break_if: None,
                    }
                }
                (Token::Word("break"), span) => {
                    // Check if the next token is an `if`; this indicates
                    // that the user tried to type out a `break if`, which
                    // is illegal in this position.
                    let (peeked_token, peeked_span) = lexer.peek();
                    if let Token::Word("if") = peeked_token {
                        let span = span.until(&peeked_span);
                        return Err(Box::new(Error::InvalidBreakIf(span)));
                    }
                    lexer.expect(Token::Separator(';'))?;
                    ast::StatementKind::Break
                }
                (Token::Word("continue"), _) => {
                    lexer.expect(Token::Separator(';'))?;
                    ast::StatementKind::Continue
                }
                (Token::Word("discard"), _) => {
                    lexer.expect(Token::Separator(';'))?;
                    ast::StatementKind::Kill
                }
                // https://www.w3.org/TR/WGSL/#const-assert-statement
                (Token::Word("const_assert"), _) => {
                    // parentheses are optional
                    let paren = lexer.next_if(Token::Paren('('));

                    let condition = this.expression(lexer, ctx)?;

                    if paren {
                        lexer.expect(Token::Paren(')'))?;
                    }
                    lexer.expect(Token::Separator(';'))?;
                    ast::StatementKind::ConstAssert(condition)
                }
                token => {
                    this.variable_or_value_or_func_call_or_variable_updating_statement(
                        lexer,
                        ctx,
                        block,
                        token,
                        ExpectedToken::Statement,
                    )?;
                    lexer.expect(Token::Separator(';'))?;
                    this.pop_rule_span(lexer);
                    return Ok(());
                }
            };

            let span = this.pop_rule_span(lexer);
            block.stmts.push(ast::Statement { kind, span });

            Ok(())
        })
    }

    fn r#loop<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
        brace_nesting_level: u8,
    ) -> Result<'a, ast::StatementKind<'a>> {
        let mut body = ast::Block::default();
        let mut continuing = ast::Block::default();
        let mut break_if = None;

        let brace_span = lexer.expect_span(Token::Paren('{'))?;
        let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?;

        ctx.local_table.push_scope();

        loop {
            if lexer.next_if(Token::Word("continuing")) {
                // Branch for the `continuing` block; this must be
                // the last thing in the loop body

                // Expect an opening brace to start the continuing block
                let brace_span = lexer.expect_span(Token::Paren('{'))?;
                let brace_nesting_level =
                    Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
                loop {
                    if lexer.next_if(Token::Word("break")) {
                        // Branch for the `break if` statement; this statement
                        // has the form `break if <expr>;` and must be the last
                        // statement in a continuing block

                        // The break must be followed by an `if` to form
                        // the break if
                        lexer.expect(Token::Word("if"))?;

                        let condition = self.expression(lexer, ctx)?;
                        // Set the condition of the break if to the newly parsed
                        // expression
                        break_if = Some(condition);

                        // Expect a semicolon to close the statement
                        lexer.expect(Token::Separator(';'))?;
                        // Expect a closing brace to close the continuing block,
                        // since the break if must be the last statement
                        lexer.expect(Token::Paren('}'))?;
                        // Stop parsing the continuing block
                        break;
                    } else if lexer.next_if(Token::Paren('}')) {
                        // If we encounter a closing brace it means we have reached
                        // the end of the continuing block and should stop processing
                        break;
                    } else {
                        // Otherwise try to parse a statement
                        self.statement(lexer, ctx, &mut continuing, brace_nesting_level)?;
                    }
                }
                // Since the continuing block must be the last part of the loop body,
                // we expect to see a closing brace to end the loop body
                lexer.expect(Token::Paren('}'))?;
                break;
            }
            if lexer.next_if(Token::Paren('}')) {
                // If we encounter a closing brace it means we have reached
                // the end of the loop body and should stop processing
                break;
            }
            // Otherwise try to parse a statement
            self.statement(lexer, ctx, &mut body, brace_nesting_level)?;
        }

        ctx.local_table.pop_scope();

        Ok(ast::StatementKind::Loop {
            body,
            continuing,
            break_if,
        })
    }

    /// compound_statement
    fn block<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
        brace_nesting_level: u8,
    ) -> Result<'a, (ast::Block<'a>, Span)> {
        self.push_rule_span(Rule::Block, lexer);

        ctx.local_table.push_scope();

        let mut diagnostic_filters = DiagnosticFilterMap::new();
        self.push_rule_span(Rule::Attribute, lexer);
        while lexer.next_if(Token::Attribute) {
            let (name, name_span) = lexer.next_ident_with_span()?;
            if let Some(DirectiveKind::Diagnostic) = DirectiveKind::from_ident(name) {
                let filter = self.diagnostic_filter(lexer)?;
                let span = self.peek_rule_span(lexer);
                diagnostic_filters
                    .add(filter, span, ShouldConflictOnFullDuplicate::Yes)
                    .map_err(|e| Box::new(e.into()))?;
            } else {
                return Err(Box::new(Error::Unexpected(
                    name_span,
                    ExpectedToken::DiagnosticAttribute,
                )));
            }
        }
        self.pop_rule_span(lexer);

        if !diagnostic_filters.is_empty() {
            return Err(Box::new(
                Error::DiagnosticAttributeNotYetImplementedAtParseSite {
                    site_name_plural: "compound statements",
                    spans: diagnostic_filters.spans().collect(),
                },
            ));
        }

        let brace_span = lexer.expect_span(Token::Paren('{'))?;
        let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
        let mut block = ast::Block::default();
        while !lexer.next_if(Token::Paren('}')) {
            self.statement(lexer, ctx, &mut block, brace_nesting_level)?;
        }

        ctx.local_table.pop_scope();

        let span = self.pop_rule_span(lexer);
        Ok((block, span))
    }

    fn varying_binding<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        ctx: &mut ExpressionContext<'a, '_, '_>,
    ) -> Result<'a, Option<ast::Binding<'a>>> {
        let mut bind_parser = BindingParser::default();
        self.push_rule_span(Rule::Attribute, lexer);

        while lexer.next_if(Token::Attribute) {
            let (word, span) = lexer.next_ident_with_span()?;
            bind_parser.parse(self, lexer, word, span, ctx)?;
        }

        let span = self.pop_rule_span(lexer);
        bind_parser.finish(span)
    }

    fn function_decl<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
        must_use: Option<Span>,
        out: &mut ast::TranslationUnit<'a>,
        dependencies: &mut FastIndexSet<ast::Dependency<'a>>,
    ) -> Result<'a, ast::Function<'a>> {
        self.push_rule_span(Rule::FunctionDecl, lexer);
        // read function name
        let fun_name = lexer.next_ident()?;

        let mut locals = Arena::new();

        let mut ctx = ExpressionContext {
            expressions: &mut out.expressions,
            local_table: &mut SymbolTable::default(),
            locals: &mut locals,
            unresolved: dependencies,
        };

        // start a scope that contains arguments as well as the function body
        ctx.local_table.push_scope();
        // Reduce lookup scope to parse the parameter list and return type
        // avoiding identifier lookup to match newly declared param names.
        ctx.local_table.reduce_lookup_scope();

        // read parameter list
        let mut arguments = Vec::new();
        lexer.expect(Token::Paren('('))?;
        let mut ready = true;
        while !lexer.next_if(Token::Paren(')')) {
            if !ready {
                return Err(Box::new(Error::Unexpected(
                    lexer.next().1,
                    ExpectedToken::Token(Token::Separator(',')),
                )));
            }
            let binding = self.varying_binding(lexer, &mut ctx)?;

            let param_name = lexer.next_ident()?;

            lexer.expect(Token::Separator(':'))?;
            let param_type = self.type_specifier(lexer, &mut ctx)?;

            let handle = ctx.declare_local(param_name)?;
            arguments.push(ast::FunctionArgument {
                name: param_name,
                ty: param_type,
                binding,
                handle,
            });
            ready = lexer.next_if(Token::Separator(','));
        }
        // read return type
        let result = if lexer.next_if(Token::Arrow) {
            let binding = self.varying_binding(lexer, &mut ctx)?;
            let ty = self.type_specifier(lexer, &mut ctx)?;
            let must_use = must_use.is_some();
            Some(ast::FunctionResult {
                ty,
                binding,
                must_use,
            })
        } else if let Some(must_use) = must_use {
            return Err(Box::new(Error::FunctionMustUseReturnsVoid(
                must_use,
                self.peek_rule_span(lexer),
            )));
        } else {
            None
        };

        ctx.local_table.reset_lookup_scope();

        // do not use `self.block` here, since we must not push a new scope
        lexer.expect(Token::Paren('{'))?;
        let brace_nesting_level = 1;
        let mut body = ast::Block::default();
        while !lexer.next_if(Token::Paren('}')) {
            self.statement(lexer, &mut ctx, &mut body, brace_nesting_level)?;
        }

        ctx.local_table.pop_scope();

        let fun = ast::Function {
            entry_point: None,
            name: fun_name,
            arguments,
            result,
            body,
            diagnostic_filter_leaf,
            doc_comments: Vec::new(),
        };

        // done
        self.pop_rule_span(lexer);

        Ok(fun)
    }

    fn directive_ident_list<'a>(
        &self,
        lexer: &mut Lexer<'a>,
        handler: impl FnMut(&'a str, Span) -> Result<'a, ()>,
    ) -> Result<'a, ()> {
        let mut handler = handler;
        'next_arg: loop {
            let (ident, span) = lexer.next_ident_with_span()?;
            handler(ident, span)?;

            let expected_token = match lexer.peek().0 {
                Token::Separator(',') => {
                    let _ = lexer.next();
                    if matches!(lexer.peek().0, Token::Word(..)) {
                        continue 'next_arg;
                    }
                    ExpectedToken::AfterIdentListComma
                }
                _ => ExpectedToken::AfterIdentListArg,
            };

            if !matches!(lexer.next().0, Token::Separator(';')) {
                return Err(Box::new(Error::Unexpected(span, expected_token)));
            }

            break Ok(());
        }
    }

    fn global_decl<'a>(
        &mut self,
        lexer: &mut Lexer<'a>,
        out: &mut ast::TranslationUnit<'a>,
    ) -> Result<'a, ()> {
        let doc_comments = lexer.accumulate_doc_comments();

        // read attributes
        let mut binding = None;
        let mut stage = ParsedAttribute::default();
        // Span in case we need to report an error for a shader stage missing something (e.g. its workgroup size).
        // Doesn't need to be set in the vertex and fragment stages because they don't have errors like that.
        let mut shader_stage_error_span = Span::new(0, 0);
        let mut workgroup_size = ParsedAttribute::default();
        let mut early_depth_test = ParsedAttribute::default();
        let (mut bind_index, mut bind_group) =
            (ParsedAttribute::default(), ParsedAttribute::default());
        let mut id = ParsedAttribute::default();
        // the payload variable for a mesh shader
        let mut payload = ParsedAttribute::default();
        // the incoming payload from a traceRay call
        let mut incoming_payload = ParsedAttribute::default();
        let mut mesh_output = ParsedAttribute::default();

        let mut must_use: ParsedAttribute<Span> = ParsedAttribute::default();

        let mut memory_decorations = crate::MemoryDecorations::empty();

        let mut dependencies = FastIndexSet::default();
        let mut ctx = ExpressionContext {
            expressions: &mut out.expressions,
            local_table: &mut SymbolTable::default(),
            locals: &mut Arena::new(),
            unresolved: &mut dependencies,
        };
        let mut diagnostic_filters = DiagnosticFilterMap::new();
        let ensure_no_diag_attrs = |on_what, filters: DiagnosticFilterMap| -> Result<()> {
            if filters.is_empty() {
                Ok(())
            } else {
                Err(Box::new(Error::DiagnosticAttributeNotSupported {
                    on_what,
                    spans: filters.spans().collect(),
                }))
            }
        };

        self.push_rule_span(Rule::Attribute, lexer);
        while lexer.next_if(Token::Attribute) {
            let (name, name_span) = lexer.next_ident_with_span()?;
            if let Some(DirectiveKind::Diagnostic) = DirectiveKind::from_ident(name) {
                let filter = self.diagnostic_filter(lexer)?;
                let span = self.peek_rule_span(lexer);
                diagnostic_filters
                    .add(filter, span, ShouldConflictOnFullDuplicate::Yes)
                    .map_err(|e| Box::new(e.into()))?;
                continue;
            }
            match name {
                "binding" => {
                    lexer.expect(Token::Paren('('))?;
                    bind_index.set(self.expression(lexer, &mut ctx)?, name_span)?;
                    lexer.next_if(Token::Separator(','));
                    lexer.expect(Token::Paren(')'))?;
                }
                "group" => {
                    lexer.expect(Token::Paren('('))?;
                    bind_group.set(self.expression(lexer, &mut ctx)?, name_span)?;
                    lexer.next_if(Token::Separator(','));
                    lexer.expect(Token::Paren(')'))?;
                }
                "id" => {
                    lexer.expect(Token::Paren('('))?;
                    id.set(self.expression(lexer, &mut ctx)?, name_span)?;
                    lexer.next_if(Token::Separator(','));
                    lexer.expect(Token::Paren(')'))?;
                }
                "vertex" => {
                    stage.set(ShaderStage::Vertex, name_span)?;
                }
                "fragment" => {
                    stage.set(ShaderStage::Fragment, name_span)?;
                }
                "compute" => {
                    stage.set(ShaderStage::Compute, name_span)?;
                    shader_stage_error_span = name_span;
                }
                "task" => {
                    lexer.require_enable_extension(
                        ImplementedEnableExtension::WgpuMeshShader,
                        name_span,
                    )?;
                    stage.set(ShaderStage::Task, name_span)?;
                    shader_stage_error_span = name_span;
                }
                "mesh" => {
                    lexer.require_enable_extension(
                        ImplementedEnableExtension::WgpuMeshShader,
                        name_span,
                    )?;
                    stage.set(ShaderStage::Mesh, name_span)?;
                    shader_stage_error_span = name_span;

                    lexer.expect(Token::Paren('('))?;
                    mesh_output.set(lexer.next_ident_with_span()?, name_span)?;
                    lexer.expect(Token::Paren(')'))?;
                }
                "ray_generation" => {
                    lexer.require_enable_extension(
                        ImplementedEnableExtension::WgpuRayTracingPipeline,
                        name_span,
                    )?;
                    stage.set(ShaderStage::RayGeneration, name_span)?;
                    shader_stage_error_span = name_span;
                }
                "any_hit" => {
                    lexer.require_enable_extension(
                        ImplementedEnableExtension::WgpuRayTracingPipeline,
                        name_span,
                    )?;
                    stage.set(ShaderStage::AnyHit, name_span)?;
                    shader_stage_error_span = name_span;
                }
                "closest_hit" => {
                    lexer.require_enable_extension(
                        ImplementedEnableExtension::WgpuRayTracingPipeline,
                        name_span,
                    )?;
                    stage.set(ShaderStage::ClosestHit, name_span)?;
                    shader_stage_error_span = name_span;
                }
                "miss" => {
                    lexer.require_enable_extension(
                        ImplementedEnableExtension::WgpuRayTracingPipeline,
                        name_span,
                    )?;
                    stage.set(ShaderStage::Miss, name_span)?;
                    shader_stage_error_span = name_span;
                }
                "payload" => {
                    lexer.require_enable_extension(
                        ImplementedEnableExtension::WgpuMeshShader,
                        name_span,
                    )?;
                    lexer.expect(Token::Paren('('))?;
                    payload.set(lexer.next_ident_with_span()?, name_span)?;
                    lexer.expect(Token::Paren(')'))?;
                }
                "incoming_payload" => {
                    lexer.require_enable_extension(
                        ImplementedEnableExtension::WgpuRayTracingPipeline,
                        name_span,
                    )?;
                    lexer.expect(Token::Paren('('))?;
                    incoming_payload.set(lexer.next_ident_with_span()?, name_span)?;
                    lexer.expect(Token::Paren(')'))?;
                }
                "workgroup_size" => {
                    lexer.expect(Token::Paren('('))?;
                    let mut new_workgroup_size = [None; 3];
                    for size in new_workgroup_size.iter_mut() {
                        *size = Some(self.expression(lexer, &mut ctx)?);
                        match lexer.next() {
                            (Token::Paren(')'), _) => break,
                            (Token::Separator(','), _) => {
                                if lexer.next_if(Token::Paren(')')) {
                                    break;
                                }
                            }
                            other => {
                                return Err(Box::new(Error::Unexpected(
                                    other.1,
                                    ExpectedToken::WorkgroupSizeSeparator,
                                )))
                            }
                        }
                    }
                    workgroup_size.set(new_workgroup_size, name_span)?;
                }
                "early_depth_test" => {
                    lexer.expect(Token::Paren('('))?;
                    let (ident, ident_span) = lexer.next_ident_with_span()?;
                    let value = if ident == "force" {
                        crate::EarlyDepthTest::Force
                    } else {
                        crate::EarlyDepthTest::Allow {
                            conservative: conv::map_conservative_depth(ident, ident_span)?,
                        }
                    };
                    lexer.expect(Token::Paren(')'))?;
                    early_depth_test.set(value, name_span)?;
                }
                "must_use" => {
                    must_use.set(name_span, name_span)?;
                }
                "coherent" => {
                    memory_decorations |= crate::MemoryDecorations::COHERENT;
                }
                "volatile" => {
                    memory_decorations |= crate::MemoryDecorations::VOLATILE;
                }
                _ => return Err(Box::new(Error::UnknownAttribute(name_span))),
            }
        }

        let attrib_span = self.pop_rule_span(lexer);
        match (bind_group.value, bind_index.value) {
            (Some(group), Some(index)) => {
                binding = Some(ast::ResourceBinding {
                    group,
                    binding: index,
                });
            }
            (Some(_), None) => {
                return Err(Box::new(Error::MissingAttribute("binding", attrib_span)))
            }
            (None, Some(_)) => return Err(Box::new(Error::MissingAttribute("group", attrib_span))),
            (None, None) => {}
        }

        // read item
        let start = lexer.start_byte_offset();
        let kind = match lexer.next() {
            (Token::Separator(';'), _) => {
                ensure_no_diag_attrs(
                    DiagnosticAttributeNotSupportedPosition::SemicolonInModulePosition,
                    diagnostic_filters,
                )?;
                None
            }
            (Token::Word(word), directive_span) if DirectiveKind::from_ident(word).is_some() => {
                return Err(Box::new(Error::DirectiveAfterFirstGlobalDecl {
                    directive_span,
                }));
            }
            (Token::Word("struct"), _) => {
                ensure_no_diag_attrs("`struct`s".into(), diagnostic_filters)?;

                let name = lexer.next_ident()?;

                let members = self.struct_body(lexer, &mut ctx)?;

                Some(ast::GlobalDeclKind::Struct(ast::Struct {
                    name,
                    members,
                    doc_comments,
                }))
            }
            (Token::Word("alias"), _) => {
                ensure_no_diag_attrs("`alias`es".into(), diagnostic_filters)?;

                let name = lexer.next_ident()?;

                lexer.expect(Token::Operation('='))?;
                let ty = self.type_specifier(lexer, &mut ctx)?;
                lexer.expect(Token::Separator(';'))?;
                Some(ast::GlobalDeclKind::Type(ast::TypeAlias { name, ty }))
            }
            (Token::Word("const"), _) => {
                ensure_no_diag_attrs("`const`s".into(), diagnostic_filters)?;

                let (name, ty) = self.optionally_typed_ident(lexer, &mut ctx)?;

                lexer.expect(Token::Operation('='))?;
                let init = self.expression(lexer, &mut ctx)?;
                lexer.expect(Token::Separator(';'))?;

                Some(ast::GlobalDeclKind::Const(ast::Const {
                    name,
                    ty,
                    init,
                    doc_comments,
                }))
            }
            (Token::Word("override"), _) => {
                ensure_no_diag_attrs("`override`s".into(), diagnostic_filters)?;

                let (name, ty) = self.optionally_typed_ident(lexer, &mut ctx)?;

                let init = if lexer.next_if(Token::Operation('=')) {
                    Some(self.expression(lexer, &mut ctx)?)
                } else {
                    None
                };

                lexer.expect(Token::Separator(';'))?;

                Some(ast::GlobalDeclKind::Override(ast::Override {
                    name,
                    id: id.value,
                    ty,
                    init,
                }))
            }
            (Token::Word("var"), _) => {
                ensure_no_diag_attrs("`var`s".into(), diagnostic_filters)?;

                let mut var = self.variable_decl(lexer, &mut ctx)?;
                var.binding = binding.take();
                var.doc_comments = doc_comments;
                var.memory_decorations = memory_decorations;
                Some(ast::GlobalDeclKind::Var(var))
            }
            (Token::Word("fn"), _) => {
                let diagnostic_filter_leaf = Self::write_diagnostic_filters(
                    &mut out.diagnostic_filters,
                    diagnostic_filters,
                    out.diagnostic_filter_leaf,
                );

                let function = self.function_decl(
                    lexer,
                    diagnostic_filter_leaf,
                    must_use.value,
                    out,
                    &mut dependencies,
                )?;
                Some(ast::GlobalDeclKind::Fn(ast::Function {
                    entry_point: if let Some(stage) = stage.value {
                        if stage.compute_like() && workgroup_size.value.is_none() {
                            return Err(Box::new(Error::MissingWorkgroupSize(
                                shader_stage_error_span,
                            )));
                        }

                        match stage {
                            ShaderStage::AnyHit | ShaderStage::ClosestHit | ShaderStage::Miss => {
                                if incoming_payload.value.is_none() {
                                    return Err(Box::new(Error::MissingIncomingPayload(
                                        shader_stage_error_span,
                                    )));
                                }
                            }
                            _ => {}
                        }

                        Some(ast::EntryPoint {
                            stage,
                            early_depth_test: early_depth_test.value,
                            workgroup_size: workgroup_size.value,
                            mesh_output_variable: mesh_output.value,
                            task_payload: payload.value,
                            ray_incoming_payload: incoming_payload.value,
                        })
                    } else {
                        None
                    },
                    doc_comments,
                    ..function
                }))
            }
            (Token::Word("const_assert"), _) => {
                ensure_no_diag_attrs("`const_assert`s".into(), diagnostic_filters)?;

                // parentheses are optional
                let paren = lexer.next_if(Token::Paren('('));

                let condition = self.expression(lexer, &mut ctx)?;

                if paren {
                    lexer.expect(Token::Paren(')'))?;
                }
                lexer.expect(Token::Separator(';'))?;
                Some(ast::GlobalDeclKind::ConstAssert(condition))
            }
            (Token::End, _) => return Ok(()),
            other => {
                return Err(Box::new(Error::Unexpected(
                    other.1,
                    ExpectedToken::GlobalItem,
                )))
            }
        };

        if let Some(kind) = kind {
            out.decls.append(
                ast::GlobalDecl {
kind, dependencies }, lexer.span_from(start), ); } if !self.rules.is_empty() { log::error!("Reached the end of global decl, but rule stack is not empty"); log::error!("Rules: {:?}", self.rules); return Err(Box::new(Error::Internal("rule stack is not empty"))); }; match binding { None => Ok(()), Some(_) => Err(Box::new(Error::Internal( "we had the attribute but no var?", ))), } } pub fn parse<'a>( &mut self, source: &'a str, options: &Options, ) -> Result<'a, ast::TranslationUnit<'a>> { self.reset(); let mut lexer = Lexer::new(source, !options.parse_doc_comments); let mut tu = ast::TranslationUnit::default(); let mut enable_extensions = EnableExtensions::empty(); let mut diagnostic_filters = DiagnosticFilterMap::new(); // Parse module doc comments. tu.doc_comments = lexer.accumulate_module_doc_comments(); // Parse directives. while let (Token::Word(word), _) = lexer.peek() { if let Some(kind) = DirectiveKind::from_ident(word) { self.push_rule_span(Rule::Directive, &mut lexer); let _ = lexer.next_ident_with_span().unwrap(); match kind { DirectiveKind::Diagnostic => { let diagnostic_filter = self.diagnostic_filter(&mut lexer)?; let span = self.peek_rule_span(&lexer); diagnostic_filters .add(diagnostic_filter, span, ShouldConflictOnFullDuplicate::No) .map_err(|e| Box::new(e.into()))?; lexer.expect(Token::Separator(';'))?; } DirectiveKind::Enable => { self.directive_ident_list(&mut lexer, |ident, span| { let kind = EnableExtension::from_ident(ident, span)?; let extension = match kind { EnableExtension::Implemented(kind) => kind, EnableExtension::Unimplemented(kind) => { return Err(Box::new(Error::EnableExtensionNotYetImplemented { kind, span, })) } }; // Check if the required capability is supported let required_capability = extension.capability(); if !options.capabilities.contains(required_capability) { return Err(Box::new(Error::EnableExtensionNotSupported { kind, span, })); } enable_extensions.add(extension); Ok(()) })?; } DirectiveKind::Requires => { 
self.directive_ident_list(&mut lexer, |ident, span| { match LanguageExtension::from_ident(ident) { Some(LanguageExtension::Implemented(_kind)) => { // NOTE: No further validation is needed for an extension, so // just throw parsed information away. If we ever want to apply // what we've parsed to diagnostics, maybe we'll want to refer // to enabled extensions later? Ok(()) } Some(LanguageExtension::Unimplemented(kind)) => { Err(Box::new(Error::LanguageExtensionNotYetImplemented { kind, span, })) } None => Err(Box::new(Error::UnknownLanguageExtension(span, ident))), } })?; } } self.pop_rule_span(&lexer); } else { break; } } lexer.enable_extensions = enable_extensions; tu.enable_extensions = enable_extensions; tu.diagnostic_filter_leaf = Self::write_diagnostic_filters(&mut tu.diagnostic_filters, diagnostic_filters, None); loop { match self.global_decl(&mut lexer, &mut tu) { Err(error) => return Err(error), Ok(()) => { if lexer.peek().0 == Token::End { break; } } } } Ok(tu) } fn increase_brace_nesting(brace_nesting_level: u8, brace_span: Span) -> Result<'static, u8> { // From [spec.](https://gpuweb.github.io/gpuweb/wgsl/#limits): // // > § 2.4. 
Limits // > // > … // > // > Maximum nesting depth of brace-enclosed statements in a function[:] 127 const BRACE_NESTING_MAXIMUM: u8 = 127; if brace_nesting_level + 1 > BRACE_NESTING_MAXIMUM { return Err(Box::new(Error::ExceededLimitForNestedBraces { span: brace_span, limit: BRACE_NESTING_MAXIMUM, })); } Ok(brace_nesting_level + 1) } fn diagnostic_filter<'a>(&self, lexer: &mut Lexer<'a>) -> Result<'a, DiagnosticFilter> { lexer.expect(Token::Paren('('))?; let (severity_control_name, severity_control_name_span) = lexer.next_ident_with_span()?; let new_severity = diagnostic_filter::Severity::from_wgsl_ident(severity_control_name) .ok_or(Error::DiagnosticInvalidSeverity { severity_control_name_span, })?; lexer.expect(Token::Separator(','))?; let (diagnostic_name_token, diagnostic_name_token_span) = lexer.next_ident_with_span()?; let triggering_rule = if lexer.next_if(Token::Separator('.')) { let (ident, _span) = lexer.next_ident_with_span()?; FilterableTriggeringRule::User(Box::new([diagnostic_name_token.into(), ident.into()])) } else { let diagnostic_rule_name = diagnostic_name_token; let diagnostic_rule_name_span = diagnostic_name_token_span; if let Some(triggering_rule) = StandardFilterableTriggeringRule::from_wgsl_ident(diagnostic_rule_name) { FilterableTriggeringRule::Standard(triggering_rule) } else { diagnostic_filter::Severity::Warning.report_wgsl_parse_diag( Box::new(Error::UnknownDiagnosticRuleName(diagnostic_rule_name_span)), lexer.source, )?; FilterableTriggeringRule::Unknown(diagnostic_rule_name.into()) } }; let filter = DiagnosticFilter { triggering_rule, new_severity, }; lexer.next_if(Token::Separator(',')); lexer.expect(Token::Paren(')'))?; Ok(filter) } pub(crate) fn write_diagnostic_filters( arena: &mut Arena, filters: DiagnosticFilterMap, parent: Option>, ) -> Option> { filters .into_iter() .fold(parent, |parent, (triggering_rule, (new_severity, span))| { Some(arena.append( DiagnosticFilterNode { inner: DiagnosticFilter { new_severity, 
                            triggering_rule,
                        },
                        parent,
                    },
                    span,
                ))
            })
    }
}

const fn is_start_of_compound_statement<'a>(token: Token<'a>) -> bool {
    matches!(token, Token::Attribute | Token::Paren('{'))
}
naga-29.0.3/src/front/wgsl/parse/number.rs000064400000000000000000000374611046102023000165030ustar 00000000000000
use alloc::format;

use crate::front::wgsl::error::NumberError;
use crate::front::wgsl::parse::directive::enable_extension::ImplementedEnableExtension;
use crate::front::wgsl::parse::lexer::Token;
use half::f16;

/// When using this type assume no Abstract Int/Float for now
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Number {
    /// Abstract Int (-2^63 ≤ i < 2^63)
    AbstractInt(i64),
    /// Abstract Float (IEEE-754 binary64)
    AbstractFloat(f64),
    /// Concrete i32
    I32(i32),
    /// Concrete u32
    U32(u32),
    /// Concrete i64
    I64(i64),
    /// Concrete u64
    U64(u64),
    /// Concrete f16
    F16(f16),
    /// Concrete f32
    F32(f32),
    /// Concrete f64
    F64(f64),
}

impl Number {
    pub(super) const fn requires_enable_extension(&self) -> Option<ImplementedEnableExtension> {
        match *self {
            Number::F16(_) => Some(ImplementedEnableExtension::F16),
            _ => None,
        }
    }
}

pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
    let (result, rest) = parse(input);
    (Token::Number(result), rest)
}

enum Kind {
    Int(IntKind),
    Float(FloatKind),
}

enum IntKind {
    I32,
    U32,
    I64,
    U64,
}

#[derive(Debug)]
enum FloatKind {
    F16,
    F32,
    F64,
}

// The following regexes (from the WGSL spec) will be matched:

// int_literal:
// | / 0                              [iu]? /
// | / [1-9][0-9]*                    [iu]? /
// | / 0[xX][0-9a-fA-F]+              [iu]? /

// decimal_float_literal:
// | / 0                              [fh] /
// | / [1-9][0-9]*                    [fh] /
// | / [0-9]* \.[0-9]+ ([eE][+-]?[0-9]+)? [fh]? /
// | / [0-9]+ \.[0-9]* ([eE][+-]?[0-9]+)? [fh]? /
// | / [0-9]+ [eE][+-]?[0-9]+         [fh]? /

// hex_float_literal:
// | / 0[xX][0-9a-fA-F]* \.[0-9a-fA-F]+ ([pP][+-]?[0-9]+ [fh]?)? /
// | / 0[xX][0-9a-fA-F]+ \.[0-9a-fA-F]* ([pP][+-]?[0-9]+ [fh]?)? /
// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? /

// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing
// (?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))

// Leading signs are handled as unary operators.

fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
    /// Returns `true` and consumes `X` bytes from the given byte buffer
    /// if the given `X` patterns are found at the start of the buffer.
    macro_rules! consume {
        ($bytes:ident, $($pattern:pat),*) => {
            match $bytes {
                &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
                _ => false,
            }
        };
    }

    /// Consumes the bytes of the first matching pattern sequence (one or more bytes)
    /// from the start of the given byte buffer, returning the corresponding expr
    /// wrapped in `Some`, or `None` if no pattern matched.
    macro_rules! consume_map {
        ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => {
            match $bytes {
                $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
                _ => None,
            }
        };
    }

    /// Consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer,
    /// returning the number of consumed bytes.
    macro_rules! consume_dec_digits {
        ($bytes:ident) => {{
            let start_len = $bytes.len();
            while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
                $bytes = rest;
            }
            start_len - $bytes.len()
        }};
    }

    /// Consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given
    /// byte buffer, returning the number of consumed bytes.
    macro_rules! consume_hex_digits {
        ($bytes:ident) => {{
            let start_len = $bytes.len();
            while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
                $bytes = rest;
            }
            start_len - $bytes.len()
        }};
    }

    macro_rules! consume_float_suffix {
        ($bytes:ident) => {
            consume_map!($bytes, [
                b'h' => FloatKind::F16,
                b'f' => FloatKind::F32,
                b'l', b'f' => FloatKind::F64,
            ])
        };
    }

    /// Maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str`.
    macro_rules! rest_to_str {
        ($bytes:ident) => {
            &input[input.len() - $bytes.len()..]
        };
    }

    struct ExtractSubStr<'a>(&'a str);

    impl<'a> ExtractSubStr<'a> {
        /// Given an `input` and a `start` (tail of the `input`),
        /// creates a new [`ExtractSubStr`](`Self`).
        fn start(input: &'a str, start: &'a [u8]) -> Self {
            let start = input.len() - start.len();
            Self(&input[start..])
        }
        /// Given an `end` (tail of the initial `input`),
        /// returns the substring of `input` from this extractor's start up to `end`.
        fn end(&self, end: &'a [u8]) -> &'a str {
            let end = self.0.len() - end.len();
            &self.0[..end]
        }
    }

    let mut bytes = input.as_bytes();

    let general_extract = ExtractSubStr::start(input, bytes);

    if consume!(bytes, b'0', b'x' | b'X') {
        let digits_extract = ExtractSubStr::start(input, bytes);

        let consumed = consume_hex_digits!(bytes);

        if consume!(bytes, b'.') {
            let consumed_after_period = consume_hex_digits!(bytes);

            if consumed + consumed_after_period == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            let significand = general_extract.end(bytes);

            if consume!(bytes, b'p' | b'P') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let number = general_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (parse_hex_float(number, kind), rest_to_str!(bytes))
            } else {
                (
                    parse_hex_float_missing_exponent(significand, None),
                    rest_to_str!(bytes),
                )
            }
        } else {
            if consumed == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            let significand = general_extract.end(bytes);
            let digits = digits_extract.end(bytes);

            let exp_extract = ExtractSubStr::start(input, bytes);

            if consume!(bytes, b'p' | b'P') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let exponent = exp_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (
                    parse_hex_float_missing_period(significand, exponent, kind),
                    rest_to_str!(bytes),
                )
            } else {
                let kind = consume_map!(bytes, [
                    b'i' => IntKind::I32,
                    b'u' => IntKind::U32,
                    b'l', b'i' => IntKind::I64,
                    b'l', b'u' => IntKind::U64,
                ]);

                (parse_hex_int(digits, kind), rest_to_str!(bytes))
            }
        }
    } else {
        let is_first_zero = bytes.first() == Some(&b'0');

        let consumed = consume_dec_digits!(bytes);

        if consume!(bytes, b'.') {
            let consumed_after_period = consume_dec_digits!(bytes);

            if consumed + consumed_after_period == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            if consume!(bytes, b'e' | b'E') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }
            }

            let number = general_extract.end(bytes);

            let kind = consume_float_suffix!(bytes);

            (parse_dec_float(number, kind), rest_to_str!(bytes))
        } else {
            if consumed == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            if consume!(bytes, b'e' | b'E') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let number = general_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (parse_dec_float(number, kind), rest_to_str!(bytes))
            } else {
                // make sure the multi-digit numbers don't start with zero
                if consumed > 1 && is_first_zero {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let digits = general_extract.end(bytes);

                let kind = consume_map!(bytes, [
                    b'i' => Kind::Int(IntKind::I32),
                    b'u' => Kind::Int(IntKind::U32),
                    b'l', b'i' => Kind::Int(IntKind::I64),
                    b'l', b'u' => Kind::Int(IntKind::U64),
                    b'h' => Kind::Float(FloatKind::F16),
                    b'f' => Kind::Float(FloatKind::F32),
                    b'l', b'f' => Kind::Float(FloatKind::F64),
                ]);

                (parse_dec(digits, kind), rest_to_str!(bytes))
            }
        }
    }
}
fn parse_hex_float_missing_exponent(
    // format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
    significand: &str,
    kind: Option<FloatKind>,
) -> Result<Number, NumberError> {
    let hexf_input = format!("{}{}", significand, "p0");
    parse_hex_float(&hexf_input, kind)
}

fn parse_hex_float_missing_period(
    // format: 0[xX] [0-9a-fA-F]+
    significand: &str,
    // format: [pP][+-]?[0-9]+
    exponent: &str,
    kind: Option<FloatKind>,
) -> Result<Number, NumberError> {
    let hexf_input = format!("{significand}.{exponent}");
    parse_hex_float(&hexf_input, kind)
}

fn parse_hex_int(
    // format: [0-9a-fA-F]+
    digits: &str,
    kind: Option<IntKind>,
) -> Result<Number, NumberError> {
    parse_int(digits, kind, 16)
}

fn parse_dec(
    // format: ( [0-9] | [1-9][0-9]+ )
    digits: &str,
    kind: Option<Kind>,
) -> Result<Number, NumberError> {
    match kind {
        None => parse_int(digits, None, 10),
        Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10),
        Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)),
    }
}

// Float parsing notes

// The following chapters of IEEE 754-2019 are relevant:
//
// 7.4 Overflow (largest finite number is exceeded by what would have been
//     the rounded floating-point result were the exponent range unbounded)
//
// 7.5 Underflow (tiny non-zero result is detected;
//     for decimal formats tininess is detected before rounding when a non-zero result
//     computed as though both the exponent range and the precision were unbounded
//     would lie strictly between ±b^emin, e.g. ±2^−126 for binary32)
//
// 7.6 Inexact (rounded result differs from what would have been computed
//     were both exponent range and precision unbounded)

// The WGSL spec requires us to error:
//   on overflow for decimal floating point literals
//   on overflow and inexact for hexadecimal floating point literals
// (underflow is not mentioned)

// hexf_parse errors on overflow, underflow, inexact
// Rust's std float `from_str` handles overflow, underflow and inexact results transparently
// (it rounds, returning infinity on overflow, and does not error)

// Therefore we only check for overflow manually for decimal floating point literals

// input format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
    match kind {
        None => match hexf_parse::parse_hexf64(input, false) {
            Ok(num) => Ok(Number::AbstractFloat(num)),
            // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
            _ => Err(NumberError::NotRepresentable),
        },
        // TODO: f16 is not supported by hexf_parse
        Some(FloatKind::F16) => Err(NumberError::NotRepresentable),
        Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
            Ok(num) => Ok(Number::F32(num)),
            // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
            _ => Err(NumberError::NotRepresentable),
        },
        Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) {
            Ok(num) => Ok(Number::F64(num)),
            // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
            _ => Err(NumberError::NotRepresentable),
        },
    }
}

// input format: ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
//             | [0-9]+ [eE][+-]?[0-9]+
fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
    match kind {
        None => {
            let num = input.parse::<f64>().unwrap(); // will never fail
            num.is_finite()
                .then_some(Number::AbstractFloat(num))
                .ok_or(NumberError::NotRepresentable)
        }
        Some(FloatKind::F32) => {
            let num = input.parse::<f32>().unwrap(); // will never fail
            num.is_finite()
                .then_some(Number::F32(num))
                .ok_or(NumberError::NotRepresentable)
        }
        Some(FloatKind::F64) => {
            let num = input.parse::<f64>().unwrap(); // will never fail
            num.is_finite()
                .then_some(Number::F64(num))
                .ok_or(NumberError::NotRepresentable)
        }
        Some(FloatKind::F16) => {
            let num = input.parse::<f16>().unwrap(); // will never fail
            num.is_finite()
                .then_some(Number::F16(num))
                .ok_or(NumberError::NotRepresentable)
        }
    }
}

fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> {
    fn map_err(e: core::num::ParseIntError) -> NumberError {
        match *e.kind() {
            core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
                NumberError::NotRepresentable
            }
            _ =>
unreachable!(), } } match kind { None => match i64::from_str_radix(input, radix) { Ok(num) => Ok(Number::AbstractInt(num)), Err(e) => Err(map_err(e)), }, Some(IntKind::I32) => match i32::from_str_radix(input, radix) { Ok(num) => Ok(Number::I32(num)), Err(e) => Err(map_err(e)), }, Some(IntKind::U32) => match u32::from_str_radix(input, radix) { Ok(num) => Ok(Number::U32(num)), Err(e) => Err(map_err(e)), }, Some(IntKind::I64) => match i64::from_str_radix(input, radix) { Ok(num) => Ok(Number::I64(num)), Err(e) => Err(map_err(e)), }, Some(IntKind::U64) => match u64::from_str_radix(input, radix) { Ok(num) => Ok(Number::U64(num)), Err(e) => Err(map_err(e)), }, } } naga-29.0.3/src/front/wgsl/tests.rs000064400000000000000000000572011046102023000152350ustar 00000000000000use alloc::format; use super::parse_str; #[test] fn parse_comment() { parse_str( "// //// ///////////////////////////////////////////////////////// asda //////////////////// dad ////////// / ///////////////////////////////////////////////////////////////////////////////////////////////////// // ", ) .unwrap(); } #[test] fn parse_types() { parse_str("const a : i32 = 2;").unwrap(); parse_str("const a : u64 = 2lu;").unwrap(); assert!(parse_str("const a : x32 = 2;").is_err()); parse_str("var t: texture_2d;").unwrap(); parse_str("var t: texture_cube_array;").unwrap(); parse_str("var t: texture_multisampled_2d;").unwrap(); parse_str("var t: texture_storage_1d;").unwrap(); parse_str("var t: texture_storage_3d;").unwrap(); } #[test] fn parse_type_inference() { parse_str( " fn foo() { let a = 2u; let b: u32 = a; var x = 3.; var y = vec2(1, 2); }", ) .unwrap(); assert!(parse_str( " fn foo() { let c : i32 = 2.0; }", ) .is_err()); } #[test] fn parse_type_cast() { parse_str( " const a : i32 = 2; fn main() { var x: f32 = f32(a); x = f32(i32(a + 1) / 2); } ", ) .unwrap(); parse_str( " fn main() { let x: vec2 = vec2(1.0, 2.0); let y: vec2 = vec2(x); } ", ) .unwrap(); parse_str( " fn main() { let x: vec2 = vec2(0.0); } ", ) 
.unwrap(); assert!(parse_str( " fn main() { let x: vec2 = vec2(0.0, 0.0); } ", ) .is_err()); } #[test] fn parse_type_coercion() { parse_str( " fn foo(bar: f32) {} fn main() { foo(0); } ", ) .unwrap(); assert!(parse_str( " fn foo(bar: i32) {} fn main() { foo(0.0); } ", ) .is_err()); } #[test] fn parse_struct() { parse_str( " struct Foo { x: i32 } struct Bar { @size(16) x: vec2, @align(16) y: f32, @size(32) @align(128) z: vec3, }; struct Empty {} var s: Foo; ", ) .unwrap(); } #[test] fn parse_standard_fun() { parse_str( " fn main() { var x: i32 = min(max(1, 2), 3); } ", ) .unwrap(); } #[test] fn parse_statement() { parse_str( " fn main() { ; {} {;} } ", ) .unwrap(); parse_str( " fn foo() {} fn bar() { foo(); } ", ) .unwrap(); } #[test] fn parse_if() { parse_str( " fn main() { if true { discard; } else {} if 0 != 1 {} if false { return; } else if true { return; } else {} } ", ) .unwrap(); } #[test] fn parse_parentheses_if() { parse_str( " fn main() { if (true) { discard; } else {} if (0 != 1) {} if (false) { return; } else if (true) { return; } else {} } ", ) .unwrap(); } #[test] fn parse_loop() { parse_str( " fn main() { var i: i32 = 0; loop { if i == 1 { break; } continuing { i = 1; } } loop { if i == 0 { continue; } break; } } ", ) .unwrap(); parse_str( " fn main() { var found: bool = false; var i: i32 = 0; while !found { if i == 10 { found = true; } i = i + 1; } } ", ) .unwrap(); parse_str( " fn main() { while true { break; } } ", ) .unwrap(); parse_str( " fn main() { var a: i32 = 0; for(var i: i32 = 0; i < 4; i = i + 1) { a = a + 2; } } ", ) .unwrap(); parse_str( " fn main() { for(;;) { break; } } ", ) .unwrap(); } #[test] fn parse_switch() { parse_str( " fn main() { var pos: f32; switch (3) { case 0, 1: { pos = 0.0; } case 2: { pos = 1.0; } default: { pos = 3.0; } } } ", ) .unwrap(); } #[test] fn parse_switch_optional_colon_in_case() { parse_str( " fn main() { var pos: f32; switch (3) { case 0, 1 { pos = 0.0; } case 2 { pos = 1.0; } default { pos = 3.0; } } } ", 
) .unwrap(); } #[test] fn parse_switch_default_in_case() { parse_str( " fn main() { var pos: f32; switch (3) { case 0, 1: { pos = 0.0; } case 2: {} case default, 3: { pos = 3.0; } } } ", ) .unwrap(); } #[test] fn parse_parentheses_switch() { parse_str( " fn main() { var pos: i32; switch pos + 1 { default: { pos = 3; } } } ", ) .unwrap(); } #[test] fn parse_texture_load() { parse_str( " var t: texture_3d; fn foo() { let r: vec4 = textureLoad(t, vec3(0u, 1u, 2u), 1); } ", ) .unwrap(); parse_str( " var t: texture_2d_array; fn foo() { let r: vec4 = textureLoad(t, vec2(10, 20), 2, 3); } ", ) .unwrap(); parse_str( " var t: texture_storage_1d; fn foo() { let r: vec4 = textureLoad(t, 10); } ", ) .unwrap(); } #[test] fn parse_texture_store() { parse_str( " var t: texture_storage_2d; fn foo() { textureStore(t, vec2(10, 20), vec4(0.0, 1.0, 2.0, 3.0)); } ", ) .unwrap(); } #[test] fn parse_texture_query() { parse_str( " var t: texture_multisampled_2d; fn foo() { let dim = textureDimensions(t); let samples = textureNumSamples(t); } ", ) .unwrap(); parse_str( " var t: texture_2d_array; fn foo() { let dim = textureDimensions(t); let levels = textureNumLevels(t); let layers = textureNumLayers(t); } ", ) .unwrap(); } #[test] fn parse_postfix() { parse_str( "fn foo() { let x: f32 = vec4(1.0, 2.0, 3.0, 4.0).xyz.rgbr.aaaa.wz.g; let y: f32 = fract(vec2(0.5, x)).x; }", ) .unwrap(); let err = parse_str( "fn foo() { let v = mat4x4().x; }", ) .unwrap_err(); assert_eq!(err.message(), "invalid field accessor `x`"); } #[test] fn parse_expressions() { parse_str("fn foo() { let x: f32 = select(0.0, 1.0, true); let y: vec2 = select(vec2(1.0, 1.0), vec2(x, x), vec2((x < 0.5), (x > 0.5))); let z: bool = !(0.0 == 1.0); }").unwrap(); } #[test] fn parse_assignment_statements() { parse_str( " struct Foo { x: i32 }; fn foo() { var x: u32 = 0u; x++; x--; x = 1u; x += 1u; var v: vec2 = vec2(1.0, 1.0); v[0] += 1.0; (v)[0] += 1.0; var s: Foo = Foo(0); s.x -= 1; (s.x) -= 1; (s).x -= 1; _ = 5u; }", ) 
.unwrap(); let error = parse_str( "fn foo() { x|x++; }", ) .unwrap_err(); assert_eq!( error.message(), "expected assignment or increment/decrement, found \"|\"", ); } #[test] fn parse_local_var_address_space() { parse_str( " fn foo() { var a = true; var b: i32 = 5; var c = 10; }", ) .unwrap(); let error = parse_str( "fn foo() { var x: i32 = 5; }", ) .unwrap_err(); assert_eq!( error.message(), "invalid address space for local variable: `private`", ); let error = parse_str( "fn foo() { var x: i32 = 5; }", ) .unwrap_err(); assert_eq!( error.message(), "invalid address space for local variable: `storage`", ); } #[test] fn binary_expression_mixed_scalar_and_vector_operands() { for (operand, expect_splat) in [ ('<', false), ('>', false), ('&', false), ('|', false), ('+', true), ('-', true), ('*', false), ('/', true), ('%', true), ] { let module = parse_str(&format!( " @fragment fn main(@location(0) some_vec: vec3) -> @location(0) vec4 {{ if (all(1.0 {operand} some_vec)) {{ return vec4(0.0); }} return vec4(1.0); }} " )) .unwrap(); let expressions = &&module.entry_points[0].function.expressions; let found_expressions = expressions .iter() .filter(|&(_, e)| { if let crate::Expression::Binary { left, .. } = *e { matches!( (expect_splat, &expressions[left]), (false, &crate::Expression::Literal(crate::Literal::F32(..))) | (true, &crate::Expression::Splat { .. }) ) } else { false } }) .count(); assert_eq!( found_expressions, 1, "expected `{operand}` expression {} splat", if expect_splat { "with" } else { "without" } ); } let module = parse_str( "@fragment fn main(mat: mat3x3) { let vec = vec3(1.0, 1.0, 1.0); let result = mat / vec; }", ) .unwrap(); let expressions = &&module.entry_points[0].function.expressions; let found_splat = expressions.iter().any(|(_, e)| { if let crate::Expression::Binary { left, .. } = *e { matches!(&expressions[left], &crate::Expression::Splat { .. 
}) } else { false } }); assert!(!found_splat, "'mat / vec' should not be splatted"); } #[test] fn parse_pointers() { parse_str( "fn foo(a: ptr) -> f32 { return *a; } fn bar() { var x: f32 = 1.0; let px = &x; let py = foo(px); }", ) .unwrap(); } #[test] fn parse_struct_instantiation() { parse_str( " struct Foo { a: f32, b: vec3, } @fragment fn fs_main() { var foo: Foo = Foo(0.0, vec3(0.0, 1.0, 42.0)); } ", ) .unwrap(); } #[test] fn parse_array_length() { parse_str( " struct Foo { data: array } // this is used as both input and output for convenience @group(0) @binding(0) var foo: Foo; @group(0) @binding(1) var bar: array; fn baz() { var x: u32 = arrayLength(foo.data); var y: u32 = arrayLength(bar); } ", ) .unwrap(); } #[test] fn parse_storage_buffers() { parse_str( " @group(0) @binding(0) var foo: array; ", ) .unwrap(); parse_str( " @group(0) @binding(0) var foo: array; ", ) .unwrap(); parse_str( " @group(0) @binding(0) var foo: array; ", ) .unwrap(); parse_str( " @group(0) @binding(0) var foo: array; ", ) .unwrap(); } #[test] fn parse_alias() { parse_str( " alias Vec4 = vec4; ", ) .unwrap(); } #[test] fn shadowing_predeclared_types() { parse_str( " fn test(f32: vec2f) -> vec2f { return f32; } ", ) .unwrap(); parse_str( " fn test(vec2: vec2f) -> vec2f { return vec2; } ", ) .unwrap(); parse_str( " alias vec2f = vec2u; fn test(v: vec2f) -> vec2u { return v; } ", ) .unwrap(); parse_str( " struct vec2f { inner: vec2 }; fn test(v: vec2f) -> vec2 { return v.inner; } ", ) .unwrap(); } #[test] fn parse_texture_load_store_expecting_four_args() { for (func, texture) in [ ( "textureStore", "texture_storage_2d_array", ), ("textureLoad", "texture_2d_array"), ] { let error = parse_str(&format!( " @group(0) @binding(0) var tex_los_res: {texture}; @compute @workgroup_size(1) fn main(@builtin(global_invocation_id) id: vec3) {{ var color = vec4(1, 1, 1, 1); {func}(tex_los_res, id, color); }} " )) .unwrap_err(); assert_eq!( error.message(), "wrong number of arguments: expected 4, 
found 3" ); } } #[test] fn parse_repeated_attributes() { use crate::{ front::wgsl::{error::Error, Frontend}, Span, }; let template_vs = "@vertex fn vs() -> __REPLACE__ vec4 { return vec4(0.0); }"; let template_struct = "struct A { __REPLACE__ data: vec3 }"; let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array;"; let template_stage = "__REPLACE__ fn vs() -> vec4 { return vec4(0.0); }"; for (attribute, template) in [ ("align(16)", template_struct), ("binding(0)", template_resource), ("builtin(position)", template_vs), ("compute", template_stage), ("fragment", template_stage), ("group(0)", template_resource), ("interpolate(flat)", template_vs), ("invariant", template_vs), ("location(0)", template_vs), ("size(16)", template_struct), ("vertex", template_stage), ("early_depth_test(less_equal)", template_resource), ("workgroup_size(1)", template_stage), ] { let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}")); let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32; let span_start = shader.rfind(attribute).unwrap() as u32; let span_end = span_start + name_length; let expected_span = Span::new(span_start, span_end); let result = Frontend::new().inner(&shader); assert!(matches!( *result.unwrap_err(), Error::RepeatedAttribute(span) if span == expected_span )); } } #[test] fn parse_missing_workgroup_size() { use crate::{ front::wgsl::{error::Error, Frontend}, Span, }; let shader = "@compute fn vs() -> vec4 { return vec4(0.0); }"; let result = Frontend::new().inner(shader); assert!(matches!( *result.unwrap_err(), Error::MissingWorkgroupSize(span) if span == Span::new(1, 8) )); } mod diagnostic_filter { use crate::front::wgsl::assert_parse_err; #[test] fn intended_global_directive() { let shader = "@diagnostic(off, my.lint);"; assert_parse_err( shader, "\ error: `@diagnostic(…)` attribute(s) on semicolons are not supported ┌─ wgsl:1:1 │ 1 │ @diagnostic(off, my.lint); │ ^^^^^^^^^^^^^^^^^^^^^^^^^ │ = note: 
`@diagnostic(…)` attributes are only permitted on `fn`s, some statements, and `switch`/`loop` bodies. = note: If you meant to declare a diagnostic filter that applies to the entire module, move this line to the top of the file and remove the `@` symbol. " ); } mod parse_sites_not_yet_supported { use crate::front::wgsl::assert_parse_err; #[test] fn user_rules() { let shader = " fn myfunc() { if (true) @diagnostic(off, my.lint) { // ^^^^^^^^^^^^^^^^^^^^^^^^^ not yet supported, should report an error } } "; assert_parse_err(shader, "\ error: `@diagnostic(…)` attribute(s) not yet implemented ┌─ wgsl:3:15 │ 3 │ if (true) @diagnostic(off, my.lint) { │ ^^^^^^^^^^^^^^^^^^^^^^^^^ can't use this on compound statements (yet) │ = note: Let Naga maintainers know that you ran into this at , so they can prioritize it! "); } #[test] fn unknown_rules() { let shader = " fn myfunc() { if (true) @diagnostic(off, wat_is_this) { // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ should emit a warning } } "; assert_parse_err(shader, "\ error: `@diagnostic(…)` attribute(s) not yet implemented ┌─ wgsl:3:12 │ 3 │ if (true) @diagnostic(off, wat_is_this) { │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ can't use this on compound statements (yet) │ = note: Let Naga maintainers know that you ran into this at , so they can prioritize it! "); } } mod directive_conflict { use crate::front::wgsl::assert_parse_err; #[test] fn user_rules() { let shader = " diagnostic(off, my.lint); diagnostic(warning, my.lint); "; assert_parse_err(shader, "\ error: found conflicting `diagnostic(…)` rule(s) ┌─ wgsl:2:1 │ 2 │ diagnostic(off, my.lint); │ ^^^^^^^^^^^^^^^^^^^^^^^^ first rule 3 │ diagnostic(warning, my.lint); │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ second rule │ = note: Multiple `diagnostic(…)` rules with the same rule name conflict unless they are directives and the severity is the same. = note: You should delete the rule you don't want. 
"); } #[test] fn unknown_rules() { let shader = " diagnostic(off, wat_is_this); diagnostic(warning, wat_is_this); "; assert_parse_err(shader, "\ error: found conflicting `diagnostic(…)` rule(s) ┌─ wgsl:2:1 │ 2 │ diagnostic(off, wat_is_this); │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ first rule 3 │ diagnostic(warning, wat_is_this); │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ second rule │ = note: Multiple `diagnostic(…)` rules with the same rule name conflict unless they are directives and the severity is the same. = note: You should delete the rule you don't want. "); } } mod attribute_conflict { use crate::front::wgsl::assert_parse_err; #[test] fn user_rules() { let shader = " diagnostic(off, my.lint); diagnostic(warning, my.lint); "; assert_parse_err(shader, "\ error: found conflicting `diagnostic(…)` rule(s) ┌─ wgsl:2:1 │ 2 │ diagnostic(off, my.lint); │ ^^^^^^^^^^^^^^^^^^^^^^^^ first rule 3 │ diagnostic(warning, my.lint); │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ second rule │ = note: Multiple `diagnostic(…)` rules with the same rule name conflict unless they are directives and the severity is the same. = note: You should delete the rule you don't want. "); } #[test] fn unknown_rules() { let shader = " diagnostic(off, wat_is_this); diagnostic(warning, wat_is_this); "; assert_parse_err(shader, "\ error: found conflicting `diagnostic(…)` rule(s) ┌─ wgsl:2:1 │ 2 │ diagnostic(off, wat_is_this); │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ first rule 3 │ diagnostic(warning, wat_is_this); │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ second rule │ = note: Multiple `diagnostic(…)` rules with the same rule name conflict unless they are directives and the severity is the same. = note: You should delete the rule you don't want. 
"); } } } mod template { use crate::front::wgsl::assert_parse_err; #[test] fn missing_template_end() { assert_parse_err( " fn storage() {} var s: u32; ", "\ error: identifier `storage` resolves to a declaration ┌─ wgsl:3:5 │ 3 │ var s: u32; │ ^^^^^^^ needs to resolve to a predeclared enumerant ", ); } #[test] fn unexpected_expr_as_enumerant() { assert_parse_err( " var<1 + 1> s: u32; ", "\ error: unexpected expression ┌─ wgsl:2:5 │ 2 │ var<1 + 1> s: u32; │ ^^^^^ needs to be an identifier resolving to a predeclared enumerant ", ); } #[test] fn unused_exprs_for_template() { assert_parse_err( " var s: u32; ", "\ error: unused expressions for template ┌─ wgsl:2:26 │ 2 │ var s: u32; │ ^^^^^^ ^^^^^^ unused │ │\x20\x20\x20\x20\x20\x20\x20\x20 │ unused ", ); } #[test] fn unused_template_list_for_fn() { assert_parse_err( " fn inner_test() {} fn test() { inner_test(); } ", "\ error: unused expressions for template ┌─ wgsl:4:16 │ 4 │ inner_test(); │ ^^^^^^^^^^^^^^^^^^^ unused ", ); } #[test] fn unused_template_list_for_struct() { assert_parse_err( " struct test_struct {} fn test() { _ = test_struct(); } ", "\ error: unused expressions for template ┌─ wgsl:4:21 │ 4 │ _ = test_struct(); │ ^^^^^^^^^^^^^^^^^^^ unused ", ); } #[test] fn unused_template_list_for_alias() { assert_parse_err( " alias test_alias = f32; fn test() { _ = test_alias(); } ", "\ error: unused expressions for template ┌─ wgsl:4:20 │ 4 │ _ = test_alias(); │ ^^^^^^^^^^^^^^^^^^^ unused ", ); } #[test] fn unexpected_template() { assert_parse_err( " fn vertex() -> vec4 { return vec4; } ", "\ error: unexpected template ┌─ wgsl:3:12 │ 3 │ return vec4; │ ^^^^^^^^^ expected identifier ", ); } #[test] fn expected_template_arg() { assert_parse_err( " fn test() { bitcast(8); } ", "\ error: `bitcast` needs a template argument specified: `T`, a type ┌─ wgsl:3:5 │ 3 │ bitcast(8); │ ^^^^^^^ is missing a template argument ", ); } } naga-29.0.3/src/ir/block.rs000064400000000000000000000066661046102023000135040ustar 
00000000000000use alloc::vec::Vec; use core::ops::{Deref, DerefMut, RangeBounds}; use crate::{Span, Statement}; /// A code block is a vector of statements, with maybe a vector of spans. #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "serialize", serde(transparent))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct Block { body: Vec<Statement>, #[cfg_attr(feature = "serialize", serde(skip))] span_info: Vec<Span>, } impl Block { pub const fn new() -> Self { Self { body: Vec::new(), span_info: Vec::new(), } } pub fn from_vec(body: Vec<Statement>) -> Self { let span_info = core::iter::repeat_n(Span::default(), body.len()).collect(); Self { body, span_info } } pub fn with_capacity(capacity: usize) -> Self { Self { body: Vec::with_capacity(capacity), span_info: Vec::with_capacity(capacity), } } #[allow(unused_variables)] pub fn push(&mut self, end: Statement, span: Span) { self.body.push(end); self.span_info.push(span); } pub fn extend(&mut self, item: Option<(Statement, Span)>) { if let Some((end, span)) = item { self.push(end, span) } } pub fn extend_block(&mut self, other: Self) { self.span_info.extend(other.span_info); self.body.extend(other.body); } pub fn append(&mut self, other: &mut Self) { self.span_info.append(&mut other.span_info); self.body.append(&mut other.body); } pub fn cull<R: RangeBounds<usize> + Clone>(&mut self, range: R) { self.span_info.drain(range.clone()); self.body.drain(range); } pub fn splice<R: RangeBounds<usize> + Clone>(&mut self, range: R, other: Self) { self.span_info.splice(range.clone(), other.span_info); self.body.splice(range, other.body); } pub fn span_into_iter(self) -> impl Iterator<Item = (Statement, Span)> { let Block { body, span_info } = self; body.into_iter().zip(span_info) } pub fn span_iter(&self) -> impl Iterator<Item = (&Statement, &Span)> { let span_iter = self.span_info.iter(); self.body.iter().zip(span_iter) } pub fn span_iter_mut(&mut self) -> impl Iterator<Item = (&mut Statement, Option<&mut Span>)> { let span_iter = self.span_info.iter_mut().map(Some); self.body.iter_mut().zip(span_iter) } 
pub const fn is_empty(&self) -> bool { self.body.is_empty() } pub const fn len(&self) -> usize { self.body.len() } } impl Deref for Block { type Target = [Statement]; fn deref(&self) -> &[Statement] { &self.body } } impl DerefMut for Block { fn deref_mut(&mut self) -> &mut [Statement] { &mut self.body } } impl<'a> IntoIterator for &'a Block { type Item = &'a Statement; type IntoIter = core::slice::Iter<'a, Statement>; fn into_iter(self) -> core::slice::Iter<'a, Statement> { self.iter() } } #[cfg(feature = "deserialize")] impl<'de> serde::Deserialize<'de> for Block { fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: serde::Deserializer<'de>, { Ok(Self::from_vec(Vec::deserialize(deserializer)?)) } } impl From<Vec<Statement>> for Block { fn from(body: Vec<Statement>) -> Self { Self::from_vec(body) } } naga-29.0.3/src/ir/mod.rs000064400000000000000000003257431046102023000131710ustar 00000000000000/*! The Intermediate Representation shared by all frontends and backends. The central structure of the IR, and the crate, is [`Module`]. A `Module` contains: - [`Function`]s, which have arguments, a return type, local variables, and a body, - [`EntryPoint`]s, which are specialized functions that can serve as the entry point for pipeline stages like vertex shading or fragment shading, - [`Constant`]s and [`GlobalVariable`]s used by `EntryPoint`s and `Function`s, and - [`Type`]s used by the above. The body of an `EntryPoint` or `Function` is represented using two types: - An [`Expression`] produces a value, but has no side effects or control flow. `Expression`s include variable references, unary and binary operators, and so on. - A [`Statement`] can have side effects and structured control flow. `Statement`s do not produce a value, other than by storing one in some designated place. `Statement`s include blocks, conditionals, and loops, but also operations that have side effects, like stores and function calls. `Statement`s form a tree, with pointers into the DAG of `Expression`s.
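As a rough illustration of this split, here is a deliberately tiny model in plain Rust. It is a hypothetical sketch, not Naga's API. `ToyExpr`, `ToyStmt`, and `run` are invented names, and expressions are addressed by `usize` index into a slice rather than by Naga's `Arena` and `Handle`. Expressions only compute values from earlier expressions, forming a DAG. Statements run in order and carry every side effect, here a store into a toy local variable. An `Emit`-like statement decides when a range of expressions gets evaluated.

```rust
/// Hypothetical stand-in for `Expression`: pure values, no side effects.
/// Operands refer to earlier entries by index, so the expressions form a DAG.
enum ToyExpr {
    /// Treated as evaluated before execution begins, like Naga's `Literal`.
    Literal(i32),
    /// Sum of two previously evaluated expressions.
    Add(usize, usize),
    /// Reads the current value of toy local variable `var`.
    Load(usize),
}

/// Hypothetical stand-in for `Statement`: executed in order, may have effects.
enum ToyStmt {
    /// Evaluates expressions `start..end`, in the spirit of `Statement::Emit`.
    Emit(usize, usize),
    /// Stores the value of expression `value` into local variable `var`.
    Store { var: usize, value: usize },
}

/// Executes `body` and returns the final values of the toy local variables.
fn run(exprs: &[ToyExpr], body: &[ToyStmt], num_locals: usize) -> Vec<i32> {
    let mut locals = vec![0; num_locals];
    // Literals count as evaluated up front; everything else waits for an `Emit`.
    let mut values: Vec<Option<i32>> = exprs
        .iter()
        .map(|e| match *e {
            ToyExpr::Literal(x) => Some(x),
            _ => None,
        })
        .collect();
    for stmt in body {
        match *stmt {
            ToyStmt::Emit(start, end) => {
                for i in start..end {
                    let v = match exprs[i] {
                        ToyExpr::Literal(x) => x,
                        ToyExpr::Add(a, b) => {
                            values[a].expect("operand not evaluated")
                                + values[b].expect("operand not evaluated")
                        }
                        // A load sees exactly the stores executed before this `Emit`.
                        ToyExpr::Load(var) => locals[var],
                    };
                    values[i] = Some(v);
                }
            }
            ToyStmt::Store { var, value } => {
                locals[var] = values[value].expect("stored value not evaluated");
            }
        }
    }
    locals
}
```

Consider the expressions `Literal(2)`, `Literal(3)`, `Add(0, 1)`, `Load(0)` and `Add(3, 2)`. Emitting the `Load` after `Add(0, 1)` has been stored into local 0 leaves the locals as `[5, 10]`. Emitting it before that store yields `[5, 5]`. The same expression DAG gives different results depending on when the emitting statement runs relative to other statements.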
Restricting side effects to statements simplifies analysis and code generation. A Naga backend can generate code to evaluate an `Expression` however and whenever it pleases, as long as it is certain to observe the side effects of all previously executed `Statement`s. Many `Statement` variants use the [`Block`] type, which is a `Vec<Statement>` with optional span info, representing a series of statements executed in order. The body of an `EntryPoint` or `Function` is a `Block`, and `Statement` has a [`Block`][Statement::Block] variant. ## Function Calls Naga's representation of function calls is unusual. Most languages treat function calls as expressions, but because calls may have side effects, Naga represents them as a kind of statement, [`Statement::Call`]. If the function returns a value, a call statement designates a particular [`Expression::CallResult`] expression to represent its return value, for use by subsequent statements and expressions. ## `Expression` evaluation time It is essential to know when an [`Expression`] should be evaluated, because its value may depend on previous [`Statement`]s' effects. But whereas the order of execution for a tree of `Statement`s is apparent from its structure, it is not so clear for `Expression`s, since an expression may be referred to by any number of `Statement`s and other `Expression`s. Naga's rules for when `Expression`s are evaluated are as follows: - [`Literal`], [`Constant`], and [`ZeroValue`] expressions are considered to be implicitly evaluated before execution begins. - [`FunctionArgument`] and [`LocalVariable`] expressions are considered implicitly evaluated upon entry to the function to which they belong. Function arguments cannot be assigned to, and `LocalVariable` expressions produce a *pointer to* the variable's value (for use with [`Load`] and [`Store`]). Neither varies while the function executes, so it suffices to consider these expressions evaluated once on entry.
- Similarly, [`GlobalVariable`] expressions are considered implicitly evaluated before execution begins, since their value does not change while code executes, for one of two reasons: - Most `GlobalVariable` expressions produce a pointer to the variable's value, for use with [`Load`] and [`Store`], as `LocalVariable` expressions do. Although the variable's value may change, its address does not. - A `GlobalVariable` expression referring to a global in the [`AddressSpace::Handle`] address space produces the value directly, not a pointer. Such global variables hold opaque types like samplers or images, and cannot be assigned to. - A [`CallResult`] expression that is the `result` of a [`Statement::Call`], representing the call's return value, is evaluated when the `Call` statement is executed. - Similarly, an [`AtomicResult`] expression that is the `result` of an [`Atomic`] statement, representing the result of the atomic operation, is evaluated when the `Atomic` statement is executed. - A [`RayQueryProceedResult`] expression, which is a boolean indicating if the ray query is finished, is evaluated when the [`RayQuery`] statement whose [`Proceed::result`] points to it is executed. - All other expressions are evaluated when the (unique) [`Statement::Emit`] statement that covers them is executed. Now, strictly speaking, not all `Expression` variants actually care when they're evaluated. For example, you can evaluate a [`BinaryOperator::Add`] expression any time you like, as long as you give it the right operands. It's really only a very small set of expressions that are affected by timing: - [`Load`], [`ImageSample`], and [`ImageLoad`] expressions are influenced by stores to the variables or images they access, and must execute at the proper time relative to them. - [`Derivative`] expressions are sensitive to control flow uniformity: they must not be moved out of an area of uniform control flow into a non-uniform area.
- More generally, any expression that's used by more than one other expression or statement should probably be evaluated only once, and then stored in a variable to be cited at each point of use. Naga tries to help back ends handle all these cases correctly in a somewhat circuitous way. The [`ModuleInfo`] structure returned by [`Validator::validate`] provides a reference count for each expression in each function in the module. Naturally, any expression with a reference count of two or more deserves to be evaluated and stored in a temporary variable at the point that the `Emit` statement covering it is executed. But if we selectively lower the reference count threshold to _one_ for the sensitive expression types listed above, so that we _always_ generate a temporary variable and save their value, then the same code that manages multiply referenced expressions will take care of introducing temporaries for time-sensitive expressions as well. The `Expression::bake_ref_count` method (private to the back ends) is meant to help with this. ## `Expression` scope Each `Expression` has a *scope*, which is the region of the function within which it can be used by `Statement`s and other `Expression`s. It is a validation error to use an `Expression` outside its scope. An expression's scope is defined as follows: - The scope of a [`Constant`], [`GlobalVariable`], [`FunctionArgument`] or [`LocalVariable`] expression covers the entire `Function` in which it occurs. - The scope of an expression evaluated by an [`Emit`] statement covers the subsequent expressions in that `Emit`, the subsequent statements in the `Block` to which that `Emit` belongs (if any) and their sub-statements (if any). - The `result` expression of a [`Call`] or [`Atomic`] statement has a scope covering the subsequent statements in the `Block` in which the statement occurs (if any) and their sub-statements (if any). 
For example, this implies that an expression evaluated by some statement in a nested `Block` is not available in the `Block`'s parents. Such a value would need to be stored in a local variable to be carried upwards in the statement tree. ## Constant expressions A Naga *constant expression* is one of the following [`Expression`] variants, whose operands (if any) are also constant expressions: - [`Literal`] - [`Constant`], for [`Constant`]s - [`ZeroValue`], for fixed-size types - [`Compose`] - [`Access`] - [`AccessIndex`] - [`Splat`] - [`Swizzle`] - [`Unary`] - [`Binary`] - [`Select`] - [`Relational`] - [`Math`] - [`As`] A constant expression can be evaluated at module translation time. ## Override expressions A Naga *override expression* is the same as a [constant expression], except that it is also allowed to reference other [`Override`]s. An override expression can be evaluated at pipeline creation time. [`AtomicResult`]: Expression::AtomicResult [`RayQueryProceedResult`]: Expression::RayQueryProceedResult [`CallResult`]: Expression::CallResult [`Constant`]: Expression::Constant [`ZeroValue`]: Expression::ZeroValue [`Literal`]: Expression::Literal [`Derivative`]: Expression::Derivative [`FunctionArgument`]: Expression::FunctionArgument [`GlobalVariable`]: Expression::GlobalVariable [`ImageLoad`]: Expression::ImageLoad [`ImageSample`]: Expression::ImageSample [`Load`]: Expression::Load [`LocalVariable`]: Expression::LocalVariable [`Atomic`]: Statement::Atomic [`Call`]: Statement::Call [`Emit`]: Statement::Emit [`Store`]: Statement::Store [`RayQuery`]: Statement::RayQuery [`Proceed::result`]: RayQueryFunction::Proceed::result [`Validator::validate`]: crate::valid::Validator::validate [`ModuleInfo`]: crate::valid::ModuleInfo [`Literal`]: Expression::Literal [`ZeroValue`]: Expression::ZeroValue [`Compose`]: Expression::Compose [`Access`]: Expression::Access [`AccessIndex`]: Expression::AccessIndex [`Splat`]: Expression::Splat [`Swizzle`]: Expression::Swizzle 
[`Unary`]: Expression::Unary [`Binary`]: Expression::Binary [`Select`]: Expression::Select [`Relational`]: Expression::Relational [`Math`]: Expression::Math [`As`]: Expression::As [constant expression]: #constant-expressions */ mod block; use alloc::{boxed::Box, string::String, vec::Vec}; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; use half::f16; #[cfg(feature = "deserialize")] use serde::Deserialize; #[cfg(feature = "serialize")] use serde::Serialize; use crate::arena::{Arena, Handle, Range, UniqueArena}; use crate::diagnostic_filter::DiagnosticFilterNode; use crate::{FastIndexMap, NamedExpressions}; pub use block::Block; /// Explicitly allows early depth/stencil tests. /// /// Normally, depth/stencil tests are performed after fragment shading. However, as an optimization, /// most drivers will move the depth/stencil tests before fragment shading if this does not /// have any observable consequences. This optimization is disabled under the following /// circumstances: /// - `discard` is called in the fragment shader. /// - The fragment shader writes to the depth buffer. /// - The fragment shader writes to any storage bindings. /// /// When `EarlyDepthTest` is set, it is allowed to perform an early depth/stencil test even if the /// above conditions are not met. When [`EarlyDepthTest::Force`] is used, depth/stencil tests /// **must** be performed before fragment shading. /// /// To force early depth/stencil tests in a shader: /// - GLSL: `layout(early_fragment_tests) in;` /// - HLSL: the `[earlydepthstencil]` attribute /// - SPIR-V: `ExecutionMode EarlyFragmentTests` /// - WGSL: `@early_depth_test(force)` /// /// This may also be enabled in a shader by specifying a [`ConservativeDepth`].
/// /// For more, see: /// - /// - /// - #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum EarlyDepthTest { /// Requires depth/stencil tests to be performed before fragment shading. /// /// This will disable depth/stencil tests after fragment shading, so discarding the fragment /// or overwriting the fragment depth will have no effect. Force, /// Allows an additional depth/stencil test to be performed before fragment shading. /// /// It is up to the driver to decide whether early tests are performed. Unlike `Force`, this /// does not disable depth/stencil tests after fragment shading. Allow { /// Specifies restrictions on how the depth value can be modified within the fragment /// shader. /// /// This may be taken into account when deciding whether to perform early tests. conservative: ConservativeDepth, }, } /// Enables adjusting depth without disabling early Z. /// /// To use in a shader: /// - GLSL: `layout (depth_) out float gl_FragDepth;` /// - `depth_any` option behaves as if the layout qualifier was not present. /// - HLSL: `SV_DepthGreaterEqual`/`SV_DepthLessEqual`/`SV_Depth` /// - SPIR-V: `ExecutionMode Depth` /// - WGSL: `@early_depth_test(greater_equal/less_equal/unchanged)` /// /// For more, see: /// - /// - /// - #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ConservativeDepth { /// Shader may rewrite depth only with a value greater than or equal to the calculated one. GreaterEqual, /// Shader may rewrite depth only with a value less than or equal to the one that would have been written without the modification. LessEqual, /// Shader may not rewrite depth value. Unchanged, } /// Stage of the programmable pipeline.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ShaderStage { /// A vertex shader, in a render pipeline. Vertex, /// A task shader, in a mesh render pipeline. Task, /// A mesh shader, in a mesh render pipeline. Mesh, /// A fragment shader, in a render pipeline. Fragment, /// Compute pipeline shader. Compute, /// A ray generation shader, in a ray tracing pipeline. RayGeneration, /// A miss shader, in a ray tracing pipeline. Miss, /// An any hit shader, in a ray tracing pipeline. AnyHit, /// A closest hit shader, in a ray tracing pipeline. ClosestHit, } /// Address space of variables. #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum AddressSpace { /// Function locals. Function, /// Private data, per invocation, mutable. Private, /// Workgroup shared data, mutable. WorkGroup, /// Uniform buffer data. Uniform, /// Storage buffer data, potentially mutable. Storage { access: StorageAccess }, /// Opaque handles, such as samplers and images. Handle, /// Immediate data. /// /// A [`Module`] may contain at most one [`GlobalVariable`] in /// this address space. Its contents are provided not by a buffer /// but by `SetImmediates` pass commands, allowing the CPU to /// establish different values for each draw/dispatch. /// /// `Immediate` variables may not contain `f16` values, even if /// the [`SHADER_FLOAT16`] capability is enabled. /// /// Backends generally place tight limits on the size of /// `Immediate` variables.
/// /// [`SHADER_FLOAT16`]: crate::valid::Capabilities::SHADER_FLOAT16 Immediate, /// Task shader to mesh shader payload TaskPayload, /// Ray tracing payload, passed as input to TraceRays RayPayload, /// Ray tracing payload, for entrypoints invoked by a TraceRays call /// /// Each entrypoint may reference only one variable in this scope, as /// only one may be passed as a payload. IncomingRayPayload, } /// Built-in inputs and outputs. #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum BuiltIn { // This must be at the top so that it gets sorted to the top. PrimitiveIndex is considered a non-SV // by FXC so it must appear before any other SVs. /// Read in fragment shaders, written in mesh shaders, read in any and closest hit shaders. PrimitiveIndex, /// Written in vertex/mesh shaders, read in fragment shaders Position { invariant: bool }, /// Read in task, mesh, vertex, and fragment shaders ViewIndex, /// Read in vertex shaders BaseInstance, /// Read in vertex shaders BaseVertex, /// Written in vertex & mesh shaders ClipDistance, /// Written in vertex & mesh shaders CullDistance, /// Read in vertex, any- and closest-hit shaders InstanceIndex, /// Written in vertex & mesh shaders PointSize, /// Read in vertex shaders VertexIndex, /// Read in vertex & task shaders, or mesh shaders in pipelines without task shaders DrawIndex, /// Written in fragment shaders FragDepth, /// Read in fragment shaders PointCoord, /// Read in fragment shaders FrontFacing, /// Read in fragment shaders Barycentric { perspective: bool }, /// Read in fragment shaders SampleIndex, /// Read or written in fragment shaders SampleMask, /// Read in compute, task, and mesh shaders GlobalInvocationId, /// Read in compute, task, and mesh shaders LocalInvocationId, /// Read in compute, task, and mesh shaders
LocalInvocationIndex, /// Read in compute, task, and mesh shaders WorkGroupId, /// Read in compute, task, and mesh shaders WorkGroupSize, /// Read in compute, task, and mesh shaders NumWorkGroups, /// Read in compute, task, and mesh shaders NumSubgroups, /// Read in compute, task, and mesh shaders SubgroupId, /// Read in compute, fragment, task, and mesh shaders SubgroupSize, /// Read in compute, fragment, task, and mesh shaders SubgroupInvocationId, /// Written in task shaders MeshTaskSize, /// Written in mesh shaders CullPrimitive, /// Written in mesh shaders PointIndex, /// Written in mesh shaders LineIndices, /// Written in mesh shaders TriangleIndices, /// Written to a workgroup variable in mesh shaders VertexCount, /// Written to a workgroup variable in mesh shaders Vertices, /// Written to a workgroup variable in mesh shaders PrimitiveCount, /// Written to a workgroup variable in mesh shaders Primitives, /// Read in all ray tracing pipeline shaders, the ID of the current ray within /// the total number of rays. RayInvocationId, /// Read in all ray tracing pipeline shaders, the number of rays created. NumRayInvocations, /// Read in closest hit and any hit shaders, the custom data in the tlas /// instance InstanceCustomData, /// Read in closest hit and any hit shaders, the index of the geometry in /// the blas. GeometryIndex, /// Read in closest hit, any hit, and miss shaders, the origin of the ray. WorldRayOrigin, /// Read in closest hit, any hit, and miss shaders, the direction of the /// ray. WorldRayDirection, /// Read in closest hit and any hit shaders, the origin of the ray in /// object space. ObjectRayOrigin, /// Read in closest hit and any hit shaders, the direction of the ray in /// object space. ObjectRayDirection, /// Read in closest hit, any hit, and miss shaders, the t min provided /// in the ray desc.
RayTmin, /// Read in closest hit, any hit, and miss shaders, the current upper bound at which /// a hit is accepted (the t of the closest committed hit if there is one, otherwise the t /// max provided in the ray desc). RayTCurrentMax, /// Read in closest hit and any hit shaders, the matrix for converting from /// object space to world space ObjectToWorld, /// Read in closest hit and any hit shaders, the matrix for converting from /// world space to object space WorldToObject, /// Read in closest hit and any hit shaders, the type of hit as provided by /// the intersection function if any, otherwise this is 254 (0xFE) for a /// front facing triangle and 255 (0xFF) for a back facing triangle HitKind, } /// Number of bytes per scalar. pub type Bytes = u8; /// Number of components in a vector. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum VectorSize { /// 2D vector Bi = 2, /// 3D vector Tri = 3, /// 4D vector Quad = 4, } impl VectorSize { pub const MAX: usize = Self::Quad as usize; } impl From<VectorSize> for u8 { fn from(size: VectorSize) -> u8 { size as u8 } } impl From<VectorSize> for u32 { fn from(size: VectorSize) -> u32 { size as u32 } } /// Number of components in a cooperative vector. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum CooperativeSize { Eight = 8, Sixteen = 16, } /// Primitive type for a scalar. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ScalarKind { /// Signed integer type.
Sint, /// Unsigned integer type. Uint, /// Floating point type. Float, /// Boolean type. Bool, /// WGSL abstract integer type. /// /// These are forbidden by validation, and should never reach backends. AbstractInt, /// Abstract floating-point type. /// /// These are forbidden by validation, and should never reach backends. AbstractFloat, } /// Role of a cooperative variable in the equation "A * B + C" #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum CooperativeRole { A, B, C, } /// Characteristics of a scalar type. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Scalar { /// How the value's bits are to be interpreted. pub kind: ScalarKind, /// The size of the value in bytes. pub width: Bytes, } /// Size of an array. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ArraySize { /// The array size is constant. Constant(core::num::NonZeroU32), /// The array size is an override-expression. Pending(Handle), /// The array size can change at runtime. Dynamic, } /// The interpolation qualifier of a binding or struct field. #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Interpolation { /// The value will be interpolated in a perspective-correct fashion.
/// Also known as "smooth" in GLSL. Perspective, /// Indicates that linear (non-perspective-correct) /// interpolation must be used. /// Also known as "no_perspective" in GLSL. Linear, /// Indicates that no interpolation will be performed. Flat, /// Indicates the fragment input binding holds an array of per-vertex values. /// This is typically used with barycentrics. PerVertex, } /// The sampling qualifiers of a binding or struct field. #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Sampling { /// Interpolate the value at the center of the pixel. Center, /// Interpolate the value at a point that lies within all samples covered by /// the fragment within the current primitive. In multisampling, use a /// single value for all samples in the primitive. Centroid, /// Interpolate the value at each sample location. In multisampling, invoke /// the fragment shader once per sample. Sample, /// Use the value provided by the first vertex of the current primitive. First, /// Use the value provided by the first or last vertex of the current primitive. The exact /// choice is implementation-dependent. Either, } /// Member of a user-defined structure. // Clone is used only for error reporting and is not intended for end users #[derive(Clone, Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct StructMember { pub name: Option<String>, /// Type of the field. pub ty: Handle<Type>, /// For I/O structs, defines the binding. pub binding: Option<Binding>, /// Offset from the beginning of the struct. pub offset: u32, } /// The number of dimensions an image has.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ImageDimension { /// 1D image D1, /// 2D image D2, /// 3D image D3, /// Cube map Cube, } bitflags::bitflags! { /// Flags describing an image. #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct StorageAccess: u32 { /// Storage can be used as a source for load ops. const LOAD = 0x1; /// Storage can be used as a target for store ops. const STORE = 0x2; /// Storage can be used as a target for atomic ops. const ATOMIC = 0x4; } } bitflags::bitflags! { /// Memory decorations for global variables. #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct MemoryDecorations: u8 { /// Reads and writes are automatically visible to other invocations /// without explicit barriers. const COHERENT = 0x1; /// The variable may be modified by something external to the shader, /// preventing certain compiler optimizations. const VOLATILE = 0x2; } } /// Image storage format. 
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum StorageFormat { // 8-bit formats R8Unorm, R8Snorm, R8Uint, R8Sint, // 16-bit formats R16Uint, R16Sint, R16Float, Rg8Unorm, Rg8Snorm, Rg8Uint, Rg8Sint, // 32-bit formats R32Uint, R32Sint, R32Float, Rg16Uint, Rg16Sint, Rg16Float, Rgba8Unorm, Rgba8Snorm, Rgba8Uint, Rgba8Sint, Bgra8Unorm, // Packed 32-bit formats Rgb10a2Uint, Rgb10a2Unorm, Rg11b10Ufloat, // 64-bit formats R64Uint, Rg32Uint, Rg32Sint, Rg32Float, Rgba16Uint, Rgba16Sint, Rgba16Float, // 128-bit formats Rgba32Uint, Rgba32Sint, Rgba32Float, // Normalized 16-bit per channel formats R16Unorm, R16Snorm, Rg16Unorm, Rg16Snorm, Rgba16Unorm, Rgba16Snorm, } /// Sub-class of the image type. #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ImageClass { /// Regular sampled image. Sampled { /// Kind of values to sample. kind: ScalarKind, /// Multi-sampled image. /// /// A multi-sampled image holds several samples per texel. Multi-sampled /// images cannot have mipmaps. multi: bool, }, /// Depth comparison image. Depth { /// Multi-sampled depth image. multi: bool, }, /// External texture. External, /// Storage image. Storage { format: StorageFormat, access: StorageAccess, }, } /// A data type declared in the module. #[derive(Clone, Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Type { /// The name of the type, if any. pub name: Option<String>, /// Inner structure that depends on the kind of the type.
    pub inner: TypeInner,
}

/// Enum with additional information, depending on the kind of type.
///
/// Comparison using `==` is not reliable in the case of [`Pointer`],
/// [`ValuePointer`], or [`Struct`] variants. For these variants,
/// use [`TypeInner::non_struct_equivalent`] or [`compare_types`].
///
/// [`compare_types`]: crate::proc::compare_types
/// [`ValuePointer`]: TypeInner::ValuePointer
/// [`Pointer`]: TypeInner::Pointer
/// [`Struct`]: TypeInner::Struct
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum TypeInner {
    /// Number of integral or floating-point kind.
    Scalar(Scalar),
    /// Vector of numbers.
    Vector { size: VectorSize, scalar: Scalar },
    /// Matrix of numbers.
    Matrix {
        columns: VectorSize,
        rows: VectorSize,
        scalar: Scalar,
    },
    /// Matrix that is cooperatively processed by all the threads
    /// in an opaque mapping.
    CooperativeMatrix {
        columns: CooperativeSize,
        rows: CooperativeSize,
        scalar: Scalar,
        role: CooperativeRole,
    },
    /// Atomic scalar.
    Atomic(Scalar),
    /// Pointer to another type.
    ///
    /// Pointers to scalars and vectors should be treated as equivalent to
    /// [`ValuePointer`] types. Use either [`TypeInner::non_struct_equivalent`]
    /// or [`compare_types`] to compare types in a way that treats pointers
    /// correctly.
    ///
    /// ## Pointers to non-`SIZED` types
    ///
    /// The `base` type of a pointer may be a non-[`SIZED`] type like a
    /// dynamically-sized [`Array`], or a [`Struct`] whose last member is a
    /// dynamically sized array. Such pointers occur as the types of
    /// [`GlobalVariable`] or [`AccessIndex`] expressions referring to
    /// dynamically-sized arrays.
    ///
    /// However, among pointers to non-`SIZED` types, only pointers to `Struct`s
    /// are [`DATA`]. Pointers to dynamically sized `Array`s cannot be passed as
    /// arguments, stored in variables, or held in arrays or structures.
    /// Their only use is as the types of `AccessIndex` expressions.
    ///
    /// [`SIZED`]: crate::valid::TypeFlags::SIZED
    /// [`DATA`]: crate::valid::TypeFlags::DATA
    /// [`Array`]: TypeInner::Array
    /// [`Struct`]: TypeInner::Struct
    /// [`ValuePointer`]: TypeInner::ValuePointer
    /// [`GlobalVariable`]: Expression::GlobalVariable
    /// [`AccessIndex`]: Expression::AccessIndex
    /// [`compare_types`]: crate::proc::compare_types
    Pointer {
        base: Handle<Type>,
        space: AddressSpace,
    },

    /// Pointer to a scalar or vector.
    ///
    /// A `ValuePointer` type is equivalent to a `Pointer` whose `base` is a
    /// `Scalar` or `Vector` type. This is for use in [`TypeResolution::Value`]
    /// variants; see the documentation for [`TypeResolution`] for details.
    ///
    /// Use [`TypeInner::non_struct_equivalent`] or [`compare_types`] to compare
    /// types that could be pointers, to ensure that `Pointer` and
    /// `ValuePointer` types are recognized as equivalent.
    ///
    /// [`TypeResolution`]: crate::proc::TypeResolution
    /// [`TypeResolution::Value`]: crate::proc::TypeResolution::Value
    /// [`compare_types`]: crate::proc::compare_types
    ValuePointer {
        size: Option<VectorSize>,
        scalar: Scalar,
        space: AddressSpace,
    },

    /// Homogeneous list of elements.
    ///
    /// The `base` type must be a [`SIZED`], [`DATA`] type.
    ///
    /// ## Dynamically sized arrays
    ///
    /// An `Array` is [`SIZED`] unless its `size` is [`Dynamic`].
    /// Dynamically-sized arrays may only appear in a few situations:
    ///
    /// - They may appear as the type of a [`GlobalVariable`], or as the last
    ///   member of a [`Struct`].
    ///
    /// - They may appear as the base type of a [`Pointer`]. An
    ///   [`AccessIndex`] expression referring to a struct's final
    ///   unsized array member would have such a pointer type. However, such
    ///   pointer types may only appear as the types of such intermediate
    ///   expressions. They are not [`DATA`], and cannot be stored in
    ///   variables, held in arrays or structs, or passed as parameters.
    ///
    /// [`SIZED`]: crate::valid::TypeFlags::SIZED
    /// [`DATA`]: crate::valid::TypeFlags::DATA
    /// [`Dynamic`]: ArraySize::Dynamic
    /// [`Struct`]: TypeInner::Struct
    /// [`Pointer`]: TypeInner::Pointer
    /// [`AccessIndex`]: Expression::AccessIndex
    Array {
        base: Handle<Type>,
        size: ArraySize,
        stride: u32,
    },

    /// User-defined structure.
    ///
    /// There must always be at least one member.
    ///
    /// A `Struct` type is [`DATA`], and the types of its members must be
    /// `DATA` as well.
    ///
    /// Member types must be [`SIZED`], except for the final member of a
    /// struct, which may be a dynamically sized [`Array`]. The
    /// `Struct` type itself is `SIZED` when all its members are `SIZED`.
    ///
    /// Two structure types with different names are not equivalent. Because
    /// this variant does not contain the name, it is not possible to use it
    /// to compare struct types. Use [`compare_types`] to compare two types
    /// that may be structs.
    ///
    /// [`DATA`]: crate::valid::TypeFlags::DATA
    /// [`SIZED`]: crate::valid::TypeFlags::SIZED
    /// [`Array`]: TypeInner::Array
    /// [`compare_types`]: crate::proc::compare_types
    Struct {
        members: Vec<StructMember>,
        //TODO: should this be unaligned?
        span: u32,
    },
    /// Possibly multidimensional array of texels.
    Image {
        dim: ImageDimension,
        arrayed: bool,
        //TODO: consider moving `multisampled: bool` out
        class: ImageClass,
    },
    /// Can be used to sample values from images.
    Sampler { comparison: bool },

    /// Opaque object representing an acceleration structure of geometry.
    AccelerationStructure { vertex_return: bool },

    /// Locally used handle for ray queries.
    RayQuery { vertex_return: bool },

    /// Array of bindings.
    ///
    /// A `BindingArray` represents an array where each element draws its value
    /// from a separate bound resource. The array's element type `base` may be
    /// [`Image`], [`Sampler`], or any type that would be permitted for a global
    /// in the [`Uniform`] or [`Storage`] address spaces.
    /// Only global variables may be binding arrays; on the host side, their
    /// values are provided by [`TextureViewArray`], [`SamplerArray`], or
    /// [`BufferArray`] bindings.
    ///
    /// Since each element comes from a distinct resource, a binding array of
    /// images could have images of varying sizes (but not varying dimensions;
    /// they must all have the same `Image` type). Or, a binding array of
    /// buffers could have elements that are dynamically sized arrays, each with
    /// a different length.
    ///
    /// Binding arrays are in the same address spaces as their underlying type.
    /// As such, referring to an array of images produces an [`Image`] value
    /// directly (as opposed to a pointer). The only operation permitted on
    /// `BindingArray` values is indexing, which works transparently: indexing
    /// a binding array of samplers yields a [`Sampler`], indexing a pointer to the
    /// binding array of storage buffers produces a pointer to the storage struct.
    ///
    /// Unlike textures and samplers, binding arrays are not [`ARGUMENT`], so
    /// they cannot be passed as arguments to functions.
    ///
    /// Naga's WGSL front end supports binding arrays with the type syntax
    /// `binding_array<T, N>`.
    ///
    /// [`Image`]: TypeInner::Image
    /// [`Sampler`]: TypeInner::Sampler
    /// [`Uniform`]: AddressSpace::Uniform
    /// [`Storage`]: AddressSpace::Storage
    /// [`TextureViewArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.TextureViewArray
    /// [`SamplerArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.SamplerArray
    /// [`BufferArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.BufferArray
    /// [`DATA`]: crate::valid::TypeFlags::DATA
    /// [`ARGUMENT`]: crate::valid::TypeFlags::ARGUMENT
    /// [naga#1864]: https://github.com/gfx-rs/naga/issues/1864
    BindingArray { base: Handle<Type>, size: ArraySize },
}

#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum Literal {
    /// May not be NaN or infinity.
    F64(f64),
    /// May not be NaN or infinity.
    F32(f32),
    /// May not be NaN or infinity.
    F16(f16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    AbstractFloat(f64),
}

/// Pipeline-overridable constant.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct Override {
    pub name: Option<String>,
    /// Pipeline Constant ID.
    pub id: Option<u16>,
    pub ty: Handle<Type>,

    /// The default value of the pipeline-overridable constant.
    ///
    /// This [`Handle`] refers to [`Module::global_expressions`], not
    /// any [`Function::expressions`] arena.
    pub init: Option<Handle<Expression>>,
}

/// Constant value.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct Constant {
    pub name: Option<String>,
    pub ty: Handle<Type>,

    /// The value of the constant.
    ///
    /// This [`Handle`] refers to [`Module::global_expressions`], not
    /// any [`Function::expressions`] arena.
    pub init: Handle<Expression>,
}

/// Describes how an input/output variable is to be bound.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum Binding {
    /// Built-in shader variable.
    BuiltIn(BuiltIn),

    /// Indexed location.
    ///
    /// This is a value passed to a [`Fragment`] shader from a [`Vertex`] or
    /// [`Mesh`] shader.
    ///
    /// Values passed from the [`Vertex`] stage to the [`Fragment`] stage must
    /// have their `interpolation` defaulted (i.e. not `None`) by the front end
    /// as appropriate for that language.
    ///
    /// For other stages, we permit interpolations even though they're ignored.
    /// When a front end is parsing a struct type, it usually doesn't know what
    /// stages will be using it for IO, so it's easiest if it can apply the
    /// defaults to anything with a `Location` binding, just in case.
    ///
    /// For anything other than floating-point scalars and vectors, the
    /// interpolation must be `Flat`.
    ///
    /// [`Vertex`]: crate::ShaderStage::Vertex
    /// [`Mesh`]: crate::ShaderStage::Mesh
    /// [`Fragment`]: crate::ShaderStage::Fragment
    Location {
        location: u32,
        interpolation: Option<Interpolation>,
        sampling: Option<Sampling>,

        /// Optional `blend_src` index used for dual source blending.
        /// See the `@blend_src` attribute in the WGSL specification.
        blend_src: Option<u32>,

        /// Whether the binding is a per-primitive binding for use with mesh shaders.
        ///
        /// This must be `true` if this binding is a mesh shader primitive output, or such
        /// an output's corresponding fragment shader input. It must be `false` otherwise.
        ///
        /// A stage's outputs must all have unique `location` numbers, regardless of
        /// whether they are per-primitive; a mesh shader's per-vertex and per-primitive
        /// outputs share the same location numbering space.
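        ///
        /// A minimal sketch of the shared numbering rule. It models each output as a
        /// hypothetical `(location, per_primitive)` pair rather than a Naga type, and
        /// checks that no location is used twice:
        ///
        /// ```rust
        /// use std::collections::HashSet;
        ///
        /// // Returns `true` if no two outputs share a `location`, whether or not
        /// // they are per-primitive.
        /// fn locations_are_unique(outputs: &[(u32, bool)]) -> bool {
        ///     let mut seen = HashSet::new();
        ///     outputs
        ///         .iter()
        ///         .all(|&(location, _per_primitive)| seen.insert(location))
        /// }
        ///
        /// // A per-vertex and a per-primitive output may not both use location 0.
        /// assert!(!locations_are_unique(&[(0, false), (0, true)]));
        /// assert!(locations_are_unique(&[(0, false), (1, true)]));
        /// ```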
        ///
        /// Per-primitive values are not interpolated at all and are not dependent on the
        /// vertices or pixel location. For example, it may be used to store a
        /// non-interpolated normal vector.
        per_primitive: bool,
    },
}

/// Pipeline binding information for global resources.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct ResourceBinding {
    /// The bind group index.
    pub group: u32,
    /// Binding number within the group.
    pub binding: u32,
}

/// Variable defined at module level.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct GlobalVariable {
    /// Name of the variable, if any.
    pub name: Option<String>,
    /// How this variable is to be stored.
    pub space: AddressSpace,
    /// For resources, defines the binding point.
    pub binding: Option<ResourceBinding>,
    /// The type of this variable.
    pub ty: Handle<Type>,
    /// Initial value for this variable.
    ///
    /// This refers to an [`Expression`] in [`Module::global_expressions`].
    pub init: Option<Handle<Expression>>,
    /// Memory decorations for this variable.
    ///
    /// These are meaningful for storage address space variables in SPIR-V,
    /// where they map to SPIR-V memory decorations on the variable.
    ///
    /// In WGSL, these can be set with attributes like `@coherent` or `@volatile`.
    pub memory_decorations: MemoryDecorations,
}

/// Variable defined at function level.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct LocalVariable {
    /// Name of the variable, if any.
    pub name: Option<String>,
    /// The type of this variable.
    pub ty: Handle<Type>,
    /// Initial value for this variable.
    ///
    /// This handle refers to an expression in this `LocalVariable`'s function's
    /// [`expressions`] arena, but it is required to be an evaluated override
    /// expression.
    ///
    /// [`expressions`]: Function::expressions
    pub init: Option<Handle<Expression>>,
}

/// Operation that can be applied on a single value.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum UnaryOperator {
    Negate,
    LogicalNot,
    BitwiseNot,
}

/// Operation that can be applied on two values.
///
/// ## Arithmetic type rules
///
/// The arithmetic operations `Add`, `Subtract`, `Multiply`, `Divide`, and
/// `Modulo` can all be applied to [`Scalar`] types other than [`Bool`], or
/// [`Vector`]s thereof. Both operands must have the same type.
///
/// `Add` and `Subtract` can also be applied to [`Matrix`] values. Both operands
/// must have the same type.
///
/// `Multiply` supports additional cases:
///
/// - A [`Matrix`] or [`Vector`] can be multiplied by a scalar [`Float`],
///   either on the left or the right.
///
/// - A [`Matrix`] on the left can be multiplied by a [`Vector`] on the right
///   if the matrix has as many columns as the vector has components
///   (`matCxR * VecC`).
///
/// - A [`Vector`] on the left can be multiplied by a [`Matrix`] on the right
///   if the matrix has as many rows as the vector has components
///   (`VecR * matCxR`).
///
/// - Two matrices can be multiplied if the left operand has as many columns
///   as the right operand has rows (`matNxR * matCxN`).
///
/// In all the above `Multiply` cases, the byte widths of the underlying scalar
/// types of both operands must be the same.
///
/// Note that `Multiply` supports mixed vector and scalar operations directly,
/// whereas the other arithmetic operations require an explicit [`Splat`] for
/// mixed-type use.
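///
/// For illustration only, the `Multiply` shape rules above can be sketched with
/// a hypothetical `Shape` type that is not part of Naga. The sketch covers
/// shapes only: it ignores scalar kinds, the `Bool` restriction, and byte
/// widths, and it accepts any scalar operand.
///
/// ```rust
/// #[derive(Clone, Copy, Debug, PartialEq)]
/// enum Shape {
///     Scalar,
///     Vector(u8),
///     Matrix { columns: u8, rows: u8 },
/// }
///
/// // Result shape of `left * right`, or `None` if the shapes are incompatible.
/// fn multiply_shape(left: Shape, right: Shape) -> Option<Shape> {
///     match (left, right) {
///         (Shape::Scalar, Shape::Scalar) => Some(Shape::Scalar),
///         // Component-wise product of two vectors of the same size.
///         (Shape::Vector(a), Shape::Vector(b)) if a == b => Some(Shape::Vector(a)),
///         // Scaling a vector or matrix by a scalar, on either side.
///         (Shape::Vector(n), Shape::Scalar) | (Shape::Scalar, Shape::Vector(n)) => {
///             Some(Shape::Vector(n))
///         }
///         (m @ Shape::Matrix { .. }, Shape::Scalar)
///         | (Shape::Scalar, m @ Shape::Matrix { .. }) => Some(m),
///         // matCxR * vecC -> vecR
///         (Shape::Matrix { columns, rows }, Shape::Vector(n)) if columns == n => {
///             Some(Shape::Vector(rows))
///         }
///         // vecR * matCxR -> vecC
///         (Shape::Vector(n), Shape::Matrix { columns, rows }) if rows == n => {
///             Some(Shape::Vector(columns))
///         }
///         // matNxR * matCxN -> matCxR
///         (Shape::Matrix { columns: n, rows }, Shape::Matrix { columns, rows: m })
///             if n == m =>
///         {
///             Some(Shape::Matrix { columns, rows })
///         }
///         _ => None,
///     }
/// }
///
/// let mat3x2 = Shape::Matrix { columns: 3, rows: 2 };
/// assert_eq!(multiply_shape(mat3x2, Shape::Vector(3)), Some(Shape::Vector(2)));
/// assert_eq!(multiply_shape(Shape::Vector(2), mat3x2), Some(Shape::Vector(3)));
/// assert_eq!(multiply_shape(mat3x2, Shape::Vector(2)), None);
/// ```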
///
/// [`Scalar`]: TypeInner::Scalar
/// [`Vector`]: TypeInner::Vector
/// [`Matrix`]: TypeInner::Matrix
/// [`Float`]: ScalarKind::Float
/// [`Bool`]: ScalarKind::Bool
/// [`Splat`]: Expression::Splat
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum BinaryOperator {
    Add,
    Subtract,
    Multiply,
    Divide,
    /// Equivalent to WGSL's `%` operator or SPIR-V's `OpFRem`.
    Modulo,
    Equal,
    NotEqual,
    Less,
    LessEqual,
    Greater,
    GreaterEqual,
    And,
    ExclusiveOr,
    InclusiveOr,
    LogicalAnd,
    LogicalOr,
    ShiftLeft,
    /// Right shift carries the sign of signed integers only.
    ShiftRight,
}

/// Function on an atomic value.
///
/// Note: these do not include load/store, which use the existing
/// [`Expression::Load`] and [`Statement::Store`].
///
/// All `Handle<Expression>` values here refer to an expression in
/// [`Function::expressions`].
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum AtomicFunction {
    Add,
    Subtract,
    And,
    ExclusiveOr,
    InclusiveOr,
    Min,
    Max,
    Exchange { compare: Option<Handle<Expression>> },
}

/// Hint at which precision to compute a derivative.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum DerivativeControl {
    Coarse,
    Fine,
    None,
}

/// Axis on which to compute a derivative.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum DerivativeAxis {
    X,
    Y,
    Width,
}

/// Built-in shader function for testing relation between values.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum RelationalFunction {
    All,
    Any,
    IsNan,
    IsInf,
}

/// Built-in shader function for math.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum MathFunction {
    // comparison
    Abs,
    Min,
    Max,
    Clamp,
    Saturate,
    // trigonometry
    Cos,
    Cosh,
    Sin,
    Sinh,
    Tan,
    Tanh,
    Acos,
    Asin,
    Atan,
    Atan2,
    Asinh,
    Acosh,
    Atanh,
    Radians,
    Degrees,
    // decomposition
    Ceil,
    Floor,
    Round,
    Fract,
    Trunc,
    Modf,
    Frexp,
    Ldexp,
    // exponent
    Exp,
    Exp2,
    Log,
    Log2,
    Pow,
    // geometry
    Dot,
    Dot4I8Packed,
    Dot4U8Packed,
    Outer,
    Cross,
    Distance,
    Length,
    Normalize,
    FaceForward,
    Reflect,
    Refract,
    // computational
    Sign,
    Fma,
    Mix,
    Step,
    SmoothStep,
    Sqrt,
    InverseSqrt,
    Inverse,
    Transpose,
    Determinant,
    QuantizeToF16,
    // bits
    CountTrailingZeros,
    CountLeadingZeros,
    CountOneBits,
    ReverseBits,
    ExtractBits,
    InsertBits,
    FirstTrailingBit,
    FirstLeadingBit,
    // data packing
    Pack4x8snorm,
    Pack4x8unorm,
    Pack2x16snorm,
    Pack2x16unorm,
    Pack2x16float,
    Pack4xI8,
    Pack4xU8,
    Pack4xI8Clamp,
    Pack4xU8Clamp,
    // data unpacking
    Unpack4x8snorm,
    Unpack4x8unorm,
    Unpack2x16snorm,
    Unpack2x16unorm,
    Unpack2x16float,
    Unpack4xI8,
    Unpack4xU8,
}

/// Sampling modifier to control the level of detail.
///
/// All `Handle<Expression>` values here refer to an expression in
/// [`Function::expressions`].
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum SampleLevel {
    Auto,
    Zero,
    Exact(Handle<Expression>),
    Bias(Handle<Expression>),
    Gradient {
        x: Handle<Expression>,
        y: Handle<Expression>,
    },
}

/// Type of an image query.
///
/// All `Handle<Expression>` values here refer to an expression in
/// [`Function::expressions`].
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum ImageQuery {
    /// Get the size at the specified level.
    ///
    /// The return value is a `u32` for 1D images, and a `vecN<u32>`
    /// for an image with N >= 2 dimensions.
    Size {
        /// If `None`, the base level is considered.
        level: Option<Handle<Expression>>,
    },
    /// Get the number of mipmap levels, a `u32`.
    NumLevels,
    /// Get the number of array layers, a `u32`.
    NumLayers,
    /// Get the number of samples, a `u32`.
    NumSamples,
}

/// Component selection for a vector swizzle.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum SwizzleComponent {
    X = 0,
    Y = 1,
    Z = 2,
    W = 3,
}

/// The specific behavior of a [`SubgroupGather`] statement.
///
/// All `Handle<Expression>` values here refer to an expression in
/// [`Function::expressions`].
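///
/// A rough sketch of which lane each mode reads from, for an invocation in lane
/// `lane` of a subgroup with `size` lanes. `Mode` and `source_lane` are
/// hypothetical stand-ins, not Naga API. `BroadcastFirst` is left out because it
/// depends on which lanes are active, and the sketch returns `None` where the
/// source lane would fall outside the subgroup. The quad cases assume quads of
/// four consecutive lanes, where X flips bit 0 and Y flips bit 1 of the lane
/// index.
///
/// ```rust
/// #[derive(Clone, Copy)]
/// enum Mode {
///     Broadcast(u32),
///     // For `Shuffle`, each lane supplies its own index.
///     Shuffle(u32),
///     ShuffleDown(u32),
///     ShuffleUp(u32),
///     ShuffleXor(u32),
///     QuadBroadcast(u32),
///     // 0 = X, 1 = Y, 2 = Diagonal, mirroring the `Direction` discriminants.
///     QuadSwap(u32),
/// }
///
/// fn source_lane(mode: Mode, lane: u32, size: u32) -> Option<u32> {
///     let src = match mode {
///         Mode::Broadcast(index) | Mode::Shuffle(index) => index,
///         Mode::ShuffleDown(delta) => lane.checked_add(delta)?,
///         Mode::ShuffleUp(delta) => lane.checked_sub(delta)?,
///         Mode::ShuffleXor(mask) => lane ^ mask,
///         // Keep the quad's base lane and replace the position within the quad.
///         Mode::QuadBroadcast(index) => (lane & !3) | (index & 3),
///         // X swaps quad positions 0/1 and 2/3, Y swaps 0/2 and 1/3,
///         // and Diagonal swaps 0/3 and 1/2.
///         Mode::QuadSwap(direction) => lane ^ (direction + 1),
///     };
///     (src < size).then_some(src)
/// }
///
/// assert_eq!(source_lane(Mode::ShuffleDown(1), 3, 32), Some(4));
/// assert_eq!(source_lane(Mode::ShuffleUp(1), 0, 32), None);
/// assert_eq!(source_lane(Mode::ShuffleXor(1), 6, 32), Some(7));
/// assert_eq!(source_lane(Mode::QuadBroadcast(2), 5, 32), Some(6));
/// assert_eq!(source_lane(Mode::QuadSwap(2), 4, 32), Some(7));
/// ```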
///
/// [`SubgroupGather`]: Statement::SubgroupGather
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum GatherMode {
    /// All gather from the active lane with the smallest index
    BroadcastFirst,
    /// All gather from the same lane at the index given by the expression
    Broadcast(Handle<Expression>),
    /// Each gathers from a different lane at the index given by the expression
    Shuffle(Handle<Expression>),
    /// Each gathers from their lane plus the shift given by the expression
    ShuffleDown(Handle<Expression>),
    /// Each gathers from their lane minus the shift given by the expression
    ShuffleUp(Handle<Expression>),
    /// Each gathers from their lane XORed with the mask given by the expression
    ShuffleXor(Handle<Expression>),
    /// All gather from the same quad lane at the index given by the expression
    QuadBroadcast(Handle<Expression>),
    /// Each gathers from the opposite quad lane along the given direction
    QuadSwap(Direction),
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum Direction {
    X = 0,
    Y = 1,
    Diagonal = 2,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum SubgroupOperation {
    All = 0,
    Any = 1,
    Add = 2,
    Mul = 3,
    Min = 4,
    Max = 5,
    And = 6,
    Or = 7,
    Xor = 8,
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum CollectiveOperation {
    Reduce = 0,
    InclusiveScan = 1,
    ExclusiveScan = 2,
}
bitflags::bitflags! {
    /// Memory barrier flags.
    #[cfg_attr(feature = "serialize", derive(Serialize))]
    #[cfg_attr(feature = "deserialize", derive(Deserialize))]
    #[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct Barrier: u32 {
        /// Barrier affects all [`AddressSpace::Storage`] accesses.
        const STORAGE = 1 << 0;
        /// Barrier affects all [`AddressSpace::WorkGroup`] accesses.
        const WORK_GROUP = 1 << 1;
        /// Barrier synchronizes execution across all invocations within a
        /// subgroup that execute this instruction.
        const SUB_GROUP = 1 << 2;
        /// Barrier synchronizes texture memory accesses in a workgroup.
        const TEXTURE = 1 << 3;
    }
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct CooperativeData {
    pub pointer: Handle<Expression>,
    pub stride: Handle<Expression>,
    pub row_major: bool,
}

/// An expression that can be evaluated to obtain a value.
///
/// This is a Static Single Assignment (SSA) scheme similar to SPIR-V.
///
/// When an `Expression` variant holds `Handle<Expression>` fields, they refer
/// to another expression in the same arena, unless explicitly noted otherwise.
/// One `Arena<Expression>` may only refer to a different arena indirectly, via
/// [`Constant`] or [`Override`] expressions, which hold handles for their
/// respective types.
///
/// [`Constant`]: Expression::Constant
/// [`Override`]: Expression::Override
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum Expression {
    /// Literal.
    Literal(Literal),
    /// Constant value.
    Constant(Handle<Constant>),
    /// Pipeline-overridable constant.
    Override(Handle<Override>),
    /// Zero value of a type.
    ZeroValue(Handle<Type>),
    /// Composite expression.
    Compose {
        ty: Handle<Type>,
        components: Vec<Handle<Expression>>,
    },

    /// Array access with a computed index.
    ///
    /// ## Typing rules
    ///
    /// The `base` operand must be some composite type: [`Vector`], [`Matrix`],
    /// [`Array`], a [`Pointer`] to one of those, or a [`ValuePointer`] with a
    /// `size`.
    ///
    /// The `index` operand must be an integer, signed or unsigned.
    ///
    /// Indexing a [`Vector`] or [`Array`] produces a value of its element type.
    /// Indexing a [`Matrix`] produces a [`Vector`].
    ///
    /// Indexing a [`Pointer`] to any of the above produces a pointer to the
    /// element/component type, in the same [`space`]. In the case of [`Array`],
    /// the result is an actual [`Pointer`], but for vectors and matrices, there
    /// may not be any type in the arena representing the component's type, so
    /// those produce [`ValuePointer`] types equivalent to the appropriate
    /// [`Pointer`].
    ///
    /// ## Dynamic indexing restrictions
    ///
    /// To accommodate restrictions in some of the shader languages that Naga
    /// targets, it is not permitted to subscript a matrix with a dynamically
    /// computed index unless that matrix appears behind a pointer. In other
    /// words, if the inner type of `base` is [`Matrix`], then `index` must be a
    /// constant. But if the type of `base` is a [`Pointer`] to a matrix, then
    /// the index may be any expression of integer type.
    ///
    /// You can use the [`Expression::is_dynamic_index`] method to determine
    /// whether a given index expression requires matrix base operands to be
    /// behind a pointer.
    ///
    /// (It would be simpler to always require the use of `AccessIndex` when
    /// subscripting matrices that are not behind pointers, but to accommodate
    /// existing front ends, Naga also permits `Access`, with a restricted
    /// `index`.)
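    ///
    /// A rough sketch of the matrix restriction above. The `Base` and `Index`
    /// enums are hypothetical stand-ins for the type and expression information
    /// Naga actually consults; they are not part of its API.
    ///
    /// ```rust
    /// #[derive(Clone, Copy)]
    /// enum Base {
    ///     // A matrix value that is not behind a pointer.
    ///     Matrix,
    ///     // A pointer to a matrix.
    ///     PointerToMatrix,
    /// }
    ///
    /// #[derive(Clone, Copy)]
    /// enum Index {
    ///     // A constant index, such as a literal.
    ///     Constant,
    ///     // An index computed at run time.
    ///     Dynamic,
    /// }
    ///
    /// // Returns `true` if `Access { base, index }` respects the matrix rule.
    /// fn matrix_access_allowed(base: Base, index: Index) -> bool {
    ///     !matches!((base, index), (Base::Matrix, Index::Dynamic))
    /// }
    ///
    /// assert!(matrix_access_allowed(Base::Matrix, Index::Constant));
    /// assert!(!matrix_access_allowed(Base::Matrix, Index::Dynamic));
    /// assert!(matrix_access_allowed(Base::PointerToMatrix, Index::Dynamic));
    /// ```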
    ///
    /// [`Vector`]: TypeInner::Vector
    /// [`Matrix`]: TypeInner::Matrix
    /// [`Array`]: TypeInner::Array
    /// [`Pointer`]: TypeInner::Pointer
    /// [`space`]: TypeInner::Pointer::space
    /// [`ValuePointer`]: TypeInner::ValuePointer
    /// [`Float`]: ScalarKind::Float
    Access {
        base: Handle<Expression>,
        index: Handle<Expression>,
    },
    /// Access the same types as [`Access`], plus [`Struct`] with a known index.
    ///
    /// [`Access`]: Expression::Access
    /// [`Struct`]: TypeInner::Struct
    AccessIndex {
        base: Handle<Expression>,
        index: u32,
    },
    /// Splat scalar into a vector.
    Splat {
        size: VectorSize,
        value: Handle<Expression>,
    },
    /// Vector swizzle.
    Swizzle {
        size: VectorSize,
        vector: Handle<Expression>,
        pattern: [SwizzleComponent; 4],
    },

    /// Reference a function parameter, by its index.
    ///
    /// A `FunctionArgument` expression evaluates to the argument's value.
    FunctionArgument(u32),

    /// Reference a global variable.
    ///
    /// If the given `GlobalVariable`'s [`space`] is [`AddressSpace::Handle`],
    /// then the variable stores some opaque type like a sampler or an image,
    /// and a `GlobalVariable` expression referring to it produces the
    /// variable's value directly.
    ///
    /// For any other address space, a `GlobalVariable` expression produces a
    /// pointer to the variable's value. You must use a [`Load`] expression to
    /// retrieve its value, or a [`Store`] statement to assign it a new value.
    ///
    /// [`space`]: GlobalVariable::space
    /// [`Load`]: Expression::Load
    /// [`Store`]: Statement::Store
    GlobalVariable(Handle<GlobalVariable>),

    /// Reference a local variable.
    ///
    /// A `LocalVariable` expression evaluates to a pointer to the variable's value.
    /// You must use a [`Load`](Expression::Load) expression to retrieve its value,
    /// or a [`Store`](Statement::Store) statement to assign it a new value.
    LocalVariable(Handle<LocalVariable>),

    /// Load a value indirectly.
    ///
    /// For [`TypeInner::Atomic`] the result is a corresponding scalar.
    /// For other types behind the `pointer<T>`, the result is `T`.
    Load { pointer: Handle<Expression> },
    /// Sample a point from a sampled or a depth image.
    ImageSample {
        image: Handle<Expression>,
        sampler: Handle<Expression>,
        /// If `Some(component)`, this operation is a gather operation
        /// on the selected component.
        gather: Option<SwizzleComponent>,
        coordinate: Handle<Expression>,
        array_index: Option<Handle<Expression>>,
        /// This must be a const-expression.
        offset: Option<Handle<Expression>>,
        level: SampleLevel,
        depth_ref: Option<Handle<Expression>>,
        /// Whether the sampling operation should clamp each component of
        /// `coordinate` to the range `[half_texel, 1 - half_texel]`, regardless
        /// of `sampler`.
        clamp_to_edge: bool,
    },

    /// Load a texel from an image.
    ///
    /// For most images, this returns a four-element vector of the same
    /// [`ScalarKind`] as the image. If the format of the image does not have
    /// four components, default values are provided: the first three components
    /// (typically R, G, and B) default to zero, and the final component
    /// (typically alpha) defaults to one.
    ///
    /// However, if the image's [`class`] is [`Depth`], then this returns a
    /// [`Float`] scalar value.
    ///
    /// [`ScalarKind`]: ScalarKind
    /// [`class`]: TypeInner::Image::class
    /// [`Depth`]: ImageClass::Depth
    /// [`Float`]: ScalarKind::Float
    ImageLoad {
        /// The image to load a texel from. This must have type [`Image`]. (This
        /// will necessarily be a [`GlobalVariable`] or [`FunctionArgument`]
        /// expression, since no other expressions are allowed to have that
        /// type.)
        ///
        /// [`Image`]: TypeInner::Image
        /// [`GlobalVariable`]: Expression::GlobalVariable
        /// [`FunctionArgument`]: Expression::FunctionArgument
        image: Handle<Expression>,

        /// The coordinate of the texel we wish to load. This must be a scalar
        /// for [`D1`] images, a [`Bi`] vector for [`D2`] images, and a [`Tri`]
        /// vector for [`D3`] images. (Array indices, sample indices, and
        /// explicit level-of-detail values are supplied separately.) Its
        /// component type must be [`Sint`].
        ///
        /// [`D1`]: ImageDimension::D1
        /// [`D2`]: ImageDimension::D2
        /// [`D3`]: ImageDimension::D3
        /// [`Bi`]: VectorSize::Bi
        /// [`Tri`]: VectorSize::Tri
        /// [`Sint`]: ScalarKind::Sint
        coordinate: Handle<Expression>,

        /// The index into an arrayed image. If the [`arrayed`] flag in
        /// `image`'s type is `true`, then this must be `Some(expr)`, where
        /// `expr` is a [`Sint`] scalar. Otherwise, it must be `None`.
        ///
        /// [`arrayed`]: TypeInner::Image::arrayed
        /// [`Sint`]: ScalarKind::Sint
        array_index: Option<Handle<Expression>>,

        /// A sample index, for multisampled [`Sampled`] and [`Depth`] images.
        ///
        /// [`Sampled`]: ImageClass::Sampled
        /// [`Depth`]: ImageClass::Depth
        sample: Option<Handle<Expression>>,

        /// A level of detail, for mipmapped images.
        ///
        /// This must be present when accessing non-multisampled
        /// [`Sampled`] and [`Depth`] images, even if only the
        /// full-resolution level is present (in which case the only
        /// valid level is zero).
        ///
        /// [`Sampled`]: ImageClass::Sampled
        /// [`Depth`]: ImageClass::Depth
        level: Option<Handle<Expression>>,
    },

    /// Query information from an image.
    ImageQuery {
        image: Handle<Expression>,
        query: ImageQuery,
    },
    /// Apply a unary operator.
    Unary {
        op: UnaryOperator,
        expr: Handle<Expression>,
    },
    /// Apply a binary operator.
    Binary {
        op: BinaryOperator,
        left: Handle<Expression>,
        right: Handle<Expression>,
    },
    /// Select between two values based on a condition.
    ///
    /// Note that, because expressions have no side effects, it is unobservable
    /// whether the non-selected branch is evaluated.
    Select {
        /// Boolean expression
        condition: Handle<Expression>,
        accept: Handle<Expression>,
        reject: Handle<Expression>,
    },
    /// Compute the derivative on an axis.
    Derivative {
        axis: DerivativeAxis,
        ctrl: DerivativeControl,
        expr: Handle<Expression>,
    },
    /// Call a relational function.
    Relational {
        fun: RelationalFunction,
        argument: Handle<Expression>,
    },
    /// Call a math function.
    Math {
        fun: MathFunction,
        arg: Handle<Expression>,
        arg1: Option<Handle<Expression>>,
        arg2: Option<Handle<Expression>>,
        arg3: Option<Handle<Expression>>,
    },
    /// Cast a simple type to another kind.
    As {
        /// Source expression, which can only be a scalar or a vector.
        expr: Handle<Expression>,
        /// Target scalar kind.
        kind: ScalarKind,
        /// If provided, converts to the specified byte width.
        /// Otherwise, bitcast.
        convert: Option<Bytes>,
    },
    /// Result of calling another function.
    CallResult(Handle<Function>),

    /// Result of an atomic operation.
    ///
    /// This expression must be referred to by the [`result`] field of exactly one
    /// [`Atomic`][stmt] statement somewhere in the same function. Let `T` be the
    /// scalar type contained by the [`Atomic`][type] value that the statement
    /// operates on.
    ///
    /// If `comparison` is `false`, then `ty` must be the scalar type `T`.
    ///
    /// If `comparison` is `true`, then `ty` must be a [`Struct`] with two members:
    ///
    /// - A member named `old_value`, whose type is `T`, and
    ///
    /// - A member named `exchanged`, of type [`BOOL`].
    ///
    /// [`result`]: Statement::Atomic::result
    /// [stmt]: Statement::Atomic
    /// [type]: TypeInner::Atomic
    /// [`Struct`]: TypeInner::Struct
    /// [`BOOL`]: Scalar::BOOL
    AtomicResult { ty: Handle<Type>, comparison: bool },

    /// Result of a [`WorkGroupUniformLoad`] statement.
    ///
    /// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad
    WorkGroupUniformLoadResult {
        /// The type of the result
        ty: Handle<Type>,
    },
    /// Get the length of an array.
    /// The expression must resolve to a pointer to an array with a dynamic size.
    ///
    /// This doesn't match the semantics of SPIR-V's `OpArrayLength`, which must be passed
    /// a pointer to a structure containing a runtime array in its last field.
    ArrayLength(Handle<Expression>),

    /// Get the positions of the triangle hit by the [`RayQuery`].
    ///
    /// [`RayQuery`]: Statement::RayQuery
    RayQueryVertexPositions {
        query: Handle<Expression>,
        committed: bool,
    },

    /// Result of a [`Proceed`] [`RayQuery`] statement.
    ///
    /// [`Proceed`]: RayQueryFunction::Proceed
    /// [`RayQuery`]: Statement::RayQuery
    RayQueryProceedResult,

    /// Return an intersection found by `query`.
    ///
    /// If `committed` is true, return the committed result available when
    RayQueryGetIntersection {
        query: Handle<Expression>,
        committed: bool,
    },

    /// Result of a [`SubgroupBallot`] statement.
/// /// [`SubgroupBallot`]: Statement::SubgroupBallot SubgroupBallotResult, /// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement. /// /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation /// [`SubgroupGather`]: Statement::SubgroupGather SubgroupOperationResult { ty: Handle<Type> }, /// Load a cooperative primitive from memory. CooperativeLoad { columns: CooperativeSize, rows: CooperativeSize, role: CooperativeRole, data: CooperativeData, }, /// Compute `a * b + c`. CooperativeMultiplyAdd { a: Handle<Expression>, b: Handle<Expression>, c: Handle<Expression>, }, } /// The value of the switch case. #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum SwitchValue { I32(i32), U32(u32), Default, } /// A case for a switch statement. // Clone is used only for error reporting and is not intended for end users #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct SwitchCase { /// The value for which this case is selected. pub value: SwitchValue, /// Body of the case. pub body: Block, /// If true, control flow continues to the next case in the list, /// or to the default case. pub fall_through: bool, } /// An operation that a [`RayQuery` statement] applies to its [`query`] operand. /// /// [`RayQuery` statement]: Statement::RayQuery /// [`query`]: Statement::RayQuery::query #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum RayQueryFunction { /// Initialize the `RayQuery` object. Initialize { /// The acceleration structure within which this query should search for hits.
/// /// The expression must be an [`AccelerationStructure`]. /// /// [`AccelerationStructure`]: TypeInner::AccelerationStructure acceleration_structure: Handle<Expression>, #[allow(rustdoc::private_intra_doc_links)] /// A struct of detailed parameters for the ray query. /// /// This expression should have the struct type given in /// [`SpecialTypes::ray_desc`]. This is available in the WGSL /// front end as the `RayDesc` type. descriptor: Handle<Expression>, }, /// Start or continue the query given by the statement's [`query`] operand. /// /// After executing this statement, the `result` expression is a /// [`Bool`] scalar indicating whether there are more intersection /// candidates to consider. /// /// [`query`]: Statement::RayQuery::query /// [`Bool`]: ScalarKind::Bool Proceed { result: Handle<Expression>, }, /// Add a candidate generated intersection to be included /// in the determination of the closest hit for a ray query. GenerateIntersection { hit_t: Handle<Expression>, }, /// Confirm a triangle intersection to be included in the determination of /// the closest hit for a ray query. ConfirmIntersection, Terminate, } //TODO: consider removing `Clone`. It's not valid to clone `Statement::Emit` anyway. /// Instructions which make up an executable block. /// /// `Handle<Expression>` and `Range<Expression>` values in `Statement` variants /// refer to expressions in [`Function::expressions`], unless otherwise noted. // Clone is used only for error reporting and is not intended for end users #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Statement { /// Emit a range of expressions, visible to all statements that follow in this block. /// /// See the [module-level documentation][emit] for details. /// /// [emit]: index.html#expression-evaluation-time Emit(Range<Expression>), /// A block containing more statements, to be executed sequentially.
Block(Block), /// Conditionally executes one of two blocks, based on the value of the condition. /// /// Naga IR does not have "phi" instructions. If you need to use /// values computed in an `accept` or `reject` block after the `If`, /// store them in a [`LocalVariable`]. If { condition: Handle<Expression>, //bool accept: Block, reject: Block, }, /// Conditionally executes one of multiple blocks, based on the value of the selector. /// /// Each case must have a distinct [`value`], exactly one of which must be /// [`Default`]. The `Default` may appear at any position, and covers all /// values not explicitly appearing in other cases. A `Default` appearing in /// the midst of the list of cases does not shadow the cases that follow. /// /// Some backend languages don't support fallthrough (HLSL due to FXC, /// WGSL), and may translate fallthrough cases in the IR by duplicating /// code. However, all backend languages do support cases selected by /// multiple values, like `case 1: case 2: case 3: { ... }`. This is /// represented in the IR as a series of fallthrough cases with empty /// bodies, except for the last. /// /// Naga IR does not have "phi" instructions. If you need to use /// values computed in a [`SwitchCase::body`] block after the `Switch`, /// store them in a [`LocalVariable`]. /// /// [`value`]: SwitchCase::value /// [`body`]: SwitchCase::body /// [`Default`]: SwitchValue::Default Switch { selector: Handle<Expression>, cases: Vec<SwitchCase>, }, /// Executes a block repeatedly. /// /// Each iteration of the loop executes the `body` block, followed by the /// `continuing` block. /// /// Executing a [`Break`], [`Return`] or [`Kill`] statement exits the loop. /// /// A [`Continue`] statement in `body` jumps to the `continuing` block. The /// `continuing` block is meant to be used to represent structures like the /// third expression of a C-style `for` loop head, to which `continue` /// statements in the loop's body jump.
/// /// The `continuing` block and its substatements must not contain `Return` /// or `Kill` statements, or any `Break` or `Continue` statements targeting /// this loop. (It may have `Break` and `Continue` statements targeting /// loops or switches nested within the `continuing` block.) Expressions /// emitted in `body` are in scope in `continuing`. /// /// If present, `break_if` is an expression which is evaluated after the /// `continuing` block. Expressions emitted in `body` or `continuing` are /// considered to be in scope. If the expression's value is true, control /// continues after the `Loop` statement, rather than branching back to the /// top of `body` as usual. The `break_if` expression corresponds to a "break /// if" statement in WGSL, or a loop whose back edge is an /// `OpBranchConditional` instruction in SPIR-V. /// /// Naga IR does not have "phi" instructions. If you need to use /// values computed in a `body` or `continuing` block after the /// `Loop`, store them in a [`LocalVariable`]. /// /// [`Break`]: Statement::Break /// [`Continue`]: Statement::Continue /// [`Kill`]: Statement::Kill /// [`Return`]: Statement::Return /// [`break if`]: Self::Loop::break_if Loop { body: Block, continuing: Block, break_if: Option<Handle<Expression>>, }, /// Exits the innermost enclosing [`Loop`] or [`Switch`]. /// /// A `Break` statement may only appear within a [`Loop`] or [`Switch`] /// statement. It may not break out of a [`Loop`] from within the loop's /// `continuing` block. /// /// [`Loop`]: Statement::Loop /// [`Switch`]: Statement::Switch Break, /// Skips to the `continuing` block of the innermost enclosing [`Loop`]. /// /// A `Continue` statement may only appear within the `body` block of the /// innermost enclosing [`Loop`] statement. It must not appear within that /// loop's `continuing` block. /// /// [`Loop`]: Statement::Loop Continue, /// Returns from the function (possibly with a value).
/// /// `Return` statements are forbidden within the `continuing` block of a /// [`Loop`] statement. /// /// [`Loop`]: Statement::Loop Return { value: Option<Handle<Expression>> }, /// Aborts the current shader execution. /// /// `Kill` statements are forbidden within the `continuing` block of a /// [`Loop`] statement. /// /// [`Loop`]: Statement::Loop Kill, /// Synchronize invocations within the work group. /// The `Barrier` flags control which memory accesses should be synchronized. /// If empty, this becomes purely an execution barrier. ControlBarrier(Barrier), /// Synchronize invocations within the work group. /// The `Barrier` flags control which memory accesses should be synchronized. MemoryBarrier(Barrier), /// Stores a value at an address. /// /// For a [`TypeInner::Atomic`] type behind the pointer, the value /// has to be a corresponding scalar. /// For other types behind the `pointer<T>`, the value is `T`. /// /// This statement is a barrier for any operations on the /// `Expression::LocalVariable` or `Expression::GlobalVariable` /// that is the destination of an access chain, started /// from the `pointer`. Store { pointer: Handle<Expression>, value: Handle<Expression>, }, /// Stores a texel value to an image. /// /// The `image`, `coordinate`, and `array_index` fields have the same /// meanings as the corresponding operands of an [`ImageLoad`] expression; /// see that documentation for details. Storing into multisampled images or /// images with mipmaps is not supported, so there are no `level` or /// `sample` operands. /// /// This statement is a barrier for any operations on the corresponding /// [`Expression::GlobalVariable`] for this image. /// /// [`ImageLoad`]: Expression::ImageLoad ImageStore { image: Handle<Expression>, coordinate: Handle<Expression>, array_index: Option<Handle<Expression>>, value: Handle<Expression>, }, /// Atomic function. Atomic { /// Pointer to an atomic value. /// /// This must be a [`Pointer`] to an [`Atomic`] value. The atomic's /// scalar type may be [`I32`] or [`U32`].
/// /// If [`SHADER_INT64_ATOMIC_MIN_MAX`] or [`SHADER_INT64_ATOMIC_ALL_OPS`] is /// enabled, this may also be [`I64`] or [`U64`]. /// /// If [`SHADER_FLOAT32_ATOMIC`] is enabled, this may be [`F32`]. /// /// [`Pointer`]: TypeInner::Pointer /// [`Atomic`]: TypeInner::Atomic /// [`I32`]: Scalar::I32 /// [`U32`]: Scalar::U32 /// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX /// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS /// [`SHADER_FLOAT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLOAT32_ATOMIC /// [`I64`]: Scalar::I64 /// [`U64`]: Scalar::U64 /// [`F32`]: Scalar::F32 pointer: Handle<Expression>, /// Function to run on the atomic value. /// /// If [`pointer`] refers to a 64-bit atomic value, then: /// /// - The [`SHADER_INT64_ATOMIC_ALL_OPS`] capability allows any [`AtomicFunction`] /// value here. /// /// - The [`SHADER_INT64_ATOMIC_MIN_MAX`] capability allows /// [`AtomicFunction::Min`] and [`AtomicFunction::Max`] /// in the [`Storage`] address space here. /// /// - If neither of those capabilities is present, then 64-bit scalar /// atomics are not allowed. /// /// If [`pointer`] refers to a 32-bit floating-point atomic value, then: /// /// - The [`SHADER_FLOAT32_ATOMIC`] capability allows [`AtomicFunction::Add`], /// [`AtomicFunction::Subtract`], and [`AtomicFunction::Exchange { compare: None }`] /// in the [`Storage`] address space here. /// /// [`AtomicFunction::Exchange { compare: None }`]: AtomicFunction::Exchange /// [`pointer`]: Statement::Atomic::pointer /// [`Storage`]: AddressSpace::Storage /// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX /// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS /// [`SHADER_FLOAT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLOAT32_ATOMIC fun: AtomicFunction, /// Value to use in the function.
/// /// This must be a scalar of the same type as [`pointer`]'s atomic's scalar type. /// /// [`pointer`]: Statement::Atomic::pointer value: Handle<Expression>, /// [`AtomicResult`] expression representing this function's result. /// /// If [`fun`] is [`Exchange { compare: None }`], this must be `Some`, /// as otherwise that operation would be equivalent to a simple [`Store`] /// to the atomic. /// /// Otherwise, this may be `None` if the return value of the operation is not needed. /// /// If `pointer` refers to a 64-bit atomic value, [`SHADER_INT64_ATOMIC_MIN_MAX`] /// is enabled, and [`SHADER_INT64_ATOMIC_ALL_OPS`] is not, this must be `None`. /// /// [`AtomicResult`]: crate::Expression::AtomicResult /// [`fun`]: Statement::Atomic::fun /// [`Store`]: Statement::Store /// [`Exchange { compare: None }`]: AtomicFunction::Exchange /// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX /// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS result: Option<Handle<Expression>>, }, /// Performs an atomic operation on a texel value of an image. /// /// Doing atomics on images with mipmaps is not supported, so there is no /// `level` operand. ImageAtomic { /// The image to perform an atomic operation on. This must have type /// [`Image`]. (This will necessarily be a [`GlobalVariable`] or /// [`FunctionArgument`] expression, since no other expressions are /// allowed to have that type.) /// /// [`Image`]: TypeInner::Image /// [`GlobalVariable`]: Expression::GlobalVariable /// [`FunctionArgument`]: Expression::FunctionArgument image: Handle<Expression>, /// The coordinate of the texel on which to perform the atomic operation. /// This must be a scalar for [`D1`] images, a [`Bi`] vector for [`D2`] /// images, and a [`Tri`] vector for [`D3`] images. (Array indices are /// supplied separately, in `array_index`.) Its component type must be /// [`Sint`].
/// /// [`D1`]: ImageDimension::D1 /// [`D2`]: ImageDimension::D2 /// [`D3`]: ImageDimension::D3 /// [`Bi`]: VectorSize::Bi /// [`Tri`]: VectorSize::Tri /// [`Sint`]: ScalarKind::Sint coordinate: Handle<Expression>, /// The index into an arrayed image. If the [`arrayed`] flag in /// `image`'s type is `true`, then this must be `Some(expr)`, where /// `expr` is a [`Sint`] scalar. Otherwise, it must be `None`. /// /// [`arrayed`]: TypeInner::Image::arrayed /// [`Sint`]: ScalarKind::Sint array_index: Option<Handle<Expression>>, /// The kind of atomic operation to perform on the texel. fun: AtomicFunction, /// The value with which to perform the atomic operation. value: Handle<Expression>, }, /// Load uniformly from a uniform pointer in the workgroup address space. /// /// Corresponds to the [`workgroupUniformLoad`](https://www.w3.org/TR/WGSL/#workgroupUniformLoad-builtin) /// built-in function of WGSL, and has the same barrier semantics. WorkGroupUniformLoad { /// This must be of type [`Pointer`] in the [`WorkGroup`] address space. /// /// [`Pointer`]: TypeInner::Pointer /// [`WorkGroup`]: AddressSpace::WorkGroup pointer: Handle<Expression>, /// The [`WorkGroupUniformLoadResult`] expression representing this load's result. /// /// [`WorkGroupUniformLoadResult`]: Expression::WorkGroupUniformLoadResult result: Handle<Expression>, }, /// Calls a function. /// /// If the `result` is `Some`, the corresponding expression has to be /// `Expression::CallResult`, and this statement serves as a barrier for any /// operations on that expression. Call { function: Handle<Function>, arguments: Vec<Handle<Expression>>, result: Option<Handle<Expression>>, }, RayQuery { /// The [`RayQuery`] object this statement operates on. /// /// [`RayQuery`]: TypeInner::RayQuery query: Handle<Expression>, /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, /// A ray tracing pipeline shader intrinsic.
RayPipelineFunction(RayPipelineFunction), /// Calculate a bitmask using a boolean from each active thread in the subgroup. SubgroupBallot { /// The [`SubgroupBallotResult`] expression representing this ballot's result. /// /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult result: Handle<Expression>, /// The value from this thread to store in the ballot. predicate: Option<Handle<Expression>>, }, /// Gather a value from another active thread in the subgroup. SubgroupGather { /// Specifies which thread to gather from. mode: GatherMode, /// The value to gather. argument: Handle<Expression>, /// The [`SubgroupOperationResult`] expression representing this gather's result. /// /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle<Expression>, }, /// Compute a collective operation across all active threads in the subgroup. SubgroupCollectiveOperation { /// What operation to compute. op: SubgroupOperation, /// How to combine the results. collective_op: CollectiveOperation, /// The value to compute over. argument: Handle<Expression>, /// The [`SubgroupOperationResult`] expression representing this operation's result. /// /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle<Expression>, }, /// Store a cooperative primitive into memory. CooperativeStore { target: Handle<Expression>, data: CooperativeData, }, } /// A function argument. #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct FunctionArgument { /// Name of the argument, if any. pub name: Option<String>, /// Type of the argument. pub ty: Handle<Type>, /// For entry points, an argument has to have a binding /// unless it's a structure. pub binding: Option<Binding>, } /// A function result.
#[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct FunctionResult { /// Type of the result. pub ty: Handle<Type>, /// For entry points, the result has to have a binding /// unless it's a structure. pub binding: Option<Binding>, } /// A function defined in the module. #[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Function { /// Name of the function, if any. pub name: Option<String>, /// Information about the function's arguments. pub arguments: Vec<FunctionArgument>, /// The result of this function, if any. pub result: Option<FunctionResult>, /// Local variables defined and used in the function. pub local_variables: Arena<LocalVariable>, /// Expressions used inside this function. /// /// Unless explicitly stated otherwise, if an [`Expression`] is in this /// arena, then its subexpressions are in this arena too. In other words, /// every `Handle<Expression>` in this arena refers to an [`Expression`] in /// this arena too. /// /// The main ways this arena refers to [`Module::global_expressions`] are: /// /// - [`Constant`], [`Override`], and [`GlobalVariable`] expressions hold /// handles for their respective items (constants, overrides, and global /// variables), whose initializer expressions are in /// [`Module::global_expressions`]. /// /// - Various expressions hold [`Type`] handles, and [`Type`]s may refer to /// global expressions, for things like array lengths. /// /// An [`Expression`] must occur before all other [`Expression`]s that use /// its value. /// /// [`Constant`]: Expression::Constant /// [`Override`]: Expression::Override /// [`GlobalVariable`]: Expression::GlobalVariable pub expressions: Arena<Expression>, /// Map of expressions that have associated variable names. pub named_expressions: NamedExpressions, /// Block of instructions comprising the body of the function.
pub body: Block, /// The leaf of the diagnostic filter rule tree (stored in [`Module::diagnostic_filters`]) /// parsed for this function. /// /// In WGSL, this corresponds to `@diagnostic(…)` attributes. /// /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in /// validation. pub diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>, } /// The main function for a pipeline stage. /// /// An [`EntryPoint`] is a [`Function`] that serves as the main function for a /// graphics or compute pipeline stage. For example, an `EntryPoint` whose /// [`stage`] is [`ShaderStage::Vertex`] can serve as a graphics pipeline's /// vertex shader. /// /// Since an entry point is called directly by the graphics or compute pipeline, /// not by other WGSL functions, you must specify what the pipeline should pass /// as the entry point's arguments, and what values it will return. For example, /// a vertex shader needs a vertex's attributes as its arguments, but if it's /// used for instanced draw calls, it will also want to know the instance ID. /// The vertex shader's return value will usually include an output vertex /// position, and possibly other attributes to be interpolated and passed along /// to a fragment shader. /// /// To specify this, the arguments and result of an `EntryPoint`'s [`function`] /// must each have a [`Binding`], or be structs whose members all have /// `Binding`s. This associates every value passed to or returned from the entry /// point with either a [`BuiltIn`] or a [`Location`]: /// /// - A [`BuiltIn`] has special semantics, usually specific to its pipeline /// stage. For example, the result of a vertex shader can include a /// [`BuiltIn::Position`] value, which determines the position of a vertex /// of a rendered primitive. Or, a compute shader might take an argument /// whose binding is [`BuiltIn::WorkGroupSize`], through which the compute /// pipeline would pass the number of invocations in your workgroup.
/// /// - A [`Location`] indicates user-defined IO to be passed from one pipeline /// stage to the next. For example, a vertex shader might also produce a /// `uv` texture location as a user-defined IO value. /// /// In other words, the pipeline stage's input and output interfaces are /// determined by the bindings of the arguments and result of the `EntryPoint`'s /// [`function`]. /// /// [`Function`]: crate::Function /// [`Location`]: Binding::Location /// [`function`]: EntryPoint::function /// [`stage`]: EntryPoint::stage #[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct EntryPoint { /// Name of this entry point, visible externally. /// /// Entry point names for a given `stage` must be distinct within a module. pub name: String, /// Shader stage. pub stage: ShaderStage, /// Early depth test for fragment stages. pub early_depth_test: Option<EarlyDepthTest>, /// Workgroup size for compute stages. pub workgroup_size: [u32; 3], /// Override expressions for the workgroup size, in the `global_expressions` arena. pub workgroup_size_overrides: Option<[Option<Handle<Expression>>; 3]>, /// The entrance function. pub function: Function, /// Information for [`Mesh`] shaders. /// /// [`Mesh`]: ShaderStage::Mesh pub mesh_info: Option<MeshStageInfo>, /// The unique global variable used as a task payload from the task shader to the mesh shader. pub task_payload: Option<Handle<GlobalVariable>>, /// The unique global variable used as an incoming ray payload going into any-hit, closest-hit, and miss shaders. /// Unlike the outgoing ray payload, an incoming ray payload must be unique. pub incoming_ray_payload: Option<Handle<GlobalVariable>>, } /// Return types predeclared for the frexp, modf, and atomicCompareExchangeWeak built-in functions. /// /// These cannot be spelled in WGSL source. /// /// Stored in [`SpecialTypes::predeclared_types`] and created by [`Module::generate_predeclared_type`].
#[derive(Debug, PartialEq, Eq, Hash, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum PredeclaredType { AtomicCompareExchangeWeakResult(Scalar), ModfResult { size: Option<VectorSize>, scalar: Scalar, }, FrexpResult { size: Option<VectorSize>, scalar: Scalar, }, } /// Set of special types that can be optionally generated by the frontends. #[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct SpecialTypes { /// Type for `RayDesc`. /// /// Call [`Module::generate_ray_desc_type`] to populate this if /// needed and return the handle. pub ray_desc: Option<Handle<Type>>, /// Type for `RayIntersection`. /// /// Call [`Module::generate_ray_intersection_type`] to populate /// this if needed and return the handle. pub ray_intersection: Option<Handle<Type>>, /// Type for `RayVertexReturn`. /// /// Call [`Module::generate_vertex_return_type`] to populate this if /// needed. pub ray_vertex_return: Option<Handle<Type>>, /// Struct containing parameters required by some backends to emit code for /// [`ImageClass::External`] textures. /// /// See `wgpu_core::device::resource::ExternalTextureParams` for the /// documentation of each field. /// /// In WGSL, this type would be: /// /// ```ignore /// struct NagaExternalTextureParams { // align size offset /// yuv_conversion_matrix: mat4x4<f32>, // 16 64 0 /// gamut_conversion_matrix: mat3x3<f32>, // 16 48 64 /// src_tf: NagaExternalTextureTransferFn, // 4 16 112 /// dst_tf: NagaExternalTextureTransferFn, // 4 16 128 /// sample_transform: mat3x2<f32>, // 8 24 144 /// load_transform: mat3x2<f32>, // 8 24 168 /// size: vec2<u32>, // 8 8 192 /// num_planes: u32, // 4 4 200 /// } // whole struct: 16 208 /// ``` /// /// Call [`Module::generate_external_texture_types`] to populate this if /// needed.
pub external_texture_params: Option<Handle<Type>>, /// Struct describing a gamma encoding transfer function. Member of /// `NagaExternalTextureParams`, describing how the backend should perform /// color space conversion when sampling from [`ImageClass::External`] /// textures. /// /// In WGSL, this type would be: /// /// ```ignore /// struct NagaExternalTextureTransferFn { // align size offset /// a: f32, // 4 4 0 /// b: f32, // 4 4 4 /// g: f32, // 4 4 8 /// k: f32, // 4 4 12 /// } // whole struct: 4 16 /// ``` /// /// Call [`Module::generate_external_texture_types`] to populate this if /// needed. pub external_texture_transfer_function: Option<Handle<Type>>, /// Types for predeclared WGSL types instantiated on demand. /// /// Call [`Module::generate_predeclared_type`] to populate this if /// needed and return the handle. pub predeclared_types: FastIndexMap<PredeclaredType, Handle<Type>>, } bitflags::bitflags! { /// Ray flags used when casting rays. /// Matching Vulkan constants can be found in /// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/ray_common/ray_flags_section.txt #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct RayFlag: u32 { /// Force all intersections to be treated as opaque. const FORCE_OPAQUE = 0x1; /// Force all intersections to be treated as non-opaque. const FORCE_NO_OPAQUE = 0x2; /// Stop traversal after the first hit. const TERMINATE_ON_FIRST_HIT = 0x4; /// Don't execute the closest-hit shader. const SKIP_CLOSEST_HIT_SHADER = 0x8; /// Cull back-facing geometry. const CULL_BACK_FACING = 0x10; /// Cull front-facing geometry. const CULL_FRONT_FACING = 0x20; /// Cull opaque geometry. const CULL_OPAQUE = 0x40; /// Cull non-opaque geometry. 
const CULL_NO_OPAQUE = 0x80; /// Skip triangular geometry. const SKIP_TRIANGLES = 0x100; /// Skip axis-aligned bounding boxes.
const SKIP_AABBS = 0x200; } } /// Type of a ray query intersection. /// Matching Vulkan constants can be found in /// /// but the actual values are different for candidate intersections. #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub enum RayQueryIntersection { /// No intersection found. /// Matches `RayQueryCommittedIntersectionNoneKHR`. #[default] None = 0, /// Intersecting with triangles. /// Matches `RayQueryCommittedIntersectionTriangleKHR` and `RayQueryCandidateIntersectionTriangleKHR`. Triangle = 1, /// Intersecting with generated primitives. /// Matches `RayQueryCommittedIntersectionGeneratedKHR`. Generated = 2, /// Intersecting with axis-aligned bounding boxes. /// Matches `RayQueryCandidateIntersectionAABBKHR`. Aabb = 3, } /// Doc comments preceding items. /// /// These can be used to generate automated documentation or IDE hover /// information, or to translate shaders along with their context comments. #[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct DocComments { pub types: FastIndexMap<Handle<Type>, Vec<String>>, // The key is: // - key.0: the handle to the Struct // - key.1: the index of the `StructMember`. pub struct_members: FastIndexMap<(Handle<Type>, usize), Vec<String>>, pub entry_points: FastIndexMap<usize, Vec<String>>, pub functions: FastIndexMap<Handle<Function>, Vec<String>>, pub constants: FastIndexMap<Handle<Constant>, Vec<String>>, pub global_variables: FastIndexMap<Handle<GlobalVariable>, Vec<String>>, // Top level comments, appearing before any space. pub module: Vec<String>, } /// The output topology for a mesh shader. Note that mesh shaders don't allow things like triangle strips.
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum MeshOutputTopology { /// Outputs individual vertices to be rendered as points. Points, /// Outputs groups of 2 vertices to be rendered as lines. Lines, /// Outputs groups of 3 vertices to be rendered as triangles. Triangles, } /// Information specific to mesh shader entry points. #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[allow(dead_code)] pub struct MeshStageInfo { /// The type of primitive output. pub topology: MeshOutputTopology, /// The maximum number of vertices a mesh shader may output. pub max_vertices: u32, /// If pipeline constants are used, the expression that overrides `max_vertices`. pub max_vertices_override: Option<Handle<Expression>>, /// The maximum number of primitives a mesh shader may output. pub max_primitives: u32, /// If pipeline constants are used, the expression that overrides `max_primitives`. pub max_primitives_override: Option<Handle<Expression>>, /// The type used by vertex outputs, i.e. what is passed to `setVertex`. pub vertex_output_type: Handle<Type>, /// The type used by primitive outputs, i.e. what is passed to `setPrimitive`. pub primitive_output_type: Handle<Type>, /// The global variable holding the output vertices, primitives, and counts. pub output_variable: Handle<GlobalVariable>, } /// Ray tracing pipeline intrinsics. #[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum RayPipelineFunction { /// Traces a ray through the given acceleration structure. TraceRay { /// The acceleration structure within which this ray should search for hits.
/// /// The expression must be an [`AccelerationStructure`]. /// /// [`AccelerationStructure`]: TypeInner::AccelerationStructure acceleration_structure: Handle<Expression>, #[allow(rustdoc::private_intra_doc_links)] /// A struct of detailed parameters for the ray. /// /// This expression should have the struct type given in /// [`SpecialTypes::ray_desc`]. This is available in the WGSL /// front end as the `RayDesc` type. descriptor: Handle<Expression>, /// A pointer in the `ray_payload` or `incoming_ray_payload` address space. payload: Handle<Expression>, // Do we want miss index? What about sbt offset and sbt stride (could be hard to validate)? // https://github.com/gfx-rs/wgpu/issues/8894 }, } /// Shader module. /// /// A module is a set of constants, global variables and functions, as well as /// the types required to define them. /// /// Some functions are marked as entry points, to be used in a certain shader stage. /// /// To create a new module, use the `Default` implementation. /// Alternatively, you can load an existing shader using one of the [available front ends]. /// /// When finished, you can export modules using one of the [available backends]. /// /// ## Module arenas /// /// Most module contents are stored in [`Arena`]s. In a valid module, arena /// elements only refer to prior arena elements. That is, whenever an element in /// some `Arena<T>` contains a `Handle<T>` referring to another element in the same /// arena, the handle's referent always precedes the element containing the /// handle. /// /// The elements of [`Module::types`] may refer to [`Expression`]s in /// [`Module::global_expressions`], and those expressions may in turn refer back /// to [`Type`]s in [`Module::types`]. In a valid module, there exists an order /// in which all types and global expressions can be visited such that: /// /// - types and expressions are visited in the order in which they appear in /// their arenas, and /// /// - every element refers only to previously visited elements.
/// /// This implies that the graph of types and global expressions is acyclic. /// (However, it is a stronger condition: there are cycle-free arrangements of /// types and expressions for which an order like the one described above does /// not exist. Modules arranged in such a way are not valid.) /// /// [available front ends]: crate::front /// [available backends]: crate::back #[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Module { /// Arena for the types defined in this module. /// /// See the [`Module`] docs for more details about this field. pub types: UniqueArena, /// Dictionary of special type handles. pub special_types: SpecialTypes, /// Arena for the constants defined in this module. pub constants: Arena, /// Arena for the pipeline-overridable constants defined in this module. pub overrides: Arena, /// Arena for the global variables defined in this module. pub global_variables: Arena, /// [Constant expressions] and [override expressions] used by this module. /// /// If an expression is in this arena, then its subexpressions are in this /// arena too. In other words, every `Handle` in this arena /// refers to an [`Expression`] in this arena too. /// /// See the [`Module`] docs for more details about this field. /// /// [Constant expressions]: index.html#constant-expressions /// [override expressions]: index.html#override-expressions pub global_expressions: Arena, /// Arena for the functions defined in this module. /// /// Each function must appear in this arena strictly before all its callers. /// Recursion is not supported. pub functions: Arena, /// Entry points. pub entry_points: Vec, /// Arena for all diagnostic filter rules parsed in this module, including those in functions /// and statements. /// /// This arena contains elements of a _tree_ of diagnostic filter rules. 
When nodes are built /// by a front-end, they refer to a parent scope. pub diagnostic_filters: Arena, /// The leaf of the diagnostic filter rule tree parsed from directives in this module. /// /// In WGSL, this corresponds to `diagnostic(…);` directives. /// /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in /// validation. pub diagnostic_filter_leaf: Option>, /// Doc comments. pub doc_comments: Option>, } naga-29.0.3/src/keywords/mod.rs000064400000000000000000000002611046102023000144070ustar 00000000000000/*! Lists of reserved keywords for each shading language with a [frontend][crate::front] or [backend][crate::back]. */ #[cfg(any(feature = "wgsl-in", wgsl_out))] pub mod wgsl; naga-29.0.3/src/keywords/wgsl.rs000064400000000000000000000204641046102023000146130ustar 00000000000000/*! Keywords for [WGSL][wgsl] (WebGPU Shading Language). [wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html */ use crate::proc::KeywordSet; use crate::racy_lock::RacyLock; // last sync: https://www.w3.org/TR/2025/CRD-WGSL-20250809/#keyword-summary pub const RESERVED: &[&str] = &[ // Keywords "alias", "break", "case", "const", "const_assert", "continue", "continuing", "default", "diagnostic", "discard", "else", "enable", "false", "fn", "for", "if", "let", "loop", "override", "requires", "return", "struct", "switch", "true", "var", "while", // Reserved "NULL", "Self", "abstract", "active", "alignas", "alignof", "as", "asm", "asm_fragment", "async", "attribute", "auto", "await", "become", "cast", "catch", "class", "co_await", "co_return", "co_yield", "coherent", "column_major", "common", "compile", "compile_fragment", "concept", "const_cast", "consteval", "constexpr", "constinit", "crate", "debugger", "decltype", "delete", "demote", "demote_to_helper", "do", "dynamic_cast", "enum", "explicit", "export", "extends", "extern", "external", "fallthrough", "filter", "final", "finally", "friend", "from", "fxgroup", "get", "goto", "groupshared", "highp", "impl",
"implements", "import", "inline", "instanceof", "interface", "layout", "lowp", "macro", "macro_rules", "match", "mediump", "meta", "mod", "module", "move", "mut", "mutable", "namespace", "new", "nil", "noexcept", "noinline", "nointerpolation", "non_coherent", "noncoherent", "noperspective", "null", "nullptr", "of", "operator", "package", "packoffset", "partition", "pass", "patch", "pixelfragment", "precise", "precision", "premerge", "priv", "protected", "pub", "public", "readonly", "ref", "regardless", "register", "reinterpret_cast", "require", "resource", "restrict", "self", "set", "shared", "sizeof", "smooth", "snorm", "static", "static_assert", "static_cast", "std", "subroutine", "super", "target", "template", "this", "thread_local", "throw", "trait", "try", "type", "typedef", "typeid", "typename", "typeof", "union", "unless", "unorm", "unsafe", "unsized", "use", "using", "varying", "virtual", "volatile", "wgsl", "where", "with", "writeonly", "yield", ]; /// The above set of reserved keywords, turned into a cached HashSet. This saves /// significant time during [`Namer::reset`](crate::proc::Namer::reset). /// /// See for benchmarks. pub static RESERVED_SET: RacyLock = RacyLock::new(|| KeywordSet::from_iter(RESERVED)); /// Shadowable words that the WGSL backend should avoid using for declarations. /// /// Includes: /// - [6.9. Predeclared Types and Type-Generators] /// - [6.3.1. Predeclared enumerants] /// - [17. Built-in Functions] /// /// This set must be separate from the [`RESERVED`] set above since the /// [`Namer`](crate::proc::Namer) must ignore these identifiers if they appear /// as struct member names. This is because this set contains `fract` and `exp` /// which are also names used by return types of the `frexp` and `modf` built-in functions. /// /// [6.9. Predeclared Types and Type-Generators]: https://www.w3.org/TR/WGSL/#predeclared-types /// [6.3.1. Predeclared enumerants]: https://www.w3.org/TR/WGSL/#predeclared-enumerants /// [17. 
Built-in Functions]: https://www.w3.org/TR/WGSL/#builtin-functions pub const BUILTIN_IDENTIFIERS: &[&str] = &[ // types "bool", "i32", "u32", "f32", "f16", "array", "atomic", "vec2", "vec3", "vec4", "mat2x2", "mat2x3", "mat2x4", "mat3x2", "mat3x3", "mat3x4", "mat4x2", "mat4x3", "mat4x4", "ptr", "sampler", "sampler_comparison", "texture_1d", "texture_2d", "texture_2d_array", "texture_3d", "texture_cube", "texture_cube_array", "texture_multisampled_2d", "texture_depth_multisampled_2d", "texture_external", "texture_storage_1d", "texture_storage_2d", "texture_storage_2d_array", "texture_storage_3d", "texture_depth_2d", "texture_depth_2d_array", "texture_depth_cube", "texture_depth_cube_array", // enumerants "read", "write", "read_write", "function", "private", "workgroup", "uniform", "storage", "rgba8unorm", "rgba8snorm", "rgba8uint", "rgba8sint", "rgba16unorm", "rgba16snorm", "rgba16uint", "rgba16sint", "rgba16float", "rg8unorm", "rg8snorm", "rg8uint", "rg8sint", "rg16unorm", "rg16snorm", "rg16uint", "rg16sint", "rg16float", "r32uint", "r32sint", "r32float", "rg32uint", "rg32sint", "rg32float", "rgba32uint", "rgba32sint", "rgba32float", "bgra8unorm", "r8unorm", "r8snorm", "r8uint", "r8sint", "r16unorm", "r16snorm", "r16uint", "r16sint", "r16float", "rgb10a2unorm", "rgb10a2uint", "rg11b10ufloat", // functions "bitcast", "all", "any", "select", "arrayLength", "abs", "acos", "acosh", "asin", "asinh", "atan", "atanh", "atan2", "ceil", "clamp", "cos", "cosh", "countLeadingZeros", "countOneBits", "countTrailingZeros", "cross", "degrees", "determinant", "distance", "dot", "dot4U8Packed", "dot4I8Packed", "exp", "exp2", "extractBits", "faceForward", "firstLeadingBit", "firstTrailingBit", "floor", "fma", "fract", "frexp", "insertBits", "inverseSqrt", "ldexp", "length", "log", "log2", "max", "min", "mix", "modf", "normalize", "pow", "quantizeToF16", "radians", "reflect", "refract", "reverseBits", "round", "saturate", "sign", "sin", "sinh", "smoothstep", "sqrt", "step", "tan", 
"tanh", "transpose", "trunc", "dpdx", "dpdxCoarse", "dpdxFine", "dpdy", "dpdyCoarse", "dpdyFine", "fwidth", "fwidthCoarse", "fwidthFine", "textureDimensions", "textureGather", "textureGatherCompare", "textureLoad", "textureNumLayers", "textureNumLevels", "textureNumSamples", "textureSample", "textureSampleBias", "textureSampleCompare", "textureSampleCompareLevel", "textureSampleGrad", "textureSampleLevel", "textureSampleBaseClampToEdge", "textureStore", "atomicLoad", "atomicStore", "atomicAdd", "atomicSub", "atomicMax", "atomicMin", "atomicAnd", "atomicOr", "atomicXor", "atomicExchange", "atomicCompareExchangeWeak", "pack4x8snorm", "pack4x8unorm", "pack4xI8", "pack4xU8", "pack4xI8Clamp", "pack4xU8Clamp", "pack2x16snorm", "pack2x16unorm", "pack2x16float", "unpack4x8snorm", "unpack4x8unorm", "unpack4xI8", "unpack4xU8", "unpack2x16snorm", "unpack2x16unorm", "unpack2x16float", "storageBarrier", "textureBarrier", "workgroupBarrier", "workgroupUniformLoad", "subgroupAdd", "subgroupExclusiveAdd", "subgroupInclusiveAdd", "subgroupAll", "subgroupAnd", "subgroupAny", "subgroupBallot", "subgroupBroadcast", "subgroupBroadcastFirst", "subgroupElect", "subgroupMax", "subgroupMin", "subgroupMul", "subgroupExclusiveMul", "subgroupInclusiveMul", "subgroupOr", "subgroupShuffle", "subgroupShuffleDown", "subgroupShuffleUp", "subgroupShuffleXor", "subgroupXor", "quadBroadcast", "quadSwapDiagonal", "quadSwapX", "quadSwapY", // not in the WGSL spec "i64", "u64", "f64", "push_constant", "r64uint", ]; pub static BUILTIN_IDENTIFIER_SET: RacyLock = RacyLock::new(|| KeywordSet::from_iter(BUILTIN_IDENTIFIERS)); naga-29.0.3/src/lib.rs000064400000000000000000000107121046102023000125310ustar 00000000000000/*! Naga can be used to translate source code written in one shading language to another. # Example The following example translates WGSL to GLSL. It requires the features `"wgsl-in"` and `"glsl-out"` to be enabled. 
*/ // If we don't have the required front- and backends, don't try to build this example. #![cfg_attr(all(feature = "wgsl-in", feature = "glsl-out"), doc = "```")] #![cfg_attr(not(all(feature = "wgsl-in", feature = "glsl-out")), doc = "```ignore")] /*! let wgsl_source = " @fragment fn main_fs() -> @location(0) vec4 { return vec4(1.0, 1.0, 1.0, 1.0); } "; // Parse the source into a Module. let module: naga::Module = naga::front::wgsl::parse_str(wgsl_source)?; // Validate the module. // Validation can be made less restrictive by changing the ValidationFlags. let module_info: naga::valid::ModuleInfo = naga::valid::Validator::new( naga::valid::ValidationFlags::all(), naga::valid::Capabilities::all(), ) .subgroup_stages(naga::valid::ShaderStages::all()) .subgroup_operations(naga::valid::SubgroupOperationSet::all()) .validate(&module)?; // Translate the module. use naga::back::glsl; let mut glsl_source = String::new(); glsl::Writer::new( &mut glsl_source, &module, &module_info, &glsl::Options::default(), &glsl::PipelineOptions { entry_point: "main_fs".into(), shader_stage: naga::ShaderStage::Fragment, multiview: None, }, naga::proc::BoundsCheckPolicies::default(), )?.write()?; assert_eq!(glsl_source, "\ #version 310 es precision highp float; precision highp int; layout(location = 0) out vec4 _fs2p_location0; void main() { _fs2p_location0 = vec4(1.0, 1.0, 1.0, 1.0); return; } "); # Ok::<(), Box>(()) ``` */ #![allow( clippy::new_without_default, clippy::unneeded_field_pattern, clippy::match_like_matches_macro, clippy::collapsible_if, clippy::derive_partial_eq_without_eq, clippy::needless_borrowed_reference, clippy::single_match, clippy::enum_variant_names )] #![warn( trivial_casts, trivial_numeric_casts, unused_extern_crates, unused_qualifications, clippy::pattern_type_mismatch, clippy::missing_const_for_fn, clippy::rest_pat_in_fully_bound_structs, clippy::match_wildcard_for_single_variants )] #![deny(clippy::exit)] #![cfg_attr( not(test), warn( clippy::dbg_macro, 
clippy::panic, clippy::print_stderr, clippy::print_stdout, clippy::todo ) )] #![no_std] #![forbid(unsafe_code)] #[cfg(std)] extern crate std; extern crate alloc; mod arena; pub mod back; pub mod common; pub mod compact; pub mod diagnostic_filter; pub mod error; pub mod front; pub mod ir; pub mod keywords; mod non_max_u32; pub mod proc; mod racy_lock; mod span; pub mod valid; use alloc::string::String; pub use crate::arena::{Arena, Handle, Range, UniqueArena}; pub use crate::span::{SourceLocation, Span, SpanContext, WithSpan}; // TODO: Eliminate this re-export and migrate uses of `crate::Foo` to `use crate::ir; ir::Foo`. pub use ir::*; /// Width of a boolean type, in bytes. pub const BOOL_WIDTH: Bytes = 1; /// Width of abstract types, in bytes. pub const ABSTRACT_WIDTH: Bytes = 8; /// Hash map that is faster but not resilient to DoS attacks. /// (Similar to rustc_hash::FxHashMap but using hashbrown::HashMap instead of std::collections::HashMap.) /// To construct a new instance: `FastHashMap::default()` pub type FastHashMap = hashbrown::HashMap>; /// Hash set that is faster but not resilient to DoS attacks. /// (Similar to rustc_hash::FxHashSet but using hashbrown::HashSet instead of std::collections::HashSet.) pub type FastHashSet = hashbrown::HashSet>; /// Insertion-order-preserving hash set (`IndexSet`), but with the same /// hasher as `FastHashSet` (faster but not resilient to DoS attacks). pub type FastIndexSet = indexmap::IndexSet>; /// Insertion-order-preserving hash map (`IndexMap`), but with the same /// hasher as `FastHashMap` (faster but not resilient to DoS attacks). pub type FastIndexMap = indexmap::IndexMap>; /// Map of expressions that have associated variable names. pub(crate) type NamedExpressions = FastIndexMap, String>; naga-29.0.3/src/non_max_u32.rs000064400000000000000000000114651046102023000141210ustar 00000000000000//! [`NonMaxU32`], a 32-bit type that can represent any value except [`u32::MAX`]. //! //!
Naga would like `Option>` to be a 32-bit value, which means we //! need to exclude some index value for use in representing [`None`]. We could //! have [`Handle`] store a [`NonZeroU32`], but zero is a very useful value for //! indexing. We could have a [`Handle`] store a value one greater than its index, //! but it turns out that it's not uncommon to want to work with [`Handle`]s' //! indices, so that bias of 1 becomes more visible than one would like. //! //! This module defines the type [`NonMaxU32`], for which `Option` is //! still a 32-bit value, but which is directly usable as a [`Handle`] index //! type. It still uses a bias of 1 under the hood, but that fact is isolated //! within the implementation. //! //! [`Handle`]: crate::arena::Handle //! [`NonZeroU32`]: core::num::NonZeroU32 #![allow(dead_code)] use core::num::NonZeroU32; /// An unsigned 32-bit value known not to be [`u32::MAX`]. /// /// A `NonMaxU32` value can represent any value in the range /// `0..=u32::MAX - 1`, and an `Option` is still a 32-bit value. In other words, /// `NonMaxU32` is just like [`NonZeroU32`], except that a different value is /// missing from the full `u32` range. /// /// Since zero is a very useful value in indexing, `NonMaxU32` is more useful /// for representing indices than [`NonZeroU32`]. /// /// `NonMaxU32` values and `Option` values both occupy 32 bits. /// /// # Serialization and Deserialization /// /// When the appropriate Cargo features are enabled, `NonMaxU32` implements /// [`serde::Serialize`] and [`serde::Deserialize`] in the natural way, as the /// integer value it represents. For example, serializing /// `NonMaxU32::new(0).unwrap()` as JSON or RON yields the string `"0"`. This is /// the case despite `NonMaxU32`'s implementation, described below. /// /// # Implementation /// /// Although this should not be observable to its users, a `NonMaxU32` whose /// value is `n` is a newtype around a [`NonZeroU32`] whose value is `n + 1`.
/// This way, the range of values that `NonMaxU32` can represent, /// `0..=u32::MAX - 1`, is mapped to the range `1..=u32::MAX`, which is the /// range that /// [`NonZeroU32`] can represent. (And conversely, since /// [`u32`] addition wraps around, the value unrepresentable in `NonMaxU32`, /// [`u32::MAX`], becomes the value unrepresentable in [`NonZeroU32`], `0`.) /// /// [`NonZeroU32`]: core::num::NonZeroU32 #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct NonMaxU32(NonZeroU32); impl NonMaxU32 { /// Construct a [`NonMaxU32`] whose value is `n`, if possible. pub const fn new(n: u32) -> Option { // If `n` is `u32::MAX`, then `n.wrapping_add(1)` is `0`, // so `NonZeroU32::new` returns `None` in exactly the case // where we must return `None`. match NonZeroU32::new(n.wrapping_add(1)) { Some(non_zero) => Some(NonMaxU32(non_zero)), None => None, } } /// Return the value of `self` as a [`u32`]. pub const fn get(self) -> u32 { self.0.get() - 1 } pub fn checked_add(self, n: u32) -> Option { // Adding `n` to `self` produces `u32::MAX` if and only if // adding `n` to `self.0` produces `0`. So we can simply // call `NonZeroU32::checked_add` and let its check for zero // determine whether our add would have produced `u32::MAX`. 
Some(NonMaxU32(self.0.checked_add(n)?)) } } impl core::fmt::Debug for NonMaxU32 { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.get().fmt(f) } } impl core::fmt::Display for NonMaxU32 { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.get().fmt(f) } } #[cfg(feature = "serialize")] impl serde::Serialize for NonMaxU32 { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_u32(self.get()) } } #[cfg(feature = "deserialize")] impl<'de> serde::Deserialize<'de> for NonMaxU32 { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { // Defer to `u32`'s `Deserialize` implementation. let n = ::deserialize(deserializer)?; // Constrain the range of the value further. NonMaxU32::new(n).ok_or_else(|| { ::invalid_value( serde::de::Unexpected::Unsigned(n as u64), &"a value no less than 0 and no greater than 4294967294 (2^32 - 2)", ) }) } } #[test] fn size() { assert_eq!(size_of::>(), size_of::()); } naga-29.0.3/src/proc/constant_evaluator.rs000064400000000000000000005505411046102023000166520ustar 00000000000000// Code in this file intentionally uses `for` loops and `.push()` rather than // `ArrayVec::from_iter`, because the latter is monomorphized by all three of // the item type, the capacity, and the iterator type, which can easily bloat // the compiled executable (by ~260 KiB, when it was removed). use alloc::{ format, string::{String, ToString}, vec, vec::Vec, }; use core::iter; use arrayvec::ArrayVec; use half::f16; use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero}; use crate::{ arena::{Arena, Handle, HandleVec, UniqueArena}, ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction, ScalarKind, Span, Type, TypeInner, UnaryOperator, }; #[cfg(feature = "wgsl-in")] use crate::common::wgsl::TryToWgsl; /// A macro that allows dollar signs (`$`) to be emitted by other macros. 
Useful for generating /// `macro_rules!` items that, in turn, emit their own `macro_rules!` items. /// /// Technique stolen directly from /// . macro_rules! with_dollar_sign { ($($body:tt)*) => { macro_rules! __with_dollar_sign { $($body)* } __with_dollar_sign!($); } } macro_rules! gen_component_wise_extractor { ( $ident:ident -> $target:ident, literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?], scalar_kinds: [$( $scalar_kind:ident ),* $(,)?], ) => { /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins. #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] enum $target { $( #[doc = concat!( "Maps to [`Literal::", stringify!($literal), "`]", )] $mapping([$ty; N]), )+ } impl From<$target<1>> for Expression { fn from(value: $target<1>) -> Self { match value { $( $target::$mapping([value]) => { Expression::Literal(Literal::$literal(value)) } )+ } } } #[doc = concat!( "Attempts to evaluate multiple `exprs` as a combined [`", stringify!($target), "`] to pass to `handler`. ", )] /// If `exprs` are vectors of the same length, `handler` is called for each corresponding /// component of each vector. /// /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the /// same length, a new vector expression is registered, composed of each component emitted /// by `handler`. fn $ident( eval: &mut ConstantEvaluator<'_>, span: Span, exprs: [Handle; N], handler: fn($target) -> Result<$target, ConstantEvaluatorError>, ) -> Result, ConstantEvaluatorError> where $target: Into, { assert!(N > 0); let err = ConstantEvaluatorError::InvalidMathArg; let mut exprs = exprs.into_iter(); macro_rules! sanitize { ($expr:expr) => { eval.eval_zero_value_and_splat($expr, span) .map(|expr| &eval.expressions[expr]) }; } let new_expr: Result = match sanitize!(exprs.next().unwrap())? { $( &Expression::Literal(Literal::$literal(x)) => { let mut arr = ArrayVec::<_, N>::new(); arr.push(x); for expr in exprs { match sanitize!(expr)? 
{ &Expression::Literal(Literal::$literal(val)) => arr.push(val), _ => return Err(err), } } let comps = $target::$mapping(arr.into_inner().unwrap()); Ok(handler(comps)?.into()) }, )+ &Expression::Compose { ty, ref components } => match &eval.types[ty].inner { &TypeInner::Vector { size, scalar } => match scalar.kind { $(ScalarKind::$scalar_kind)|* => { let first_ty = ty; let mut component_groups = ArrayVec::, N>::new(); { let mut inner = ArrayVec::new(); for item in crate::proc::flatten_compose( first_ty, components, eval.expressions, eval.types, ) { inner.push(item); } component_groups.push(inner); } for expr in exprs { match sanitize!(expr)? { &Expression::Compose { ty, ref components } if &eval.types[ty].inner == &eval.types[first_ty].inner => { let mut inner = ArrayVec::new(); for item in crate::proc::flatten_compose( ty, components, eval.expressions, eval.types, ) { inner.push(item); } component_groups.push(inner); } _ => return Err(err), } } let component_groups = component_groups.into_inner().unwrap(); let mut new_components = ArrayVec::<_, { crate::VectorSize::MAX }>::new(); for idx in 0..(size as u8).into() { let mut group_arr = ArrayVec::<_, N>::new(); for cs in component_groups.iter() { group_arr.push( cs.get(idx).cloned().ok_or_else(|| err.clone())?, ); } let group = group_arr.into_inner().unwrap(); new_components.push($ident( eval, span, group, handler, )?); } Ok(Expression::Compose { ty: first_ty, components: new_components.into_iter().collect(), }) } _ => return Err(err), }, _ => return Err(err), }, _ => return Err(err), }; eval.register_evaluated_expr(new_expr?, span) } with_dollar_sign! { ($d:tt) => { #[allow(unused)] #[doc = concat!( "A convenience macro for using the same RHS for each [`", stringify!($target), "`] variant in a call to [`", stringify!($ident), "`].", )] macro_rules! 
$ident { ( $eval:expr, $span:expr, [$d ($d expr:expr),+ $d (,)?], |$d ($d arg:ident),+| $d tt:tt ) => { $ident($eval, $span, [$d ($d expr),+], |args| match args { $( $target::$mapping([$d ($d arg),+]) => { let res = $d tt; Result::map(res, $target::$mapping) }, )+ }) }; } }; } }; } gen_component_wise_extractor! { component_wise_scalar -> Scalar, literals: [ AbstractFloat => AbstractFloat: f64, F32 => F32: f32, F16 => F16: f16, AbstractInt => AbstractInt: i64, U32 => U32: u32, I32 => I32: i32, U64 => U64: u64, I64 => I64: i64, ], scalar_kinds: [ Float, AbstractFloat, Sint, Uint, AbstractInt, ], } gen_component_wise_extractor! { component_wise_float -> Float, literals: [ AbstractFloat => Abstract: f64, F32 => F32: f32, F16 => F16: f16, ], scalar_kinds: [ Float, AbstractFloat, ], } gen_component_wise_extractor! { component_wise_concrete_int -> ConcreteInt, literals: [ U32 => U32: u32, I32 => I32: i32, ], scalar_kinds: [ Sint, Uint, ], } gen_component_wise_extractor! { component_wise_signed -> Signed, literals: [ AbstractFloat => AbstractFloat: f64, AbstractInt => AbstractInt: i64, F32 => F32: f32, F16 => F16: f16, I32 => I32: i32, ], scalar_kinds: [ Sint, AbstractInt, Float, AbstractFloat, ], } /// Vectors with a concrete element type. 
#[derive(Debug)] enum LiteralVector { F64(ArrayVec), F32(ArrayVec), F16(ArrayVec), U32(ArrayVec), I32(ArrayVec), U64(ArrayVec), I64(ArrayVec), Bool(ArrayVec), AbstractInt(ArrayVec), AbstractFloat(ArrayVec), } impl LiteralVector { #[allow(clippy::missing_const_for_fn, reason = "MSRV")] fn len(&self) -> usize { match *self { LiteralVector::F64(ref v) => v.len(), LiteralVector::F32(ref v) => v.len(), LiteralVector::F16(ref v) => v.len(), LiteralVector::U32(ref v) => v.len(), LiteralVector::I32(ref v) => v.len(), LiteralVector::U64(ref v) => v.len(), LiteralVector::I64(ref v) => v.len(), LiteralVector::Bool(ref v) => v.len(), LiteralVector::AbstractInt(ref v) => v.len(), LiteralVector::AbstractFloat(ref v) => v.len(), } } /// Creates a [`LiteralVector`] of size 1 from a single [`Literal`]. fn from_literal(literal: Literal) -> Self { fn arrayvec_of(val: T) -> ArrayVec { let mut v = ArrayVec::new(); v.push(val); v } match literal { Literal::F64(e) => Self::F64(arrayvec_of(e)), Literal::F32(e) => Self::F32(arrayvec_of(e)), Literal::U32(e) => Self::U32(arrayvec_of(e)), Literal::I32(e) => Self::I32(arrayvec_of(e)), Literal::U64(e) => Self::U64(arrayvec_of(e)), Literal::I64(e) => Self::I64(arrayvec_of(e)), Literal::Bool(e) => Self::Bool(arrayvec_of(e)), Literal::AbstractInt(e) => Self::AbstractInt(arrayvec_of(e)), Literal::AbstractFloat(e) => Self::AbstractFloat(arrayvec_of(e)), Literal::F16(e) => Self::F16(arrayvec_of(e)), } } /// Creates a [`LiteralVector`] from an [`ArrayVec`] of [`Literal`]s. /// Returns an error if the component types do not match. /// /// # Panics /// Panics if the vector is empty. fn from_literal_vec( components: ArrayVec, ) -> Result { assert!(!components.is_empty()); // TODO: should a vector of i32 be constructible from abstract int? macro_rules!
compose_literals { ($components:expr, $variant:ident, $self_variant:ident) => {{ let mut out = ArrayVec::new(); for l in &$components { match l { &Literal::$variant(v) => out.push(v), _ => return Err(ConstantEvaluatorError::InvalidMathArg), } } Self::$self_variant(out) }}; } Ok(match components[0] { Literal::I32(_) => compose_literals!(components, I32, I32), Literal::U32(_) => compose_literals!(components, U32, U32), Literal::I64(_) => compose_literals!(components, I64, I64), Literal::U64(_) => compose_literals!(components, U64, U64), Literal::F32(_) => compose_literals!(components, F32, F32), Literal::F64(_) => compose_literals!(components, F64, F64), Literal::Bool(_) => compose_literals!(components, Bool, Bool), Literal::AbstractInt(_) => compose_literals!(components, AbstractInt, AbstractInt), Literal::AbstractFloat(_) => { compose_literals!(components, AbstractFloat, AbstractFloat) } Literal::F16(_) => compose_literals!(components, F16, F16), }) } #[allow(dead_code)] /// Returns [`ArrayVec`] of [`Literal`]s fn to_literal_vec(&self) -> ArrayVec { macro_rules! 
decompose_literals { ($v:expr, $variant:ident) => {{ let mut out = ArrayVec::new(); for e in $v { out.push(Literal::$variant(*e)); } out }}; } match *self { LiteralVector::F64(ref v) => decompose_literals!(v, F64), LiteralVector::F32(ref v) => decompose_literals!(v, F32), LiteralVector::F16(ref v) => decompose_literals!(v, F16), LiteralVector::U32(ref v) => decompose_literals!(v, U32), LiteralVector::I32(ref v) => decompose_literals!(v, I32), LiteralVector::U64(ref v) => decompose_literals!(v, U64), LiteralVector::I64(ref v) => decompose_literals!(v, I64), LiteralVector::Bool(ref v) => decompose_literals!(v, Bool), LiteralVector::AbstractInt(ref v) => decompose_literals!(v, AbstractInt), LiteralVector::AbstractFloat(ref v) => decompose_literals!(v, AbstractFloat), } } #[allow(dead_code)] /// Puts self into eval's expressions arena and returns handle to it fn register_as_evaluated_expr( &self, eval: &mut ConstantEvaluator<'_>, span: Span, ) -> Result, ConstantEvaluatorError> { let lit_vec = self.to_literal_vec(); assert!(!lit_vec.is_empty()); let expr = if lit_vec.len() == 1 { Expression::Literal(lit_vec[0]) } else { Expression::Compose { ty: eval.types.insert( Type { name: None, inner: TypeInner::Vector { size: match lit_vec.len() { 2 => crate::VectorSize::Bi, 3 => crate::VectorSize::Tri, 4 => crate::VectorSize::Quad, _ => unreachable!(), }, scalar: lit_vec[0].scalar(), }, }, Span::UNDEFINED, ), components: lit_vec .iter() .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span)) .collect::>()?, } }; eval.register_evaluated_expr(expr, span) } } /// A macro for matching on [`LiteralVector`] variants. /// /// `Float` variant expands to `F16`, `F32`, `F64` and `AbstractFloat`. /// `Integer` variant expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`. /// /// For output both [`Literal`] (fold) and [`LiteralVector`] (map) are supported. 
/// /// Example usage: /// /// ```rust,ignore /// match_literal_vector!(match v => Literal { /// F16 => |v| {v.sum()}, /// Integer => |v| {v.sum()}, /// U32 => |v| -> I32 {v.sum()}, // optionally override return type /// }) /// ``` /// /// ```rust,ignore /// match_literal_vector!(match (e1, e2) => LiteralVector { /// F16 => |e1, e2| {e1+e2}, /// Integer => |e1, e2| {e1+e2}, /// U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type /// }) /// ``` macro_rules! match_literal_vector { (match $lit_vec:expr => $out:ident { $( $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr } ),+ $(,)? }) => { match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+]) }; (@inner_start $lit_vec:expr; $out:ident; [$($ty:ident),+]; [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+] ) => { match_literal_vector!(@inner $lit_vec; $out; [$($ty),+]; [] <> [$({ $($var),+ ; $($ret)? ; $body }),+] ) }; (@inner $lit_vec:expr; $out:ident; [$ty:ident $(, $ty1:ident)*]; [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+] ) => { match_literal_vector!(@inner $ty; $lit_vec; $out; [$($ty1),*]; [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <> [$({ $($var),+ ; $($ret)? ; $body }),+] ) }; (@inner Integer; $lit_vec:expr; $out:ident; [$($ty:ident),*]; [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [ { $($var:ident),+ ; $($ret:ident)? ; $body:expr } $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })* ] ) => { match_literal_vector!(@inner $lit_vec; $out; [U32, I32, U64, I64, AbstractInt $(, $ty)*]; [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <> [ { $($var),+ ; $($ret)? ; $body }, // U32 { $($var),+ ; $($ret)? ; $body }, // I32 { $($var),+ ; $($ret)? ; $body }, // U64 { $($var),+ ; $($ret)? ; $body }, // I64 { $($var),+ ; $($ret)? ; $body } // AbstractInt $(,{ $($var1),+ ; $($ret1)? 
; $body1 })* ] ) }; (@inner Float; $lit_vec:expr; $out:ident; [$($ty:ident),*]; [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [ { $($var:ident),+ ; $($ret:ident)? ; $body:expr } $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })* ] ) => { match_literal_vector!(@inner $lit_vec; $out; [F16, F32, F64, AbstractFloat $(, $ty)*]; [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <> [ { $($var),+ ; $($ret)? ; $body }, // F16 { $($var),+ ; $($ret)? ; $body }, // F32 { $($var),+ ; $($ret)? ; $body }, // F64 { $($var),+ ; $($ret)? ; $body } // AbstractFloat $(,{ $($var1),+ ; $($ret1)? ; $body1 })* ] ) }; (@inner $ty:ident; $lit_vec:expr; $out:ident; [$ty1:ident $(,$ty2:ident)*]; [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [ { $($var:ident),+ ; $($ret:ident)? ; $body:expr } $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })* ] ) => { match_literal_vector!(@inner $ty1; $lit_vec; $out; [$($ty2),*]; [ $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)* { $ty; $($var),+ ; $($ret)? ; $body } ] <> [$({ $($var1),+ ; $($ret1)? ; $body1 }),*] ) }; (@inner $ty:ident; $lit_vec:expr; $out:ident; []; [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }] ) => { match_literal_vector!(@inner_finish $lit_vec; $out; [ $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)* { $ty; $($var),+ ; $($ret)? ; $body } ] ) }; (@inner_finish $lit_vec:expr; $out:ident; [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+] ) => { match $lit_vec { $( #[allow(unused_parens)] ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? 
                        ; $body))
                }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}

#[derive(Debug)]
enum Behavior<'a> {
    Wgsl(WgslRestrictions<'a>),
    Glsl(GlslRestrictions<'a>),
}

impl Behavior<'_> {
    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
    const fn has_runtime_restrictions(&self) -> bool {
        matches!(
            self,
            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
        )
    }
}

/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`]
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    layouter: &'a mut crate::proc::Layouter,
}

#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    Const(Option<FunctionLocalData<'a>>),
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    Override,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}

#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    Const,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}

#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions
    global_expressions: &'a Arena<Expression>,
    emitter: &'a mut super::Emitter,
    block: &'a mut crate::Block,
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    Const,
    Override,
    Runtime,
}

#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}

impl ExpressionKindTracker {
    pub const fn new() -> Self {
        Self {
            inner: HandleVec::new(),
        }
    }

    /// Forces the expression to not be const
    pub fn force_non_const(&mut self, value: Handle<Expression>) {
        self.inner[value] = ExpressionKind::Runtime;
    }

    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
        self.inner.insert(value, expr_type);
    }

    pub fn is_const(&self, h: Handle<Expression>) -> bool {
        matches!(self.type_of(h), ExpressionKind::Const)
    }

    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
        matches!(
            self.type_of(h),
            ExpressionKind::Const | ExpressionKind::Override
        )
    }

    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
        self.inner[value]
    }

    pub fn from_arena(arena: &Arena<Expression>) -> Self {
        let mut tracker = Self {
            inner: HandleVec::with_capacity(arena.len()),
        };
        for (handle, expr) in arena.iter() {
            tracker
                .inner
                .insert(handle, tracker.type_of_with_expr(expr));
        }
        tracker
    }

    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
        match *expr {
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                ExpressionKind::Const
            }
            Expression::Override(_) => ExpressionKind::Override,
            Expression::Compose { ref components, .. } => {
                let mut expr_type = ExpressionKind::Const;
                for component in components {
                    expr_type = expr_type.max(self.type_of(*component))
                }
                expr_type
            }
            Expression::Splat { value, .. } => self.type_of(value),
            Expression::AccessIndex { base, .. } => self.type_of(base),
            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
            Expression::Swizzle { vector, .. } => self.type_of(vector),
            Expression::Unary { expr, .. } => self.type_of(expr),
            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
            Expression::Math {
                arg,
                arg1,
                arg2,
                arg3,
                ..
            } => self
                .type_of(arg)
                .max(
                    arg1.map(|arg| self.type_of(arg))
                        .unwrap_or(ExpressionKind::Const),
                )
                .max(
                    arg2.map(|arg| self.type_of(arg))
                        .unwrap_or(ExpressionKind::Const),
                )
                .max(
                    arg3.map(|arg| self.type_of(arg))
                        .unwrap_or(ExpressionKind::Const),
                ),
            Expression::As { expr, .. } => self.type_of(expr),
            Expression::Select {
                condition,
                accept,
                reject,
            } => self
                .type_of(condition)
                .max(self.type_of(accept))
                .max(self.type_of(reject)),
            Expression::Relational { argument, ..
            } => self.type_of(argument),
            Expression::ArrayLength(expr) => self.type_of(expr),
            _ => ExpressionKind::Runtime,
        }
    }
}

#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Cannot call arrayLength on array sized by override-expression")]
    ArrayLengthOverridden,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    #[error("{0} built-in function argument is out of valid range")]
    InvalidMathArgValue(String),
    #[error("Cannot apply relational function to type")]
    InvalidRelationalArg(RelationalFunction),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Constructor expects {expected} components, found {actual}")]
    InvalidVectorComposeLength { expected: usize, actual: usize },
    #[error("Constructor must only contain vector or scalar arguments")]
    InvalidVectorComposeComponent,
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    #[error("{0} operation overflowed")]
    Overflow(String),
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    #[error("Unexpected override-expression")]
    OverrideExpr,
    #[error("Expected boolean expression for condition argument of `select`, got something else")]
    SelectScalarConditionNotABool,
    #[error(
        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
        reject,
        accept
    )]
    SelectVecRejectAcceptSizeMismatch {
        reject: crate::VectorSize,
        accept: crate::VectorSize,
    },
    #[error("Expected boolean vector for condition arg., got something else")]
    SelectConditionNotAVecBool,
    #[error(
        "Expected same number of vector components between condition, accept, and reject args., got something else",
    )]
    SelectConditionVecSizeMismatch,
    #[error(
        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
    )]
    SelectAcceptRejectTypeMismatch,
    #[error("Cooperative operations can't be constant")]
    CooperativeOperation,
}

impl<'a> ConstantEvaluator<'a> {
    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
    /// constant expression arena.
    ///
    /// Report errors according to WGSL's rules for constant evaluation.
    pub const fn for_wgsl_module(
        module: &'a mut crate::Module,
        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
        layouter: &'a mut crate::proc::Layouter,
        in_override_ctx: bool,
    ) -> Self {
        Self::for_module(
            Behavior::Wgsl(if in_override_ctx {
                WgslRestrictions::Override
            } else {
                WgslRestrictions::Const(None)
            }),
            module,
            global_expression_kind_tracker,
            layouter,
        )
    }

    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
    /// constant expression arena.
    ///
    /// Report errors according to GLSL's rules for constant evaluation.
    pub const fn for_glsl_module(
        module: &'a mut crate::Module,
        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
        layouter: &'a mut crate::proc::Layouter,
    ) -> Self {
        Self::for_module(
            Behavior::Glsl(GlslRestrictions::Const),
            module,
            global_expression_kind_tracker,
            layouter,
        )
    }

    const fn for_module(
        behavior: Behavior<'a>,
        module: &'a mut crate::Module,
        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
        layouter: &'a mut crate::proc::Layouter,
    ) -> Self {
        Self {
            behavior,
            types: &mut module.types,
            constants: &module.constants,
            overrides: &module.overrides,
            expressions: &mut module.global_expressions,
            expression_kind_tracker: global_expression_kind_tracker,
            layouter,
        }
    }

    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
    /// expression arena.
    ///
    /// Report errors according to WGSL's rules for constant evaluation.
    pub const fn for_wgsl_function(
        module: &'a mut crate::Module,
        expressions: &'a mut Arena<Expression>,
        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
        layouter: &'a mut crate::proc::Layouter,
        emitter: &'a mut super::Emitter,
        block: &'a mut crate::Block,
        is_const: bool,
    ) -> Self {
        let local_data = FunctionLocalData {
            global_expressions: &module.global_expressions,
            emitter,
            block,
        };
        Self {
            behavior: Behavior::Wgsl(if is_const {
                WgslRestrictions::Const(Some(local_data))
            } else {
                WgslRestrictions::Runtime(local_data)
            }),
            types: &mut module.types,
            constants: &module.constants,
            overrides: &module.overrides,
            expressions,
            expression_kind_tracker: local_expression_kind_tracker,
            layouter,
        }
    }

    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
    /// expression arena.
    ///
    /// Report errors according to GLSL's rules for constant evaluation.
    pub const fn for_glsl_function(
        module: &'a mut crate::Module,
        expressions: &'a mut Arena<Expression>,
        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
        layouter: &'a mut crate::proc::Layouter,
        emitter: &'a mut super::Emitter,
        block: &'a mut crate::Block,
    ) -> Self {
        Self {
            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
                global_expressions: &module.global_expressions,
                emitter,
                block,
            })),
            types: &mut module.types,
            constants: &module.constants,
            overrides: &module.overrides,
            expressions,
            expression_kind_tracker: local_expression_kind_tracker,
            layouter,
        }
    }

    pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
        crate::proc::GlobalCtx {
            types: self.types,
            constants: self.constants,
            overrides: self.overrides,
            global_expressions: match self.function_local_data() {
                Some(data) => data.global_expressions,
                None => self.expressions,
            },
        }
    }

    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
        if !self.expression_kind_tracker.is_const(expr) {
            log::debug!("check: SubexpressionsAreNotConstant");
            return
                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
        }
        Ok(())
    }

    fn check_and_get(
        &mut self,
        expr: Handle<Expression>,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        match self.expressions[expr] {
            Expression::Constant(c) => {
                // Are we working in a function's expression arena, or the
                // module's constant expression arena?
                if let Some(function_local_data) = self.function_local_data() {
                    // Deep-copy the constant's value into our arena.
                    self.copy_from(
                        self.constants[c].init,
                        function_local_data.global_expressions,
                    )
                } else {
                    // "See through" the constant and use its initializer.
                    Ok(self.constants[c].init)
                }
            }
            _ => {
                self.check(expr)?;
                Ok(expr)
            }
        }
    }

    /// Try to evaluate `expr` at compile time.
    ///
    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
    /// we can determine its value at compile time, we append an expression
    /// representing its value - a tree of [`Literal`], [`Compose`],
    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
    /// `self` contributes to.
    ///
    /// If `expr`'s value cannot be determined at compile time, and `self` is
    /// contributing to some function's expression arena, then append `expr` to
    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
    /// contributing to the module's constant expression arena; since `expr`'s
    /// value is not a constant, return an error.
    ///
    /// We only consider `expr` itself, without recursing into its operands. Its
    /// operands must all have been produced by prior calls to
    /// `try_eval_and_append`, to ensure that they have already been reduced to
    /// an evaluated form if possible.
    ///
    /// [`Literal`]: Expression::Literal
    /// [`Compose`]: Expression::Compose
    /// [`ZeroValue`]: Expression::ZeroValue
    /// [`Swizzle`]: Expression::Swizzle
    pub fn try_eval_and_append(
        &mut self,
        expr: Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        match self.expression_kind_tracker.type_of_with_expr(&expr) {
            ExpressionKind::Const => {
                let eval_result = self.try_eval_and_append_impl(&expr, span);
                // We should be able to evaluate `Const` expressions at this
                // point. If we failed to, then that probably means we just
                // haven't implemented that part of constant evaluation. Work
                // around this by simply emitting it as a run-time expression.
                if self.behavior.has_runtime_restrictions()
                    && matches!(
                        eval_result,
                        Err(ConstantEvaluatorError::NotImplemented(_)
                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
                    )
                {
                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
                } else {
                    eval_result
                }
            }
            ExpressionKind::Override => match self.behavior {
                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
                }
                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
                    Err(ConstantEvaluatorError::OverrideExpr)
                }
                // GLSL specialization constants (constant_id) become Override expressions
                Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
                }
                Behavior::Glsl(GlslRestrictions::Const) => {
                    Err(ConstantEvaluatorError::OverrideExpr)
                }
            },
            ExpressionKind::Runtime => {
                if self.behavior.has_runtime_restrictions() {
                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
                } else {
                    Err(ConstantEvaluatorError::RuntimeExpr)
                }
            }
        }
    }

    /// Is the [`Self::expressions`] arena the global module expression arena?
    const fn is_global_arena(&self) -> bool {
        matches!(
            self.behavior,
            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
                | Behavior::Glsl(GlslRestrictions::Const)
        )
    }

    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
        match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref function_local_data)
                | WgslRestrictions::Const(Some(ref function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
                Some(function_local_data)
            }
            _ => None,
        }
    }

    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {expr:?}");
        match *expr {
            Expression::Constant(c) if self.is_global_arena() => {
                // "See through" the constant and use its initializer.
                // This is mainly done to avoid having constants pointing to other constants.
                Ok(self.constants[c].init)
            }
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;
                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;
                let index_val: u32 = self
                    .to_ctx()
                    .get_const_val_from(index, self.expressions)
                    .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
                self.access(base, index_val as usize, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector =
                    self.check_and_get(vector)?;
                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;
                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;
                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select {
                reject,
                accept,
                condition,
            } => {
                let mut arg = |expr| self.check_and_get(expr);
                let reject = arg(reject)?;
                let accept = arg(accept)?;
                let condition = arg(condition)?;
                self.select(reject, accept, condition, span)
            }
            Expression::Relational { fun, argument } => {
                let argument = self.check_and_get(argument)?;
                self.relational(fun, argument, span)
            }
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { ..
            } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult
            | Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
            Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
                Err(ConstantEvaluatorError::CooperativeOperation)
            }
        }
    }

    /// Splat `value` to `size`, without using [`Splat`] expressions.
    ///
    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
    /// build a vector with the given `size` whose components are all
    /// `value`.
    ///
    /// Use `span` as the span of the inserted expressions and
    /// resulting types.
    ///
    /// [`Splat`]: Expression::Splat
    /// [`Compose`]: Expression::Compose
    /// [`ZeroValue`]: Expression::ZeroValue
    fn splat(
        &mut self,
        value: Handle<Expression>,
        size: crate::VectorSize,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        match self.expressions[value] {
            Expression::Literal(literal) => {
                let scalar = literal.scalar();
                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: TypeInner::Vector { size, scalar },
                    },
                    span,
                );
                let expr = Expression::Compose {
                    ty,
                    components: vec![value; size as usize],
                };
                self.register_evaluated_expr(expr, span)
            }
            Expression::ZeroValue(ty) => {
                let inner = match self.types[ty].inner {
                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
                };
                let res_ty = self.types.insert(Type { name: None, inner }, span);
                let expr = Expression::ZeroValue(res_ty);
                self.register_evaluated_expr(expr, span)
            }
            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
        }
    }

    fn swizzle(
        &mut self,
        size: crate::VectorSize,
        span: Span,
        src_constant: Handle<Expression>,
        pattern: [crate::SwizzleComponent; 4],
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let mut get_dst_ty = |ty| match self.types[ty].inner {
            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
                Type {
                    name: None,
                    inner: TypeInner::Vector { size, scalar },
                },
                span,
            )),
            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
        };

        match self.expressions[src_constant] {
            Expression::ZeroValue(ty) => {
                let dst_ty = get_dst_ty(ty)?;
                let expr = Expression::ZeroValue(dst_ty);
                self.register_evaluated_expr(expr, span)
            }
            Expression::Splat { value, ..
            } => {
                let expr = Expression::Splat { size, value };
                self.register_evaluated_expr(expr, span)
            }
            Expression::Compose { ty, ref components } => {
                let dst_ty = get_dst_ty(ty)?;

                let mut flattened = [src_constant; 4]; // dummy value
                let len =
                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
                        .zip(flattened.iter_mut())
                        .map(|(component, elt)| *elt = component)
                        .count();
                let flattened = &flattened[..len];

                let swizzled_components = pattern[..size as usize]
                    .iter()
                    .map(|&sc| {
                        let sc = sc as usize;
                        if let Some(elt) = flattened.get(sc) {
                            Ok(*elt)
                        } else {
                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
                        }
                    })
                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
                let expr = Expression::Compose {
                    ty: dst_ty,
                    components: swizzled_components,
                };
                self.register_evaluated_expr(expr, span)
            }
            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
        }
    }

    fn math(
        &mut self,
        arg: Handle<Expression>,
        arg1: Option<Handle<Expression>>,
        arg2: Option<Handle<Expression>>,
        arg3: Option<Handle<Expression>>,
        fun: crate::MathFunction,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let expected = fun.argument_count();
        let given = Some(arg)
            .into_iter()
            .chain(arg1)
            .chain(arg2)
            .chain(arg3)
            .count();
        if expected != given {
            return Err(ConstantEvaluatorError::InvalidMathArgCount(
                fun, expected, given,
            ));
        }

        // NOTE: We try to match the declaration order of `MathFunction` here.
        match fun {
            // comparison
            crate::MathFunction::Abs => {
                component_wise_scalar(self, span, [arg], |args| match args {
                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
                })
            }
            crate::MathFunction::Min => {
                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
                    Ok([e1.min(e2)])
                })
            }
            crate::MathFunction::Max => {
                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
                    Ok([e1.max(e2)])
                })
            }
            crate::MathFunction::Clamp => {
                component_wise_scalar!(
                    self,
                    span,
                    [arg, arg1.unwrap(), arg2.unwrap()],
                    |e, low, high| {
                        if low > high {
                            Err(ConstantEvaluatorError::InvalidClamp)
                        } else {
                            Ok([e.clamp(low, high)])
                        }
                    }
                )
            }
            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
                Float::F16([e]) => Ok(Float::F16(
                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
                )),
                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
            }),

            // trigonometry
            crate::MathFunction::Cos => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
            }
            crate::MathFunction::Cosh => {
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.cosh();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("cosh".into()))
                    }
                })
            }
            crate::MathFunction::Sin => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
            }
            crate::MathFunction::Sinh => {
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.sinh();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("sinh".into()))
                    }
                })
            }
            crate::MathFunction::Tan => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
            }
            crate::MathFunction::Tanh => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
            }
            crate::MathFunction::Acos => {
                component_wise_float!(self, span, [arg], |e| {
                    if e.abs() <= One::one() {
                        Ok([e.acos()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("acos".into()))
                    }
                })
            }
            crate::MathFunction::Asin => {
                component_wise_float!(self, span, [arg], |e| {
                    if e.abs() <= One::one() {
                        Ok([e.asin()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("asin".into()))
                    }
                })
            }
            crate::MathFunction::Atan => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
            }
            crate::MathFunction::Atan2 => {
                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
                    Ok([y.atan2(x)])
                })
            }
            crate::MathFunction::Asinh => component_wise_float(self, span, [arg], |e| match e {
                Float::Abstract([e]) => Ok(Float::Abstract([libm::asinh(e)])),
                Float::F32([e]) => Ok(Float::F32([(e as f64).asinh() as f32])),
                Float::F16([e]) => Ok(Float::F16([e.asinh()])),
            }),
            crate::MathFunction::Acosh => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
            }
            crate::MathFunction::Atanh => {
                component_wise_float!(self, span, [arg], |e| {
                    if e.abs() < One::one() {
                        Ok([e.atanh()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("atanh".into()))
                    }
                })
            }
            crate::MathFunction::Radians => {
                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
            }
            crate::MathFunction::Degrees => {
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.to_degrees();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("degrees".into()))
                    }
                })
            }

            // decomposition
            crate::MathFunction::Ceil => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
            }
            crate::MathFunction::Floor => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
            }
            crate::MathFunction::Round => {
                component_wise_float(self, span, [arg], |e|
                match e {
                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
                    Float::F16([e]) => {
                        // TODO: `round_ties_even` is not available on `half::f16` yet.
                        //
                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-ndimage`][polyfill source],
                        // which has licensing compatible with ours. See also
                        // .
                        //
                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
                        fn round_ties_even(x: f64) -> f64 {
                            let i = x as i64;
                            let f = (x - i as f64).abs();
                            if f == 0.5 {
                                if i & 1 == 1 {
                                    // -1.5, 1.5, 3.5, ...
                                    (x.abs() + 0.5).copysign(x)
                                } else {
                                    (x.abs() - 0.5).copysign(x)
                                }
                            } else {
                                x.round()
                            }
                        }

                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
                    }
                })
            }
            crate::MathFunction::Fract => {
                component_wise_float!(self, span, [arg], |e| {
                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
                    // here.
                    Ok([e - e.floor()])
                })
            }
            crate::MathFunction::Trunc => {
                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
            }

            // exponent
            crate::MathFunction::Exp => {
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.exp();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("exp".into()))
                    }
                })
            }
            crate::MathFunction::Exp2 => {
                component_wise_float!(self, span, [arg], |e| {
                    let result = e.exp2();
                    if result.is_finite() {
                        Ok([result])
                    } else {
                        Err(ConstantEvaluatorError::Overflow("exp2".into()))
                    }
                })
            }
            crate::MathFunction::Log => {
                component_wise_float!(self, span, [arg], |e| {
                    if e > Zero::zero() {
                        Ok([e.ln()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("log".into()))
                    }
                })
            }
            crate::MathFunction::Log2 => {
                component_wise_float!(self, span, [arg], |e| {
                    if e > Zero::zero() {
                        Ok([e.log2()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("log2".into()))
                    }
                })
            }
            crate::MathFunction::Pow => {
                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
                    Ok([e1.powf(e2)])
                })
            }

            // computational
            crate::MathFunction::Sign => {
                component_wise_signed!(self, span, [arg], |e| {
                    Ok([if e.is_zero() {
                        Zero::zero()
                    } else {
                        e.signum()
                    }])
                })
            }
            crate::MathFunction::Fma => {
                component_wise_float!(
                    self,
                    span,
                    [arg, arg1.unwrap(), arg2.unwrap()],
                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
                )
            }
            crate::MathFunction::Step => {
                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
                    Float::Abstract([edge, x]) => {
                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
                    }
                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
                        f16::one()
                    } else {
                        f16::zero()
                    }])),
                })
            }
            crate::MathFunction::Sqrt => {
                component_wise_float!(self, span, [arg], |e| {
                    if e >= Zero::zero() {
                        Ok([e.sqrt()])
                    } else {
                        Err(ConstantEvaluatorError::InvalidMathArgValue("sqrt".into()))
                    }
                })
            }
            crate::MathFunction::InverseSqrt => {
                component_wise_float(self, span, [arg], |e| match e {
                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1.
                        / f32::from(e).sqrt())])),
                })
            }
            // bits
            crate::MathFunction::CountTrailingZeros => {
                component_wise_concrete_int!(self, span, [arg], |e| {
                    #[allow(clippy::useless_conversion)]
                    Ok([e
                        .trailing_zeros()
                        .try_into()
                        .expect("bit count overflowed 32 bits, somehow!?")])
                })
            }
            crate::MathFunction::CountLeadingZeros => {
                component_wise_concrete_int!(self, span, [arg], |e| {
                    #[allow(clippy::useless_conversion)]
                    Ok([e
                        .leading_zeros()
                        .try_into()
                        .expect("bit count overflowed 32 bits, somehow!?")])
                })
            }
            crate::MathFunction::CountOneBits => {
                component_wise_concrete_int!(self, span, [arg], |e| {
                    #[allow(clippy::useless_conversion)]
                    Ok([e
                        .count_ones()
                        .try_into()
                        .expect("bit count overflowed 32 bits, somehow!?")])
                })
            }
            crate::MathFunction::ReverseBits => {
                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
            }
            crate::MathFunction::FirstTrailingBit => {
                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
            }
            crate::MathFunction::FirstLeadingBit => {
                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
            }

            // vector
            crate::MathFunction::Dot4I8Packed => {
                self.packed_dot_product(arg, arg1.unwrap(), span, true)
            }
            crate::MathFunction::Dot4U8Packed => {
                self.packed_dot_product(arg, arg1.unwrap(), span, false)
            }
            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
            crate::MathFunction::Dot => {
                // https://www.w3.org/TR/WGSL/#dot-builtin
                let e1 = self.extract_vec(arg, false)?;
                let e2 = self.extract_vec(arg1.unwrap(), false)?;
                if e1.len() != e2.len() {
                    return Err(ConstantEvaluatorError::InvalidMathArg);
                }

                fn float_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
                where
                    P: num_traits::Float,
                {
                    let result = a
                        .iter()
                        .zip(b.iter())
                        .map(|(&aa, &bb)| aa * bb)
                        .fold(P::zero(), |acc, x| acc + x);
                    if result.is_finite() {
                        Ok(result)
                    } else {
                        Err(ConstantEvaluatorError::Overflow("in dot built-in".into()))
                    }
                }

                fn int_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
                where
                    P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
                {
                    a.iter()
                        .zip(b.iter())
                        .map(|(&aa, bb)| aa.checked_mul(bb))
                        .try_fold(P::zero(), |acc, x| {
                            if let Some(x) = x {
                                acc.checked_add(&x)
                            } else {
                                None
                            }
                        })
                        .ok_or(ConstantEvaluatorError::Overflow(
                            "in dot built-in".to_string(),
                        ))
                }

                fn int_dot_wrapping<P>(a: &[P], b: &[P]) -> P
                where
                    P: num_traits::PrimInt + num_traits::WrappingAdd + num_traits::WrappingMul,
                {
                    a.iter()
                        .zip(b.iter())
                        .map(|(&aa, bb)| aa.wrapping_mul(bb))
                        .fold(P::zero(), |acc, x| acc.wrapping_add(&x))
                }

                let result = match_literal_vector!(match (e1, e2) => Literal {
                    Float => |e1, e2| { float_dot_checked(e1, e2)? },
                    AbstractInt => |e1, e2| { int_dot_checked(e1, e2)? },
                    I32 => |e1, e2| { int_dot_wrapping(e1, e2) },
                    U32 => |e1, e2| { int_dot_wrapping(e1, e2) },
                })?;
                self.register_evaluated_expr(Expression::Literal(result), span)
            }
            crate::MathFunction::Length => {
                // https://www.w3.org/TR/WGSL/#length-builtin
                let e1 = self.extract_vec(arg, true)?;

                fn float_length<F>(e: &[F]) -> F
                where
                    F: core::ops::Mul,
                    F: num_traits::Float + iter::Sum,
                {
                    if e.len() == 1 {
                        // Avoids possible overflow in squaring
                        e[0].abs()
                    } else {
                        e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
                    }
                }

                let result = match_literal_vector!(match e1 => Literal {
                    Float => |e1| { float_length(e1) },
                })?;
                self.register_evaluated_expr(Expression::Literal(result), span)
            }
            crate::MathFunction::Distance => {
                // https://www.w3.org/TR/WGSL/#distance-builtin
                let e1 = self.extract_vec(arg, true)?;
                let e2 = self.extract_vec(arg1.unwrap(), true)?;
                if e1.len() != e2.len() {
                    return Err(ConstantEvaluatorError::InvalidMathArg);
                }

                fn float_distance<F>(a: &[F], b: &[F]) -> F
                where
                    F: core::ops::Mul,
                    F: num_traits::Float + iter::Sum + core::ops::Sub,
                {
                    if a.len() == 1 {
                        // Avoids possible overflow in squaring
                        (a[0] - b[0]).abs()
                    } else {
                        a.iter()
                            .zip(b.iter())
                            .map(|(&aa, &bb)| aa - bb)
                            .map(|ei| ei * ei)
                            .sum::<F>()
                            .sqrt()
                    }
                }

                let result = match_literal_vector!(match (e1, e2) => Literal {
                    Float => |e1, e2| { float_distance(e1, e2) },
                })?;
                self.register_evaluated_expr(Expression::Literal(result), span)
            }
            crate::MathFunction::Normalize => {
                // https://www.w3.org/TR/WGSL/#normalize-builtin
                let e1 = self.extract_vec(arg, true)?;

                fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
                where
                    F: core::ops::Mul,
                    F: num_traits::Float + iter::Sum,
                {
                    let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
                    let mut out = ArrayVec::new();
                    for &ei in e {
                        out.push(ei / len);
                    }
                    out
                }

                let result = match_literal_vector!(match e1 => LiteralVector {
                    Float => |e1| { float_normalize(e1) },
                })?;
                result.register_as_evaluated_expr(self, span)
            }

            // unimplemented
            crate::MathFunction::Modf
            | crate::MathFunction::Frexp
            | crate::MathFunction::Ldexp
            | crate::MathFunction::Outer
            | crate::MathFunction::FaceForward
            | crate::MathFunction::Reflect
            | crate::MathFunction::Refract
            | crate::MathFunction::Mix
            | crate::MathFunction::SmoothStep
            | crate::MathFunction::Inverse
            | crate::MathFunction::Transpose
            | crate::MathFunction::Determinant
            | crate::MathFunction::QuantizeToF16
            | crate::MathFunction::ExtractBits
            | crate::MathFunction::InsertBits
            | crate::MathFunction::Pack4x8snorm
            | crate::MathFunction::Pack4x8unorm
            | crate::MathFunction::Pack2x16snorm
            | crate::MathFunction::Pack2x16unorm
            | crate::MathFunction::Pack2x16float
            | crate::MathFunction::Pack4xI8
            | crate::MathFunction::Pack4xU8
            | crate::MathFunction::Pack4xI8Clamp
            | crate::MathFunction::Pack4xU8Clamp
            | crate::MathFunction::Unpack4x8snorm
            | crate::MathFunction::Unpack4x8unorm
            | crate::MathFunction::Unpack2x16snorm
            | crate::MathFunction::Unpack2x16unorm
            | crate::MathFunction::Unpack2x16float
            | crate::MathFunction::Unpack4xI8
            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
                format!("{fun:?} built-in function"),
            )),
        }
    }

    /// Dot product of two packed vectors (`dot4I8Packed` and `dot4U8Packed`).
    fn packed_dot_product(
        &mut self,
        a: Handle<Expression>,
        b: Handle<Expression>,
        span: Span,
        signed: bool,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
            return Err(ConstantEvaluatorError::InvalidMathArg);
        };
        let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
            return Err(ConstantEvaluatorError::InvalidMathArg);
        };

        let result = if signed {
            Literal::I32(
                (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
                    + ((a >> 8)
                    & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
                    + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
                    + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
            )
        } else {
            Literal::U32(
                (a & 0xFF) * (b & 0xFF)
                    + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
                    + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
                    + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
            )
        };

        self.register_evaluated_expr(Expression::Literal(result), span)
    }

    /// Vector cross product.
    fn cross_product(
        &mut self,
        a: Handle<Expression>,
        b: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use Literal as Li;
        let (a, ty) = self.extract_vec_with_size::<3>(a)?;
        let (b, _) = self.extract_vec_with_size::<3>(b)?;

        let product = match (a, b) {
            (
                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
            ) => {
                // `cross` has no overload for AbstractInt, so AbstractInt
                // arguments are automatically converted to AbstractFloat. Since
                // `f64` has a much wider range than `i64`, there's no danger of
                // overflow here.
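                //
                // For example, AbstractInt inputs `(1, 0, 0)` and `(0, 1, 0)`
                // produce the AbstractFloat components `(0.0, 0.0, 1.0)`.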
                let p = cross_product(
                    [a0 as f64, a1 as f64, a2 as f64],
                    [b0 as f64, b1 as f64, b2 as f64],
                );
                [
                    Li::AbstractFloat(p[0]),
                    Li::AbstractFloat(p[1]),
                    Li::AbstractFloat(p[2]),
                ]
            }
            (
                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
            ) => {
                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
                [
                    Li::AbstractFloat(p[0]),
                    Li::AbstractFloat(p[1]),
                    Li::AbstractFloat(p[2]),
                ]
            }
            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
            }
            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
            }
            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
            }
            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
        };

        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;

        self.register_evaluated_expr(
            Expression::Compose {
                ty,
                components: vec![p0, p1, p2],
            },
            span,
        )
    }

    /// Extract the values of a `vecN<S>` from `expr`.
    ///
    /// Return the value of `expr`, whose type is `vecN<S>` for some
    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
    /// values.
    ///
    /// Also return the type handle from the `Compose` expression.
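    ///
    /// For example, a constant `vec3(1.0f, 2.0f, 3.0f)` with `N = 3` yields
    /// `[Literal::F32(1.0), Literal::F32(2.0), Literal::F32(3.0)]` together
    /// with the handle of its `vec3<f32>` type. `ZeroValue` and `Splat`
    /// operands are lowered first, so `vec3<f32>()` yields three
    /// `Literal::F32(0.0)` values.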
    fn extract_vec_with_size<const N: usize>(
        &mut self,
        expr: Handle<Expression>,
    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
        let span = self.expressions.get_span(expr);
        let expr = self.eval_zero_value_and_splat(expr, span)?;
        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
            return Err(ConstantEvaluatorError::InvalidMathArg);
        };

        let mut value = [Literal::Bool(false); N];
        for (component, elt) in
            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
                .zip(value.iter_mut())
        {
            let Expression::Literal(literal) = self.expressions[component] else {
                return Err(ConstantEvaluatorError::InvalidMathArg);
            };
            *elt = literal;
        }

        Ok((value, ty))
    }

    /// Extract the values of a vector from `expr`.
    ///
    /// Return the value of `expr`, whose type is `vecN<S>` for some
    /// vector size `N` and scalar `S`, as a [`LiteralVector`].
    ///
    /// If `allow_single` is true, a scalar literal is also accepted and
    /// returned as a single-element [`LiteralVector`].
    fn extract_vec(
        &mut self,
        expr: Handle<Expression>,
        allow_single: bool,
    ) -> Result<LiteralVector, ConstantEvaluatorError> {
        let span = self.expressions.get_span(expr);
        let expr = self.eval_zero_value_and_splat(expr, span)?;

        match self.expressions[expr] {
            Expression::Literal(literal) if allow_single => {
                Ok(LiteralVector::from_literal(literal))
            }
            Expression::Compose { ty, ref components } => {
                let mut components_out = ArrayVec::<Literal, { crate::VectorSize::MAX }>::new();
                for expr in
                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
                {
                    match self.expressions[expr] {
                        Expression::Literal(l) => components_out.push(l),
                        _ => return Err(ConstantEvaluatorError::InvalidMathArg),
                    }
                }
                LiteralVector::from_literal_vec(components_out)
            }
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    }

    fn array_length(
        &mut self,
        array: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        match self.expressions[array] {
            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
                match self.types[ty].inner {
                    TypeInner::Array { size, ..
                    } => match size {
                        ArraySize::Constant(len) => {
                            let expr = Expression::Literal(Literal::U32(len.get()));
                            self.register_evaluated_expr(expr, span)
                        }
                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
                    },
                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
                }
            }
            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
        }
    }

    fn access(
        &mut self,
        base: Handle<Expression>,
        index: usize,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        match self.expressions[base] {
            Expression::ZeroValue(ty) => {
                let ty_inner = &self.types[ty].inner;
                let components = ty_inner
                    .components()
                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;

                if index >= components as usize {
                    Err(ConstantEvaluatorError::InvalidAccessBase)
                } else {
                    let ty_res = ty_inner
                        .component_type(index)
                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
                    let ty = match ty_res {
                        crate::proc::TypeResolution::Handle(ty) => ty,
                        crate::proc::TypeResolution::Value(inner) => {
                            self.types.insert(Type { name: None, inner }, span)
                        }
                    };
                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
                }
            }
            Expression::Splat { size, value } => {
                if index >= size as usize {
                    Err(ConstantEvaluatorError::InvalidAccessBase)
                } else {
                    Ok(value)
                }
            }
            Expression::Compose { ty, ref components } => {
                let _ = self.types[ty]
                    .inner
                    .components()
                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;

                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
                    .nth(index)
                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
            }
            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
        }
    }

    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
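    ///
    /// For example, a `ZeroValue` of type `vec2<i32>` becomes a `Compose` of
    /// two `Literal::I32(0)` components, and a `Splat` of size three whose
    /// `value` is `Literal::U32(1)` becomes a `Compose` of three
    /// `Literal::U32(1)` components.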
    ///
    /// [`ZeroValue`]: Expression::ZeroValue
    /// [`Splat`]: Expression::Splat
    /// [`Literal`]: Expression::Literal
    /// [`Compose`]: Expression::Compose
    fn eval_zero_value_and_splat(
        &mut self,
        mut expr: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
        // each of its components.
        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
            let components = components
                .clone()
                .iter()
                .map(|component| self.eval_zero_value_and_splat(*component, span))
                .collect::<Result<_, _>>()?;
            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
        }

        // The result of the splat() for a Splat of a scalar ZeroValue is a
        // vector ZeroValue, so we must call eval_zero_value_impl() after
        // splat() in order to ensure we have no ZeroValues remaining.
        if let Expression::Splat { size, value } = self.expressions[expr] {
            expr = self.splat(value, size, span)?;
        }
        if let Expression::ZeroValue(ty) = self.expressions[expr] {
            expr = self.eval_zero_value_impl(ty, span)?;
        }
        Ok(expr)
    }

    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
    ///
    /// [`ZeroValue`]: Expression::ZeroValue
    /// [`Literal`]: Expression::Literal
    /// [`Compose`]: Expression::Compose
    fn eval_zero_value(
        &mut self,
        expr: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        match self.expressions[expr] {
            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
            _ => Ok(expr),
        }
    }

    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
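    ///
    /// For example, the zero value of `mat2x2<f32>` becomes a `Compose` of two
    /// `vec2<f32>` `Compose` expressions, each holding two `Literal::F32(0.0)`
    /// components. Repeated components share a single expression handle
    /// rather than being registered once per element.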
    ///
    /// [`ZeroValue`]: Expression::ZeroValue
    /// [`Literal`]: Expression::Literal
    /// [`Compose`]: Expression::Compose
    fn eval_zero_value_impl(
        &mut self,
        ty: Handle<Type>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        match self.types[ty].inner {
            TypeInner::Scalar(scalar) => {
                let expr = Expression::Literal(
                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
                );
                self.register_evaluated_expr(expr, span)
            }
            TypeInner::Vector { size, scalar } => {
                let scalar_ty = self.types.insert(
                    Type {
                        name: None,
                        inner: TypeInner::Scalar(scalar),
                    },
                    span,
                );
                let el = self.eval_zero_value_impl(scalar_ty, span)?;
                let expr = Expression::Compose {
                    ty,
                    components: vec![el; size as usize],
                };
                self.register_evaluated_expr(expr, span)
            }
            TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => {
                let vec_ty = self.types.insert(
                    Type {
                        name: None,
                        inner: TypeInner::Vector { size: rows, scalar },
                    },
                    span,
                );
                let el = self.eval_zero_value_impl(vec_ty, span)?;
                let expr = Expression::Compose {
                    ty,
                    components: vec![el; columns as usize],
                };
                self.register_evaluated_expr(expr, span)
            }
            TypeInner::Array {
                base,
                size: ArraySize::Constant(size),
                ..
            } => {
                let el = self.eval_zero_value_impl(base, span)?;
                let expr = Expression::Compose {
                    ty,
                    components: vec![el; size.get() as usize],
                };
                self.register_evaluated_expr(expr, span)
            }
            TypeInner::Struct { ref members, .. } => {
                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
                let mut components = Vec::with_capacity(members.len());
                for ty in types {
                    components.push(self.eval_zero_value_impl(ty, span)?);
                }
                let expr = Expression::Compose { ty, components };
                self.register_evaluated_expr(expr, span)
            }
            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
        }
    }

    /// Convert the scalar components of `expr` to `target`.
    ///
    /// Treat `span` as the location of the resulting expression.
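    ///
    /// For example, casting `Literal::F32(-3.7)` to `Scalar::I32` yields
    /// `Literal::I32(-3)`. For `F32` sources (and for `F64` sources when the
    /// target is 64-bit), float-to-integer conversions clamp to the target's
    /// float-representable range and then truncate toward zero.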
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        let expr = self.eval_zero_value(expr, span)?;

        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl_for_diagnostics();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        use crate::proc::type_methods::IntFloatLimits;

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                let literal = match target {
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                        Literal::F16(v) => f16::to_i32(&v).unwrap(), // Only None on NaN or Inf
                        Literal::Bool(v) => v as i32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                        Literal::Bool(v) => v as u32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::F16(v) => f16::to_i64(&v).unwrap(),
                        // The F16 conversion above is only None on NaN or Inf.
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    Sc::F16 => Literal::F16(match literal {
                        Literal::F16(v) => v,
                        Literal::F32(v) => f16::from_f32(v),
                        Literal::F64(v) => f16::from_f64(v),
                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                        Literal::I64(v) => f16::from_i64(v).unwrap(),
                        Literal::U64(v) => f16::from_u64(v).unwrap(),
                        Literal::I32(v) => f16::from_i32(v).unwrap(),
                        Literal::U32(v) => f16::from_u32(v).unwrap(),
                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        Literal::Bool(v) => v as u32 as f32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::F16(v) => f16::to_f32(v),
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F16(v) => f16::to_f64(v),
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    Sc::BOOL =>
                        Literal::Bool(match literal {
                            Literal::I32(v) => v != 0,
                            Literal::U32(v) => v != 0,
                            Literal::F32(v) => v != 0.0,
                            Literal::F16(v) => v != f16::zero(),
                            Literal::Bool(v) => v,
                            Literal::AbstractInt(v) => v != 0,
                            Literal::AbstractFloat(v) => v != 0.0,
                            Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                                return make_error();
                            }
                        }),
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // Overflow is forbidden, but inexact conversions
                            // are fine. The range of f64 is far larger than
                            // that of i64, so we don't have to check anything
                            // here.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                        Literal::AbstractInt(v) => v,
                        _ => return make_error(),
                    }),
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }

    /// Convert the scalar leaves of `expr` to `target`, handling arrays.
    ///
    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
    /// matrix, or nested arrays of such.
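    ///
    /// For example, casting an `array<vec2<f32>, 2>` constant to `Scalar::F16`
    /// rebuilds each element as a `vec2<f16>` and gives the new array type a
    /// stride computed by the layouter for `vec2<f16>`, rather than reusing the
    /// `vec2<f32>` stride.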
    ///
    /// This is basically the same as the [`cast`] method, except that [`cast`]
    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
    ///
    /// Treat `span` as the location of the resulting expression.
    ///
    /// [`cast`]: ConstantEvaluator::cast
    /// [`As`]: crate::Expression::As
    pub fn cast_array(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let expr = self.check_and_get(expr)?;

        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
            return self.cast(expr, target, span);
        };

        let TypeInner::Array {
            base: _,
            size,
            stride: _,
        } = self.types[ty].inner
        else {
            return self.cast(expr, target, span);
        };

        let mut components = components.clone();
        for component in &mut components {
            *component = self.cast_array(*component, target, span)?;
        }

        let first = components.first().unwrap();
        let new_base = match self.resolve_type(*first)? {
            crate::proc::TypeResolution::Handle(ty) => ty,
            crate::proc::TypeResolution::Value(inner) => {
                self.types.insert(Type { name: None, inner }, span)
            }
        };
        let mut layouter = core::mem::take(self.layouter);
        layouter.update(self.to_ctx()).unwrap();
        *self.layouter = layouter;

        let new_base_stride = self.layouter[new_base].to_stride();
        let new_array_ty = self.types.insert(
            Type {
                name: None,
                inner: TypeInner::Array {
                    base: new_base,
                    size,
                    stride: new_base_stride,
                },
            },
            span,
        );

        let compose = Expression::Compose {
            ty: new_array_ty,
            components,
        };
        self.register_evaluated_expr(compose, span)
    }

    fn unary_op(
        &mut self,
        op: UnaryOperator,
        expr: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let expr = self.eval_zero_value_and_splat(expr, span)?;

        let expr = match self.expressions[expr] {
            Expression::Literal(value) => Expression::Literal(match op {
                UnaryOperator::Negate => match value {
                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
                    Literal::F32(v) => Literal::F32(-v),
                    Literal::F16(v) => Literal::F16(-v),
                    Literal::F64(v) => Literal::F64(-v),
                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
                },
                UnaryOperator::LogicalNot => match value {
                    Literal::Bool(v) => Literal::Bool(!v),
                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
                },
                UnaryOperator::BitwiseNot => match value {
                    Literal::I32(v) => Literal::I32(!v),
                    Literal::I64(v) => Literal::I64(!v),
                    Literal::U32(v) => Literal::U32(!v),
                    Literal::U64(v) => Literal::U64(!v),
                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
                },
            }),
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                match self.types[ty].inner {
                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
                }

                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.unary_op(op, *component, span)?;
                }

                Expression::Compose { ty, components }
            }
            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
        };

        self.register_evaluated_expr(expr, span)
    }

    fn binary_op(
        &mut self,
        op: BinaryOperator,
        left: Handle<Expression>,
        right: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let left = self.eval_zero_value_and_splat(left, span)?;
        let right = self.eval_zero_value_and_splat(right, span)?;

        // Note: in most cases constant evaluation checks for overflow, but for
        // i32/u32, it uses wrapping arithmetic. See
        // .
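        //
        // For example, `I32(i32::MAX) + I32(1)` wraps to `I32(i32::MIN)`, whereas
        // `AbstractInt(i64::MAX) + AbstractInt(1)` is rejected with
        // `ConstantEvaluatorError::Overflow`.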
        let expr = match (&self.expressions[left], &self.expressions[right]) {
            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
                if !matches!(op, BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight)
                    && core::mem::discriminant(&left_value) != core::mem::discriminant(&right_value)
                {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }

                let literal = match op {
                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),

                    _ => match (left_value, right_value) {
                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            BinaryOperator::Divide => {
                                if b == 0 {
                                    return Err(ConstantEvaluatorError::DivisionByZero);
                                } else {
                                    a.wrapping_div(b)
                                }
                            }
                            BinaryOperator::Modulo => {
                                if b == 0 {
                                    return Err(ConstantEvaluatorError::RemainderByZero);
                                } else {
                                    a.wrapping_rem(b)
                                }
                            }
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
                            BinaryOperator::ShiftLeft => {
                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                                }
                                a.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                            }
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
                            BinaryOperator::Add => a.wrapping_add(b),
                            BinaryOperator::Subtract => a.wrapping_sub(b),
                            BinaryOperator::Multiply => a.wrapping_mul(b),
                            BinaryOperator::Divide => a
                                .checked_div(b)
                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                            BinaryOperator::Modulo => a
                                .checked_rem(b)
                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            BinaryOperator::ShiftLeft => a
                                .checked_mul(
                                    1u32.checked_shl(b)
                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                                )
                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::AbstractInt(a), Literal::U32(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::ShiftLeft => {
                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                        return Err(ConstantEvaluatorError::Overflow(
                                            "<<".to_string(),
                                        ));
                                    }
                                    a.checked_shl(b).unwrap_or(0)
                                }
                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::F16(a), Literal::F16(b)) => {
                            let result = match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return
                                    Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            };
                            if !result.is_finite() {
                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
                            }

                            Literal::F16(result)
                        }
                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("addition".into())
                                })?,
                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("subtraction".into())
                                })?,
                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("multiplication".into())
                                })?,
                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::DivisionByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("division".into())
                                    }
                                })?,
                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::RemainderByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("remainder".into())
                                    }
                                })?,
                                BinaryOperator::And => a & b,
                                BinaryOperator::ExclusiveOr => a ^ b,
                                BinaryOperator::InclusiveOr => a | b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
                            let result = match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            };

                            if !result.is_finite() {
                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
                            }

                            Literal::AbstractFloat(result)
                        }
                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
                            BinaryOperator::LogicalAnd => a && b,
                            BinaryOperator::LogicalOr => a || b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    },
                };
                Expression::Literal(literal)
            }
            (
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
                &Expression::Literal(_),
            ) => {
                if
                    !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, *component, right, span)?;
                }
                Expression::Compose { ty, components }
            }
            (
                &Expression::Literal(_),
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
            ) => {
                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
                }
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, left, *component, span)?;
                }
                Expression::Compose { ty, components }
            }
            (
                &Expression::Compose {
                    components: ref left_components,
                    ty: left_ty,
                },
                &Expression::Compose {
                    components: ref right_components,
                    ty: right_ty,
                },
            ) => {
                // We have to make a copy of the component lists, because the
                // call to `binary_op_compose` needs `&mut self`, but `self` owns
                // the component lists.
                let left_flattened = crate::proc::flatten_compose(
                    left_ty,
                    left_components,
                    self.expressions,
                    self.types,
                )
                .collect::<Vec<_>>();
                let right_flattened = crate::proc::flatten_compose(
                    right_ty,
                    right_components,
                    self.expressions,
                    self.types,
                )
                .collect::<Vec<_>>();

                self.binary_op_compose(
                    op,
                    &left_flattened,
                    &right_flattened,
                    left_ty,
                    right_ty,
                    span,
                )?
            }
            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        };

        return self.register_evaluated_expr(expr, span);

        fn is_allowed_compose_literal_op(compose_ty: &TypeInner, op: BinaryOperator) -> bool {
            let is_numeric_vec = matches!(
                compose_ty,
                TypeInner::Vector { scalar, .. } if scalar.kind != ScalarKind::Bool
            );
            let is_allowed_vec_scalar_op = matches!(
                op,
                BinaryOperator::Add
                    | BinaryOperator::Subtract
                    | BinaryOperator::Multiply
                    | BinaryOperator::Divide
                    | BinaryOperator::Modulo
            );
            let is_mat = matches!(compose_ty, TypeInner::Matrix { ..
            });
            let is_allowed_mat_scalar_op = matches!(op, BinaryOperator::Multiply);

            is_numeric_vec && is_allowed_vec_scalar_op || is_mat && is_allowed_mat_scalar_op
        }
    }

    fn binary_op_compose(
        &mut self,
        op: BinaryOperator,
        left_components: &[Handle<Expression>],
        right_components: &[Handle<Expression>],
        left_ty: Handle<Type>,
        right_ty: Handle<Type>,
        span: Span,
    ) -> Result<Expression, ConstantEvaluatorError> {
        match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
            // Binary operation on vector-vector
            (
                &TypeInner::Vector {
                    size: left_size, ..
                },
                &TypeInner::Vector {
                    size: right_size, ..
                },
            ) if left_size == right_size => self.binary_op_vector(
                op,
                left_size,
                left_components,
                right_components,
                left_ty,
                span,
            ),
            // Binary operation on vector-matrix
            (
                &TypeInner::Vector { size, .. },
                &TypeInner::Matrix {
                    columns,
                    rows,
                    scalar,
                },
            ) if op == BinaryOperator::Multiply && size == rows => self.multiply_vector_matrix(
                left_components,
                right_components,
                columns,
                scalar,
                span,
            ),
            // Binary operation on matrix-vector
            (
                &TypeInner::Matrix {
                    columns,
                    rows,
                    scalar,
                },
                &TypeInner::Vector { size, .. },
            ) if op == BinaryOperator::Multiply && size == columns => {
                self.multiply_matrix_vector(left_components, right_components, rows, scalar, span)
            }
            // Binary operation on matrix-matrix
            (
                &TypeInner::Matrix {
                    columns: left_columns,
                    rows: left_rows,
                    scalar,
                },
                &TypeInner::Matrix {
                    columns: right_columns,
                    rows: right_rows,
                    ..
                },
            ) => match op {
                BinaryOperator::Add | BinaryOperator::Subtract
                    if left_columns == right_columns && left_rows == right_rows =>
                {
                    let components = left_components
                        .iter()
                        .zip(right_components)
                        .map(|(&left, &right)| self.binary_op(op, left, right, span))
                        .collect::<Result<Vec<_>, _>>()?;
                    Ok(Expression::Compose {
                        ty: left_ty,
                        components,
                    })
                }
                BinaryOperator::Multiply if left_columns == right_rows => self
                    .multiply_matrix_matrix(
                        left_components,
                        right_components,
                        left_rows,
                        right_columns,
                        scalar,
                        span,
                    ),
                _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
            },
            _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        }
    }

    fn binary_op_vector(
        &mut self,
        op: BinaryOperator,
        size: crate::VectorSize,
        left_components: &[Handle<Expression>],
        right_components: &[Handle<Expression>],
        left_ty: Handle<Type>,
        span: Span,
    ) -> Result<Expression, ConstantEvaluatorError> {
        let ty = match op {
            // Relational operators produce vectors of booleans.
            BinaryOperator::Equal
            | BinaryOperator::NotEqual
            | BinaryOperator::Less
            | BinaryOperator::LessEqual
            | BinaryOperator::Greater
            | BinaryOperator::GreaterEqual => self.types.insert(
                Type {
                    name: None,
                    inner: TypeInner::Vector {
                        size,
                        scalar: crate::Scalar::BOOL,
                    },
                },
                span,
            ),

            // Other operators produce the same type as their left
            // operand.
            BinaryOperator::Add
            | BinaryOperator::Subtract
            | BinaryOperator::Multiply
            | BinaryOperator::Divide
            | BinaryOperator::Modulo
            | BinaryOperator::And
            | BinaryOperator::ExclusiveOr
            | BinaryOperator::InclusiveOr
            | BinaryOperator::ShiftLeft
            | BinaryOperator::ShiftRight => left_ty,

            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
                // Not supported on vectors
                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
            }
        };

        let components = left_components
            .iter()
            .zip(right_components)
            .map(|(&left, &right)| self.binary_op(op, left, right, span))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Expression::Compose { ty, components })
    }

    fn multiply_vector_matrix(
        &mut self,
        vec_components: &[Handle<Expression>],
        mat_components: &[Handle<Expression>],
        mat_columns: crate::VectorSize,
        scalar: crate::Scalar,
        span: Span,
    ) -> Result<Expression, ConstantEvaluatorError> {
        let ty = self.types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: mat_columns,
                    scalar,
                },
            },
            span,
        );
        let components = mat_components
            .iter()
            .map(|&column| {
                let Expression::Compose { ref components, ..
                } = self.expressions[column]
                else {
                    unreachable!()
                };
                self.dot_exprs(
                    vec_components.iter().cloned(),
                    components.clone().into_iter(),
                    span,
                )
            })
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Expression::Compose { ty, components })
    }

    fn multiply_matrix_vector(
        &mut self,
        mat_components: &[Handle<Expression>],
        vec_components: &[Handle<Expression>],
        mat_rows: crate::VectorSize,
        scalar: crate::Scalar,
        span: Span,
    ) -> Result<Expression, ConstantEvaluatorError> {
        let ty = self.types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: mat_rows,
                    scalar,
                },
            },
            span,
        );
        let flatten = self.flatten_matrix(mat_components);
        let nr = mat_rows as usize;
        let components = (0..nr)
            .map(|r| {
                let row = flatten.iter().skip(r).step_by(nr).cloned();
                self.dot_exprs(row, vec_components.iter().cloned(), span)
            })
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Expression::Compose { ty, components })
    }

    fn multiply_matrix_matrix(
        &mut self,
        left_components: &[Handle<Expression>],
        right_components: &[Handle<Expression>],
        left_rows: crate::VectorSize,
        right_columns: crate::VectorSize,
        scalar: crate::Scalar,
        span: Span,
    ) -> Result<Expression, ConstantEvaluatorError> {
        let left_nc = left_components.len();
        let left_nr = left_rows as usize;
        let right_nc = right_columns as usize;
        let right_nr = left_nc;

        let mut result = Vec::with_capacity(right_nc);
        let result_ty = self.types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: right_columns,
                    rows: left_rows,
                    scalar,
                },
            },
            span,
        );
        let result_column_ty = self.types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: left_rows,
                    scalar,
                },
            },
            span,
        );
        let left_flattened = self.flatten_matrix(left_components);
        let right_flattened = self.flatten_matrix(right_components);

        for c in 0..right_nc {
            let result_column = (0..left_nr)
                .map(|r| {
                    let row = left_flattened.iter().skip(r).step_by(left_nr);
                    let column = right_flattened.iter().skip(c * right_nr).take(right_nr);
                    self.dot_exprs(row.cloned(), column.cloned(), span)
                })
                .collect::<Result<Vec<_>, _>>()?;
            let expr = Expression::Compose {
                ty: result_column_ty,
                components: result_column,
            };
            let handle = self.register_evaluated_expr(expr, span)?;
            result.push(handle);
        }

        Ok(Expression::Compose {
            ty: result_ty,
            components: result,
        })
    }

    fn flatten_matrix(&self, columns: &[Handle<Expression>]) -> ArrayVec<Handle<Expression>, 16> {
        let mut flattened = ArrayVec::<_, 16>::new();
        for &column in columns {
            let Expression::Compose { ref components, .. } = self.expressions[column] else {
                unreachable!()
            };
            flattened.extend(components.iter().cloned());
        }
        flattened
    }

    fn dot_exprs(
        &mut self,
        left: impl Iterator<Item = Handle<Expression>>,
        right: impl Iterator<Item = Handle<Expression>>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let mut acc = None;
        for (l, r) in left.zip(right) {
            let result = self.binary_op(BinaryOperator::Multiply, l, r, span)?;
            match acc.as_mut() {
                Some(acc) => *acc = self.binary_op(BinaryOperator::Add, *acc, result, span)?,
                None => acc = Some(result),
            }
        }
        Ok(acc.unwrap())
    }

    fn relational(
        &mut self,
        fun: RelationalFunction,
        arg: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let arg = self.eval_zero_value_and_splat(arg, span)?;
        match fun {
            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
                Expression::Literal(Literal::Bool(_)) => Ok(arg),
                Expression::Compose { ty, ref components }
                    if matches!(self.types[ty].inner, TypeInner::Vector { ..
                    }) =>
                {
                    let mut bool_components = ArrayVec::<bool, { crate::VectorSize::MAX }>::new();
                    for component in
                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
                    {
                        match self.expressions[component] {
                            Expression::Literal(Literal::Bool(val)) => {
                                bool_components.push(val);
                            }
                            _ => {
                                return Err(ConstantEvaluatorError::InvalidRelationalArg(fun));
                            }
                        }
                    }
                    let components = bool_components;
                    let result = match fun {
                        RelationalFunction::All => components.iter().all(|c| *c),
                        RelationalFunction::Any => components.iter().any(|c| *c),
                        _ => unreachable!(),
                    };
                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
                }
                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
            },
            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
                "{fun:?} built-in function"
            ))),
        }
    }

    /// Deep copy `expr` from `expressions` into `self.expressions`.
    ///
    /// Return the root of the new copy.
    ///
    /// This is used when we're evaluating expressions in a function's
    /// expression arena that refer to a constant: we need to copy the
    /// constant's value into the function's arena so we can operate on it.
    fn copy_from(
        &mut self,
        expr: Handle<Expression>,
        expressions: &Arena<Expression>,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let span = expressions.get_span(expr);
        match expressions[expr] {
            ref expr @ (Expression::Literal(_)
            | Expression::Constant(_)
            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
            Expression::Compose { ty, ref components } => {
                let mut components = components.clone();
                for component in &mut components {
                    *component = self.copy_from(*component, expressions)?;
                }
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.copy_from(value, expressions)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            _ => {
                log::debug!("copy_from: SubexpressionsAreNotConstant");
                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
            }
        }
    }

    /// Returns the total number of components, after flattening, of a vector compose expression.
    fn vector_compose_flattened_size(
        &self,
        components: &[Handle<Expression>],
    ) -> Result<usize, ConstantEvaluatorError> {
        components
            .iter()
            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
                    TypeInner::Scalar(_) => 1,
                    // We trust that the vector size of `component` is correct,
                    // as it will have already been validated when `component`
                    // was registered.
                    TypeInner::Vector { size, .. } => size as usize,
                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
                };
                Ok(acc + size)
            })
    }

    fn register_evaluated_expr(
        &mut self,
        expr: Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        // It suffices to call `check_literal_value()` only for `Literal`
        // expressions, since we only register one expression at a time,
        // `Compose` expressions can only refer to other expressions, and
        // `ZeroValue` expressions are always okay.
        if let Expression::Literal(literal) = expr {
            crate::valid::check_literal_value(literal)?;
        }

        // Ensure vector composes contain the correct number of components.
        // We do so here when each compose is registered to avoid having to deal
        // with the mess each time the compose is used in another expression.
        if let Expression::Compose { ty, ref components } = expr {
            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
                let expected = size as usize;
                let actual = self.vector_compose_flattened_size(components)?;
                if expected != actual {
                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
                        expected,
                        actual,
                    });
                }
            }
        }

        Ok(self.append_expr(expr, span, ExpressionKind::Const))
    }

    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }

    /// Resolve the type of `expr` if it is a constant expression.
    ///
    /// If `expr` was evaluated to a constant, returns its type.
    /// Otherwise, returns an error.
    fn resolve_type(
        &self,
        expr: Handle<Expression>,
    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
        use crate::proc::TypeResolution as Tr;
        use crate::Expression as Ex;
        let resolution = match self.expressions[expr] {
            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
            Ex::Splat { size, value } => {
                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)?
                else {
                    return Err(ConstantEvaluatorError::SplatScalarOnly);
                };
                Tr::Value(TypeInner::Vector { scalar, size })
            }
            _ => {
                log::debug!("resolve_type: SubexpressionsAreNotConstant");
                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
            }
        };

        Ok(resolution)
    }

    fn select(
        &mut self,
        reject: Handle<Expression>,
        accept: Handle<Expression>,
        condition: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);

        let reject = arg(reject)?;
        let accept = arg(accept)?;
        let condition = arg(condition)?;

        let select_single_component =
            |this: &mut Self, reject_scalar, reject, accept, condition| {
                let accept = this.cast(accept, reject_scalar, span)?;
                if condition {
                    Ok(accept)
                } else {
                    Ok(reject)
                }
            };

        match (&self.expressions[reject], &self.expressions[accept]) {
            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
                let reject_scalar = reject_lit.scalar();
                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
                else {
                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
                };
                select_single_component(self, reject_scalar, reject, accept, condition)
            }
            (
                &Expression::Compose {
                    ty: reject_ty,
                    components: ref reject_components,
                },
                &Expression::Compose {
                    ty: accept_ty,
                    components: ref accept_components,
                },
            ) => {
                let ty_deets = |ty| {
                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
                    (size.unwrap(), scalar)
                };

                let expected_vec_size = {
                    let [(reject_vec_size, _), (accept_vec_size, _)] =
                        [reject_ty, accept_ty].map(ty_deets);

                    if reject_vec_size != accept_vec_size {
                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
                            reject: reject_vec_size,
                            accept: accept_vec_size,
                        });
                    }
                    reject_vec_size
                };

                let condition_components = match self.expressions[condition] {
                    Expression::Literal(Literal::Bool(condition)) => {
                        vec![condition; (expected_vec_size as u8).into()]
                    }
                    Expression::Compose {
                        ty: condition_ty,
                        components: ref condition_components,
                    } => {
                        let
                            (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
                        if condition_scalar.kind != ScalarKind::Bool {
                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
                        }
                        if condition_vec_size != expected_vec_size {
                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
                        }
                        condition_components
                            .iter()
                            .copied()
                            .map(|component| match &self.expressions[component] {
                                &Expression::Literal(Literal::Bool(condition)) => condition,
                                _ => unreachable!(),
                            })
                            .collect()
                    }

                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
                };

                let evaluated = Expression::Compose {
                    ty: reject_ty,
                    components: reject_components
                        .clone()
                        .into_iter()
                        .zip(accept_components.clone().into_iter())
                        .zip(condition_components.into_iter())
                        .map(|((reject, accept), condition)| {
                            let reject_scalar = match &self.expressions[reject] {
                                &Expression::Literal(lit) => lit.scalar(),
                                _ => unreachable!(),
                            };
                            select_single_component(self, reject_scalar, reject, accept, condition)
                        })
                        .collect::<Result<_, _>>()?,
                };
                self.register_evaluated_expr(evaluated, span)
            }
            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
        }
    }
}

fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, an
    // input of 1 (only the least significant bit set) returns 0, and an input of `0x8000_0000`
    // returns 31. An input of 0 has no set bits and returns -1 (`u32::MAX` for unsigned inputs).
    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
        match e {
            idx @ 0..=31 => idx,
            32 => u32::MAX,
            _ => unreachable!(),
        }
    };
    match concrete_int {
        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
        ConcreteInt::I32([e]) => {
            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
        }
    }
}

#[test]
fn first_trailing_bit_smoke() {
    assert_eq!(
        first_trailing_bit(ConcreteInt::I32([0])),
        ConcreteInt::I32([-1])
    );
    assert_eq!(
        first_trailing_bit(ConcreteInt::I32([1])),
        ConcreteInt::I32([0])
    );
    assert_eq!(
        first_trailing_bit(ConcreteInt::I32([2])),
        ConcreteInt::I32([1])
    );
    assert_eq!(
        first_trailing_bit(ConcreteInt::I32([-1])),
        ConcreteInt::I32([0]),
    );
    assert_eq!(
        first_trailing_bit(ConcreteInt::I32([i32::MIN])),
        ConcreteInt::I32([31]),
    );
    assert_eq!(
        first_trailing_bit(ConcreteInt::I32([i32::MAX])),
        ConcreteInt::I32([0]),
    );
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    assert_eq!(
        first_trailing_bit(ConcreteInt::U32([0])),
        ConcreteInt::U32([u32::MAX])
    );
    assert_eq!(
        first_trailing_bit(ConcreteInt::U32([1])),
        ConcreteInt::U32([0])
    );
    assert_eq!(
        first_trailing_bit(ConcreteInt::U32([2])),
        ConcreteInt::U32([1])
    );
    assert_eq!(
        first_trailing_bit(ConcreteInt::U32([1 << 31])),
        ConcreteInt::U32([31]),
    );
    assert_eq!(
        first_trailing_bit(ConcreteInt::U32([u32::MAX])),
        ConcreteInt::U32([0]),
    );
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}

fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
    // index of 0.
    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
        match e {
            idx @ 0..=31 => 31 - idx,
            32 => u32::MAX,
            _ => unreachable!(),
        }
    };
    match concrete_int {
        ConcreteInt::I32([e]) => ConcreteInt::I32([{
            let rtl_bit_index = if e.is_negative() {
                e.leading_ones()
            } else {
                e.leading_zeros()
            };
            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
        }]),
        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
    }
}

#[test]
fn first_leading_bit_smoke() {
    assert_eq!(
        first_leading_bit(ConcreteInt::I32([-1])),
        ConcreteInt::I32([-1])
    );
    assert_eq!(
        first_leading_bit(ConcreteInt::I32([0])),
        ConcreteInt::I32([-1])
    );
    assert_eq!(
        first_leading_bit(ConcreteInt::I32([1])),
        ConcreteInt::I32([0])
    );
    assert_eq!(
        first_leading_bit(ConcreteInt::I32([-2])),
        ConcreteInt::I32([0])
    );
    assert_eq!(
        first_leading_bit(ConcreteInt::I32([1234 + 4567])),
        ConcreteInt::I32([12])
    );
    assert_eq!(
        first_leading_bit(ConcreteInt::I32([i32::MAX])),
        ConcreteInt::I32([30])
    );
    assert_eq!(
        first_leading_bit(ConcreteInt::I32([i32::MIN])),
        ConcreteInt::I32([30])
    );
    // NOTE: Ignore the sign bit, which is a separate (above) case.
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    assert_eq!(
        first_leading_bit(ConcreteInt::U32([0])),
        ConcreteInt::U32([u32::MAX])
    );
    assert_eq!(
        first_leading_bit(ConcreteInt::U32([1])),
        ConcreteInt::U32([0])
    );
    assert_eq!(
        first_leading_bit(ConcreteInt::U32([u32::MAX])),
        ConcreteInt::U32([31])
    );
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}

/// Trait for conversions of abstract values to concrete types.
trait TryFromAbstract<T>: Sized {
    /// Convert an abstract literal `value` to `Self`.
    ///
    /// Since Naga's [`AbstractInt`] and [`AbstractFloat`] exist to support
    /// WGSL, we follow WGSL's conversion rules here:
    ///
    /// - WGSL §6.1.2.
    ///   Conversion Rank says that automatic conversions
    ///   from [`AbstractInt`] to an integer type are either lossless or an
    ///   error.
    ///
    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
    ///   to floating point in constant expressions and override
    ///   expressions are errors if the value is out of range for the
    ///   destination type, but rounding is okay.
    ///
    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
    ///   other floating point type, following the scalar floating point to
    ///   integral conversion algorithm (§15.7.6). There is no automatic
    ///   conversion from AbstractFloat to integer types.
    ///
    /// [`AbstractInt`]: crate::Literal::AbstractInt
    /// [`AbstractFloat`]: crate::Literal::AbstractFloat
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}

impl TryFromAbstract<i64> for i32 {
    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
            value: format!("{value:?}"),
            to_type: "i32",
        })
    }
}

impl TryFromAbstract<i64> for u32 {
    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
            value: format!("{value:?}"),
            to_type: "u32",
        })
    }
}

impl TryFromAbstract<i64> for u64 {
    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
            value: format!("{value:?}"),
            to_type: "u64",
        })
    }
}

impl TryFromAbstract<i64> for i64 {
    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
        Ok(value)
    }
}

impl TryFromAbstract<i64> for f32 {
    fn try_from_abstract(value: i64) -> Result<f32, ConstantEvaluatorError> {
        let f = value as f32;
        // The range of `i64` is roughly ±9.2 × 10¹⁸, whereas the range of
        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
        // overflow here.
        Ok(f)
    }
}

impl TryFromAbstract<f64> for f32 {
    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
        let f = value as f32;
        if f.is_infinite() {
            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
                value: format!("{value:?}"),
                to_type: "f32",
            });
        }
        Ok(f)
    }
}

impl TryFromAbstract<i64> for f64 {
    fn try_from_abstract(value: i64) -> Result<f64, ConstantEvaluatorError> {
        let f = value as f64;
        // The range of `i64` is roughly ±9.2 × 10¹⁸, whereas the range of
        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
        // overflow here.
        Ok(f)
    }
}

impl TryFromAbstract<f64> for f64 {
    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
        Ok(value)
    }
}

impl TryFromAbstract<f64> for i32 {
    fn try_from_abstract(value: f64) -> Result<i32, ConstantEvaluatorError> {
        // https://www.w3.org/TR/WGSL/#floating-point-conversion
        // To convert a floating point scalar value X to an integer scalar type T:
        // * If X is a NaN, the result is an indeterminate value in T.
        // * If X is exactly representable in the target type T, then the
        //   result is that value.
        // * Otherwise, the result is the value in T closest to truncate(X) and
        //   also exactly representable in the original floating point type.
        //
        // A Rust `as` cast satisfies these requirements apart from "the result
        // is... exactly representable in the original floating point type".
        // However, i32::MIN and i32::MAX are exactly representable by f64, so
        // we're all good.
        Ok(value as i32)
    }
}

impl TryFromAbstract<f64> for u32 {
    fn try_from_abstract(value: f64) -> Result<u32, ConstantEvaluatorError> {
        // As above, u32::MIN and u32::MAX are exactly representable by f64,
        // so a simple Rust cast is sufficient.
        Ok(value as u32)
    }
}

impl TryFromAbstract<f64> for i64 {
    fn try_from_abstract(value: f64) -> Result<i64, ConstantEvaluatorError> {
        // As above, except we clamp to the minimum and maximum values
        // representable by both f64 and i64.
        use crate::proc::type_methods::IntFloatLimits;
        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
    }
}

impl TryFromAbstract<f64> for u64 {
    fn try_from_abstract(value: f64) -> Result<u64, ConstantEvaluatorError> {
        // As above, this time clamping to the minimum and maximum values
        // representable by both f64 and u64.
        use crate::proc::type_methods::IntFloatLimits;
        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
    }
}

impl TryFromAbstract<f64> for f16 {
    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
        let f = f16::from_f64(value);
        if f.is_infinite() {
            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
                value: format!("{value:?}"),
                to_type: "f16",
            });
        }
        Ok(f)
    }
}

impl TryFromAbstract<i64> for f16 {
    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
        let f = f16::from_i64(value);
        if f.is_none() {
            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
                value: format!("{value:?}"),
                to_type: "f16",
            });
        }
        Ok(f.unwrap())
    }
}

fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy,
    T: core::ops::Mul<T, Output = T>,
    T: core::ops::Sub<T, Output = T>,
{
    [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]
}

#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use crate::{
        Arena, BinaryOperator, Constant, Expression, FastHashMap, Handle, Literal, ScalarKind,
        Type, TypeInner, UnaryOperator, UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        let expression_kind_tracker =
            &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }

    #[test]
    fn matrix_op() {
        let mut helper = MatrixTestHelper::new();

        for nc in 2..=4 {
            for nr in 2..=4 {
                // Validates multiplication on vector-matrix.
                // vecR(0, 1, .., nr - 1) * matCxR(0, 1, .., nc * nr - 1)
                let evaluated = helper.eval_vector_multiply_matrix(nc, nr);
                let expected = (0..nc)
                    .map(|c| (0..nr).map(|r| (r * (c * nr + r)) as f32).sum())
                    .collect::<Vec<f32>>();
                assert_eq!(evaluated, expected);

                // Validates multiplication on matrix-vector.
                // matCxR(0, 1, .., nc * nr - 1) * vecC(0, 1, .., nc - 1)
                let evaluated = helper.eval_matrix_multiply_vector(nc, nr);
                let expected = (0..nr)
                    .map(|r| (0..nc).map(|c| (c * (c * nr + r)) as f32).sum())
                    .collect::<Vec<f32>>();
                assert_eq!(evaluated, expected);

                for k in 2..=4 {
                    // Validates multiplication on matrix-matrix.
                    // matKxR(0, 1, .., k * nr - 1) * matCxK(0, 1, .., nc * k - 1)
                    let evaluated = helper.eval_matrix_multiply_matrix(nr, nc, k);
                    let expected = (0..nc)
                        .flat_map(|c| {
                            (0..nr).map(move |r| {
                                (0..k).map(|v| ((v * nr + r) * (c * k + v)) as f32).sum()
                            })
                        })
                        .collect::<Vec<f32>>();
                    assert_eq!(evaluated, expected);
                }
            }
        }
    }

    /// Test fixture providing pre-built f32 vector and matrix constant
    /// expressions with sequential element values, used to evaluate and verify
    /// matrix operations.
    struct MatrixTestHelper {
        types: UniqueArena<Type>,
        expressions: Arena<Expression>,
        /// Vector expressions from [0, 1] to [0, 1, 2, 3].
        vec_exprs: FastHashMap<usize, Handle<Expression>>,
        /// Matrix expressions from [0, .., 3] to [0, .., 15].
        mat_exprs: FastHashMap<(usize, usize), Handle<Expression>>,
    }

    impl MatrixTestHelper {
        fn new() -> Self {
            let mut types = UniqueArena::new();
            let mut expressions = Arena::new();
            let span = crate::Span::default();

            let (mut vec_tys, mut mat_tys) = (FastHashMap::default(), FastHashMap::default());
            for c in 2..=4 {
                let vec_ty = types.insert(
                    Type {
                        name: None,
                        inner: TypeInner::Vector {
                            size: Self::int_to_vector_size(c),
                            scalar: crate::Scalar::F32,
                        },
                    },
                    span,
                );
                vec_tys.insert(c, vec_ty);

                for r in 2..=4 {
                    let mat_ty = types.insert(
                        Type {
                            name: None,
                            inner: TypeInner::Matrix {
                                columns: Self::int_to_vector_size(c),
                                rows: Self::int_to_vector_size(r),
                                scalar: crate::Scalar::F32,
                            },
                        },
                        span,
                    );
                    mat_tys.insert((c, r), mat_ty);
                }
            }

            let mut lit_exprs = FastHashMap::default();
            for i in 0..16 {
                let expr = expressions.append(Expression::Literal(Literal::F32(i as f32)), span);
                lit_exprs.insert(i, expr);
            }

            let mut vec_exprs = FastHashMap::default();
            for c in 2..=4 {
                let expr = expressions.append(
                    Expression::Compose {
                        ty: *vec_tys.get(&c).unwrap(),
                        components: (0..c)
                            .map(|i| *lit_exprs.get(&i).unwrap())
                            .collect::<Vec<_>>(),
                    },
                    span,
                );
                vec_exprs.insert(c, expr);
            }

            let mut mat_exprs = FastHashMap::default();
            for c in 2..=4 {
                for r in 2..=4 {
                    let mut columns = Vec::with_capacity(c);
                    for cc in 0..c {
                        let start = cc * r;
                        let expr = expressions.append(
                            Expression::Compose {
                                ty: *vec_tys.get(&r).unwrap(),
                                components: (start..start + r)
                                    .map(|i| *lit_exprs.get(&i).unwrap())
                                    .collect::<Vec<_>>(),
                            },
                            span,
                        );
                        columns.push(expr);
                    }
                    let expr = expressions.append(
                        Expression::Compose {
                            ty: *mat_tys.get(&(c, r)).unwrap(),
                            components: columns,
                        },
                        span,
                    );
                    mat_exprs.insert((c, r), expr);
                }
            }

            Self {
                types,
                expressions,
                vec_exprs,
                mat_exprs,
            }
        }

        /// Evaluates vec[0..nr] * mat[0..nc*nr] and returns the result as f32s.
        fn eval_vector_multiply_matrix(&mut self, nc: usize, nr: usize) -> Vec<f32> {
            let expression_kind_tracker =
                &mut ExpressionKindTracker::from_arena(&self.expressions);
            let mut solver = ConstantEvaluator {
                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
                types: &mut self.types,
                constants: &Arena::new(),
                overrides: &Arena::new(),
                expressions: &mut self.expressions,
                expression_kind_tracker,
                layouter: &mut crate::proc::Layouter::default(),
            };
            let result = solver
                .try_eval_and_append(
                    Expression::Binary {
                        op: BinaryOperator::Multiply,
                        left: *self.vec_exprs.get(&nr).unwrap(),
                        right: *self.mat_exprs.get(&(nc, nr)).unwrap(),
                    },
                    Default::default(),
                )
                .unwrap();
            self.flatten(result)
        }

        /// Evaluates mat[0..nc*nr] * vec[0..nc] and returns the result as f32s.
        fn eval_matrix_multiply_vector(&mut self, nc: usize, nr: usize) -> Vec<f32> {
            let expression_kind_tracker =
                &mut ExpressionKindTracker::from_arena(&self.expressions);
            let mut solver = ConstantEvaluator {
                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
                types: &mut self.types,
                constants: &Arena::new(),
                overrides: &Arena::new(),
                expressions: &mut self.expressions,
                expression_kind_tracker,
                layouter: &mut crate::proc::Layouter::default(),
            };
            let result = solver
                .try_eval_and_append(
                    Expression::Binary {
                        op: BinaryOperator::Multiply,
                        left: *self.mat_exprs.get(&(nc, nr)).unwrap(),
                        right: *self.vec_exprs.get(&nc).unwrap(),
                    },
                    Default::default(),
                )
                .unwrap();
            self.flatten(result)
        }

        /// Evaluates mat[0..k*l_nr] * mat[0..r_nc*k] and returns the result as
        /// f32s.
        fn eval_matrix_multiply_matrix(&mut self, l_nr: usize, r_nc: usize, k: usize) -> Vec<f32> {
            let expression_kind_tracker =
                &mut ExpressionKindTracker::from_arena(&self.expressions);
            let mut solver = ConstantEvaluator {
                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
                types: &mut self.types,
                constants: &Arena::new(),
                overrides: &Arena::new(),
                expressions: &mut self.expressions,
                expression_kind_tracker,
                layouter: &mut crate::proc::Layouter::default(),
            };
            let result = solver
                .try_eval_and_append(
                    Expression::Binary {
                        op: BinaryOperator::Multiply,
                        left: *self.mat_exprs.get(&(k, l_nr)).unwrap(),
                        right: *self.mat_exprs.get(&(r_nc, k)).unwrap(),
                    },
                    Default::default(),
                )
                .unwrap();
            self.flatten(result)
        }

        fn flatten(&self, expr: Handle<Expression>) -> Vec<f32> {
            let Expression::Compose {
                ref components,
                ref ty,
            } = self.expressions[expr]
            else {
                unreachable!()
            };

            match self.types[*ty].inner {
                TypeInner::Vector { .. } => components
                    .iter()
                    .map(|&comp| {
                        let Expression::Literal(Literal::F32(v)) = self.expressions[comp] else {
                            unreachable!()
                        };
                        v
                    })
                    .collect(),
                TypeInner::Matrix { ..
} => components .iter() .flat_map(|&comp| self.flatten(comp)) .collect(), _ => unreachable!(), } } fn int_to_vector_size(int: usize) -> VectorSize { match int { 2 => VectorSize::Bi, 3 => VectorSize::Tri, 4 => VectorSize::Quad, _ => unreachable!(), } } } #[test] fn cast() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); let overrides = Arena::new(); let mut global_expressions = Arena::new(); let scalar_ty = types.insert( Type { name: None, inner: TypeInner::Scalar(crate::Scalar::I32), }, Default::default(), ); let h = constants.append( Constant { name: None, ty: scalar_ty, init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); let expr = global_expressions.append(Expression::Constant(h), Default::default()); let root = Expression::As { expr, kind: ScalarKind::Bool, convert: Some(crate::BOOL_WIDTH), }; let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut global_expressions, expression_kind_tracker, layouter: &mut crate::proc::Layouter::default(), }; let res = solver .try_eval_and_append(root, Default::default()) .unwrap(); assert_eq!( global_expressions[res], Expression::Literal(Literal::Bool(true)) ); } #[test] fn access() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); let overrides = Arena::new(); let mut global_expressions = Arena::new(); let matrix_ty = types.insert( Type { name: None, inner: TypeInner::Matrix { columns: VectorSize::Bi, rows: VectorSize::Tri, scalar: crate::Scalar::F32, }, }, Default::default(), ); let vec_ty = types.insert( Type { name: None, inner: TypeInner::Vector { size: VectorSize::Tri, scalar: crate::Scalar::F32, }, }, Default::default(), ); let mut vec1_components = Vec::with_capacity(3); let mut vec2_components = 
Vec::with_capacity(3); for i in 0..3 { let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); vec1_components.push(h) } for i in 3..6 { let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); vec2_components.push(h) } let vec1 = constants.append( Constant { name: None, ty: vec_ty, init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec1_components, }, Default::default(), ), }, Default::default(), ); let vec2 = constants.append( Constant { name: None, ty: vec_ty, init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec2_components, }, Default::default(), ), }, Default::default(), ); let h = constants.append( Constant { name: None, ty: matrix_ty, init: global_expressions.append( Expression::Compose { ty: matrix_ty, components: vec![constants[vec1].init, constants[vec2].init], }, Default::default(), ), }, Default::default(), ); let base = global_expressions.append(Expression::Constant(h), Default::default()); let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut global_expressions, expression_kind_tracker, layouter: &mut crate::proc::Layouter::default(), }; let root1 = Expression::AccessIndex { base, index: 1 }; let res1 = solver .try_eval_and_append(root1, Default::default()) .unwrap(); let root2 = Expression::AccessIndex { base: res1, index: 2, }; let res2 = solver .try_eval_and_append(root2, Default::default()) .unwrap(); match global_expressions[res1] { Expression::Compose { ref ty, ref components, } => { assert_eq!(*ty, vec_ty); let mut components_iter = components.iter().copied(); assert_eq!( global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(3.)) ); assert_eq!( 
global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(4.)) ); assert_eq!( global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(5.)) ); assert!(components_iter.next().is_none()); } _ => panic!("Expected vector"), } assert_eq!( global_expressions[res2], Expression::Literal(Literal::F32(5.)) ); } #[test] fn compose_of_constants() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); let overrides = Arena::new(); let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { name: None, inner: TypeInner::Scalar(crate::Scalar::I32), }, Default::default(), ); let vec2_i32_ty = types.insert( Type { name: None, inner: TypeInner::Vector { size: VectorSize::Bi, scalar: crate::Scalar::I32, }, }, Default::default(), ); let h = constants.append( Constant { name: None, ty: i32_ty, init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut global_expressions, expression_kind_tracker, layouter: &mut crate::proc::Layouter::default(), }; let solved_compose = solver .try_eval_and_append( Expression::Compose { ty: vec2_i32_ty, components: vec![h_expr, h_expr], }, Default::default(), ) .unwrap(); let solved_negate = solver .try_eval_and_append( Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, Default::default(), ) .unwrap(); let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { let component = &global_expressions[component]; matches!(*component, 
Expression::Literal(Literal::I32(-4))) }) } _ => false, }; if !pass { panic!("unexpected evaluation result") } } #[test] fn splat_of_constant() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); let overrides = Arena::new(); let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { name: None, inner: TypeInner::Scalar(crate::Scalar::I32), }, Default::default(), ); let vec2_i32_ty = types.insert( Type { name: None, inner: TypeInner::Vector { size: VectorSize::Bi, scalar: crate::Scalar::I32, }, }, Default::default(), ); let h = constants.append( Constant { name: None, ty: i32_ty, init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut global_expressions, expression_kind_tracker, layouter: &mut crate::proc::Layouter::default(), }; let solved_compose = solver .try_eval_and_append( Expression::Splat { size: VectorSize::Bi, value: h_expr, }, Default::default(), ) .unwrap(); let solved_negate = solver .try_eval_and_append( Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, Default::default(), ) .unwrap(); let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::I32(-4))) }) } _ => false, }; if !pass { panic!("unexpected evaluation result") } } #[test] fn splat_of_zero_value() { let mut types = UniqueArena::new(); let constants = Arena::new(); let overrides = Arena::new(); let mut global_expressions = 
Arena::new(); let f32_ty = types.insert( Type { name: None, inner: TypeInner::Scalar(crate::Scalar::F32), }, Default::default(), ); let vec2_f32_ty = types.insert( Type { name: None, inner: TypeInner::Vector { size: VectorSize::Bi, scalar: crate::Scalar::F32, }, }, Default::default(), ); let five = global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default()); let five_splat = global_expressions.append( Expression::Splat { size: VectorSize::Bi, value: five, }, Default::default(), ); let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default()); let zero_splat = global_expressions.append( Expression::Splat { size: VectorSize::Bi, value: zero, }, Default::default(), ); let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut global_expressions, expression_kind_tracker, layouter: &mut crate::proc::Layouter::default(), }; let solved_add = solver .try_eval_and_append( Expression::Binary { op: BinaryOperator::Add, left: zero_splat, right: five_splat, }, Default::default(), ) .unwrap(); let pass = match global_expressions[solved_add] { Expression::Compose { ty, ref components } => { ty == vec2_f32_ty && components.iter().all(|&component| { let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::F32(5.0))) }) } _ => false, }; if !pass { panic!("unexpected evaluation result") } } } naga-29.0.3/src/proc/emitter.rs000064400000000000000000000020151046102023000143740ustar 00000000000000use crate::arena::Arena; /// Helper class to emit expressions #[derive(Default, Debug)] pub struct Emitter { start_len: Option<usize>, } impl Emitter { pub fn start(&mut self, arena: &Arena<crate::Expression>) { if self.start_len.is_some() { unreachable!("Emitting has already started!"); } self.start_len = Some(arena.len());
} pub const fn is_running(&self) -> bool { self.start_len.is_some() } #[must_use] pub fn finish( &mut self, arena: &Arena<crate::Expression>, ) -> Option<(crate::Statement, crate::span::Span)> { let start_len = self.start_len.take().unwrap(); if start_len != arena.len() { let mut span = crate::span::Span::default(); let range = arena.range_from(start_len); for handle in range.clone() { span.subsume(arena.get_span(handle)) } Some((crate::Statement::Emit(range), span)) } else { None } } } naga-29.0.3/src/proc/index.rs000064400000000000000000000637001046102023000140420ustar 00000000000000/*! Definitions for index bounds checking. */ use core::iter::{self, zip}; use crate::arena::{Handle, HandleSet, UniqueArena}; use crate::{valid, FastHashSet}; /// How should code generated by Naga do bounds checks? /// /// When a vector, matrix, or array index is out of bounds (either negative, or /// greater than or equal to the number of elements in the type), WGSL requires /// the implementation to use some in-bounds index of its own choosing instead. /// (There are no types with zero elements.) /// /// Similarly, when out-of-bounds coordinates, array indices, or sample indices /// are presented to the WGSL `textureLoad` and `textureStore` operations, the /// operation is redirected to do something safe. /// /// Different users of Naga will prefer different defaults: /// /// - When used as part of a WebGPU implementation, the WGSL specification /// requires the `Restrict` behavior for array, vector, and matrix accesses, /// and either the `Restrict` or `ReadZeroSkipWrite` behaviors for texture /// accesses. /// /// - When used by the `wgpu` crate for native development, `wgpu` selects /// `ReadZeroSkipWrite` as its default. /// /// - Naga's own default is `Unchecked`, so that shader translations /// are as faithful to the original as possible.
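/**

As a rough illustration of these policies, the plain-Rust sketch below applies each one to a single access into a fixed-size array. The helper functions are hypothetical stand-ins written for this comment. They are not part of Naga's API, and they model only the *behavior* of the generated shader code, not how any back end emits it.

```rust
/// `Restrict`, implemented by clamping: negative indices map to the first element.
fn restrict_clamp(index: i32, len: u32) -> u32 {
    // Types never have zero elements, so `len - 1` is a valid index.
    (index as i64).clamp(0, len as i64 - 1) as u32
}

/// `Restrict`, implemented by reinterpreting the index as unsigned and taking
/// the minimum with the largest valid index: negative indices map to the last element.
fn restrict_unsigned_min(index: i32, len: u32) -> u32 {
    (index as u32).min(len - 1)
}

/// `ReadZeroSkipWrite` for a read: out-of-bounds reads produce zero.
fn read_zero(array: &[f32], index: i32) -> f32 {
    if index >= 0 && (index as usize) < array.len() {
        array[index as usize]
    } else {
        0.0
    }
}

/// `ReadZeroSkipWrite` for a write: out-of-bounds writes have no effect.
fn write_skip(array: &mut [f32], index: i32, value: f32) {
    if index >= 0 {
        if let Some(slot) = array.get_mut(index as usize) {
            *slot = value;
        }
    }
}
```

With a four-element array, `restrict_clamp(-1, 4)` picks index `0` while `restrict_unsigned_min(-1, 4)` picks index `3`. Both are acceptable `Restrict` implementations, because the only requirement is that the chosen index be in bounds.
*/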
/// /// Sometimes the underlying hardware and drivers can perform bounds checks /// themselves, in a way that performs better than the checks Naga would inject. /// If you're using native checks like this, then having Naga inject its own /// checks as well would be redundant, and the `Unchecked` policy is /// appropriate. #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub enum BoundsCheckPolicy { /// Replace out-of-bounds indexes with some arbitrary in-bounds index. /// /// (This does not necessarily mean clamping. For example, interpreting the /// index as unsigned and taking the minimum with the largest valid index /// would also be a valid implementation. That would map negative indices to /// the last element, not the first.) Restrict, /// Out-of-bounds reads return zero, and writes have no effect. /// /// When applied to a chain of accesses, like `a[i][j].b[k]`, all index /// expressions are evaluated, regardless of whether prior or later index /// expressions were in bounds. But all the accesses per se are skipped /// if any index is out of bounds. ReadZeroSkipWrite, /// Naga adds no checks to indexing operations. Generate the fastest code /// possible. This is the default for Naga, as a translator, but consumers /// should consider defaulting to a safer behavior. Unchecked, } /// Policies for injecting bounds checks during code generation. #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(feature = "deserialize", serde(default))] pub struct BoundsCheckPolicies { /// How should the generated code handle array, vector, or matrix indices /// that are out of range? 
pub index: BoundsCheckPolicy, /// How should the generated code handle array, vector, or matrix indices /// that are out of range, when those values live in a [`GlobalVariable`] in /// the [`Storage`] or [`Uniform`] address spaces? /// /// Some graphics hardware provides "robust buffer access", a feature that /// ensures that using a pointer cannot access memory outside the 'buffer' /// that it was derived from. In Naga terms, this means that the hardware /// ensures that pointers computed by applying [`Access`] and /// [`AccessIndex`] expressions to a [`GlobalVariable`] whose [`space`] is /// [`Storage`] or [`Uniform`] will never read or write memory outside that /// global variable. /// /// When hardware offers such a feature, it is probably undesirable to have /// Naga inject bounds checking code for such accesses, since the hardware /// can probably provide the same protection more efficiently. However, /// bounds checks are still needed on accesses to indexable values that do /// not live in buffers, like local variables. /// /// So, this option provides a separate policy that applies only to accesses /// to storage and uniform globals. When depending on hardware bounds /// checking, this policy can be `Unchecked` to avoid unnecessary overhead. /// /// When special hardware support is not available, this should probably be /// the same as the `index` policy. /// /// [`GlobalVariable`]: crate::GlobalVariable /// [`space`]: crate::GlobalVariable::space /// [`Restrict`]: crate::proc::BoundsCheckPolicy::Restrict /// [`ReadZeroSkipWrite`]: crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite /// [`Access`]: crate::Expression::Access /// [`AccessIndex`]: crate::Expression::AccessIndex /// [`Storage`]: crate::AddressSpace::Storage /// [`Uniform`]: crate::AddressSpace::Uniform pub buffer: BoundsCheckPolicy, /// How should the generated code handle image texel loads that are out /// of range?
/// /// This controls the behavior of [`ImageLoad`] expressions when a coordinate, /// texture array index, level of detail, or multisampled sample number is out of range. /// /// There is no corresponding policy for [`ImageStore`] statements. All the /// platforms we support already discard out-of-bounds image stores, /// effectively implementing the "skip write" part of [`ReadZeroSkipWrite`]. /// /// [`ImageLoad`]: crate::Expression::ImageLoad /// [`ImageStore`]: crate::Statement::ImageStore /// [`ReadZeroSkipWrite`]: BoundsCheckPolicy::ReadZeroSkipWrite pub image_load: BoundsCheckPolicy, /// How should the generated code handle binding array indexes that are out of bounds? pub binding_array: BoundsCheckPolicy, } /// The default `BoundsCheckPolicy` is `Unchecked`. impl Default for BoundsCheckPolicy { fn default() -> Self { BoundsCheckPolicy::Unchecked } } impl BoundsCheckPolicies { /// Determine which policy applies to `base`. /// /// `base` is the "base" expression (the expression being indexed) of an `Access` /// or `AccessIndex` expression. This is either a pointer, a value being directly /// indexed, or a binding array. /// /// See the documentation for [`BoundsCheckPolicy`] for details about /// when each policy applies. pub fn choose_policy( &self, base: Handle<crate::Expression>, types: &UniqueArena<crate::Type>, info: &valid::FunctionInfo, ) -> BoundsCheckPolicy { let ty = info[base].ty.inner_with(types); if let crate::TypeInner::BindingArray { .. } = *ty { return self.binding_array; } match ty.pointer_space() { Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => { self.buffer } // This covers other address spaces, but also accessing vectors and // matrices by value, where no pointer is involved. _ => self.index, } } /// Return `true` if any of `self`'s `index`, `buffer`, or `image_load` policies are `policy`.
pub fn contains(&self, policy: BoundsCheckPolicy) -> bool { self.index == policy || self.buffer == policy || self.image_load == policy } } /// An index that may be statically known, or may need to be computed at runtime. /// /// This enum lets us handle both [`Access`] and [`AccessIndex`] expressions /// with the same code. /// /// [`Access`]: crate::Expression::Access /// [`AccessIndex`]: crate::Expression::AccessIndex #[derive(Clone, Copy, Debug)] pub enum GuardedIndex { Known(u32), Expression(Handle<crate::Expression>), } /// Build a set of expressions used as indices, to cache in temporary variables when /// emitted. /// /// Given the bounds-check policies `policies`, construct a `HandleSet` containing the handle /// indices of all the expressions in `function` that are ever used as guarded indices /// under the [`ReadZeroSkipWrite`] policy. The `module` argument must be the module to /// which `function` belongs, and `info` should be that function's analysis results. /// /// Such index expressions will be used twice in the generated code: first for the /// comparison to see if the index is in bounds, and then for the access itself, should /// the comparison succeed. To avoid computing the expressions twice, the generated code /// should cache them in temporary variables. /// /// Why do we need to build such a set in advance, instead of just processing access /// expressions as we encounter them? Whether an expression needs to be cached depends on /// whether it appears as something like the [`index`] operand of an [`Access`] expression /// or the [`level`] operand of an [`ImageLoad`] expression, and on the index bounds check /// policies that apply to those accesses. But [`Emit`] statements just identify a range /// of expressions by index; there's no good way to tell what an expression is used /// for. The only way to do it is to iterate over all the expressions looking for /// relevant `Access` and `ImageLoad` expressions, which is what this function does.
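/**

To make the caching concrete, the sketch below shows, in Rust rather than a shading language, roughly what a guarded read `a[index_expr()]` amounts to under `ReadZeroSkipWrite`: the index is evaluated once into a temporary, and that temporary feeds both the bounds comparison and the access. `guarded_read` and `index_expr` are hypothetical names for this illustration; this is not the output of any real Naga back end.

```rust
/// Hypothetical lowering of a guarded read `a[index_expr()]` under
/// `ReadZeroSkipWrite`.
fn guarded_read(a: &[i32], mut index_expr: impl FnMut() -> u32) -> i32 {
    // Evaluate the index expression exactly once and cache it in a temporary.
    let index = index_expr() as usize;
    // Reuse the cached value for both the bounds comparison and the access.
    if index < a.len() {
        a[index]
    } else {
        // Out-of-bounds reads produce zero.
        0
    }
}
```

Without the temporary, the index expression would appear twice in the generated text and would be evaluated twice, which is wasteful and, if it had side effects, incorrect.
*/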
/// /// Simple expressions like variable loads and constants don't make sense to cache: caching /// them is no better than re-evaluating them. But constants are not covered by `Emit` /// statements, and `Load`s are always cached to ensure they occur at the right time, so /// we don't bother filtering them out from this set. /// /// Fortunately, we don't need to deal with [`ImageStore`] statements here. When we emit /// code for a statement, the writer isn't in the middle of an expression, so we can just /// emit declarations for temporaries, initialized appropriately. /// /// None of these concerns apply to SPIR-V output, since it's easy to just reuse an /// instruction ID in two places; that has the same semantics as a temporary variable, and /// it's inherent in the design of SPIR-V. This function is more useful for text-based /// back ends. /// /// [`ReadZeroSkipWrite`]: BoundsCheckPolicy::ReadZeroSkipWrite /// [`index`]: crate::Expression::Access::index /// [`Access`]: crate::Expression::Access /// [`level`]: crate::Expression::ImageLoad::level /// [`ImageLoad`]: crate::Expression::ImageLoad /// [`Emit`]: crate::Statement::Emit /// [`ImageStore`]: crate::Statement::ImageStore pub fn find_checked_indexes( module: &crate::Module, function: &crate::Function, info: &valid::FunctionInfo, policies: BoundsCheckPolicies, ) -> HandleSet<crate::Expression> { use crate::Expression as Ex; let mut guarded_indices = HandleSet::for_arena(&function.expressions); // Don't bother scanning if we never need `ReadZeroSkipWrite`. if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) { for (_handle, expr) in function.expressions.iter() { // There's no need to handle `AccessIndex` expressions, as their // indices never need to be cached.
match *expr { Ex::Access { base, index } => { if policies.choose_policy(base, &module.types, info) == BoundsCheckPolicy::ReadZeroSkipWrite && access_needs_check( base, GuardedIndex::Expression(index), module, &function.expressions, info, ) .is_some() { guarded_indices.insert(index); } } Ex::ImageLoad { coordinate, array_index, sample, level, .. } => { if policies.image_load == BoundsCheckPolicy::ReadZeroSkipWrite { guarded_indices.insert(coordinate); if let Some(array_index) = array_index { guarded_indices.insert(array_index); } if let Some(sample) = sample { guarded_indices.insert(sample); } if let Some(level) = level { guarded_indices.insert(level); } } } _ => {} } } } guarded_indices } /// Determine whether `index` is statically known to be in bounds for `base`. /// /// If we can't be sure that the index is in bounds, return the length of /// `base`, which every valid index must be less than. /// /// The return value is one of the following: /// /// - `Some(Known(n))` indicates that `n` is the length of `base`, so the /// largest valid index is `n - 1`. /// /// - `Some(Dynamic)` indicates that the length of `base` is only known at /// run time (for example, a runtime-sized array), so the index must be /// checked against that run-time length. /// /// - `None` indicates that the index need not be checked, because it is /// statically known to be in bounds. /// /// This function only handles subscriptable types: arrays, binding arrays, /// vectors, and matrices. It does not handle struct member indices; those /// never require run-time checks, so it's best to deal with them further up /// the call chain. /// /// This function assumes that any relevant overrides have fully-evaluated /// constants as their values (as arranged by [`process_overrides`], for /// example). /// /// [`process_overrides`]: crate::back::pipeline_constants::process_overrides /// /// # Panics /// /// - If `base` is not an indexable type, panic.
/// /// - If `base` is an override-sized array, but the override's value is not a /// fully-evaluated constant expression, panic. pub fn access_needs_check( base: Handle<crate::Expression>, mut index: GuardedIndex, module: &crate::Module, expressions: &crate::Arena<crate::Expression>, info: &valid::FunctionInfo, ) -> Option<IndexableLength> { let base_inner = info[base].ty.inner_with(&module.types); // Unwrap safety: `Err` here indicates unindexable base types and invalid // length constants, but `access_needs_check` is only used by back ends, so // validation should have caught those problems. let length = base_inner.indexable_length_resolved(module).unwrap(); index.try_resolve_to_constant(expressions, module); if let (&GuardedIndex::Known(index), &IndexableLength::Known(length)) = (&index, &length) { if index < length { // Index is statically known to be in bounds, no check needed. return None; } }; Some(length) } /// Items returned by the [`bounds_check_iter`] iterator. #[cfg_attr(not(feature = "msl-out"), allow(dead_code))] pub(crate) struct BoundsCheck { /// The base of the [`Access`] or [`AccessIndex`] expression. /// /// [`Access`]: crate::Expression::Access /// [`AccessIndex`]: crate::Expression::AccessIndex pub base: Handle<crate::Expression>, /// The index being accessed. pub index: GuardedIndex, /// The length of `base`. pub length: IndexableLength, } /// Returns an iterator of accesses within the chain of `Access` and /// `AccessIndex` expressions starting from `chain` that may need to be /// bounds-checked at runtime. /// /// Items are yielded as [`BoundsCheck`] instances. /// /// Accesses through a struct are omitted, since you never need a bounds check /// for accessing a struct field. /// /// If `chain` isn't an `Access` or `AccessIndex` expression at all, the /// iterator is empty.
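/**

The chain walk can be pictured with the simplified, self-contained sketch below. `Expr` and `accesses` are hypothetical types written for this comment, not Naga's `Expression` arena. The sketch keeps only the `iter::from_fn` plus `flatten` pattern that steps from the outermost access toward the root while skipping struct member accesses. It omits the final `filter_map` that drops accesses already known to be in bounds.

```rust
use std::iter;

/// Hypothetical, greatly simplified expression arena: each access names its
/// base by position, and `is_struct` marks a struct member access.
#[derive(Clone, Copy)]
enum Expr {
    Variable,
    Access { base: usize, index: u32, is_struct: bool },
}

/// Walk from `chain` toward the root, yielding `(base, index)` for every
/// access that is not a struct member access.
fn accesses(exprs: &[Expr], mut chain: usize) -> impl Iterator<Item = (usize, u32)> + '_ {
    iter::from_fn(move || {
        let (next, item) = match exprs[chain] {
            // Struct member accesses yield `None`, which `flatten` drops,
            // but the walk still continues through their base.
            Expr::Access { base, index, is_struct } => {
                (base, if is_struct { None } else { Some((base, index)) })
            }
            // Reaching a non-access expression ends the iteration.
            Expr::Variable => return None,
        };
        chain = next;
        Some(item)
    })
    .flatten()
}
```

For an access chain like `a[1].field[2]`, the iterator yields the `[2]` access first and the `[1]` access second, and it skips `.field`.
*/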
pub(crate) fn bounds_check_iter<'a>( mut chain: Handle<crate::Expression>, module: &'a crate::Module, function: &'a crate::Function, info: &'a valid::FunctionInfo, ) -> impl Iterator<Item = BoundsCheck> + 'a { iter::from_fn(move || { let (next_expr, result) = match function.expressions[chain] { crate::Expression::Access { base, index } => { (base, Some((base, GuardedIndex::Expression(index)))) } crate::Expression::AccessIndex { base, index } => { // Don't try to check indices into structs. Validation already took // care of them, and access_needs_check doesn't handle that case. let mut base_inner = info[base].ty.inner_with(&module.types); if let crate::TypeInner::Pointer { base, .. } = *base_inner { base_inner = &module.types[base].inner; } match *base_inner { crate::TypeInner::Struct { .. } => (base, None), _ => (base, Some((base, GuardedIndex::Known(index)))), } } _ => return None, }; chain = next_expr; Some(result) }) .flatten() .filter_map(|(base, index)| { access_needs_check(base, index, module, &function.expressions, info).map(|length| { BoundsCheck { base, index, length, } }) }) } /// Returns all the types which we need out-of-bounds locals for; that is, /// all of the types which the code might attempt to get an out-of-bounds /// pointer to, in which case we yield a pointer to the out-of-bounds local /// of the correct type. pub fn oob_local_types( module: &crate::Module, function: &crate::Function, info: &valid::FunctionInfo, policies: BoundsCheckPolicies, ) -> FastHashSet<Handle<crate::Type>> { let mut result = FastHashSet::default(); if policies.index != BoundsCheckPolicy::ReadZeroSkipWrite { return result; } for statement in &function.body { // The only situation in which we end up actually needing to create an // out-of-bounds pointer is when passing one to a function. // // This is because pointers are never baked; they're just inlined everywhere // they're used.
That means that loads can just return 0, and stores can just do // nothing; functions are the only case where you actually *have* to produce a // pointer. if let crate::Statement::Call { function: callee, ref arguments, .. } = *statement { // Now go through the arguments of the function looking for pointers which need bounds checks. for (arg_info, &arg) in zip(&module.functions[callee].arguments, arguments) { match module.types[arg_info.ty].inner { crate::TypeInner::ValuePointer { .. } => { // `ValuePointer`s should only ever be used when resolving the types of // expressions, since the arena can no longer be modified at that point; things // in the arena should always use proper `Pointer`s. unreachable!("`ValuePointer` found in arena") } crate::TypeInner::Pointer { base, .. } => { if bounds_check_iter(arg, module, function, info) .next() .is_some() { result.insert(base); } } _ => continue, }; } } } result } impl GuardedIndex { /// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible. /// /// Leave values that are already `Known` unchanged. pub(crate) fn try_resolve_to_constant( &mut self, expressions: &crate::Arena<crate::Expression>, module: &crate::Module, ) { if let GuardedIndex::Expression(expr) = *self { *self = GuardedIndex::from_expression(expr, expressions, module); } } pub(crate) fn from_expression( expr: Handle<crate::Expression>, expressions: &crate::Arena<crate::Expression>, module: &crate::Module, ) -> Self { match module.to_ctx().get_const_val_from(expr, expressions) { Ok(value) => Self::Known(value), Err(_) => Self::Expression(expr), } } } #[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)] pub enum IndexableLengthError { #[error("Type is not indexable, and has no length (validation error)")] TypeNotIndexable, #[error(transparent)] ResolveArraySizeError(#[from] super::ResolveArraySizeError), #[error("Array size is still pending")] Pending(crate::ArraySize), } impl crate::TypeInner { /// Return the length of a subscriptable type.
/// /// `self` should describe a vector, matrix, array, or binding array type, /// a pointer to one of those, or a value pointer with a vector size. Arrays /// may be fixed-size, dynamically sized, or sized by an override. This /// function does not handle struct member references, as with `AccessIndex`. /// /// The value returned is appropriate for bounds checks on subscripting. /// /// Return an error if `self` does not describe a subscriptable type at all, /// or if it is an override-sized array whose size is still pending. pub fn indexable_length( &self, module: &crate::Module, ) -> Result<IndexableLength, IndexableLengthError> { use crate::TypeInner as Ti; let known_length = match *self { Ti::Vector { size, .. } => size as _, Ti::Matrix { columns, .. } => columns as _, Ti::Array { size, .. } | Ti::BindingArray { size, .. } => { return size.to_indexable_length(module); } Ti::ValuePointer { size: Some(size), .. } => size as _, Ti::Pointer { base, .. } => { // When assigning types to expressions, ResolveContext::Resolve // does a separate sub-match here instead of a full recursion, // so we'll do the same. let base_inner = &module.types[base].inner; match *base_inner { Ti::Vector { size, .. } => size as _, Ti::Matrix { columns, .. } => columns as _, Ti::Array { size, .. } | Ti::BindingArray { size, .. } => { return size.to_indexable_length(module) } _ => return Err(IndexableLengthError::TypeNotIndexable), } } _ => return Err(IndexableLengthError::TypeNotIndexable), }; Ok(IndexableLength::Known(known_length)) } /// Return the length of `self`, assuming overrides are yet to be supplied. /// /// Return the number of elements in `self`: /// /// - If `self` is a runtime-sized array, then return /// [`IndexableLength::Dynamic`]. /// /// - If `self` is an override-sized array, then assume that override values /// have not yet been supplied, and return [`IndexableLength::Dynamic`]. /// /// - Otherwise, the type simply tells us the length of `self`, so return /// [`IndexableLength::Known`]. /// /// If `self` is not an indexable type at all, return an error.
/// /// The difference between this and `indexable_length_resolved` is that we /// treat override-sized arrays and dynamically-sized arrays both as /// [`Dynamic`], on the assumption that our callers want to treat both cases /// as "not yet possible to check". /// /// [`Dynamic`]: IndexableLength::Dynamic pub fn indexable_length_pending( &self, module: &crate::Module, ) -> Result<IndexableLength, IndexableLengthError> { let length = self.indexable_length(module); if let Err(IndexableLengthError::Pending(_)) = length { return Ok(IndexableLength::Dynamic); } length } /// Return the length of `self`, assuming overrides have been resolved. /// /// Return the number of elements in `self`: /// /// - If `self` is a runtime-sized array, then return /// [`IndexableLength::Dynamic`]. /// /// - If `self` is an override-sized array, then assume that the override's /// value is a fully-evaluated constant expression, and return /// [`IndexableLength::Known`]. Otherwise, return an error. /// /// - Otherwise, the type simply tells us the length of `self`, so return /// [`IndexableLength::Known`]. /// /// If `self` is not an indexable type at all, return an error. /// /// The difference between this and `indexable_length_pending` is /// that if `self` is override-sized, we require the override's /// value to be known. pub fn indexable_length_resolved( &self, module: &crate::Module, ) -> Result<IndexableLength, IndexableLengthError> { let length = self.indexable_length(module); // If the length is override-based, then try to compute its value now. if let Err(IndexableLengthError::Pending(size)) = length { if let IndexableLength::Known(computed) = size.resolve(module.to_ctx())? { return Ok(IndexableLength::Known(computed)); } } length } } /// The number of elements in an indexable type. /// /// This summarizes the length of vectors, matrices, and arrays in a way that is /// convenient for indexing and bounds-checking code. #[derive(Debug)] pub enum IndexableLength { /// Values of this type always have the given number of elements.
Known(u32), /// The number of elements is determined at runtime. Dynamic, } impl crate::ArraySize { pub const fn to_indexable_length( self, _module: &crate::Module, ) -> Result<IndexableLength, IndexableLengthError> { match self { Self::Constant(length) => Ok(IndexableLength::Known(length.get())), Self::Pending(_) => Err(IndexableLengthError::Pending(self)), Self::Dynamic => Ok(IndexableLength::Dynamic), } } } naga-29.0.3/src/proc/keyword_set.rs000064400000000000000000000124021046102023000152630ustar 00000000000000use core::{fmt, hash}; use crate::racy_lock::RacyLock; use crate::FastHashSet; /// A case-sensitive set of strings, /// for use with [`Namer`][crate::proc::Namer] to avoid collisions with keywords and other reserved /// identifiers. /// /// This is currently implemented as a hash table. /// Future versions of Naga may change the implementation based on speed and code size /// considerations. #[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct KeywordSet(FastHashSet<&'static str>); impl KeywordSet { /// Returns a new mutable empty set. pub fn new() -> Self { Self::default() } /// Returns a reference to the empty set. pub fn empty() -> &'static Self { static EMPTY: RacyLock<KeywordSet> = RacyLock::new(Default::default); &EMPTY } /// Returns whether the set contains the given string. #[inline] pub fn contains(&self, identifier: &str) -> bool { self.0.contains(identifier) } } impl Default for &'static KeywordSet { fn default() -> Self { KeywordSet::empty() } } impl FromIterator<&'static str> for KeywordSet { fn from_iter<T: IntoIterator<Item = &'static str>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } /// Accepts double references so that `KeywordSet::from_iter(&["foo"])` works.
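///
/// # Examples
///
/// A minimal sketch of building a set from a slice literal and querying it.
/// The identifiers are illustrative only.
///
/// ```
/// use naga::proc::KeywordSet;
///
/// let set = KeywordSet::from_iter(&["foo", "bar"]);
/// assert!(set.contains("foo"));
/// // `KeywordSet` compares case-sensitively.
/// assert!(!set.contains("FOO"));
/// ```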
impl<'a> FromIterator<&'a &'static str> for KeywordSet { fn from_iter<T: IntoIterator<Item = &'a &'static str>>(iter: T) -> Self { Self::from_iter(iter.into_iter().copied()) } } impl Extend<&'static str> for KeywordSet { #[expect( clippy::useless_conversion, reason = "doing .into_iter() sooner reduces distinct monomorphizations" )] fn extend<T: IntoIterator<Item = &'static str>>(&mut self, iter: T) { self.0.extend(iter.into_iter()) } } /// Accepts double references so that `.extend(&["foo"])` works. impl<'a> Extend<&'a &'static str> for KeywordSet { fn extend<T: IntoIterator<Item = &'a &'static str>>(&mut self, iter: T) { self.extend(iter.into_iter().copied()) } } /// A case-insensitive, ASCII-only set of strings, /// for use with [`Namer`][crate::proc::Namer] to avoid collisions with keywords and other reserved /// identifiers. /// /// This is currently implemented as a hash table. /// Future versions of Naga may change the implementation based on speed and code size /// considerations. #[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct CaseInsensitiveKeywordSet(FastHashSet<AsciiUniCase<&'static str>>); impl CaseInsensitiveKeywordSet { /// Returns a new mutable empty set. pub fn new() -> Self { Self::default() } /// Returns a reference to the empty set. pub fn empty() -> &'static Self { static EMPTY: RacyLock<CaseInsensitiveKeywordSet> = RacyLock::new(Default::default); &EMPTY } /// Returns whether the set contains the given string, with comparison /// by [`str::eq_ignore_ascii_case()`]. #[inline] pub fn contains(&self, identifier: &str) -> bool { self.0.contains(&AsciiUniCase(identifier)) } } impl Default for &'static CaseInsensitiveKeywordSet { fn default() -> Self { CaseInsensitiveKeywordSet::empty() } } impl FromIterator<&'static str> for CaseInsensitiveKeywordSet { fn from_iter<T: IntoIterator<Item = &'static str>>(iter: T) -> Self { Self( iter.into_iter() .inspect(debug_assert_ascii) .map(AsciiUniCase) .collect(), ) } } /// Accepts double references so that `CaseInsensitiveKeywordSet::from_iter(&["foo"])` works.
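///
/// # Examples
///
/// A minimal sketch showing that lookups ignore ASCII case. The identifiers
/// are illustrative only.
///
/// ```
/// use naga::proc::CaseInsensitiveKeywordSet;
///
/// let set = CaseInsensitiveKeywordSet::from_iter(&["Float"]);
/// assert!(set.contains("float"));
/// assert!(set.contains("FLOAT"));
/// assert!(!set.contains("double"));
/// ```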
impl<'a> FromIterator<&'a &'static str> for CaseInsensitiveKeywordSet { fn from_iter<T: IntoIterator<Item = &'a &'static str>>(iter: T) -> Self { Self::from_iter(iter.into_iter().copied()) } } impl Extend<&'static str> for CaseInsensitiveKeywordSet { fn extend<T: IntoIterator<Item = &'static str>>(&mut self, iter: T) { self.0.extend( iter.into_iter() .inspect(debug_assert_ascii) .map(AsciiUniCase), ) } } /// Accepts double references so that `.extend(&["foo"])` works. impl<'a> Extend<&'a &'static str> for CaseInsensitiveKeywordSet { fn extend<T: IntoIterator<Item = &'a &'static str>>(&mut self, iter: T) { self.extend(iter.into_iter().copied()) } } /// A string wrapper type with ASCII case-insensitive `Eq` and `Hash` impls. #[derive(Clone, Copy)] struct AsciiUniCase<S: AsRef<str> + ?Sized>(S); impl<S: AsRef<str>> fmt::Debug for AsciiUniCase<S> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.as_ref().fmt(f) } } impl<S: AsRef<str>> PartialEq for AsciiUniCase<S> { #[inline] fn eq(&self, other: &Self) -> bool { self.0.as_ref().eq_ignore_ascii_case(other.0.as_ref()) } } impl<S: AsRef<str>> Eq for AsciiUniCase<S> {} impl<S: AsRef<str>> hash::Hash for AsciiUniCase<S> { #[inline] fn hash<H: hash::Hasher>(&self, hasher: &mut H) { for byte in self .0 .as_ref() .as_bytes() .iter() .map(|b| b.to_ascii_lowercase()) { hasher.write_u8(byte); } } } fn debug_assert_ascii(s: &&'static str) { debug_assert!(s.is_ascii(), "{s:?} not ASCII") } naga-29.0.3/src/proc/layouter.rs000064400000000000000000000235041046102023000145750ustar 00000000000000use core::{fmt::Display, num::NonZeroU32, ops}; use crate::{ arena::{Handle, HandleVec}, valid::MAX_TYPE_SIZE, }; /// A newtype whose only valid values are powers of 2. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct Alignment(NonZeroU32); impl Alignment { pub const ONE: Self = Self(NonZeroU32::new(1).unwrap()); pub const TWO: Self = Self(NonZeroU32::new(2).unwrap()); pub const FOUR: Self = Self(NonZeroU32::new(4).unwrap()); pub const EIGHT: Self =
Self(NonZeroU32::new(8).unwrap()); pub const SIXTEEN: Self = Self(NonZeroU32::new(16).unwrap()); pub const MIN_UNIFORM: Self = Self::SIXTEEN; pub const fn new(n: u32) -> Option<Self> { if n.is_power_of_two() { // Value can't be 0 since we just checked if it's a power of 2. Some(Self(NonZeroU32::new(n).unwrap())) } else { None } } /// # Panics /// If `width` is not a power of 2. pub const fn from_width(width: u8) -> Self { Self::new(width as u32).unwrap() } /// Returns whether or not `n` is a multiple of this alignment. pub const fn is_aligned(&self, n: u32) -> bool { // equivalent to: `n % self.0.get() == 0` but much faster n & (self.0.get() - 1) == 0 } /// Round `n` up to the nearest alignment boundary. pub const fn round_up(&self, n: u32) -> u32 { // equivalent to: // match n % self.0.get() { // 0 => n, // rem => n + (self.0.get() - rem), // } let mask = self.0.get() - 1; (n + mask) & !mask } } impl Display for Alignment { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.0.get().fmt(f) } } impl ops::Mul<u32> for Alignment { type Output = u32; fn mul(self, rhs: u32) -> Self::Output { self.0.get() * rhs } } impl ops::Mul for Alignment { type Output = Alignment; fn mul(self, rhs: Alignment) -> Self::Output { // Both lhs and rhs are powers of 2, so the result is also a power of 2. Self(NonZeroU32::new(self.0.get() * rhs.0.get()).unwrap()) } } impl From<crate::VectorSize> for Alignment { fn from(size: crate::VectorSize) -> Self { match size { crate::VectorSize::Bi => Alignment::TWO, crate::VectorSize::Tri => Alignment::FOUR, crate::VectorSize::Quad => Alignment::FOUR, } } } impl From<crate::CooperativeSize> for Alignment { fn from(size: crate::CooperativeSize) -> Self { Self(NonZeroU32::new(size as u32).unwrap()) } } /// Size and alignment information for a type.
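///
/// # Examples
///
/// A minimal sketch, assuming the WGSL layout of `vec3<f32>` (12 bytes of
/// data, 16-byte alignment). As an array element, its size is padded up to
/// the alignment.
///
/// ```
/// use naga::proc::{Alignment, TypeLayout};
///
/// let vec3_f32 = TypeLayout { size: 12, alignment: Alignment::SIXTEEN };
/// assert_eq!(vec3_f32.to_stride(), 16);
/// ```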
#[derive(Clone, Copy, Debug, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct TypeLayout { pub size: u32, pub alignment: Alignment, } impl TypeLayout { /// Return the array stride of this type: its size rounded up to its /// alignment, as if it were the element type of an array. pub const fn to_stride(&self) -> u32 { self.alignment.round_up(self.size) } } /// Helper processor that computes the sizes and alignments of all types. /// /// `Layouter` uses the default layout algorithm/table, described in /// [WGSL §4.3.7, "Memory Layout"]. /// /// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the /// layout of the type whose handle is `handle`. /// /// [WGSL §4.3.7, "Memory Layout"]: https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts #[derive(Debug, Default)] pub struct Layouter { /// Layouts for types in an arena. layouts: HandleVec<crate::Type, TypeLayout>, } impl ops::Index<Handle<crate::Type>> for Layouter { type Output = TypeLayout; fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout { &self.layouts[handle] } } /// Errors generated by the `Layouter`. /// /// All of these errors can be produced when validating an arbitrary module. /// When processing WGSL source, only the `TooLarge` error should be /// produced by the `Layouter`, as the front-end should not produce IR /// that would result in the other errors.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] pub enum LayoutErrorInner { #[error("Array element type {0:?} doesn't exist")] InvalidArrayElementType(Handle<crate::Type>), #[error("Struct member[{0}] type {1:?} doesn't exist")] InvalidStructMemberType(u32, Handle<crate::Type>), #[error("Type width must be a power of two")] NonPowerOfTwoWidth, #[error("Size exceeds limit of {MAX_TYPE_SIZE} bytes")] TooLarge, } #[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] #[error("Error laying out type {ty:?}: {inner}")] pub struct LayoutError { pub ty: Handle<crate::Type>, pub inner: LayoutErrorInner, } impl LayoutErrorInner { const fn with(self, ty: Handle<crate::Type>) -> LayoutError { LayoutError { ty, inner: self } } } impl Layouter { /// Remove all entries from this `Layouter`, retaining storage. pub fn clear(&mut self) { self.layouts.clear(); } #[expect(rustdoc::private_intra_doc_links)] /// Extend this `Layouter` with layouts for any new entries in `gctx.types`. /// /// Ensure that every type in `gctx.types` has a corresponding [TypeLayout] /// in [`Self::layouts`]. /// /// Some front ends need to be able to compute layouts for existing types /// while module construction is still in progress and new types are still /// being added. This function assumes that the `TypeLayout` values already /// present in `self.layouts` cover their corresponding entries in `types`, /// and extends `self.layouts` as needed to cover the rest. Thus, a front /// end can call this function at any time, passing a `GlobalCtx` for its /// current arenas, and then assume that layouts are available for all /// types.
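///
/// # Examples
///
/// A minimal sketch, assuming a module that so far contains a single
/// `vec3<f32>` type. The type is illustrative only.
///
/// ```
/// use naga::{proc::{Alignment, Layouter}, Module, Scalar, Span, Type, TypeInner, VectorSize};
///
/// let mut module = Module::default();
/// let vec3 = module.types.insert(
///     Type {
///         name: None,
///         inner: TypeInner::Vector { size: VectorSize::Tri, scalar: Scalar::F32 },
///     },
///     Span::UNDEFINED,
/// );
///
/// let mut layouter = Layouter::default();
/// layouter.update(module.to_ctx()).unwrap();
/// // Three 4-byte components, aligned like a four-component vector.
/// assert_eq!(layouter[vec3].size, 12);
/// assert_eq!(layouter[vec3].alignment, Alignment::SIXTEEN);
/// ```
///
/// Because `update` skips types it has already laid out, a front end can call
/// it again after inserting more types.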
#[allow(clippy::or_fun_call)] pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> { use crate::TypeInner as Ti; for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) { let size = ty .inner .try_size(gctx) .ok_or_else(|| LayoutErrorInner::TooLarge.with(ty_handle))?; let layout = match ty.inner { Ti::Scalar(scalar) | Ti::Atomic(scalar) => { let alignment = Alignment::new(scalar.width as u32) .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; TypeLayout { size, alignment } } Ti::Vector { size: vec_size, scalar, } => { let alignment = Alignment::new(scalar.width as u32) .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; TypeLayout { size, alignment: Alignment::from(vec_size) * alignment, } } Ti::Matrix { columns: _, rows, scalar, } => { let alignment = Alignment::new(scalar.width as u32) .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; TypeLayout { size, alignment: Alignment::from(rows) * alignment, } } Ti::CooperativeMatrix { columns: _, rows, scalar, role: _, } => { let alignment = Alignment::new(scalar.width as u32) .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; TypeLayout { size, alignment: Alignment::from(rows) * alignment, } } Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout { size, alignment: Alignment::ONE, }, Ti::Array { base, stride: _, size: _, } => TypeLayout { size, alignment: if base < ty_handle { self[base].alignment } else { return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle)); }, }, Ti::Struct { span, ref members } => { let mut alignment = Alignment::ONE; for (index, member) in members.iter().enumerate() { alignment = if member.ty < ty_handle { alignment.max(self[member.ty].alignment) } else { return Err(LayoutErrorInner::InvalidStructMemberType( index as u32, member.ty, ) .with(ty_handle)); }; } TypeLayout { size: span, alignment, } } Ti::Image { .. } | Ti::Sampler { .. } | Ti::AccelerationStructure { .. } | Ti::RayQuery { .. 
} | Ti::BindingArray { .. } => TypeLayout { size, alignment: Alignment::ONE, }, }; debug_assert!(size <= layout.size); self.layouts.insert(ty_handle, layout); } Ok(()) } } naga-29.0.3/src/proc/mod.rs000064400000000000000000001075341046102023000135160ustar 00000000000000/*! [`Module`](super::Module) processing functionality. */ mod constant_evaluator; mod emitter; pub mod index; mod keyword_set; mod layouter; mod namer; mod overloads; mod terminator; mod type_methods; mod typifier; pub use constant_evaluator::{ ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker, }; pub use emitter::Emitter; pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet}; pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout}; pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer}; pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule}; pub use terminator::ensure_block_returns; use thiserror::Error; pub use type_methods::{ concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes, }; pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution}; use crate::non_max_u32::NonMaxU32; impl From<super::StorageFormat> for super::Scalar { fn from(format: super::StorageFormat) -> Self { use super::{ScalarKind as Sk, StorageFormat as Sf}; let kind = match format { Sf::R8Unorm => Sk::Float, Sf::R8Snorm => Sk::Float, Sf::R8Uint => Sk::Uint, Sf::R8Sint => Sk::Sint, Sf::R16Uint => Sk::Uint, Sf::R16Sint => Sk::Sint, Sf::R16Float => Sk::Float, Sf::Rg8Unorm => Sk::Float, Sf::Rg8Snorm => Sk::Float, Sf::Rg8Uint => Sk::Uint, Sf::Rg8Sint => Sk::Sint, Sf::R32Uint => Sk::Uint, Sf::R32Sint => Sk::Sint, Sf::R32Float => Sk::Float, Sf::Rg16Uint => Sk::Uint, Sf::Rg16Sint => Sk::Sint, Sf::Rg16Float => Sk::Float, Sf::Rgba8Unorm => Sk::Float, Sf::Rgba8Snorm => Sk::Float, Sf::Rgba8Uint => Sk::Uint, Sf::Rgba8Sint
=> Sk::Sint, Sf::Bgra8Unorm => Sk::Float, Sf::Rgb10a2Uint => Sk::Uint, Sf::Rgb10a2Unorm => Sk::Float, Sf::Rg11b10Ufloat => Sk::Float, Sf::R64Uint => Sk::Uint, Sf::Rg32Uint => Sk::Uint, Sf::Rg32Sint => Sk::Sint, Sf::Rg32Float => Sk::Float, Sf::Rgba16Uint => Sk::Uint, Sf::Rgba16Sint => Sk::Sint, Sf::Rgba16Float => Sk::Float, Sf::Rgba32Uint => Sk::Uint, Sf::Rgba32Sint => Sk::Sint, Sf::Rgba32Float => Sk::Float, Sf::R16Unorm => Sk::Float, Sf::R16Snorm => Sk::Float, Sf::Rg16Unorm => Sk::Float, Sf::Rg16Snorm => Sk::Float, Sf::Rgba16Unorm => Sk::Float, Sf::Rgba16Snorm => Sk::Float, }; let width = match format { Sf::R64Uint => 8, _ => 4, }; super::Scalar { kind, width } } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum HashableLiteral { F64(u64), F32(u32), F16(u16), U32(u32), I32(i32), U64(u64), I64(i64), Bool(bool), AbstractInt(i64), AbstractFloat(u64), } impl From<crate::Literal> for HashableLiteral { fn from(l: crate::Literal) -> Self { match l { crate::Literal::F64(v) => Self::F64(v.to_bits()), crate::Literal::F32(v) => Self::F32(v.to_bits()), crate::Literal::F16(v) => Self::F16(v.to_bits()), crate::Literal::U32(v) => Self::U32(v), crate::Literal::I32(v) => Self::I32(v), crate::Literal::U64(v) => Self::U64(v), crate::Literal::I64(v) => Self::I64(v), crate::Literal::Bool(v) => Self::Bool(v), crate::Literal::AbstractInt(v) => Self::AbstractInt(v), crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()), } } } impl crate::Literal { pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> { match (value, scalar.kind, scalar.width) { (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)), (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)), (value, crate::ScalarKind::Float, 2) => { Some(Self::F16(half::f16::from_f32_const(value as _))) } (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)), (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)), (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value
as _)), (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)), (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)), (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)), (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)), (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)), _ => None, } } pub const fn zero(scalar: crate::Scalar) -> Option<Self> { Self::new(0, scalar) } pub const fn one(scalar: crate::Scalar) -> Option<Self> { Self::new(1, scalar) } pub const fn minus_one(scalar: crate::Scalar) -> Option<Self> { match (scalar.kind, scalar.width) { (crate::ScalarKind::Float, 8) => Some(Self::F64(-1.0)), (crate::ScalarKind::Float, 4) => Some(Self::F32(-1.0)), (crate::ScalarKind::Float, 2) => Some(Self::F16(half::f16::from_f32_const(-1.0))), (crate::ScalarKind::Sint, 8) => Some(Self::I64(-1)), (crate::ScalarKind::Sint, 4) => Some(Self::I32(-1)), (crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(-1)), _ => None, } } pub const fn width(&self) -> crate::Bytes { match *self { Self::F64(_) | Self::I64(_) | Self::U64(_) => 8, Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, Self::F16(_) => 2, Self::Bool(_) => crate::BOOL_WIDTH, Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH, } } pub const fn scalar(&self) -> crate::Scalar { match *self { Self::F64(_) => crate::Scalar::F64, Self::F32(_) => crate::Scalar::F32, Self::F16(_) => crate::Scalar::F16, Self::U32(_) => crate::Scalar::U32, Self::I32(_) => crate::Scalar::I32, Self::U64(_) => crate::Scalar::U64, Self::I64(_) => crate::Scalar::I64, Self::Bool(_) => crate::Scalar::BOOL, Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT, Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT, } } pub const fn scalar_kind(&self) -> crate::ScalarKind { self.scalar().kind } pub const fn ty_inner(&self) -> crate::TypeInner { crate::TypeInner::Scalar(self.scalar()) } } impl TryFrom<crate::Literal> for u32 { type Error =
ConstValueError; fn try_from(value: crate::Literal) -> Result<Self, Self::Error> { match value { crate::Literal::U32(value) => Ok(value), crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative), _ => Err(ConstValueError::InvalidType), } } } impl TryFrom<crate::Literal> for bool { type Error = ConstValueError; fn try_from(value: crate::Literal) -> Result<Self, Self::Error> { match value { crate::Literal::Bool(value) => Ok(value), _ => Err(ConstValueError::InvalidType), } } } impl super::AddressSpace { pub fn access(self) -> crate::StorageAccess { use crate::StorageAccess as Sa; match self { crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE, crate::AddressSpace::Uniform => Sa::LOAD, crate::AddressSpace::Storage { access } => access, crate::AddressSpace::Handle => Sa::LOAD, crate::AddressSpace::Immediate => Sa::LOAD, // TaskPayload isn't always writable, but this is checked elsewhere, // when not using multiple payloads and matching the entry payload is checked.
crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE, crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => { Sa::LOAD | Sa::STORE } } } } impl super::MathFunction { pub const fn argument_count(&self) -> usize { match *self { // comparison Self::Abs => 1, Self::Min => 2, Self::Max => 2, Self::Clamp => 3, Self::Saturate => 1, // trigonometry Self::Cos => 1, Self::Cosh => 1, Self::Sin => 1, Self::Sinh => 1, Self::Tan => 1, Self::Tanh => 1, Self::Acos => 1, Self::Asin => 1, Self::Atan => 1, Self::Atan2 => 2, Self::Asinh => 1, Self::Acosh => 1, Self::Atanh => 1, Self::Radians => 1, Self::Degrees => 1, // decomposition Self::Ceil => 1, Self::Floor => 1, Self::Round => 1, Self::Fract => 1, Self::Trunc => 1, Self::Modf => 1, Self::Frexp => 1, Self::Ldexp => 2, // exponent Self::Exp => 1, Self::Exp2 => 1, Self::Log => 1, Self::Log2 => 1, Self::Pow => 2, // geometry Self::Dot => 2, Self::Dot4I8Packed => 2, Self::Dot4U8Packed => 2, Self::Outer => 2, Self::Cross => 2, Self::Distance => 2, Self::Length => 1, Self::Normalize => 1, Self::FaceForward => 3, Self::Reflect => 2, Self::Refract => 3, // computational Self::Sign => 1, Self::Fma => 3, Self::Mix => 3, Self::Step => 2, Self::SmoothStep => 3, Self::Sqrt => 1, Self::InverseSqrt => 1, Self::Inverse => 1, Self::Transpose => 1, Self::Determinant => 1, Self::QuantizeToF16 => 1, // bits Self::CountTrailingZeros => 1, Self::CountLeadingZeros => 1, Self::CountOneBits => 1, Self::ReverseBits => 1, Self::ExtractBits => 3, Self::InsertBits => 4, Self::FirstTrailingBit => 1, Self::FirstLeadingBit => 1, // data packing Self::Pack4x8snorm => 1, Self::Pack4x8unorm => 1, Self::Pack2x16snorm => 1, Self::Pack2x16unorm => 1, Self::Pack2x16float => 1, Self::Pack4xI8 => 1, Self::Pack4xU8 => 1, Self::Pack4xI8Clamp => 1, Self::Pack4xU8Clamp => 1, // data unpacking Self::Unpack4x8snorm => 1, Self::Unpack4x8unorm => 1, Self::Unpack2x16snorm => 1, Self::Unpack2x16unorm => 1, Self::Unpack2x16float => 1, Self::Unpack4xI8 => 
1, Self::Unpack4xU8 => 1, } } } impl crate::Expression { /// Returns true if the expression is considered emitted at the start of a function. pub const fn needs_pre_emit(&self) -> bool { match *self { Self::Literal(_) | Self::Constant(_) | Self::Override(_) | Self::ZeroValue(_) | Self::FunctionArgument(_) | Self::GlobalVariable(_) | Self::LocalVariable(_) => true, _ => false, } } /// Return true if this expression is a dynamic array/vector/matrix index, /// for [`Access`]. /// /// This method returns true if this expression is a dynamically computed /// index, and as such can only be used to index matrices when they appear /// behind a pointer. See the documentation for [`Access`] for details. /// /// Note, this does not check the _type_ of the given expression. It's up to /// the caller to establish that the `Access` expression is well-typed /// through other means, like [`ResolveContext`]. /// /// [`Access`]: crate::Expression::Access /// [`ResolveContext`]: crate::proc::ResolveContext pub const fn is_dynamic_index(&self) -> bool { match *self { Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false, _ => true, } } } impl crate::Function { /// Return the global variable being accessed by the expression `pointer`. /// /// Assuming that `pointer` is a series of `Access` and `AccessIndex` /// expressions that ultimately access some part of a `GlobalVariable`, /// return a handle for that global. /// /// If the expression does not ultimately access a global variable, return /// `None`. pub fn originating_global( &self, mut pointer: crate::Handle<crate::Expression>, ) -> Option<crate::Handle<crate::GlobalVariable>> { loop { pointer = match self.expressions[pointer] { crate::Expression::Access { base, .. } => base, crate::Expression::AccessIndex { base, .. } => base, crate::Expression::GlobalVariable(handle) => return Some(handle), crate::Expression::LocalVariable(_) => return None, crate::Expression::FunctionArgument(_) => return None, // There are no other expressions that produce pointer values.
_ => unreachable!(), } } } } impl crate::SampleLevel { pub const fn implicit_derivatives(&self) -> bool { match *self { Self::Auto | Self::Bias(_) => true, Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false, } } } impl crate::Binding { pub const fn to_built_in(&self) -> Option<crate::BuiltIn> { match *self { crate::Binding::BuiltIn(built_in) => Some(built_in), Self::Location { .. } => None, } } } impl super::SwizzleComponent { pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W]; pub const fn index(&self) -> u32 { match *self { Self::X => 0, Self::Y => 1, Self::Z => 2, Self::W => 3, } } pub const fn from_index(idx: u32) -> Self { match idx { 0 => Self::X, 1 => Self::Y, 2 => Self::Z, _ => Self::W, } } } impl super::ImageClass { pub const fn is_multisampled(self) -> bool { match self { crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi, crate::ImageClass::Storage { .. } => false, crate::ImageClass::External => false, } } pub const fn is_mipmapped(self) -> bool { match self { crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi, crate::ImageClass::Storage { .. } => false, crate::ImageClass::External => false, } } pub const fn is_depth(self) -> bool { matches!(self, crate::ImageClass::Depth { ..
}) } } impl crate::Module { pub const fn to_ctx(&self) -> GlobalCtx<'_> { GlobalCtx { types: &self.types, constants: &self.constants, overrides: &self.overrides, global_expressions: &self.global_expressions, } } pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool { compare_types(lhs, rhs, &self.types) } } #[derive(Debug)] pub enum ConstValueError { NonConst, Negative, InvalidType, } impl From<core::convert::Infallible> for ConstValueError { fn from(_: core::convert::Infallible) -> Self { unreachable!() } } #[derive(Clone, Copy)] pub struct GlobalCtx<'a> { pub types: &'a crate::UniqueArena<crate::Type>, pub constants: &'a crate::Arena<crate::Constant>, pub overrides: &'a crate::Arena<crate::Override>, pub global_expressions: &'a crate::Arena<crate::Expression>, } impl GlobalCtx<'_> { /// Try to evaluate the expression in `self.global_expressions` using its `handle` /// and return it as a `T: TryFrom<crate::Literal>`. /// /// This currently only evaluates scalar expressions. If adding support for vectors, /// consider changing `valid::expression::validate_constant_shift_amounts` to use that /// support.
#[cfg_attr( not(any( feature = "glsl-in", feature = "spv-in", feature = "wgsl-in", glsl_out, hlsl_out, msl_out, wgsl_out )), allow(dead_code) )] pub(super) fn get_const_val<T, E>( &self, handle: crate::Handle<crate::Expression>, ) -> Result<T, ConstValueError> where T: TryFrom<crate::Literal, Error = E>, E: Into<ConstValueError>, { self.get_const_val_from(handle, self.global_expressions) } pub(super) fn get_const_val_from<T, E>( &self, handle: crate::Handle<crate::Expression>, arena: &crate::Arena<crate::Expression>, ) -> Result<T, ConstValueError> where T: TryFrom<crate::Literal, Error = E>, E: Into<ConstValueError>, { fn get( gctx: GlobalCtx, handle: crate::Handle<crate::Expression>, arena: &crate::Arena<crate::Expression>, ) -> Option<crate::Literal> { match arena[handle] { crate::Expression::Literal(literal) => Some(literal), crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner { crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar), _ => None, }, _ => None, } } let value = match arena[handle] { crate::Expression::Constant(c) => { get(*self, self.constants[c].init, self.global_expressions) } _ => get(*self, handle, arena), }; match value { Some(v) => v.try_into().map_err(Into::into), None => Err(ConstValueError::NonConst), } } pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool { compare_types(lhs, rhs, self.types) } } #[derive(Error, Debug, Clone, Copy, PartialEq)] pub enum ResolveArraySizeError { #[error("array element count must be positive (> 0)")] ExpectedPositiveArrayLength, #[error("internal: array size override has not been resolved")] NonConstArrayLength, } impl crate::ArraySize { /// Return the number of elements that `size` represents, if known at code generation time. /// /// If `size` is override-based, return an error unless the override's /// initializer is a fully evaluated constant expression. You can call /// [`pipeline_constants::process_overrides`] to supply values for a /// module's overrides and ensure their initializers are fully evaluated, as /// this function expects.
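///
/// # Examples
///
/// A minimal sketch of the two cases that need no override values. The
/// length `4` is illustrative only.
///
/// ```
/// use core::num::NonZeroU32;
/// use naga::{proc::IndexableLength, ArraySize, Module};
///
/// let module = Module::default();
/// let fixed = ArraySize::Constant(NonZeroU32::new(4).unwrap());
/// assert!(matches!(fixed.resolve(module.to_ctx()), Ok(IndexableLength::Known(4))));
/// assert!(matches!(
///     ArraySize::Dynamic.resolve(module.to_ctx()),
///     Ok(IndexableLength::Dynamic)
/// ));
/// ```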
/// /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> { match *self { crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())), crate::ArraySize::Pending(handle) => { let Some(expr) = gctx.overrides[handle].init else { return Err(ResolveArraySizeError::NonConstArrayLength); }; let length = gctx.get_const_val(expr).map_err(|err| match err { ConstValueError::NonConst => ResolveArraySizeError::NonConstArrayLength, ConstValueError::Negative | ConstValueError::InvalidType => { ResolveArraySizeError::ExpectedPositiveArrayLength } })?; if length == 0 { return Err(ResolveArraySizeError::ExpectedPositiveArrayLength); } Ok(IndexableLength::Known(length)) } crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic), } } } /// Return an iterator over the individual components assembled by a /// `Compose` expression. /// /// Given `ty` and `components` from an `Expression::Compose`, return an /// iterator over the components of the resulting value. /// /// Normally, this would just be an iterator over `components`. However, /// `Compose` expressions can concatenate vectors, in which case the i'th /// value being composed is not generally the i'th element of `components`. /// This function consults `ty` to decide if this concatenation is occurring, /// and returns an iterator that produces the components of the result of /// the `Compose` expression in either case. pub fn flatten_compose<'arenas>( ty: crate::Handle<crate::Type>, components: &'arenas [crate::Handle<crate::Expression>], expressions: &'arenas crate::Arena<crate::Expression>, types: &'arenas crate::UniqueArena<crate::Type>, ) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas { // Returning `impl Iterator` is a bit tricky. We may or may not // want to flatten the components, but we have to settle on a // single concrete type to return. This function returns a single // iterator chain that handles both the flattening and // non-flattening cases.
let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner { (size as usize, true) } else { (components.len(), false) }; /// Flatten `Compose` expressions if `is_vector` is true. fn flatten_compose<'c>( component: &'c crate::Handle<crate::Expression>, is_vector: bool, expressions: &'c crate::Arena<crate::Expression>, ) -> &'c [crate::Handle<crate::Expression>] { if is_vector { if let crate::Expression::Compose { ty: _, components: ref subcomponents, } = expressions[*component] { return subcomponents; } } core::slice::from_ref(component) } /// Flatten `Splat` expressions if `is_vector` is true. fn flatten_splat<'c>( component: &'c crate::Handle<crate::Expression>, is_vector: bool, expressions: &'c crate::Arena<crate::Expression>, ) -> impl Iterator<Item = crate::Handle<crate::Expression>> { let mut expr = *component; let mut count = 1; if is_vector { if let crate::Expression::Splat { size, value } = expressions[expr] { expr = value; count = size as usize; } } core::iter::repeat_n(expr, count) } // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to // flatten up to two levels of `Compose` expressions. // // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten // `Splat` expressions. Fortunately, the operand of a `Splat` must // be a scalar, so we can stop there.
components .iter() .flat_map(move |component| flatten_compose(component, is_vector, expressions)) .flat_map(move |component| flatten_compose(component, is_vector, expressions)) .flat_map(move |component| flatten_splat(component, is_vector, expressions)) .take(size) } impl super::ShaderStage { pub const fn compute_like(self) -> bool { match self { Self::Vertex | Self::Fragment => false, Self::Compute | Self::Task | Self::Mesh => true, Self::RayGeneration | Self::AnyHit | Self::ClosestHit | Self::Miss => false, } } /// Mesh or task shader pub const fn mesh_like(self) -> bool { match self { Self::Task | Self::Mesh => true, _ => false, } } } #[test] fn test_matrix_size() { let module = crate::Module::default(); assert_eq!( crate::TypeInner::Matrix { columns: crate::VectorSize::Tri, rows: crate::VectorSize::Tri, scalar: crate::Scalar::F32, } .size(module.to_ctx()), 48, ); } impl crate::Module { /// Extracts mesh shader info from a mesh output global variable. Used in frontends /// and by validators. This only validates the output variable itself, and not the /// vertex and primitive output types. /// /// The output contains the extracted mesh stage info, with overrides unset, /// and then the overrides separately. This is because the overrides should be /// treated as expressions elsewhere, but that requires mutably modifying the /// module and the expressions should only be created at parse time, not validation /// time. 
#[allow(clippy::type_complexity)] pub fn analyze_mesh_shader_info( &self, gv: crate::Handle<crate::GlobalVariable>, ) -> ( crate::MeshStageInfo, [Option<crate::Handle<crate::Override>>; 2], Option<crate::WithSpan<crate::valid::EntryPointError>>, ) { use crate::span::AddSpan; use crate::valid::EntryPointError; #[derive(Default)] struct OutError { pub inner: Option<EntryPointError>, } impl OutError { pub fn set(&mut self, err: EntryPointError) { if self.inner.is_none() { self.inner = Some(err); } } } // Placeholder handle, used only to initialize fields until the real types are found let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap()); let mut output = crate::MeshStageInfo { topology: crate::MeshOutputTopology::Triangles, max_vertices: 0, max_vertices_override: None, max_primitives: 0, max_primitives_override: None, vertex_output_type: null_type, primitive_output_type: null_type, output_variable: gv, }; // Stores the error to output, if any. let mut error = OutError::default(); let r#type = &self.types[self.global_variables[gv].ty].inner; let mut topology = output.topology; // Max, max override, type let mut vertex_info = (0, None, null_type); let mut primitive_info = (0, None, null_type); match r#type { &crate::TypeInner::Struct { ref members, ..
} => { let mut builtins = crate::FastHashSet::default(); for member in members { match member.binding { Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => { // Must have type u32 if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) { error.set(EntryPointError::BadMeshOutputVariableField); } // Each builtin should only occur once if builtins.contains(&crate::BuiltIn::VertexCount) { error.set(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::VertexCount); } Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => { // Must have type u32 if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) { error.set(EntryPointError::BadMeshOutputVariableField); } // Each builtin should only occur once if builtins.contains(&crate::BuiltIn::PrimitiveCount) { error.set(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::PrimitiveCount); } Some(crate::Binding::BuiltIn( crate::BuiltIn::Vertices | crate::BuiltIn::Primitives, )) => { let ty = &self.types[member.ty].inner; // Analyze the array type to determine size and vertex/primitive type let (a, b, c) = match ty { &crate::TypeInner::Array { base, size, .. } => { let ty = base; let (max, max_override) = match size { crate::ArraySize::Constant(a) => (a.get(), None), crate::ArraySize::Pending(o) => (0, Some(o)), crate::ArraySize::Dynamic => { error.set(EntryPointError::BadMeshOutputVariableField); (0, None) } }; (max, max_override, ty) } _ => { error.set(EntryPointError::BadMeshOutputVariableField); (0, None, null_type) } }; if matches!( member.binding, Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives)) ) { // Primitives require special analysis to determine topology primitive_info = (a, b, c); match self.types[c].inner { crate::TypeInner::Struct { ref members, .. 
} => { for member in members { match member.binding { Some(crate::Binding::BuiltIn( crate::BuiltIn::PointIndex, )) => { topology = crate::MeshOutputTopology::Points; } Some(crate::Binding::BuiltIn( crate::BuiltIn::LineIndices, )) => { topology = crate::MeshOutputTopology::Lines; } Some(crate::Binding::BuiltIn( crate::BuiltIn::TriangleIndices, )) => { topology = crate::MeshOutputTopology::Triangles; } _ => (), } } } _ => (), } // Each builtin should only occur once if builtins.contains(&crate::BuiltIn::Primitives) { error.set(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::Primitives); } else { vertex_info = (a, b, c); // Each builtin should only occur once if builtins.contains(&crate::BuiltIn::Vertices) { error.set(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::Vertices); } } _ => error.set(EntryPointError::BadMeshOutputVariableType), } } output = crate::MeshStageInfo { topology, max_vertices: vertex_info.0, max_vertices_override: None, vertex_output_type: vertex_info.2, max_primitives: primitive_info.0, max_primitives_override: None, primitive_output_type: primitive_info.2, ..output } } _ => error.set(EntryPointError::BadMeshOutputVariableType), } ( output, [vertex_info.1, primitive_info.1], error .inner .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)), ) } pub fn uses_mesh_shaders(&self) -> bool { let binding_uses_mesh = |b: &crate::Binding| { matches!( b, crate::Binding::BuiltIn( crate::BuiltIn::MeshTaskSize | crate::BuiltIn::CullPrimitive | crate::BuiltIn::PointIndex | crate::BuiltIn::LineIndices | crate::BuiltIn::TriangleIndices | crate::BuiltIn::VertexCount | crate::BuiltIn::Vertices | crate::BuiltIn::PrimitiveCount | crate::BuiltIn::Primitives, ) | crate::Binding::Location { per_primitive: true, .. } ) }; for (_, ty) in self.types.iter() { match ty.inner { crate::TypeInner::Struct { ref members, .. 
} => { for binding in members.iter().filter_map(|m| m.binding.as_ref()) { if binding_uses_mesh(binding) { return true; } } } _ => (), } } for ep in &self.entry_points { if matches!( ep.stage, crate::ShaderStage::Mesh | crate::ShaderStage::Task ) { return true; } for binding in ep .function .arguments .iter() .filter_map(|arg| arg.binding.as_ref()) .chain( ep.function .result .iter() .filter_map(|res| res.binding.as_ref()), ) { if binding_uses_mesh(binding) { return true; } } } if self .global_variables .iter() .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload) { return true; } false } } impl crate::MeshOutputTopology { pub const fn to_builtin(self) -> crate::BuiltIn { match self { Self::Points => crate::BuiltIn::PointIndex, Self::Lines => crate::BuiltIn::LineIndices, Self::Triangles => crate::BuiltIn::TriangleIndices, } } } impl crate::AddressSpace { pub const fn is_workgroup_like(self) -> bool { matches!(self, Self::WorkGroup | Self::TaskPayload) } } naga-29.0.3/src/proc/namer.rs000064400000000000000000000406111046102023000140310ustar 00000000000000use alloc::{ borrow::Cow, format, string::{String, ToString}, vec::Vec, }; use crate::{ arena::Handle, proc::{keyword_set::CaseInsensitiveKeywordSet, KeywordSet}, FastHashMap, }; pub type EntryPointIndex = u16; const SEPARATOR: char = '_'; /// A component of a lowered external texture. /// /// Whereas the WGSL backend implements [`ImageClass::External`] /// images directly, most other Naga backends lower them to a /// collection of ordinary textures that represent individual planes /// (as received from a video decoder, perhaps), together with a /// struct of parameters saying how they should be cropped, sampled, /// and color-converted. /// /// This lowering means that individual globals and function /// parameters in Naga IR must be split out by the backends into /// collections of globals and parameters of simpler types. 
/// /// A value of this enum serves as a name key for one specific /// component in the lowered representation of an external texture. /// That is, these keys are for variables/parameters that do not exist /// in the Naga IR, only in its lowered form. /// /// [`ImageClass::External`]: crate::ir::ImageClass::External #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum ExternalTextureNameKey { Plane(usize), Params, } impl ExternalTextureNameKey { const ALL: &[(&str, ExternalTextureNameKey)] = &[ ("_plane0", ExternalTextureNameKey::Plane(0)), ("_plane1", ExternalTextureNameKey::Plane(1)), ("_plane2", ExternalTextureNameKey::Plane(2)), ("_params", ExternalTextureNameKey::Params), ]; } #[derive(Debug, Eq, Hash, PartialEq)] pub enum NameKey { Constant(Handle<crate::Constant>), Override(Handle<crate::Override>), GlobalVariable(Handle<crate::GlobalVariable>), Type(Handle<crate::Type>), StructMember(Handle<crate::Type>, u32), Function(Handle<crate::Function>), FunctionArgument(Handle<crate::Function>, u32), FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>), /// A local variable used by the `ReadZeroSkipWrite` bounds-check policy /// when it needs to produce a pointer-typed result for an OOB access. /// These are unique per accessed type, so the second element is a /// type handle. See docs for [`crate::back::msl`]. FunctionOobLocal(Handle<crate::Function>, Handle<crate::Type>), EntryPoint(EntryPointIndex), EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>), EntryPointArgument(EntryPointIndex, u32), /// Entry point version of `FunctionOobLocal`. EntryPointOobLocal(EntryPointIndex, Handle<crate::Type>), /// A global variable holding a component of a lowered external texture. /// /// See [`ExternalTextureNameKey`] for details. ExternalTextureGlobalVariable(Handle<crate::GlobalVariable>, ExternalTextureNameKey), /// A function argument holding a component of a lowered external /// texture. /// /// See [`ExternalTextureNameKey`] for details. ExternalTextureFunctionArgument(Handle<crate::Function>, u32, ExternalTextureNameKey), } /// This processor assigns names to all the things in a module /// that may need identifiers in a textual backend.
#[derive(Default)] pub struct Namer { /// The last numeric suffix used for each base name. Zero means "no suffix". unique: FastHashMap<String, u32>, keywords: &'static KeywordSet, builtin_identifiers: &'static KeywordSet, keywords_case_insensitive: &'static CaseInsensitiveKeywordSet, reserved_prefixes: Vec<&'static str>, } impl Namer { /// Return a form of `string` suitable for use as the base of an identifier. /// /// - Drop leading digits and trailing `_` characters. /// - Keep ASCII alphanumeric and `_` characters, turn `:`, `<`, `>`, and `,` into `_`, /// and replace any other character with an escape of the form `uXXXX_` (its code point in hex). /// - Replace consecutive `_` characters with a single `_` character. /// - Fall back to `unnamed` if nothing remains. /// - Prefix the result with `gen_` if it starts with one of [`Namer::reserved_prefixes`]. /// /// The return value is a valid identifier prefix in all of Naga's output languages, /// and it never ends with a `SEPARATOR` character. /// It is used as a key into the `unique` table. fn sanitize<'s>(&self, string: &'s str) -> Cow<'s, str> { let string = string .trim_start_matches(|c: char| c.is_numeric()) .trim_end_matches(SEPARATOR); let base = if !string.is_empty() && !string.contains("__") && string .chars() .all(|c: char| c.is_ascii_alphanumeric() || c == '_') { Cow::Borrowed(string) } else { let mut filtered = string.chars().fold(String::new(), |mut s, c| { let c = match c { // Make several common characters in C++-ish types become snake case // separators.
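// Illustrative traces of this fold, worked by hand from the arms below
// (any trailing `_` is trimmed right after the fold):
//   "Foo<Bar,Baz>" folds to "Foo_Bar_Baz_", which is trimmed to "Foo_Bar_Baz";
//   "a::b" folds to "a_b", because the second `_` would repeat the first;
//   "café" folds to "caf_u00e9_", escaping the non-ASCII `é` by its code point,
//   and is trimmed to "caf_u00e9".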
':' | '<' | '>' | ',' => '_', c => c, }; let had_underscore_at_end = s.ends_with('_'); if had_underscore_at_end && c == '_' { return s; } if c.is_ascii_alphanumeric() || c == '_' { s.push(c); } else { use core::fmt::Write as _; if !s.is_empty() && !had_underscore_at_end { s.push('_'); } write!(s, "u{:04x}_", c as u32).unwrap(); } s }); let stripped_len = filtered.trim_end_matches(SEPARATOR).len(); filtered.truncate(stripped_len); if filtered.is_empty() { filtered.push_str("unnamed"); } else if filtered.starts_with(|c: char| c.is_ascii_digit()) { unreachable!( "internal error: invalid identifier starting with ASCII digit {:?}", filtered.chars().nth(0) ) } Cow::Owned(filtered) }; for prefix in &self.reserved_prefixes { if base.starts_with(prefix) { return format!("gen_{base}").into(); } } base } /// Return a new identifier based on `label_raw`. /// /// The result: /// - is a valid identifier even if `label_raw` is not /// - conflicts with no keywords listed in `Namer::keywords`, and /// - is different from any identifier previously constructed by this /// `Namer`. /// /// Guarantee uniqueness by applying a numeric suffix when necessary. If `label_raw` /// itself ends with digits, separate them from the suffix with an underscore. pub fn call(&mut self, label_raw: &str) -> String { use core::fmt::Write as _; // for write!-ing to Strings let base = self.sanitize(label_raw); debug_assert!(!base.is_empty() && !base.ends_with(SEPARATOR)); // This would seem to be a natural place to use `HashMap::entry`. However, `entry` // requires an owned key, and we'd like to avoid heap-allocating strings we're // just going to throw away. The approach below double-hashes only when we create // a new entry, in which case the heap allocation of the owned key was more // expensive anyway. match self.unique.get_mut(base.as_ref()) { Some(count) => { *count += 1; // Add the suffix. This may fit in base's existing allocation. 
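// Worked trace (derived from this function and the `None` arm below): the
// first `call("x")` returns "x" and records a count of 0; a second call bumps
// the count to 1 and returns "x_1", and a third returns "x_2". A later
// `call("x_1")` cannot collide with these, because the `None` arm appends `_`
// to any base ending in a digit and so returns "x_1_".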
let mut suffixed = base.into_owned(); write!(suffixed, "{}{}", SEPARATOR, *count).unwrap(); suffixed } None => { let mut suffixed = base.to_string(); if base.ends_with(char::is_numeric) || self.keywords.contains(base.as_ref()) || self.keywords_case_insensitive.contains(base.as_ref()) || self.builtin_identifiers.contains(base.as_ref()) { suffixed.push(SEPARATOR); } debug_assert!(!self.keywords.contains(&suffixed)); // `self.unique` wants to own its keys. This allocates only if we haven't // already done so earlier. self.unique.insert(base.into_owned(), 0); suffixed } } } pub fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String { self.call(match *label { Some(ref name) => name, None => fallback, }) } /// Enter a local namespace for things like structs. /// /// Struct member names only need to be unique amongst themselves, not /// globally. This function temporarily establishes a fresh, empty naming /// context for the duration of the call to `body`. fn namespace(&mut self, capacity: usize, body: impl FnOnce(&mut Self)) { let empty_unique = FastHashMap::with_capacity_and_hasher(capacity, Default::default()); let saved_unique = core::mem::replace(&mut self.unique, empty_unique); let saved_builtin_identifiers = core::mem::take(&mut self.builtin_identifiers); body(self); self.unique = saved_unique; self.builtin_identifiers = saved_builtin_identifiers; } pub fn reset( &mut self, module: &crate::Module, reserved_keywords: &'static KeywordSet, builtin_identifiers: &'static KeywordSet, reserved_keywords_case_insensitive: &'static CaseInsensitiveKeywordSet, reserved_prefixes: &[&'static str], output: &mut FastHashMap<NameKey, String>, ) { self.reserved_prefixes.clear(); self.reserved_prefixes.extend(reserved_prefixes.iter()); self.unique.clear(); self.keywords = reserved_keywords; self.builtin_identifiers = builtin_identifiers; self.keywords_case_insensitive = reserved_keywords_case_insensitive; // Choose fallback names for anonymous entry point return types.
let mut entrypoint_type_fallbacks = FastHashMap::default(); for ep in &module.entry_points { if let Some(ref result) = ep.function.result { if let crate::Type { name: None, inner: crate::TypeInner::Struct { .. }, } = module.types[result.ty] { let label = match ep.stage { crate::ShaderStage::Vertex => "VertexOutput", crate::ShaderStage::Fragment => "FragmentOutput", crate::ShaderStage::Compute => "ComputeOutput", crate::ShaderStage::Task | crate::ShaderStage::Mesh | crate::ShaderStage::RayGeneration | crate::ShaderStage::ClosestHit | crate::ShaderStage::AnyHit | crate::ShaderStage::Miss => unreachable!(), }; entrypoint_type_fallbacks.insert(result.ty, label); } } } let mut temp = String::new(); for (ty_handle, ty) in module.types.iter() { // If the type is anonymous, check `entrypoint_type_fallbacks` for // something better than just `"type"`. let raw_label = match ty.name { Some(ref given_name) => given_name.as_str(), None => entrypoint_type_fallbacks .get(&ty_handle) .cloned() .unwrap_or("type"), }; let ty_name = self.call(raw_label); output.insert(NameKey::Type(ty_handle), ty_name); if let crate::TypeInner::Struct { ref members, ..
} = ty.inner { // struct members have their own namespace, because access is always prefixed self.namespace(members.len(), |namer| { for (index, member) in members.iter().enumerate() { let name = namer.call_or(&member.name, "member"); output.insert(NameKey::StructMember(ty_handle, index as u32), name); } }) } } for (ep_index, ep) in module.entry_points.iter().enumerate() { let ep_name = self.call(&ep.name); output.insert(NameKey::EntryPoint(ep_index as _), ep_name); for (index, arg) in ep.function.arguments.iter().enumerate() { let name = self.call_or(&arg.name, "param"); output.insert( NameKey::EntryPointArgument(ep_index as _, index as u32), name, ); } for (handle, var) in ep.function.local_variables.iter() { let name = self.call_or(&var.name, "local"); output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name); } } for (fun_handle, fun) in module.functions.iter() { let fun_name = self.call_or(&fun.name, "function"); output.insert(NameKey::Function(fun_handle), fun_name); for (index, arg) in fun.arguments.iter().enumerate() { let name = self.call_or(&arg.name, "param"); output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name); if matches!( module.types[arg.ty].inner, crate::TypeInner::Image { class: crate::ImageClass::External, .. } ) { let base = arg.name.as_deref().unwrap_or("param"); for &(suffix, ext_key) in ExternalTextureNameKey::ALL { let name = self.call(&format!("{base}_{suffix}")); output.insert( NameKey::ExternalTextureFunctionArgument( fun_handle, index as u32, ext_key, ), name, ); } } } for (handle, var) in fun.local_variables.iter() { let name = self.call_or(&var.name, "local"); output.insert(NameKey::FunctionLocal(fun_handle, handle), name); } } for (handle, var) in module.global_variables.iter() { let name = self.call_or(&var.name, "global"); output.insert(NameKey::GlobalVariable(handle), name); if matches!( module.types[var.ty].inner, crate::TypeInner::Image { class: crate::ImageClass::External, .. 
} ) { let base = var.name.as_deref().unwrap_or("global"); for &(suffix, ext_key) in ExternalTextureNameKey::ALL { let name = self.call(&format!("{base}_{suffix}")); output.insert( NameKey::ExternalTextureGlobalVariable(handle, ext_key), name, ); } } } for (handle, constant) in module.constants.iter() { let label = match constant.name { Some(ref name) => name, None => { use core::fmt::Write; // Try to be more descriptive about the constant values temp.clear(); write!(temp, "const_{}", output[&NameKey::Type(constant.ty)]).unwrap(); &temp } }; let name = self.call(label); output.insert(NameKey::Constant(handle), name); } for (handle, override_) in module.overrides.iter() { let label = match override_.name { Some(ref name) => name, None => { use core::fmt::Write; // Try to be more descriptive about the override values temp.clear(); write!(temp, "override_{}", output[&NameKey::Type(override_.ty)]).unwrap(); &temp } }; let name = self.call(label); output.insert(NameKey::Override(handle), name); } } } #[test] fn test() { let mut namer = Namer::default(); assert_eq!(namer.call("x"), "x"); assert_eq!(namer.call("x"), "x_1"); assert_eq!(namer.call("x1"), "x1_"); assert_eq!(namer.call("__x"), "_x"); assert_eq!(namer.call("1___x"), "_x_1"); } naga-29.0.3/src/proc/overloads/any_overload_set.rs000064400000000000000000000071151046102023000202640ustar 00000000000000//! Dynamically dispatched [`OverloadSet`]s. use crate::common::DiagnosticDebug; use crate::ir; use crate::proc::overloads::{list, regular, OverloadSet, Rule}; use crate::proc::{GlobalCtx, TypeResolution}; use alloc::vec::Vec; use core::fmt; macro_rules! define_any_overload_set { { $( $module:ident :: $name:ident, )* } => { /// An [`OverloadSet`] that dynamically dispatches to concrete implementations. 
#[derive(Clone)] pub(in crate::proc::overloads) enum AnyOverloadSet { $( $name ( $module :: $name ), )* } $( impl From<$module::$name> for AnyOverloadSet { fn from(concrete: $module::$name) -> Self { AnyOverloadSet::$name(concrete) } } )* impl OverloadSet for AnyOverloadSet { fn is_empty(&self) -> bool { match *self { $( AnyOverloadSet::$name(ref x) => x.is_empty(), )* } } fn min_arguments(&self) -> usize { match *self { $( AnyOverloadSet::$name(ref x) => x.min_arguments(), )* } } fn max_arguments(&self) -> usize { match *self { $( AnyOverloadSet::$name(ref x) => x.max_arguments(), )* } } fn arg( &self, i: usize, ty: &ir::TypeInner, types: &crate::UniqueArena<ir::Type>, ) -> Self { match *self { $( AnyOverloadSet::$name(ref x) => AnyOverloadSet::$name(x.arg(i, ty, types)), )* } } fn concrete_only(self, types: &crate::UniqueArena<ir::Type>) -> Self { match self { $( AnyOverloadSet::$name(x) => AnyOverloadSet::$name(x.concrete_only(types)), )* } } fn most_preferred(&self) -> Rule { match *self { $( AnyOverloadSet::$name(ref x) => x.most_preferred(), )* } } fn overload_list(&self, gctx: &GlobalCtx<'_>) -> Vec<Rule> { match *self { $( AnyOverloadSet::$name(ref x) => x.overload_list(gctx), )* } } fn allowed_args(&self, i: usize, gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> { match *self { $( AnyOverloadSet::$name(ref x) => x.allowed_args(i, gctx), )* } } fn for_debug(&self, types: &crate::UniqueArena<ir::Type>) -> impl fmt::Debug { DiagnosticDebug((self, types)) } } impl fmt::Debug for DiagnosticDebug<(&AnyOverloadSet, &crate::UniqueArena<ir::Type>)> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let (set, types) = self.0; match *set { $( AnyOverloadSet::$name(ref x) => DiagnosticDebug((x, types)).fmt(f), )* } } } } } define_any_overload_set! { list::List, regular::Regular, } naga-29.0.3/src/proc/overloads/constructor_set.rs000064400000000000000000000124671046102023000201750ustar 00000000000000//! A set of type constructors, represented as a bitset.
use crate::ir; use crate::proc::overloads::one_bits_iter::OneBitsIter; bitflags::bitflags! { /// A set of type constructors. #[derive(Copy, Clone, Debug, PartialEq)] pub(crate) struct ConstructorSet: u16 { const SCALAR = 1 << 0; const VEC2 = 1 << 1; const VEC3 = 1 << 2; const VEC4 = 1 << 3; const MAT2X2 = 1 << 4; const MAT2X3 = 1 << 5; const MAT2X4 = 1 << 6; const MAT3X2 = 1 << 7; const MAT3X3 = 1 << 8; const MAT3X4 = 1 << 9; const MAT4X2 = 1 << 10; const MAT4X3 = 1 << 11; const MAT4X4 = 1 << 12; const VECN = Self::VEC2.bits() | Self::VEC3.bits() | Self::VEC4.bits(); } } impl ConstructorSet { /// Return the single-member set containing `inner`'s constructor, /// or the empty set if `inner` is not a scalar, vector, or matrix. pub const fn singleton(inner: &ir::TypeInner) -> ConstructorSet { use ir::TypeInner as Ti; use ir::VectorSize as Vs; match *inner { Ti::Scalar(_) => Self::SCALAR, Ti::Vector { size, scalar: _ } => match size { Vs::Bi => Self::VEC2, Vs::Tri => Self::VEC3, Vs::Quad => Self::VEC4, }, Ti::Matrix { columns, rows, scalar: _, } => match (columns, rows) { (Vs::Bi, Vs::Bi) => Self::MAT2X2, (Vs::Bi, Vs::Tri) => Self::MAT2X3, (Vs::Bi, Vs::Quad) => Self::MAT2X4, (Vs::Tri, Vs::Bi) => Self::MAT3X2, (Vs::Tri, Vs::Tri) => Self::MAT3X3, (Vs::Tri, Vs::Quad) => Self::MAT3X4, (Vs::Quad, Vs::Bi) => Self::MAT4X2, (Vs::Quad, Vs::Tri) => Self::MAT4X3, (Vs::Quad, Vs::Quad) => Self::MAT4X4, }, _ => Self::empty(), } } pub const fn is_singleton(self) -> bool { self.bits().is_power_of_two() } /// Return an iterator over this set's members. /// /// Members are produced as singletons, in order from most general to least. pub fn members(self) -> impl Iterator<Item = ConstructorSet> { OneBitsIter::new(self.bits() as u64).map(|bit| Self::from_bits(bit as u16).unwrap()) } /// Return the size of the sole element of `self`. /// /// # Panics /// /// Panic if `self` is not a singleton.
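///
/// For instance, `ConstructorSet::MAT2X3.size()` is
/// `ConstructorSize::Matrix { columns: VectorSize::Bi, rows: VectorSize::Tri }`,
/// whereas `ConstructorSet::VECN.size()` panics, since `VECN` has three members.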
pub fn size(self) -> ConstructorSize { use ir::VectorSize as Vs; use ConstructorSize as Cs; match self { ConstructorSet::SCALAR => Cs::Scalar, ConstructorSet::VEC2 => Cs::Vector(Vs::Bi), ConstructorSet::VEC3 => Cs::Vector(Vs::Tri), ConstructorSet::VEC4 => Cs::Vector(Vs::Quad), ConstructorSet::MAT2X2 => Cs::Matrix { columns: Vs::Bi, rows: Vs::Bi, }, ConstructorSet::MAT2X3 => Cs::Matrix { columns: Vs::Bi, rows: Vs::Tri, }, ConstructorSet::MAT2X4 => Cs::Matrix { columns: Vs::Bi, rows: Vs::Quad, }, ConstructorSet::MAT3X2 => Cs::Matrix { columns: Vs::Tri, rows: Vs::Bi, }, ConstructorSet::MAT3X3 => Cs::Matrix { columns: Vs::Tri, rows: Vs::Tri, }, ConstructorSet::MAT3X4 => Cs::Matrix { columns: Vs::Tri, rows: Vs::Quad, }, ConstructorSet::MAT4X2 => Cs::Matrix { columns: Vs::Quad, rows: Vs::Bi, }, ConstructorSet::MAT4X3 => Cs::Matrix { columns: Vs::Quad, rows: Vs::Tri, }, ConstructorSet::MAT4X4 => Cs::Matrix { columns: Vs::Quad, rows: Vs::Quad, }, _ => unreachable!("ConstructorSet was not a singleton"), } } } /// The sizes a member of [`ConstructorSet`] might have. #[derive(Clone, Copy)] pub enum ConstructorSize { /// The constructor is [`SCALAR`]. /// /// [`SCALAR`]: ConstructorSet::SCALAR Scalar, /// The constructor is `VECN` for some `N`. Vector(ir::VectorSize), /// The constructor is `MATCXR` for some `C` and `R`. Matrix { columns: ir::VectorSize, rows: ir::VectorSize, }, } impl ConstructorSize { /// Construct a [`TypeInner`] for a type with this size and the given `scalar`. /// /// [`TypeInner`]: ir::TypeInner pub const fn to_inner(self, scalar: ir::Scalar) -> ir::TypeInner { match self { Self::Scalar => ir::TypeInner::Scalar(scalar), Self::Vector(size) => ir::TypeInner::Vector { size, scalar }, Self::Matrix { columns, rows } => ir::TypeInner::Matrix { columns, rows, scalar, }, } } } macro_rules! 
constructor_set { ( $( $constr:ident )|* ) => { { use $crate::proc::overloads::constructor_set::ConstructorSet; ConstructorSet::empty() $( .union(ConstructorSet::$constr) )* } } } pub(in crate::proc::overloads) use constructor_set; naga-29.0.3/src/proc/overloads/list.rs000064400000000000000000000142531046102023000157030ustar 00000000000000//! An [`OverloadSet`] represented as a vector of rules. //! //! [`OverloadSet`]: crate::proc::overloads::OverloadSet use crate::common::{DiagnosticDebug, ForDebug, ForDebugWithTypes}; use crate::ir; use crate::proc::overloads::one_bits_iter::OneBitsIter; use crate::proc::overloads::Rule; use crate::proc::{GlobalCtx, TypeResolution}; use alloc::rc::Rc; use alloc::vec::Vec; use core::fmt; /// A simple list of overloads. /// /// Note that this type is not quite as general as it looks, in that /// the implementation of `most_preferred` doesn't work for arbitrary /// lists of overloads. See the documentation for [`List::rules`] for /// details. #[derive(Clone)] pub(in crate::proc::overloads) struct List { /// A bitmask of which elements of `rules` are included in the set. members: u64, /// A list of type rules that are members of the set. /// /// These must be listed in order such that every rule in the list /// is always more preferred than all subsequent rules in the /// list. If there is no such arrangement of rules, then you /// cannot use `List` to represent the overload set. 
rules: Rc<Vec<Rule>>, } impl List { pub(in crate::proc::overloads) fn from_rules(rules: Vec<Rule>) -> List { List { members: len_to_full_mask(rules.len()), rules: Rc::new(rules), } } fn members(&self) -> impl Iterator<Item = (u64, &Rule)> { OneBitsIter::new(self.members).map(|mask| { let index = mask.trailing_zeros() as usize; (mask, &self.rules[index]) }) } fn filter<F>(&self, mut pred: F) -> List where F: FnMut(&Rule) -> bool, { let mut filtered_members = 0; for (mask, rule) in self.members() { if pred(rule) { filtered_members |= mask; } } List { members: filtered_members, rules: self.rules.clone(), } } } impl crate::proc::overloads::OverloadSet for List { fn is_empty(&self) -> bool { self.members == 0 } fn min_arguments(&self) -> usize { self.members() .fold(None, |best, (_, rule)| { // This is different from `max_arguments` because // `