pax_global_header00006660000000000000000000000064151437175620014525gustar00rootroot0000000000000052 comment=2aa7f58440a06b15352a2cbce01fa4c26f824969 e3nn-0.6.0/000077500000000000000000000000001514371756200123735ustar00rootroot00000000000000e3nn-0.6.0/.github/000077500000000000000000000000001514371756200137335ustar00rootroot00000000000000e3nn-0.6.0/.github/CHANGELOG.md000066400000000000000000000304151514371756200155470ustar00rootroot00000000000000# CHANGELOG All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] ## [0.5.2] - 2024-07 ### Added - `o3.experimental.FullTensorProductv2` for compatibility with `torch.compile(..., fullgraph=True)` - enable `pip` caching in CI - Optional scalar bias term in `_batchnorm.py` ### Changed - refactor to use `pyproject.toml` for packaging - refactor `gh` community files - move `pylint`, `coverage` and `flake8` configuration to `pyproject.toml` ### Fixed - Fix TorchScript warning "doesn't support instance-level annotations" (#437) ## [0.5.1] - 2022-12-12 ### Added - L=12 spherical harmonics ### Fixed - `TensorProduct.visualize` now works even if the TP is on the GPU. - GitHub Actions only push to Coveralls if the corresponding token is set in the GitHub secrets. - Batchnorm ## [0.5.0] - 2022-04-13 ### Added - Sparse Voxel Convolution - Clebsch-Gordan coefficients are computed via a change of basis from the complex to the real basis. - `o3`, `nn` and `io` are accessible through `e3nn`. For instance, `e3nn.o3.rand_axis_angle`. ### Changed - The code is no longer tested against `torch==1.8.0`; it is only tested against `torch>=1.10.0` ### Fixed - `wigner_3j` now _always_ returns a contiguous copy regardless of dtype or device ## [0.4.4] - 2021-12-15 ### Fixed - Remove `CartesianTensor._rtp`.
Instead, the `ReducedTensorProduct` is recomputed every time. The user can save the `ReducedTensorProduct` to avoid creating it each time. - `*equivariance_error` no longer keeps unneeded autograd graphs around - `CartesianTensor` builds `ReducedTensorProduct` with the correct device/dtype when called without one ### Added - Created a module for reflected imports, allowing a concise syntax for creating `irreps`, e.g. `from e3nn.o3.irreps import l3o # same as Irreps("3o")` - Add `uvu= 1.8.0 rather than 1.8.1 - Changed `o3.legendre` into a module `o3.Legendre` ### Removed - Removed `e3nn.util.codegen.eval_code` in favor of `torch.fx` ## [0.2.8] - 2021-04-21 ### Added - `squared` option to `o3.Norm` - `e3nn.nn.models.v2104.voxel_convolution.Convolution` made to be resolution-agnostic - `TensorProduct.visualize` keyword argument `aspect_ratio` ### Changed - `ReducedTensorProducts` is a (scriptable) `torch.nn.Module` - e3nn now requires the latest stable PyTorch, >=1.8.1 - `TensorProduct.visualize`: color of paths based on `w.pow(2).mean()` instead of `w.sum().sign() * w.abs().sum()` ### Fixed - No more NaN gradients of `o3.Norm`/`nn.NormActivation` at zero when using `epsilon` - Modules with `@compile_mode('trace')` can now be compiled when their dtype and the current default dtype are different - Fix errors in `ReducedTensorProducts` and add new tests ## [0.2.7] - 2021-04-14 ### Added - `uuu` connection mode in `o3.TensorProduct` now has specialized code ### Fixed - Fixed an issue with `Activation` (used by `Gate`). It was only applying the first activation function provided: `Activation('0e+0e', [act1, act2])` was equivalent to `Activation('2x0e', [act1])`. Solved by removing the `.simplify()` applied to `self.irreps_in`.
- `Gate` will not accept non-scalar `irreps_gates` or `irreps_scalars` ## [0.2.6] - 2021-04-12 ### Added - `e3nn.util.test.random_irreps` convenience function for writing tests ### Changed - `o3.Linear` now has more efficient specialized code ### Fixed - Fixed a problem with temporary files on Windows ## [0.2.5] - 2021-04-07 ### Added - Added `e3nn.set_optimization_defaults()` and `e3nn.get_optimization_defaults()` - Constructors for empty `Irreps`: `Irreps()` and `Irreps("")` - Additional tests, docs, and refactoring for `Irrep` and `Irreps`. - Added `TensorProduct.weight_views()` and `TensorProduct.weight_view_for_instruction()` - Fixed docs for `ExtractIr` ### Changed - Renamed `o3.TensorProduct` arguments to `irreps_in1`, `irreps_in2` and `irreps_out` - Renamed `o3.spherical_harmonics` argument `xyz` to `x` - Renamed `math.soft_one_hot_linspace` argument `endpoint` to `cutoff`, with `cutoff = not endpoint` - Variances are now provided to `o3.TensorProduct` through explicit `in1_var`, `in2_var`, `out_var` parameters - Submodules define `__all__`; documentation uses shorter module names for the classes/methods. ### Fixed - Enabling/disabling einsum optimization no longer affects PyTorch RNG state. ### Removed - Variances can no longer be provided to `o3.TensorProduct` in the list-of-tuple format for `irreps_in1`, etc. ## [0.2.4] - 2021-03-23 ### Added - `basis='smooth_finite'` option to `math.soft_one_hot_linspace` - `math.soft_unit_step` function - `nn.model.v2103` generic message passing model + examples of networks using it.
- `o3.TensorProduct` is JIT scriptable - `o3.TensorProduct` also broadcasts the `weight` argument - simple e3nn models can be saved/loaded with `torch.save()`/`torch.load()` - JITable `o3.SphericalHarmonics` module version of `o3.spherical_harmonics` - `in_place` option for `e3nn.util.jit` compilation functions - New `@compile_mode("unsupported")` for modules that do not support TorchScript - flake8 settings have been added to `setup.cfg` for improved code style - `TensorProduct.visualize()` can now plot weights - `basis='bessel'` option to `math.soft_one_hot_linspace` - Optional optimization of `TensorProduct` if [`opt_einsum_fx`](https://github.com/Linux-cpp-lisp/opt_einsum_fx) is installed ### Changed - `o3.TensorProduct` now uses `torch.fx` to generate its code - e3nn now requires the latest stable PyTorch, >=1.8.0 - in `soft_one_hot_linspace` the argument `base` is renamed to `basis` - `Irreps.slices()`: use `zip(irreps.slices(), irreps)` to retrieve the old behavior - `math.soft_one_hot_linspace`: very small change in the normalization of the `fourier` basis - `normalize2mom` is now a `torch.nn.Module` - renamed arguments `set_ir_...` to `filter_ir_...` - Renamed `e3nn.nn.Gate` argument `irreps_nonscalars` to `irreps_gated` - Renamed `e3nn.o3.TensorProduct` arguments `x1, x2` to `x, y` ### Fixed - `nn.Gate` was crashing when the number of scalars or gates was zero - `device` edge cases for `Gate` and `SphericalHarmonics` ## [0.2.3] - 2021-02-23 ### Added - Added argument `basis` to `math.soft_one_hot_linspace` that can take values `gaussian`, `cosine` and `fourier` - `io.SphericalTensor.sum_of_diracs` - Optional arguments `function(..., device=None, dtype=None)` for many functions - `e3nn.nn.models.gate_points_2102` using node attributes along with the length embedding to feed the radial network - `Irreps.slices()` - Module `Extract` (and `ExtractIr`) to extract subsets of irreps tensors - Recursive TorchScript compiler `e3nn.util.jit` - TorchScript support for
`TensorProduct` and subclasses, `NormActivation`, `Gate`, `FullyConnectedNet`, and `gate_points_2101.Network` ### Changed - in `o3.TensorProduct.instructions`: renamed `weight_shape` to `path_shape`, which is now set even if `has_weight` is `False` - `o3.TensorProduct` weights are now flattened tensors - renamed `io.SphericalTensor.from_geometry_adjusted` to `io.SphericalTensor.with_peaks_at` - in `ReducedTensorProducts`, `ElementwiseTensorProduct` and `FullTensorProduct`: renamed the `irreps_out` argument to `set_ir_out` to avoid confusing it with `o3.Irreps` ### Removed - `io.SphericalTensor.from_geometry_global_rescale` - `e3nn.math.reduce.reduce_tensor` in favor of `e3nn.o3.ReducedTensorProducts` - swish; use `torch.nn.functional.silu` instead - `"cartesian_vectors"` for equivariance testing: since the 0.2.2 Euler angle convention change, L=1 irreps are equivalent ### Fixed - `io.SphericalTensor.from_samples_on_s2` now handles the batch dimension - Modules that generate code now clean up their temporary files - `NormActivation` now works on GPU ## [0.2.2] - 2021-02-09 ### Changed - Euler angle convention from ZYZ to YXY - `TensorProduct.weight_shapes` content put into `TensorProduct.instructions` ### Added - Better TorchScript support e3nn-0.6.0/.github/CODE_OF_CONDUCT.md000066400000000000000000000125351514371756200165400ustar00rootroot00000000000000 # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at support@e3nn.org. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. 
Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. [homepage]: https://www.contributor-covenant.org [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html [Mozilla CoC]: https://github.com/mozilla/diversity [FAQ]: https://www.contributor-covenant.org/faq [translations]: https://www.contributor-covenant.org/translations e3nn-0.6.0/.github/CONTRIBUTING.md000066400000000000000000000015351514371756200161700ustar00rootroot00000000000000# Contribute For the docstrings we use the [numpy style](https://numpydoc.readthedocs.io/en/latest/format.html). You can install some of the commonly used development tools by using e3nn's 'dev' extra: ``` pip install -e '.[dev]' ``` To have automatic code style checks performed at each commit, you can install the pre-commit hook using: ``` pre-commit install ``` These checks are automatically run on any commit made to the GitHub repository, but the pre-commit hook allows you to see if there are any problems locally. Additionally, you may want to run the tests locally before pushing to remote.
This can be done with (from the root e3nn directory): ``` pytest tests ``` For formatting we use the [black](https://black.readthedocs.io/en/stable/index.html) library. It can be installed with: ``` pip install black ``` and run with: ``` black . ``` e3nn-0.6.0/.github/ISSUE_TEMPLATE/000077500000000000000000000000001514371756200161165ustar00rootroot00000000000000e3nn-0.6.0/.github/ISSUE_TEMPLATE/bug-report.md000066400000000000000000000023631514371756200205320ustar00rootroot00000000000000--- name: Bug report about: Create a report to help us improve title: "\U0001F41B [BUG]" labels: bug assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Minimal code to reproduce the behavior. Please try to isolate the code that produces the error from code that is specific to your task but not relevant to the error (e.g. replace input data loaded from files with random inputs). **Expected behavior** A clear and concise description of what you expected to happen. **Environment (please complete the following information):** - OS: [e.g. iOS, Ubuntu, Windows] - python version (`python --version`) - python environment (commands are given for the Python interpreter): - e3nn version (`import e3nn; e3nn.__version__`) - pytorch version (`import torch; torch.__version__`) - pytorch_geometric version (`import torch_geometric; torch_geometric.__version__`) - (if relevant) GPU support with CUDA - CUDA version according to nvcc (`nvcc --version`) - CUDA version according to PyTorch (`import torch; torch.version.cuda`) **Screenshots** If applicable, add screenshots to help explain your problem. **Additional context** Add any other context about the problem here.
e3nn-0.6.0/.github/ISSUE_TEMPLATE/config.yml000066400000000000000000000002471514371756200201110ustar00rootroot00000000000000blank_issues_enabled: false contact_links: - name: Question url: https://github.com/e3nn/e3nn/discussions about: Please ask questions on the Discussions tab e3nn-0.6.0/.github/ISSUE_TEMPLATE/feature_request.md000066400000000000000000000011601514371756200216410ustar00rootroot00000000000000--- name: Feature request about: Suggest an idea for this project title: "\U0001F31F [FEATURE]" labels: enhancement assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. e3nn-0.6.0/.github/PULL_REQUEST_TEMPLATE.md000066400000000000000000000022421514371756200175340ustar00rootroot00000000000000 ## Description ## Motivation and Context Resolves: #??? ## How Has This Been Tested? ## Checklist: - [ ] I have read the [**CONTRIBUTING**](https://github.com/e3nn/e3nn/blob/main/.github/CONTRIBUTING.md) document. - [ ] My code follows the code style of this project. - [ ] I have updated the documentation (if relevant). - [ ] I have added tests that cover my changes (if relevant). - [ ] The modified code is cuda compatible (github tests don't test cuda) (if relevant). - [ ] I have updated the [CHANGELOG](https://github.com/e3nn/e3nn/blob/main/.github/CHANGELOG.md). 
e3nn-0.6.0/.github/workflows/000077500000000000000000000000001514371756200157705ustar00rootroot00000000000000e3nn-0.6.0/.github/workflows/release.yml000066400000000000000000000012031514371756200201270ustar00rootroot00000000000000name: Upload Python Package on: release: types: [created] jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v4 with: python-version: '3.11' - name: Install dependencies run: | python -m pip install --upgrade pip pip install setuptools wheel twine build - name: Build and publish env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD_V2 }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | python -m build twine upload dist/*e3nn-0.6.0/.github/workflows/style.yaml000066400000000000000000000012321514371756200200120ustar00rootroot00000000000000name: Run Style Check on: push: branches: ["main"] pull_request: branches: ["main"] jobs: style: runs-on: ubuntu-latest strategy: fail-fast: false steps: - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v3 with: python-version: "3.11" - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install ruff==0.8.2 - name: Style check run: | ruff check . 
--ignore E741 e3nn-0.6.0/.github/workflows/tests.yml000066400000000000000000000037771514371756200176730ustar00rootroot00000000000000name: Check Syntax and Run Tests on: push: branches: - main pull_request: branches: - main jobs: build: runs-on: ubuntu-latest strategy: matrix: python-version: ["3.11"] torch-version: ["2.4.0"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} cache: "pip" cache-dependency-path: | pyproject.toml - name: Install dependencies env: TORCH: "${{ matrix.torch-version }}" CUDA: "cpu" GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | python -m pip install --upgrade pip pip install wheel pip install torch==${TORCH} torchvision --index-url https://download.pytorch.org/whl/cpu pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html pip install . pip install plotly - name: Install pytest run: | pip install pytest pytest-cov pip install coveralls - name: Test with pytest run: | coverage run --source=e3nn -m pytest --doctest-modules --ignore=docs/ --ignore-glob='**/experimental*' tests examples - name: Upload to coveralls env: COVERALLS_TOKEN: ${{ secrets.COVERALLS_TOKEN }} # Only send to coveralls if the token has been set and the user pushed if: env.COVERALLS_TOKEN != null && github.event_name == 'push' run: | COVERALLS_REPO_TOKEN=${{ secrets.COVERALLS_TOKEN }} coveralls e3nn-0.6.0/.gitignore000066400000000000000000000003311514371756200143600ustar00rootroot00000000000000notebook .vscode .coverage .pytest_cache *.egg-info __pycache__ QM9 wandb *.pk *json _build build dist *.ipynb .ipynb_checkpoints *.so examples/s2cnn/mnist/MNIST_data/MNIST/raw examples/s2cnn/mnist/s2_mnist.gz .idea e3nn-0.6.0/.pre-commit-config.yaml000066400000000000000000000027711514371756200166630ustar00rootroot00000000000000exclude: &exclude_files > (?x)^( docs/.*| tests/.*| 
.github/.*| LICENSE| )$ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.5.0 hooks: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/PyCQA/pylint rev: pylint-2.5.2 hooks: - id: pylint language: system args: [ '--disable=protected-access', '--disable=no-else-return', '--disable=raise-missing-from', '--disable=invalid-name', '--disable=duplicate-code', '--disable=import-outside-toplevel', '--disable=missing-docstring', '--disable=bad-continuation', '--disable=locally-disabled', '--disable=too-few-public-methods', '--disable=too-many-arguments', '--disable=too-many-instance-attributes', '--disable=too-many-local-variables', '--disable=too-many-locals', '--disable=too-many-branches', '--disable=too-many-statements', '--disable=too-many-return-statements', '--disable=redefined-builtin', '--disable=redefined-outer-name', '--disable=line-too-long', '--disable=fixme', ] exclude: *exclude_files - repo: https://github.com/PyCQA/flake8 rev: 6.1.0 hooks: - id: flake8 name: Check PEP8 additional_dependencies: [Flake8-pyproject] e3nn-0.6.0/.readthedocs.yml000066400000000000000000000010211514371756200154530ustar00rootroot00000000000000# .readthedocs.yml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 # Set the version of Python and other tools you might need build: os: ubuntu-22.04 tools: python: "3.11" # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/conf.py # Optionally declare the Python requirements required to build your docs python: install: - requirements: docs/requirements.txt - method: pip path: . 
e3nn-0.6.0/CITATION.bib000066400000000000000000000044061514371756200142670ustar00rootroot00000000000000@misc{https://doi.org/10.48550/arxiv.2207.09453, doi = {10.48550/ARXIV.2207.09453}, url = {https://arxiv.org/abs/2207.09453}, author = {Geiger, Mario and Smidt, Tess}, title = {e3nn: Euclidean Neural Networks}, publisher = {arXiv}, year = {2022}, copyright = {Creative Commons Attribution 4.0 International} } @misc{thomas2018tensorfieldnetworks, title={Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds}, author={Nathaniel Thomas and Tess Smidt and Steven Kearnes and Lusann Yang and Li Li and Kai Kohlhoff and Patrick Riley}, year={2018}, eprint={1802.08219}, archivePrefix={arXiv}, primaryClass={cs.LG}, url={https://arxiv.org/abs/1802.08219} } @misc{weiler20183dsteerablecnns, title={3D Steerable CNNs: Learning Rotationally Equivariant Features in Volumetric Data}, author={Maurice Weiler and Mario Geiger and Max Welling and Wouter Boomsma and Taco Cohen}, year={2018}, eprint={1807.02547}, archivePrefix={arXiv}, primaryClass={cs.LG}, url={https://arxiv.org/abs/1807.02547} } @misc{kondor2018clebschgordannets, title={Clebsch-Gordan Nets: a Fully Fourier Space Spherical Convolutional Neural Network}, author={Risi Kondor and Zhen Lin and Shubhendu Trivedi}, year={2018}, eprint={1806.09231}, archivePrefix={arXiv}, primaryClass={stat.ML}, url={https://arxiv.org/abs/1806.09231} } @software{e3nn_software, author = {Mario Geiger and Tess Smidt and Alby M. 
and Benjamin Kurt Miller and Wouter Boomsma and Bradley Dice and Kostiantyn Lapchevskyi and Maurice Weiler and Michał Tyszkiewicz and Simon Batzner and Dylan Madisetti and Martin Uhrin and Jes Frellsen and Nuri Jung and Sophia Sanborn and Mingjian Wen and Josh Rackers and Marcel Rød and Michael Bailey}, title = {Euclidean neural networks: e3nn}, month = apr, year = 2022, publisher = {Zenodo}, version = {0.5.0}, doi = {10.5281/zenodo.6459381}, url = {https://doi.org/10.5281/zenodo.6459381} }e3nn-0.6.0/INSTALL.md000066400000000000000000000015571514371756200140330ustar00rootroot00000000000000# Install ## Dependencies ### PyTorch e3nn requires PyTorch >=2.2.0. For installation instructions, please see the [PyTorch homepage](https://pytorch.org/). ### optional: torch_geometric First you have to install [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric). For `torch` 2.2 and no CUDA support: ```bash CUDA=cpu TORCH=2.2.0 pip install --upgrade --force-reinstall torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html pip install --upgrade --force-reinstall torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html pip install torch-geometric ``` See [here](https://github.com/rusty1s/pytorch_geometric#installation) to get cuda support or newer versions. ## e3nn ### Stable (PyPI) ```bash $ pip install e3nn ``` ### Unstable (Git) ```bash $ git clone https://github.com/e3nn/e3nn.git $ cd e3nn/ $ pip install . ``` e3nn-0.6.0/LICENSE000066400000000000000000000025551514371756200134070ustar00rootroot00000000000000MIT License Euclidean neural networks (e3nn) Copyright (c) 2020, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy), Ecole Polytechnique Federale de Lausanne (EPFL), Free University of Berlin and Kostiantyn Lapchevskyi. All rights reserved. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. e3nn-0.6.0/README.md000066400000000000000000000113011514371756200136460ustar00rootroot00000000000000# Euclidean neural networks [![Coverage Status](https://coveralls.io/repos/github/e3nn/e3nn/badge.svg?branch=main)](https://coveralls.io/github/e3nn/e3nn?branch=main) [![DOI](https://zenodo.org/badge/237431920.svg)](https://zenodo.org/badge/latestdoi/237431920) **[Documentation](https://docs.e3nn.org)** | **[Code](https://github.com/e3nn/e3nn)** | **[CHANGELOG](https://github.com/e3nn/e3nn/blob/main/.github/CHANGELOG.md)** | **[Colab](https://colab.research.google.com/drive/1Gps7mMOmzLe3Rt_b012xsz4UyuexTKAf?usp=sharing)** The aim of this library is to help the development of [E(3)](https://en.wikipedia.org/wiki/Euclidean_group) equivariant neural networks. It contains fundamental mathematical operations such as [tensor products](https://docs.e3nn.org/en/stable/api/o3/o3_tp.html) and [spherical harmonics](https://docs.e3nn.org/en/stable/api/o3/o3_sh.html). 
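The bookkeeping behind irreps strings such as `"0e + 1o"`, and behind the output of a full tensor product, can be illustrated without e3nn or PyTorch. The pure-Python sketch below is hypothetical and is not e3nn's implementation; `parse_irreps`, `dim`, `full_tp_irreps` and `format_irreps` are illustrative names. It assumes the usual conventions: a term `MxLp` stands for `M` copies of the irrep of degree `L` with parity `p` (`e` = even, `o` = odd), each copy has `2L + 1` components, and the product of degrees `l1` and `l2` decomposes into degrees `|l1 - l2|` through `l1 + l2` with parity `p1 * p2`. The order in which the sketch lists the output irreps is its own choice and may differ from what `o3.FullTensorProduct` reports.

```python
from collections import Counter

# Hypothetical helpers, not part of e3nn: "e" = even (+1), "o" = odd (-1)
PARITY = {"e": 1, "o": -1}


def parse_irreps(s):
    """Parse a string like "2x0e + 2x1o" into (mul, l, parity) tuples."""
    irreps = []
    for term in s.replace(" ", "").split("+"):
        if not term:
            continue
        # "2x0e" -> ("2", "0e"); a bare "1o" means multiplicity 1
        mul, ir = term.split("x") if "x" in term else ("1", term)
        irreps.append((int(mul), int(ir[:-1]), PARITY[ir[-1]]))
    return irreps


def dim(irreps):
    """Total number of components: each copy of degree l has 2l + 1."""
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


def full_tp_irreps(irreps1, irreps2):
    """Irreps appearing in the full tensor product, grouped by (l, parity)."""
    counts = Counter()
    for mul1, l1, p1 in irreps1:
        for mul2, l2, p2 in irreps2:
            # selection rule: |l1 - l2| <= l <= l1 + l2, parity p1 * p2
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                counts[(l, p1 * p2)] += mul1 * mul2
    # sort by degree, listing even before odd (an arbitrary choice of this sketch)
    ordered = sorted(counts.items(), key=lambda kv: (kv[0][0], -kv[0][1]))
    return [(mul, l, p) for (l, p), mul in ordered]


def format_irreps(irreps):
    return " + ".join(f"{mul}x{l}{'e' if p == 1 else 'o'}" for mul, l, p in irreps)


x = parse_irreps("0e + 1o")
out = full_tp_irreps(x, x)
print(format_irreps(out))  # 2x0e + 1x1e + 2x1o + 1x2e
print(dim(x), dim(out))    # 4 16
```

Because the decomposition only regroups components, `dim(out)` always equals `dim(irreps1) * dim(irreps2)`; that identity is a cheap sanity check on the selection rule.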
![](https://user-images.githubusercontent.com/333780/79220728-dbe82c00-7e54-11ea-82c7-b3acbd9b2246.gif)

```python
import torch
from e3nn import o3

# Create a random array made of a scalar (0e) and a vector (1o)
irreps_in = o3.Irreps("0e + 1o")
x = irreps_in.randn(-1)

# Apply a linear layer
irreps_out = o3.Irreps("2x0e + 2x1o")
linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out)
y = linear(x)

# Compute a tensor product with itself
tp = o3.FullTensorProduct(irreps_in1=irreps_in, irreps_in2=irreps_in)
z = tp(x, x)

# Optionally compile the tensor product
tp_pt2 = torch.compile(tp, fullgraph=True)
z_pt2 = tp_pt2(x, x)  # Warning: the first few calls might be slow due to compilation

torch.testing.assert_close(z, z_pt2)
```

## Installation

**Important:** install PyTorch first, and only then run:

```
pip install --upgrade pip
pip install --upgrade e3nn
```

For details and optional dependencies, see [INSTALL.md](https://github.com/e3nn/e3nn/blob/main/INSTALL.md).

### Breaking changes

e3nn is under development. It is recommended to install it using pip; the main branch is considered unstable.
The second version number is incremented every time a breaking change is made to the code.

```
0.(increment when backwards incompatible release).(increment for backwards compatible release)
```

## Help

We are happy to help! The best way to get help on `e3nn` is to ask a question on the [Discussions](https://github.com/e3nn/e3nn/discussions) tab or to submit a [Bug Report](https://github.com/e3nn/e3nn/issues/new?assignees=&labels=bug&template=bug-report.md&title=%F0%9F%90%9B+%5BBUG%5D).

## Want to get involved?

Great! If you want to get involved in and contribute to the development, improvement, and application of `e3nn`, introduce yourself in the [discussions](https://github.com/e3nn/e3nn/discussions/new).

## Code of conduct

Our community abides by the [Contributor Covenant Code of Conduct](./.github/CODE_OF_CONDUCT.md).
## Citing If you use e3nn in your research, please cite the following papers: ### Euclidean Neural Networks: - N. Thomas et al., "Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds" (2018). [arXiv:1802.08219](https://arxiv.org/abs/1802.08219) - M. Weiler et al., "3D Steerable CNNs: Learning Rotationally Equivariant Features in Volumetric Data" (2018). [arXiv:1807.02547](https://arxiv.org/abs/1807.02547) - R. Kondor et al., "Clebsch-Gordan Nets: a Fully Fourier Space Spherical Convolutional Neural Network" (2018). [arXiv:1806.09231](https://arxiv.org/abs/1806.09231) ### e3nn: - M. Geiger and T. Smidt, "e3nn: Euclidean Neural Networks" (2022). [arXiv:2207.09453](https://arxiv.org/abs/2207.09453) - M. Geiger et al., "Euclidean neural networks: e3nn" (2022). [Zenodo](https://doi.org/10.5281/zenodo.6459381) For BibTeX entries, please refer to the [CITATION.bib](CITATION.bib) file in this repository. ### Copyright Euclidean neural networks (e3nn) Copyright (c) 2020, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy), Ecole Polytechnique Federale de Lausanne (EPFL), Free University of Berlin and Kostiantyn Lapchevskyi. All rights reserved. If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Intellectual Property Office at IPO@lbl.gov. NOTICE. This Software was developed under funding from the U.S. Department of Energy and the U.S. Government consequently retains certain rights. As such, the U.S. Government has been granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, distribute copies to the public, prepare derivative works, and perform publicly and display publicly, and to permit others to do so. 
e3nn-0.6.0/docs/000077500000000000000000000000001514371756200133235ustar00rootroot00000000000000e3nn-0.6.0/docs/Makefile000066400000000000000000000011721514371756200147640ustar00rootroot00000000000000# Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) e3nn-0.6.0/docs/api/000077500000000000000000000000001514371756200140745ustar00rootroot00000000000000e3nn-0.6.0/docs/api/e3nn.rst000066400000000000000000000001531514371756200154700ustar00rootroot00000000000000e3nn API ======== .. toctree:: :maxdepth: 2 o3/o3 nn/nn io/io math/math util/utile3nn-0.6.0/docs/api/io/000077500000000000000000000000001514371756200145035ustar00rootroot00000000000000e3nn-0.6.0/docs/api/io/cartesian_tensor.rst000066400000000000000000000001571514371756200206030ustar00rootroot00000000000000Cartesian Tensor ================ .. autoclass:: e3nn.io.CartesianTensor :members: :show-inheritance: e3nn-0.6.0/docs/api/io/io.rst000066400000000000000000000002771514371756200156520ustar00rootroot00000000000000io == This submodule contains subclasses of `e3nn.o3.Irreps` for specialized representations. .. rubric:: Overview .. toctree:: :maxdepth: 1 spherical_tensor cartesian_tensor e3nn-0.6.0/docs/api/io/spherical_tensor.rst000066400000000000000000000060621514371756200206050ustar00rootroot00000000000000Spherical Tensor ================ There are 4 types of functions on the sphere, depending on how parity affects them.
The representation of the coefficients is affected by this choice: .. jupyter-execute:: import torch from e3nn.io import SphericalTensor print(SphericalTensor(lmax=2, p_val=1, p_arg=1)) print(SphericalTensor(lmax=2, p_val=1, p_arg=-1)) print(SphericalTensor(lmax=2, p_val=-1, p_arg=1)) print(SphericalTensor(lmax=2, p_val=-1, p_arg=-1)) .. jupyter-execute:: import plotly.graph_objects as go def plot(traces): traces = [go.Surface(**d) for d in traces] fig = go.Figure(data=traces) fig.show() The following graph shows the four possible behaviors under parity of a function on the sphere. #. The first ball shows :math:`f(x)` unaffected by parity. #. Then ``p_val=1`` but ``p_arg=-1``, so we see the signal flipped over the sphere, but the colors are unchanged. #. For ``p_val=-1`` and ``p_arg=1``, only the value of the signal flips its sign. #. For ``p_val=-1`` and ``p_arg=-1``, both happen at the same time: the signal flips over the sphere and its value flips sign. .. jupyter-execute:: :hide-code: axis = dict( showbackground=False, showticklabels=False, showgrid=False, zeroline=False, title='', nticks=3, ) layout = dict( width=680, height=260, scene=dict( xaxis=dict( **axis, range=[-4, 4] ), yaxis=dict( **axis, range=[-1, 1] ), zaxis=dict( **axis, range=[-1, 1] ), aspectmode='manual', aspectratio=dict(x=4, y=1, z=1), camera=dict( up=dict(x=0, y=0, z=1), center=dict(x=0, y=0, z=0), eye=dict(x=0, y=-5, z=0), projection=dict(type='orthographic'), ), ), paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)", margin=dict(l=0, r=0, t=0, b=0) ) cmap_bwr = [[0, 'rgb(0,50,255)'], [0.5, 'rgb(200,200,200)'], [1, 'rgb(255,50,0)']] def plot(traces): cmax = max(abs(d['surfacecolor']).max() for d in traces) traces = [go.Surface(**d, colorscale=cmap_bwr, cmin=-cmax, cmax=cmax) for d in traces] fig = go.Figure(data=traces, layout=layout) fig.show() .. 
jupyter-execute:: lmax = 1 x = torch.tensor([0.8] + [0.0, 0.0, 1.0]) parity = -torch.eye(3) x = torch.stack([ SphericalTensor(lmax, p_val, p_arg).D_from_matrix(parity) @ x for p_val in [+1, -1] for p_arg in [+1, -1] ]) centers = torch.tensor([ [-3.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0], ]) st = SphericalTensor(lmax, 1, 1) # p_val and p_arg set arbitrarily here plot(st.plotly_surface(x, centers=centers, radius=False)) .. autoclass:: e3nn.io.SphericalTensor :members: :show-inheritance: e3nn-0.6.0/docs/api/math/000077500000000000000000000000001514371756200150255ustar00rootroot00000000000000e3nn-0.6.0/docs/api/math/math.rst000066400000000000000000000003511514371756200165070ustar00rootroot00000000000000math ==== .. autofunction:: e3nn.math.direct_sum .. autofunction:: e3nn.math.orthonormalize .. autofunction:: e3nn.math.complete_basis .. autofunction:: e3nn.math.soft_one_hot_linspace .. autofunction:: e3nn.math.soft_unit_step e3nn-0.6.0/docs/api/nn/000077500000000000000000000000001514371756200145075ustar00rootroot00000000000000e3nn-0.6.0/docs/api/nn/models/000077500000000000000000000000001514371756200157725ustar00rootroot00000000000000e3nn-0.6.0/docs/api/nn/models/gate_points_2101.rst000066400000000000000000000030251514371756200215030ustar00rootroot00000000000000Model Gate of January 2021 ========================== Multipurpose equivariant neural network for point clouds. Made with `e3nn.o3.TensorProduct` for the linear part and `e3nn.nn.Gate` for the nonlinearities. .. rubric:: Convolution The linear part, module ``Convolution``, is inspired by the ``Depthwise Separable Convolution`` idea. The main operation of the Convolution module is ``tp``. It makes the atoms interact with their neighbors but does not mix the channels. To mix the channels, it is sandwiched between ``lin1`` and ``lin2``. .. literalinclude:: ../../../../e3nn/nn/models/gate_points_2101.py :lines: 22-120 .. 
rubric:: Network The network is a simple succession of ``Convolution`` and `e3nn.nn.Gate`. The activation function is ReLU when dealing with even scalars and tanh of abs when dealing with odd scalars. When the parities (``p`` in `e3nn.o3.Irrep`) are provided, the network is equivariant to ``O(3)``. To relax this constraint and make it equivariant to ``SO(3)`` only, one can simply make all the ``irreps`` parameters even (``p=1`` in `e3nn.o3.Irrep`). This is why ``irreps_sh`` is a parameter of the class ``Network``: one can use specific ``l`` of the spherical harmonics with the correct parity ``p=(-1)^l`` (one can use `e3nn.o3.Irreps.spherical_harmonics` for that) or consider that ``p=1`` in order to **not** be equivariant to parity. .. literalinclude:: ../../../../e3nn/nn/models/gate_points_2101.py :lines: 156-336 .. automodule:: e3nn.nn.models.gate_points_2101 :members: :show-inheritance: e3nn-0.6.0/docs/api/nn/models/graph.svg000066400000000000000000007523311514371756200176270ustar00rootroot00000000000000 image/svg+xmle3nn-0.6.0/docs/api/nn/models/models.rst000066400000000000000000000001551514371756200200100ustar00rootroot00000000000000nn - Models =========== .. rubric:: Overview .. toctree:: :maxdepth: 1 v2103 gate_points_2101 e3nn-0.6.0/docs/api/nn/models/v2103.rst000066400000000000000000000053101514371756200172760ustar00rootroot00000000000000Models of March 2021 ==================== Simple Network -------------- Let's create a simple network and evaluate it on random data. .. jupyter-execute:: import torch from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork net = SimpleNetwork( irreps_in="3x0e + 2x1o", irreps_out="1x1o", max_radius=2.0, num_neighbors=3.0, num_nodes=5.0 ) pos = torch.randn(5, 3) x = net.irreps_in.randn(5, -1) net({ 'pos': pos, 'x': x }) If we rotate the inputs, .. jupyter-execute:: from e3nn import o3 rot = o3.matrix_x(torch.tensor(3.14 / 3.0)) rot .. 
jupyter-execute:: net({ 'pos': pos @ rot.T, 'x': x @ net.irreps_in.D_from_matrix(rot).T }) it gives the same result as rotating the outputs. .. jupyter-execute:: net({ 'pos': pos, 'x': x }) @ net.irreps_out.D_from_matrix(rot).T Network for a graph with node/edge attributes --------------------------------------------- .. image:: graph.svg :height: 200px :width: 200px :scale: 100 % :alt: graph of the data :align: left A graph is made of nodes and edges. The nodes and edges can have attributes. Usually their only attributes are the positions of the nodes :math:`\vec r_i` and the relative positions of the edges :math:`\vec r_i - \vec r_j`. We typically don't use the node positions because they change with the global translation of the graph. The nodes and edges can have other attributes, such as atom type or bond type. The attributes define the graph properties. They don't change layer after layer (in this example). The data (``node_input``) flow through this graph layer after layer. In the following network, the edge attributes are the spherical harmonics :math:`Y^l(\vec r_i - \vec r_j)` plus the extra attributes provided by the user. .. jupyter-execute:: from e3nn.nn.models.v2103.gate_points_networks import NetworkForAGraphWithAttributes from torch_cluster import radius_graph max_radius = 3.0 net = NetworkForAGraphWithAttributes( irreps_node_input="0e+1e", irreps_node_attr="0e+1e", irreps_edge_attr="0e+1e", # attributes in addition to the spherical harmonics irreps_node_output="0e+1e", max_radius=max_radius, num_neighbors=4.0, num_nodes=5.0, ) num_nodes = 5 pos = torch.randn(num_nodes, 3) # 3D node positions edge_index = radius_graph(pos, max_radius) num_edges = edge_index.shape[1] net({ 'pos': pos, 'edge_index': edge_index, 'node_input': torch.randn(num_nodes, 4), 'node_attr': torch.randn(num_nodes, 4), 'edge_attr': torch.randn(num_edges, 4), }) e3nn-0.6.0/docs/api/nn/nn.rst000066400000000000000000000002121514371756200156470ustar00rootroot00000000000000nn == .. 
rubric:: Overview .. toctree:: :maxdepth: 1 nn_gate nn_fc nn_bn nn_s2act nn_normact models/models e3nn-0.6.0/docs/api/nn/nn_bn.rst000066400000000000000000000001571514371756200163360ustar00rootroot00000000000000Batch Normalization =================== .. autoclass:: e3nn.nn.BatchNorm :members: :show-inheritance: e3nn-0.6.0/docs/api/nn/nn_fc.rst000066400000000000000000000002151514371756200163220ustar00rootroot00000000000000Fully Connected Neural Network ============================== .. autoclass:: e3nn.nn.FullyConnectedNet :members: :show-inheritance: e3nn-0.6.0/docs/api/nn/nn_gate.rst000066400000000000000000000002241514371756200166520ustar00rootroot00000000000000Gate ==== .. autoclass:: e3nn.nn.Activation :members: :show-inheritance: .. autoclass:: e3nn.nn.Gate :members: :show-inheritance: e3nn-0.6.0/docs/api/nn/nn_normact.rst000066400000000000000000000001701514371756200173750ustar00rootroot00000000000000Norm-Based Activation ===================== .. autoclass:: e3nn.nn.NormActivation :members: :show-inheritance: e3nn-0.6.0/docs/api/nn/nn_s2act.rst000066400000000000000000000001641514371756200167510ustar00rootroot00000000000000Spherical Activation ==================== .. autoclass:: e3nn.nn.S2Activation :members: :show-inheritance: e3nn-0.6.0/docs/api/o3/000077500000000000000000000000001514371756200144155ustar00rootroot00000000000000e3nn-0.6.0/docs/api/o3/o3.rst000066400000000000000000000005441514371756200154730ustar00rootroot00000000000000o3 == All functions in this module are accessible via the ``o3`` submodule: .. jupyter-execute:: from e3nn import o3 R = o3.rand_matrix(10) D = o3.Irreps.spherical_harmonics(4).D_from_matrix(R) .. rubric:: Overview .. toctree:: :maxdepth: 1 o3_rotation o3_irreps o3_tp o3_sh o3_reduce o3_s2grid o3_wigner e3nn-0.6.0/docs/api/o3/o3_irreps.rst000066400000000000000000000036611514371756200170620ustar00rootroot00000000000000.. 
_Irreducible representations: Irreps ====== A group representation :math:`(D,V)` describes the action of a group :math:`G` on a vector space :math:`V` .. math:: D : G \longrightarrow \text{linear map on } V. The irreducible representations, in short *irreps* (definition of irreps_), are the "smallest" representations. - Any representation can be decomposed via a change of basis into a direct sum of irreps - Any physical quantity, under the action of :math:`O(3)`, transforms with a representation of :math:`O(3)` The irreps of :math:`SO(3)` are called the Wigner_ matrices :math:`D^L`. The irreps of the group of inversion (:math:`\{e, I\}`) are the trivial_ representation :math:`\sigma_+` and the sign representation :math:`\sigma_-` .. math:: \sigma_p(g) = \left \{ \begin{array}{l} 1 \text{ if } g = e \\ p \text{ if } g = I \end{array} \right.. The group :math:`O(3)` is the direct_ product of :math:`SO(3)` and inversion .. math:: g = r i, \quad r \in SO(3), i \in \text{inversion}. The irreps of :math:`O(3)` are the product of the irreps of :math:`SO(3)` and inversion. An instance of the class `e3nn.o3.Irreps` represents a direct sum of irreps of :math:`O(3)`: .. math:: g = r i \mapsto \bigoplus_{j=1}^n m_j \times \sigma_{p_j}(i) D^{L_j}(r) where :math:`(m_j \in \mathbb{N}, p_j = \pm 1, L_j = 0,1,2,3,\dots)_{j=1}^n` defines the `e3nn.o3.Irreps`. Irreps of :math:`O(3)` are often confused with the spherical harmonics; the relation between the irreps and the spherical harmonics is explained in :ref:`Spherical Harmonics`. .. _direct: https://en.wikipedia.org/wiki/Direct_product_of_groups .. _trivial: https://en.wikipedia.org/wiki/Trivial_representation .. _irreps: https://en.wikipedia.org/wiki/Irreducible_representation .. _wigner: https://en.wikipedia.org/wiki/Wigner_D-matrix .. autoclass:: e3nn.o3.Irrep :members: :show-inheritance: .. 
autoclass:: e3nn.o3.Irreps :members: :show-inheritance: e3nn-0.6.0/docs/api/o3/o3_reduce.rst000066400000000000000000000002211514371756200170120ustar00rootroot00000000000000Reduction of Tensors in Irreps ============================== .. autoclass:: e3nn.o3.ReducedTensorProducts :members: :show-inheritance: e3nn-0.6.0/docs/api/o3/o3_rotation.rst000066400000000000000000000031531514371756200174110ustar00rootroot00000000000000.. _Rotation functions: Parametrization of Rotations ============================ Matrix Parametrization ---------------------- .. autofunction:: e3nn.o3.rand_matrix .. autofunction:: e3nn.o3.matrix_x .. autofunction:: e3nn.o3.matrix_y .. autofunction:: e3nn.o3.matrix_z Euler Angles Parametrization ---------------------------- .. autofunction:: e3nn.o3.identity_angles .. autofunction:: e3nn.o3.rand_angles .. autofunction:: e3nn.o3.compose_angles .. autofunction:: e3nn.o3.inverse_angles Quaternion Parametrization -------------------------- .. autofunction:: e3nn.o3.identity_quaternion .. autofunction:: e3nn.o3.rand_quaternion .. autofunction:: e3nn.o3.compose_quaternion .. autofunction:: e3nn.o3.inverse_quaternion Axis-Angle Parametrization -------------------------- .. autofunction:: e3nn.o3.rand_axis_angle .. autofunction:: e3nn.o3.compose_axis_angle Conversions ----------- .. autofunction:: e3nn.o3.angles_to_matrix .. autofunction:: e3nn.o3.matrix_to_angles .. autofunction:: e3nn.o3.angles_to_quaternion .. autofunction:: e3nn.o3.matrix_to_quaternion .. autofunction:: e3nn.o3.axis_angle_to_quaternion .. autofunction:: e3nn.o3.quaternion_to_axis_angle .. autofunction:: e3nn.o3.matrix_to_axis_angle .. autofunction:: e3nn.o3.angles_to_axis_angle .. autofunction:: e3nn.o3.axis_angle_to_matrix .. autofunction:: e3nn.o3.quaternion_to_matrix .. autofunction:: e3nn.o3.quaternion_to_angles .. autofunction:: e3nn.o3.axis_angle_to_angles Conversions to points on the sphere ----------------------------------- .. autofunction:: e3nn.o3.angles_to_xyz .. 
autofunction:: e3nn.o3.xyz_to_angles e3nn-0.6.0/docs/api/o3/o3_s2grid.rst000066400000000000000000000005341514371756200167440ustar00rootroot00000000000000Grid Signal on the Sphere ========================= .. autofunction:: e3nn.o3.s2_grid .. autofunction:: e3nn.o3.spherical_harmonics_s2_grid .. autofunction:: e3nn.o3.rfft .. autofunction:: e3nn.o3.irfft .. autoclass:: e3nn.o3.ToS2Grid :members: :show-inheritance: .. autoclass:: e3nn.o3.FromS2Grid :members: :show-inheritance:e3nn-0.6.0/docs/api/o3/o3_sh.rst000066400000000000000000000134241514371756200161660ustar00rootroot00000000000000.. _Spherical Harmonics: Spherical Harmonics =================== The spherical harmonics :math:`Y^l(x)` are functions defined on the sphere :math:`S^2`. They form a basis of the space of functions on the sphere: .. math:: \mathcal{F} = \{ S^2 \longrightarrow \mathbb{R} \} The group :math:`O(3)` acts naturally on this space. Given two scalar representations :math:`p_a, p_v`: .. math:: [L(g) f](x) = p_v(g) f(p_a(g) R(g)^{-1} x), \quad \forall f \in \mathcal{F}, x \in S^2 :math:`L` is a representation of :math:`O(3)`, but it is not irreducible. It can be decomposed via a change of basis into a sum of irreps. In a handwavy notation we can write: .. math:: Y^T L(g) Y = 0 \oplus 1 \oplus 2 \oplus 3 \oplus \dots where the change of basis is given by the spherical harmonics! This notation is handwavy because :math:`x` is a continuous variable, and therefore the change of basis :math:`Y` is not a matrix. As a consequence, the spherical harmonics are equivariant: .. math:: Y^l(R(g) x) = D^l(g) Y^l(x) .. 
jupyter-execute:: :hide-code: import torch import math from e3nn import o3 import plotly.graph_objects as go axis = dict( showbackground=False, showticklabels=False, showgrid=False, zeroline=False, title='', nticks=3, ) layout = dict( width=690, height=160, scene=dict( xaxis=dict( **axis, range=[-8, 8] ), yaxis=dict( **axis, range=[-2, 2] ), zaxis=dict( **axis, range=[-2, 2] ), aspectmode='manual', aspectratio=dict(x=8, y=2, z=2), camera=dict( up=dict(x=0, y=0, z=1), center=dict(x=0, y=0, z=0), eye=dict(x=0, y=-5, z=5), projection=dict(type='orthographic'), ), ), paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)", margin=dict(l=0, r=0, t=0, b=0) ) cmap_bwr = [[0, 'rgb(0,50,255)'], [0.5, 'rgb(200,200,200)'], [1, 'rgb(255,50,0)']] def s2_grid(): betas = torch.linspace(0, math.pi, 40) alphas = torch.linspace(0, 2 * math.pi, 80) beta, alpha = torch.meshgrid(betas, alphas, indexing='ij') return o3.angles_to_xyz(alpha, beta) def trace(r, f, c, radial_abs: bool = True): if radial_abs: a = f.abs() else: a = 1 return dict( x=a * r[..., 0] + c[0], y=a * r[..., 1] + c[1], z=a * r[..., 2] + c[2], surfacecolor=f ) def plot(data, radial_abs: bool = True): r = s2_grid() n = data.shape[-1] traces = [ trace(r, data[..., i], torch.tensor([2.0 * i - (n - 1.0), 0.0, 0.0]), radial_abs=radial_abs) for i in range(n) ] cmax = max(d['surfacecolor'].abs().max().item() for d in traces) traces = [go.Surface(**d, colorscale=cmap_bwr, cmin=-cmax, cmax=cmax) for d in traces] fig = go.Figure(data=traces, layout=layout) fig.show() .. jupyter-execute:: r = s2_grid() ``r`` is a grid on the sphere. .. 
jupyter-execute:: :hide-code: fig = go.Figure( data=[ go.Scatter3d( x=r[..., 0].flatten(), y=r[..., 1].flatten(), z=r[..., 2].flatten(), mode='markers', marker=dict( size=1, ), ) ], layout=dict( width=500, height=300, scene=dict( xaxis=dict( **axis, range=[-1, 1] ), yaxis=dict( **axis, range=[-1, 1] ), zaxis=dict( **axis, range=[-1, 1] ), aspectmode='manual', aspectratio=dict(x=3, y=3, z=3), camera=dict( up=dict(x=0, y=0, z=1), center=dict(x=0, y=0, z=0), eye=dict(x=0, y=-5, z=5), projection=dict(type='orthographic'), ), ), paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)", margin=dict(l=0, r=0, t=0, b=0) ) ) fig.show() Each point on the sphere has 3 components. If we plot the value of each of the 3 components separately, we obtain the following figure: .. jupyter-execute:: plot(r, radial_abs=False) x, y and z are represented as 3 scalar fields on 3 different spheres. To obtain a nicer figure (that looks like the spherical harmonics shown on Wikipedia), we can deform the spheres into a shape whose radius equals the absolute value of the plotted quantity: .. jupyter-execute:: plot(r) :math:`Y^1` is the identity function. Now let's compute :math:`Y^2`. For this, we take the tensor product :math:`r \otimes r` and extract its :math:`L=2` part. .. jupyter-execute:: tp = o3.ElementwiseTensorProduct("1o", "1o", ['2e'], irrep_normalization='norm') y2 = tp(r, r) plot(y2) Similarly, the next spherical harmonic function :math:`Y^3` is the :math:`L=3` part of :math:`r \otimes r \otimes r`: .. jupyter-execute:: tp = o3.ElementwiseTensorProduct("2e", "1o", ['3o'], irrep_normalization='norm') y3 = tp(y2, r) plot(y3) The functions below are more efficient versions that do not use `e3nn.o3.ElementwiseTensorProduct`: .. rubric:: Details .. autofunction:: e3nn.o3.spherical_harmonics .. autofunction:: e3nn.o3.spherical_harmonics_alpha_beta .. 
autofunction:: e3nn.o3.Legendre e3nn-0.6.0/docs/api/o3/o3_tp.rst000066400000000000000000000100271514371756200161730ustar00rootroot00000000000000Tensor Product ============== All tensor products --- denoted :math:`\otimes` --- share two key characteristics: #. The tensor product is *bilinear*: :math:`(\alpha x_1 + x_2) \otimes y = \alpha x_1 \otimes y + x_2 \otimes y` and :math:`x \otimes (\alpha y_1 + y_2) = \alpha x \otimes y_1 + x \otimes y_2` #. The tensor product is *equivariant*: :math:`(D x) \otimes (D y) = D (x \otimes y)` where :math:`D` is the representation of some symmetry operation from :math:`E(3)` (sorry for the very loose notation) The class `e3nn.o3.TensorProduct` implements all possible tensor products between finite direct sums of irreducible representations (`e3nn.o3.Irreps`). While `e3nn.o3.TensorProduct` provides maximum flexibility, a number of subclasses provide various typical special cases of the tensor product: * `e3nn.o3.FullTensorProduct`: .. jupyter-execute:: :hide-code: from e3nn import o3 .. jupyter-execute:: tp = o3.FullTensorProduct( irreps_in1='2x0e + 3x1o', irreps_in2='5x0e + 7x1e' ) print(tp) tp.visualize(); The full tensor product is the "natural" one. Every possible output --- each output irrep for every pair of input irreps --- is created and returned independently. The outputs are not mixed with each other. Note how the multiplicities of the outputs are the product of the multiplicities of the respective inputs. * `e3nn.o3.FullyConnectedTensorProduct` .. jupyter-execute:: tp = o3.FullyConnectedTensorProduct( irreps_in1='5x0e + 5x1e', irreps_in2='6x0e + 4x1e', irreps_out='15x0e + 3x1e' ) print(tp) tp.visualize(); In a fully connected tensor product, all paths that lead to any of the irreps specified in ``irreps_out`` are created. Unlike `e3nn.o3.FullTensorProduct`, each output is a learned weighted sum of compatible paths.
This allows `e3nn.o3.FullyConnectedTensorProduct` to produce outputs with any multiplicity; note that the example above has :math:`5 \times 6 + 5 \times 4 = 50` ways of creating scalars (``0e``), but the specified ``irreps_out`` has only 15 scalars, each of which is a learned weighted combination of those 50 possible scalars. The blue color in the visualization indicates that the path has these learnable weights. All possible output irreps do **not** need to be included in ``irreps_out`` of a `e3nn.o3.FullyConnectedTensorProduct`: ``o3.FullyConnectedTensorProduct(irreps_in1='5x1o', irreps_in2='3x1o', irreps_out='20x0e')`` will only compute inner products between its inputs, since ``1e``, the output irrep of a vector cross product, is not present in ``irreps_out``. Note also in this example that there are 20 output scalars, even though the given inputs can produce only 15 unique scalars --- this is again allowed because each output is a learned linear combination of those 15 scalars, placing no restrictions on how many or how few outputs can be requested. * `e3nn.o3.ElementwiseTensorProduct` .. jupyter-execute:: tp = o3.ElementwiseTensorProduct( irreps_in1='5x0e + 5x1e', irreps_in2='4x0e + 6x1e' ) print(tp) tp.visualize(); In the elementwise tensor product, the irreps are multiplied one-by-one. Note in the visualization how the inputs have been split and that the multiplicities of the outputs match the multiplicities of the inputs. * `e3nn.o3.TensorSquare` .. jupyter-execute:: tp = o3.TensorSquare("5x1e + 2e") print(tp) tp.visualize(); The tensor square operation only computes the non-zero entries of a tensor times itself. It also applies different normalization rules, taking into account that a tensor times itself is statistically different from the product of two independent tensors. .. autoclass:: e3nn.o3.TensorProduct :members: :show-inheritance: .. autoclass:: e3nn.o3.FullyConnectedTensorProduct :members: :show-inheritance: .. 
autoclass:: e3nn.o3.FullTensorProduct :members: :show-inheritance: .. autoclass:: e3nn.o3.ElementwiseTensorProduct :members: :show-inheritance: .. autoclass:: e3nn.o3.TensorSquare :members: :show-inheritance: e3nn-0.6.0/docs/api/o3/o3_wigner.rst000066400000000000000000000001531514371756200170420ustar00rootroot00000000000000Wigner Functions ================ .. autofunction:: e3nn.o3.wigner_D .. autofunction:: e3nn.o3.wigner_3j e3nn-0.6.0/docs/api/util/000077500000000000000000000000001514371756200150515ustar00rootroot00000000000000e3nn-0.6.0/docs/api/util/jit.rst000066400000000000000000000002031514371756200163640ustar00rootroot00000000000000JIT - wrappers for TorchScript =============================== .. automodule:: e3nn.util.jit :members: :show-inheritance: e3nn-0.6.0/docs/api/util/test.rst000066400000000000000000000002051514371756200165570ustar00rootroot00000000000000test - helpers for unit testing =============================== .. automodule:: e3nn.util.test :members: :show-inheritance: e3nn-0.6.0/docs/api/util/util.rst000066400000000000000000000001441514371756200165570ustar00rootroot00000000000000util ==== Helper functions. .. rubric:: Overview .. toctree:: :maxdepth: 1 jit test e3nn-0.6.0/docs/conf.py000066400000000000000000000075431514371756200146330ustar00rootroot00000000000000# Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
import os import sys from e3nn import __version__, __file__ sys.path.insert(0, os.path.abspath("../")) # -- Project information ----------------------------------------------------- project = "e3nn" copyright = "2020, e3nn Developers" author = "e3nn Developers" # The full version, including alpha/beta/rc tags release = "0.5.1" # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ "autodocsumm", "myst_parser", "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.intersphinx", "sphinx.ext.linkcode", "sphinx.ext.mathjax", "sphinx.ext.napoleon", "jupyter_sphinx", ] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "numpy": ("https://numpy.org/doc/stable/", None), "pytorch": ("https://pytorch.org/docs/stable/", None), "torch_geometric": ("https://pytorch-geometric.readthedocs.io/en/latest/", None), "ase": ("https://wiki.fysik.dtu.dk/ase/", None), } autodoc_default_options = { "inherited-members": False, "show-inheritance": True, "autosummary": False, } # The reST default role to use for all documents. default_role = "any" # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. 
They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] myst_update_mathjax = False # Resolve function for the linkcode extension. # Thanks to https://github.com/Lasagne/Lasagne/blob/master/docs/conf.py def linkcode_resolve(domain, info): def find_source(): # try to find the file and line number, based on code from numpy: # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286 obj = sys.modules[info["module"]] for part in info["fullname"].split("."): obj = getattr(obj, part) import inspect import os fn = inspect.getsourcefile(obj) fn = os.path.relpath(fn, start=os.path.dirname(__file__)) source, lineno = inspect.getsourcelines(obj) return fn, lineno, lineno + len(source) - 1 if domain != "py" or not info["module"]: return None try: rel_path, line_start, line_end = find_source() # __file__ is imported from e3nn filename = f"e3nn/{rel_path}#L{line_start}-L{line_end}" except Exception: # no need to be relative to core here as module includes full path. filename = info["module"].replace(".", "/") + ".py" tag = __version__ return f"https://github.com/e3nn/e3nn/blob/{tag}/{filename}" e3nn-0.6.0/docs/examples/000077500000000000000000000000001514371756200151415ustar00rootroot00000000000000e3nn-0.6.0/docs/examples/examples.rst000066400000000000000000000002401514371756200175050ustar00rootroot00000000000000Examples ======== The two examples are models made to classify the toy dataset *tetris*. .. toctree:: :maxdepth: 1 tetris_polynomial tetris_gate e3nn-0.6.0/docs/examples/tetris_gate.rst000066400000000000000000000004531514371756200202070ustar00rootroot00000000000000Tetris Gate Example =================== Built on top of `tetris_polynomial`, this example adds the following: * `soft_one_hot_linspace` * `e3nn.nn.Gate` .. rubric:: code .. 
literalinclude:: ../../examples/tetris_gate.py Full code `here `_ e3nn-0.6.0/docs/examples/tetris_polynomial.rst000066400000000000000000000015151514371756200214520ustar00rootroot00000000000000.. _tetris_poly: Tetris Polynomial Example ========================= In this example we create an *equivariant polynomial* to classify tetris. We use the following features of e3nn: * `e3nn.o3.Irreps` * `o3.spherical_harmonics` * `e3nn.o3.FullyConnectedTensorProduct` And the following features of `pytorch_geometric `_ * `radius_graph `_ * `scatter `_ .. rubric:: the model .. literalinclude:: ../../examples/tetris_polynomial.py :lines: 60-107 .. rubric:: training .. literalinclude:: ../../examples/tetris_polynomial.py :lines: 111-127 :dedent: 4 Full code `here `_ e3nn-0.6.0/docs/guide/000077500000000000000000000000001514371756200144205ustar00rootroot00000000000000e3nn-0.6.0/docs/guide/change_of_basis.rst000066400000000000000000000036201514371756200202450ustar00rootroot00000000000000Change of Basis =============== In release ``0.2.2``, the Euler angle convention changed from the standard ZYZ to YXY. This amounts to a change of basis for e3nn. This change of basis means that the real spherical harmonics have been rotated from the "standard" real spherical harmonics (see this table of standard real spherical harmonics from Wikipedia_). If your network has outputs of L=0 only, this has no effect. If your network has outputs of L=1, the components are now ordered x,y,z as opposed to the "standard" y,z,x. If, however, your network has outputs of L=2 or greater, things are a little trickier. In this case there is no simple permutation of spherical harmonic indices that will get you back to the standard real spherical harmonics. You then have two options: (1) apply the change of basis to your inputs, or (2) apply the change of basis to your outputs. 1. If the only inputs you have are scalars and positions, you can just permute the indices of your coordinates.
You just need to permute from ``y,z,x`` to ``x,y,z``. If you choose this method, be careful: you must keep the permuted coordinates for all subsequent analysis calculations. 2. If you want to apply the change of basis more generally, for higher L, you can compute the appropriate rotation matrices, as in this example for L=2: .. jupyter-execute:: import torch from e3nn import o3 import matplotlib.pyplot as plt change_of_coord = torch.tensor([ # this specifies the change of basis yzx -> xyz [0., 0., 1.], [1., 0., 0.], [0., 1., 0.] ]) D = o3.Irrep(2, 1).D_from_matrix(change_of_coord) plt.imshow(D, cmap="RdBu", vmin=-1, vmax=1) plt.colorbar(); Of course, you can apply the rotation method to either the inputs or the outputs; you will get the same result. .. _Wikipedia: https://en.wikipedia.org/wiki/Table_of_spherical_harmonics#Real_spherical_harmonics e3nn-0.6.0/docs/guide/convolution.rst000066400000000000000000000167431514371756200175440ustar00rootroot00000000000000.. _conv guide: Convolution =========== In this document we will implement an equivariant convolution with ``e3nn``. We will implement this formula: .. math:: f'_i = \frac{1}{\sqrt{z}} \sum_{j \in \partial(i)} \; f_j \; \otimes\!(h(\|x_{ij}\|)) \; Y(x_{ij} / \|x_{ij}\|) where - :math:`f_j, f'_i` are the node inputs and outputs - :math:`z` is the average `degree`_ of the nodes - :math:`\partial(i)` is the set of neighbors of the node :math:`i` - :math:`x_{ij}` is the relative vector - :math:`h` is a multilayer perceptron - :math:`Y` are the spherical harmonics - :math:`x \; \otimes\!(w) \; y` is a tensor product of :math:`x` with :math:`y` parametrized by some weights :math:`w` Boilerplate imports .. jupyter-execute:: import torch from torch_cluster import radius_graph from torch_scatter import scatter from e3nn import o3, nn from e3nn.math import soft_one_hot_linspace import matplotlib.pyplot as plt Let's first define the irreps of the input and output features. .. 
jupyter-execute::

    irreps_input = o3.Irreps("10x0e + 10x1e")
    irreps_output = o3.Irreps("20x0e + 10x1e")

Then we create a random graph from random positions, connecting two nodes with an edge when their distance is smaller than ``max_radius``.

.. jupyter-execute::

    # create node positions
    num_nodes = 100
    pos = torch.randn(num_nodes, 3)  # random node positions

    # create edges
    max_radius = 1.8
    edge_src, edge_dst = radius_graph(pos, max_radius, max_num_neighbors=num_nodes - 1)

    print(edge_src.shape)

    edge_vec = pos[edge_dst] - pos[edge_src]

    # compute z
    num_neighbors = len(edge_src) / num_nodes
    num_neighbors

``edge_src`` and ``edge_dst`` contain the indices of the nodes for each edge.
We can also create some random input features.

.. jupyter-execute::

    f_in = irreps_input.randn(num_nodes, -1)

Note that our data is generated with a normal distribution.
We will make sure that all the data follows the ``component`` normalization (see :ref:`norm guide`).

.. jupyter-execute::

    f_in.pow(2).mean()  # should be close to 1

Let's start with

.. math::

    Y(x_{ij} / \|x_{ij}\|)

.. jupyter-execute::

    irreps_sh = o3.Irreps.spherical_harmonics(lmax=2)
    print(irreps_sh)

    sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization='component')
    # normalize=True ensures that x is divided by |x| before computing the sh

    sh.pow(2).mean()  # should be close to 1

Now we need to compute :math:`\otimes(w)` and :math:`h`.
Let's create the tensor product first; it will tell us how many weights it needs.

.. jupyter-execute::

    tp = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_output, shared_weights=False)

    print(f"{tp} needs {tp.weight_numel} weights")

    tp.visualize();

In this particular choice of irreps we can see that the l=1 component of the spherical harmonics cannot be used in the tensor product.
In this example it's the equivariance to inversion that prohibits the use of l=1.
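
The parity argument can be checked by hand with the O(3) selection rule: the product of the irreps (l1, p1) and (l2, p2) contains every l from |l1 - l2| to l1 + l2, all with parity p1 * p2. The pure-Python sketch below ignores multiplicities, and its helpers ``tensor_product_irreps`` and ``usable_sh_components`` are hypothetical names for illustration, not part of the e3nn API. It lists which input/output pairs each spherical-harmonic component can connect for the irreps above.

.. code:: python

    # Sketch of the O(3) selection rule (hypothetical helpers, not e3nn API).
    # An irrep is written as (l, p) with p = +1 for even ("e") and p = -1 for odd ("o").


    def tensor_product_irreps(ir1, ir2):
        """Irreps in ir1 x ir2: l runs from |l1 - l2| to l1 + l2, parity is p1 * p2."""
        (l1, p1), (l2, p2) = ir1, ir2
        return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]


    def usable_sh_components(irreps_in, irreps_sh, irreps_out):
        """For each spherical-harmonic irrep, list the (input, output) irreps it can connect."""
        return {
            ir_sh: [
                (ir_in, ir_out)
                for ir_in in irreps_in
                for ir_out in tensor_product_irreps(ir_in, ir_sh)
                if ir_out in irreps_out
            ]
            for ir_sh in irreps_sh
        }


    # Irreps of the example above, multiplicities dropped:
    # input 10x0e + 10x1e, spherical harmonics 0e + 1o + 2e, output 20x0e + 10x1e
    irreps_in = [(0, 1), (1, 1)]
    irreps_sh = [(0, 1), (1, -1), (2, 1)]
    irreps_out = [(0, 1), (1, 1)]

    paths = usable_sh_components(irreps_in, irreps_sh, irreps_out)
    for ir_sh, pairs in paths.items():
        print(ir_sh, pairs)
    # (1, -1) maps to an empty list: 1o times an even input is always odd,
    # while the output only contains even irreps.

With these irreps the ``1o`` component connects nothing, whereas an even ``1e`` component would.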
If we don't want equivariance to inversion, we can declare all irreps to be even (``irreps_sh = o3.Irreps("0e + 1e + 2e")``).

To implement :math:`h`, which has to map the relative distances to the weights of the tensor product, we will embed the distances using a basis function and then feed this embedding to a neural network.
Let's create that embedding. Here are the basis functions we will use:

.. jupyter-execute::

    num_basis = 10

    x = torch.linspace(0.0, 2.0, 1000)
    y = soft_one_hot_linspace(
        x,
        start=0.0,
        end=max_radius,
        number=num_basis,
        basis='smooth_finite',
        cutoff=True,
    )

    plt.plot(x, y);

Note that these functions are all smooth and are strictly zero beyond ``max_radius``.
This is useful to get a convolution that is smooth despite the sharp cutoff at ``max_radius``.

Let's use this embedding for the edge distances and normalize it properly (``component``, i.e. second moment close to 1).

.. jupyter-execute::

    edge_length_embedding = soft_one_hot_linspace(
        edge_vec.norm(dim=1),
        start=0.0,
        end=max_radius,
        number=num_basis,
        basis='smooth_finite',
        cutoff=True,
    )
    edge_length_embedding = edge_length_embedding.mul(num_basis**0.5)

    print(edge_length_embedding.shape)
    edge_length_embedding.pow(2).mean()  # the second moment

Now we can create an MLP and feed it the embedding:

.. jupyter-execute::

    fc = nn.FullyConnectedNet([num_basis, 16, tp.weight_numel], torch.relu)
    weight = fc(edge_length_embedding)

    print(weight.shape)
    print(len(edge_src), tp.weight_numel)

    # For a proper normalization, the weights also need to be mean 0
    print(weight.mean(), weight.std())  # should be close to 0 and 1

Now we can compute the term

.. math::

    f_j \; \otimes\!(h(\|x_{ij}\|)) \; Y(x_{ij} / \|x_{ij}\|)

The idea is to compute this quantity per edge, so we will need to "lift" the input features to the edges.
For that we use ``edge_src``, which contains, for each edge, the index of the source node.

..
jupyter-execute::

    summand = tp(f_in[edge_src], sh, weight)

    print(summand.shape)
    print(summand.pow(2).mean())  # should be close to 1

Only the sum over the neighbors remains:

.. math::

    f'_i = \frac{1}{\sqrt{z}} \sum_{j \in \partial(i)} \; f_j \; \otimes\!(h(\|x_{ij}\|)) \; Y(x_{ij} / \|x_{ij}\|)

.. jupyter-execute::

    f_out = scatter(summand, edge_dst, dim=0, dim_size=num_nodes)

    f_out = f_out.div(num_neighbors**0.5)

    f_out.pow(2).mean()  # should be close to 1

Now we can put everything into a function:

.. jupyter-execute::

    def conv(f_in, pos):
        edge_src, edge_dst = radius_graph(pos, max_radius, max_num_neighbors=len(pos) - 1)
        edge_vec = pos[edge_dst] - pos[edge_src]
        sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization='component')
        emb = soft_one_hot_linspace(edge_vec.norm(dim=1), 0.0, max_radius, num_basis, basis='smooth_finite', cutoff=True).mul(num_basis**0.5)
        return scatter(tp(f_in[edge_src], sh, fc(emb)), edge_dst, dim=0, dim_size=len(pos)).div(num_neighbors**0.5)

Now we can check the equivariance:

.. jupyter-execute::

    rot = o3.rand_matrix()
    D_in = irreps_input.D_from_matrix(rot)
    D_out = irreps_output.D_from_matrix(rot)

    # rotate before
    f_before = conv(f_in @ D_in.T, pos @ rot.T)

    # rotate after
    f_after = conv(f_in, pos) @ D_out.T

    torch.allclose(f_before, f_after, rtol=1e-4, atol=1e-4)

The tensor product dominates the execution time:

..
jupyter-execute::

    import time
    wall = time.perf_counter()

    # build the graph
    edge_src, edge_dst = radius_graph(pos, max_radius, max_num_neighbors=len(pos) - 1)
    edge_vec = pos[edge_dst] - pos[edge_src]
    print(time.perf_counter() - wall); wall = time.perf_counter()

    # spherical harmonics
    sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization='component')
    print(time.perf_counter() - wall); wall = time.perf_counter()

    # radial embedding
    emb = soft_one_hot_linspace(edge_vec.norm(dim=1), 0.0, max_radius, num_basis, basis='smooth_finite', cutoff=True).mul(num_basis**0.5)
    print(time.perf_counter() - wall); wall = time.perf_counter()

    # MLP producing the tensor product weights
    weight = fc(emb)
    print(time.perf_counter() - wall); wall = time.perf_counter()

    # tensor product
    summand = tp(f_in[edge_src], sh, weight)
    print(time.perf_counter() - wall); wall = time.perf_counter()

    # sum over neighbors
    scatter(summand, edge_dst, dim=0, dim_size=num_nodes).div(num_neighbors**0.5)
    print(time.perf_counter() - wall); wall = time.perf_counter()

.. _degree: https://en.wikipedia.org/wiki/Degree_(graph_theory)

Equivariance Testing
====================

In `e3nn.util.test`, the library provides some tools for confirming that functions are equivariant.
The main tool is `equivariance_error`, which computes the largest absolute difference between the function applied to transformed arguments and the transformation applied to the function's output:

.. jupyter-execute::

    import e3nn.o3
    from e3nn.util.test import equivariance_error

    tp = e3nn.o3.FullyConnectedTensorProduct("2x0e + 3x1o", "2x0e + 3x1o", "2x1o")

    equivariance_error(
        tp,
        args_in=[tp.irreps_in1.randn(1, -1), tp.irreps_in2.randn(1, -1)],
        irreps_in=[tp.irreps_in1, tp.irreps_in2],
        irreps_out=[tp.irreps_out]
    )

The keys in the output indicate the type of random transformation (``(parity, did_translation)``) and the values are the maximum componentwise error.
For convenience, the wrapper function `assert_equivariant` is provided:

..
jupyter-execute::

    from e3nn.util.test import assert_equivariant

    assert_equivariant(tp)

For typical e3nn operations, `assert_equivariant` can optionally infer the input and output `e3nn.o3.Irreps`, generate random inputs when no inputs are provided, and check the error against a threshold appropriate to the current ``torch.get_default_dtype()``.

In addition to `e3nn.o3.Irreps`-like objects, ``irreps_in`` can also contain two special values:

* ``'cartesian_points'``: ``(N, 3)`` tensors containing XYZ points in real space that are equivariant under rotations *and* translations
* ``None``: any input or output that is invariant and should be left alone

These can be used to test models that operate on full graphs that include position information:

.. jupyter-execute::
    :hide-code:

    kwargs = dict(
        irreps_in="3x0e + 2x1o",
        irreps_out="4x0e + 1x1o",
        max_radius=2.0,
        num_neighbors=3.0,
        num_nodes=5.0
    )

.. jupyter-execute::

    import torch
    from torch_geometric.data import Data
    from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork
    from e3nn.util.test import assert_equivariant

    # kwargs = ...
    f = SimpleNetwork(**kwargs)

    def wrapper(pos, x):
        data = dict(pos=pos, x=x)
        return f(data)

    assert_equivariant(
        wrapper,
        irreps_in=['cartesian_points', f.irreps_in],
        irreps_out=[f.irreps_out],
    )

To test equivariance on a specific graph, ``args_in`` can be used:

.. jupyter-execute::
    :hide-code:

    my_pos = torch.randn(3, 3)
    my_x = f.irreps_in.randn(3, -1)

.. jupyter-execute::

    assert_equivariant(
        wrapper,
        irreps_in=['cartesian_points', f.irreps_in],
        args_in=[my_pos, my_x],
        irreps_out=[f.irreps_out],
    )

Logging
-------

``assert_equivariant`` also logs the equivariance error to the ``e3nn.util.test`` logger with level ``INFO`` regardless of whether the test fails.
When running in pytest, these logs can be seen using the `"Live Logs" feature `_:

..
code::

    pytest tests/ --log-cli-level info

.. _user_guide:

User Guide
==========

.. rubric:: Beginner

.. toctree::
    :maxdepth: 1

    installation
    irreps
    convolution
    normalization

.. rubric:: Advanced

.. toctree::
    :maxdepth: 1

    periodic_boundary_conditions
    transformer
    equivar_testing
    jit
    change_of_basis

```{include} ../../INSTALL.md
```

.. _irreps guide:

Irreducible representations
===========================

This page is a beginner introduction to the main object of the ``e3nn`` library: `e3nn.o3.Irreps`.
All the core components of ``e3nn`` can be found in ``e3nn.o3``.
``o3`` stands for the group of 3d orthogonal matrices, which is equivalently the group of rotations and inversion.

.. jupyter-execute::

    from e3nn.o3 import Irreps

An instance of `e3nn.o3.Irreps` describes how some data behaves under rotation.
The mathematical description of irreps can be found in the API :ref:`Irreducible representations`.

.. jupyter-execute::

    irreps = Irreps("1o")
    irreps

``irreps`` does not contain any data.
Under the hood it is simply a tuple made of other tuples and ints.

.. jupyter-execute::

    # Tuple[Tuple[int, Tuple[int, int]], ...]
    # ((multiplicity, (l, p)), ...)

    print(len(irreps))
    mul_ir = irreps[0]  # a tuple

    print(mul_ir)
    print(len(mul_ir))
    mul = mul_ir[0]  # an int
    ir = mul_ir[1]  # another tuple

    print(mul)

    print(ir)
    # print(len(ir))  ir is a tuple of 2 ints but __len__ has been disabled since it is always 2
    l = ir[0]
    p = ir[1]

    print(l, p)

Our ``irreps`` means "transforms like a vector".
``irreps`` is able to provide the matrix to transform the data under a rotation:

..
jupyter-execute::

    import torch
    t = torch.tensor

    # show the transformation matrix corresponding to the inversion
    irreps.D_from_angles(alpha=t(0.0), beta=t(0.0), gamma=t(0.0), k=t(1))

.. jupyter-execute::

    # a small rotation around the y axis
    irreps.D_from_angles(alpha=t(0.1), beta=t(0.0), gamma=t(0.0), k=t(0))

In this example

.. jupyter-execute::

    irreps = Irreps("7x0e + 3x0o + 5x1o + 5x2o")

the ``irreps`` tell us how 7 scalars, 3 pseudoscalars, 5 vectors and 5 odd representations of ``l=2`` transform.
They all transform independently, as can be seen by visualizing the matrix:

.. jupyter-execute::

    from e3nn import o3
    rot = -o3.rand_matrix()

    D = irreps.D_from_matrix(rot)

    import matplotlib.pyplot as plt
    plt.imshow(D, cmap='bwr', vmin=-1, vmax=1);

=======================
TorchScript JIT Support
=======================

PyTorch provides two ways to compile code into TorchScript: `tracing and scripting `_.
Tracing follows the tensor operations on an example input, allowing complex Python control flow if that control flow does not depend on the data itself.
Scripting compiles a subset of Python directly into TorchScript, allowing data-dependent control flow but only limited Python features.

This is a problem for e3nn, where many modules (such as `e3nn.o3.TensorProduct`) use significant Python control flow based on ``e3nn.o3.Irreps``, as well as features like inheritance that are incompatible with scripting.
Other modules like ``e3nn.nn.Gate``, however, contain important but simple data-dependent control flow.
Thus ``e3nn.nn.Gate`` needs to be scripted, even though it contains a `e3nn.o3.TensorProduct` that has to be traced.
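
The limitation of tracing can be seen on a toy function with data-dependent control flow. This is a minimal sketch using plain ``torch.jit.trace`` (not ``e3nn.util.jit``), and ``flip_if_negative`` is a hypothetical example function: the trace records only the branch taken for the example input, so a later input that should take the other branch gets the wrong result.

.. code:: python

    import warnings

    import torch


    def flip_if_negative(x: torch.Tensor) -> torch.Tensor:
        # Data-dependent control flow: which branch runs depends on the values in x.
        if x.sum() > 0:
            return x * 2.0
        return x * -1.0


    example = torch.ones(3)  # sum > 0, so tracing only sees the first branch

    with warnings.catch_warnings():
        # tracing warns that converting a tensor to a Python bool may make the trace incorrect
        warnings.simplefilter("ignore")
        traced = torch.jit.trace(flip_if_negative, example)

    x = -torch.ones(3)  # sum < 0, so the second branch should run
    print(flip_if_negative(x))  # eager Python takes the second branch: tensor([1., 1., 1.])
    print(traced(x))            # the trace still multiplies by 2: tensor([-2., -2., -2.])

Scripting keeps both branches of the ``if`` instead, which is why a module with data-dependent control flow has to be scripted rather than traced.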
To hide this complexity from the user and prevent difficult-to-understand errors, ``e3nn`` implements a wrapper for ``torch.jit``, `e3nn.util.jit <../api/util/jit.rst>`_, that recursively and automatically compiles submodules according to directions they provide.
Using the ``@compile_mode`` decorator, modules can indicate whether they should be scripted, traced, or left alone.

Simple Example: Scripting
=========================

We define a simple module that includes data-dependent control flow:

.. jupyter-execute::

    import torch
    from e3nn.o3 import Norm, Irreps

    class MyModule(torch.nn.Module):
        def __init__(self, irreps_in) -> None:
            super().__init__()
            self.norm = Norm(irreps_in)

        def forward(self, x):
            norm = self.norm(x)
            if torch.any(norm > 7.):
                return norm
            else:
                return norm * 0.5

    irreps = Irreps("2x0e + 1x1o")
    mod = MyModule(irreps)

To compile it to TorchScript, we can try to use ``torch.jit.script``:

.. jupyter-execute::

    try:
        mod_script = torch.jit.script(mod)
    except Exception:
        print("Compilation failed!")

This fails because ``Norm`` is a subclass of `e3nn.o3.TensorProduct` and TorchScript doesn't support inheritance.
If we use ``e3nn.util.jit.script``, on the other hand, it works:

.. jupyter-execute::

    from e3nn.util.jit import script, trace
    mod_script = script(mod)

Internally, ``e3nn.util.jit.script`` recurses through the submodules of ``mod``, compiling each in accordance with its ``@e3nn.util.jit.compile_mode`` decorator if it has one.
In particular, ``Norm`` and other `e3nn.o3.TensorProduct` modules are marked with ``@compile_mode('trace')``, so ``e3nn.util.jit`` constructs an example input for ``mod.norm``, traces it, and replaces it with the traced TorchScript module.
Then, when the parent module ``mod`` is compiled inside ``e3nn.util.jit.script`` with ``torch.jit.script``, the submodule ``mod.norm`` has already been compiled and is integrated without issue.

As expected, the scripted module and the original give the same results:

..
jupyter-execute::

    x = irreps.randn(2, -1)
    assert torch.allclose(mod(x), mod_script(x))

Mixing Tracing and Scripting
============================

Say we define:

.. jupyter-execute::

    from e3nn.util.jit import compile_mode

    @compile_mode('script')
    class MyModule(torch.nn.Module):
        def __init__(self, irreps_in) -> None:
            super().__init__()
            self.norm = Norm(irreps_in)

        def forward(self, x):
            norm = self.norm(x)
            for row in norm:
                if torch.any(row > 0.1):
                    return row
            return norm

    class AnotherModule(torch.nn.Module):
        def __init__(self, irreps_in) -> None:
            super().__init__()
            self.mymod = MyModule(irreps_in)

        def forward(self, x):
            return self.mymod(x) + 3.

And trace an instance of ``AnotherModule`` using `e3nn.util.jit.trace`:

.. jupyter-execute::

    mod2 = AnotherModule(irreps)
    example_inputs = (irreps.randn(3, -1),)
    mod2_traced = trace(
        mod2,
        example_inputs
    )

Note that we marked ``MyModule`` with ``@compile_mode('script')`` because it contains control flow, and that the control flow is preserved even when called from the traced ``AnotherModule``:

.. jupyter-execute::

    print(mod2_traced(torch.zeros(2, irreps.dim)))
    print(mod2_traced(irreps.randn(3, -1)))

We can confirm that the submodule ``mymod`` was compiled as a script, but that ``mod2`` was traced:

.. jupyter-execute::

    print(type(mod2_traced))
    print(type(mod2_traced.mymod))

Customizing Tracing Inputs
==========================

Submodules can also be compiled automatically using tracing if they are marked with ``@compile_mode('trace')``.
When submodules are compiled by tracing, it must be possible to generate plausible input examples on the fly.

These example inputs can be generated automatically based on the ``irreps_in`` of the module (the specifics are the same as for ``assert_equivariant``).
If this is not possible or would yield incorrect results, a module can define a ``_make_tracing_inputs`` method that generates example inputs of correct shape and type.

..
jupyter-execute::

    @compile_mode('trace')
    class TracingModule(torch.nn.Module):
        def forward(self, x: torch.Tensor, indexes: torch.LongTensor):
            return x[indexes].sum()

        # Because this module has no `irreps_in`, and because
        # `irreps_in` can't describe indexes, since it's a LongTensor,
        # we implement _make_tracing_inputs
        def _make_tracing_inputs(self, n: int):
            import random
            # The compiler asks for n example inputs;
            # this is only a suggestion, the only requirement
            # is that at least one be returned.
            return [
                {
                    'forward': (
                        torch.randn(5, random.randint(1, 3)),
                        torch.arange(3)
                    )
                }
                for _ in range(n)
            ]

To recursively compile this module and its submodules in accordance with their ``@compile_mode`` annotations, we can use ``e3nn.util.jit.compile`` directly.
This can be useful if the module you are compiling is annotated with ``@compile_mode`` and you don't want to override that annotation by using ``trace`` or ``script``:

.. jupyter-execute::

    from e3nn.util.jit import compile
    mod3 = TracingModule()
    mod3_traced = compile(mod3)
    print(type(mod3_traced))

Deciding between ``'script'`` and ``'trace'``
=============================================

The easiest way to decide on a compile mode for your module is to try both.
Tracing will usually generate warnings if it encounters dynamic control flow that it cannot fully capture, and scripting will raise compiler errors for features it does not support.

In general, any module that uses inheritance or control flow based on ``e3nn.o3.Irreps`` in ``forward()`` will have to be traced.

Testing
=======

A helper function is provided to unit test that auto-JITable modules (those annotated with ``@compile_mode``) can be compiled:

.. jupyter-execute::

    from e3nn.util.test import assert_auto_jitable
    assert_auto_jitable(mod2)

By default, ``assert_auto_jitable`` will test traced modules to confirm that they reject input shapes that are likely incorrect.
Specifically, it changes ``x.shape[-1]`` on the assumption that the final dimension is a network architecture constant.
If this heuristic is wrong for your module (as it is for ``TracingModule`` above), it can be disabled:

.. jupyter-execute::

    assert_auto_jitable(mod3, strict_shapes=False)

Compile mode ``"unsupported"``
==============================

Sometimes you may write modules that use features unsupported by TorchScript regardless of whether you trace or script.
To avoid cryptic errors from TorchScript if someone tries to compile a model containing such a module, the module can be marked with ``@compile_mode("unsupported")``:

.. jupyter-execute::
    :raises:

    @compile_mode('unsupported')
    class ChildMod(torch.nn.Module):
        pass

    class Supermod(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.child = ChildMod()

    mod = Supermod()
    script(mod)

.. _norm guide:

Normalization
=============

.. jupyter-execute::
    :hide-code:

    import torch

We define two kinds of normalization: ``component`` and ``norm``.

Definition
----------

component
"""""""""

``component`` normalization refers to tensors with each component of value around 1.
More precisely, the second moment of each component is 1.

.. math::

    \langle x_i^2 \rangle = 1

Examples:

* ``[1.0, -1.0, -1.0, 1.0]``
* ``[1.0, 1.0, 1.0, 1.0]`` the mean **doesn't** need to be zero
* ``[0.0, 2.0, 0.0, 0.0]`` this is still fine because :math:`\|x\|^2 = n`

.. jupyter-execute::

    torch.randn(10)

norm
""""

``norm`` normalization refers to tensors of norm close to 1.

.. math::

    \|x\| \approx 1

Examples:

* ``[0.5, -0.5, -0.5, 0.5]``
* ``[0.5, 0.5, 0.5, 0.5]`` the mean **doesn't** need to be zero
* ``[0.0, 1.0, 0.0, 0.0]``

.. jupyter-execute::

    torch.randn(10) / 10**0.5

There is just a factor :math:`\sqrt{n}` between the two normalizations.

Motivation
----------

Assuming that the weight distribution obeys

..
math::

    \langle w_i \rangle = 0

    \langle w_i w_j \rangle = \sigma^2 \delta_{ij}

and that :math:`w` is independent of :math:`x`, the first two moments of :math:`x \cdot w` (and therefore its mean and variance) are functions only of the second moment of :math:`x`:

.. math::

    \langle x \cdot w \rangle &= \sum_i \langle x_i w_i \rangle = \sum_i \langle x_i \rangle \langle w_i \rangle = 0

    \langle (x \cdot w)^2 \rangle &= \sum_{i} \sum_{j} \langle x_i w_i x_j w_j \rangle \\
                                  &= \sum_{i} \sum_{j} \langle x_i x_j \rangle \langle w_i w_j \rangle \\
                                  &= \sigma^2 \sum_{i} \langle x_i^2 \rangle

Testing
-------

You can use ``e3nn.util.test.assert_normalized`` to check whether a function or module is normalized at initialization:

.. code::

    from e3nn.util.test import assert_normalized
    from e3nn import o3
    assert_normalized(o3.Linear("10x0e", "10x0e"))

Point inputs with periodic boundary conditions
================================================================

This example shows how to give point inputs with periodic boundary conditions (e.g. crystal data) to a Euclidean neural network built with ``e3nn``.
For a specific application, this code should be modified with a more tailored network design.

.. jupyter-execute::

    import torch
    import e3nn
    import ase
    import ase.neighborlist
    import torch_geometric
    import torch_geometric.data

    default_dtype = torch.float64
    torch.set_default_dtype(default_dtype)

Example crystal structures
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

First, we create some crystal structures which have periodic boundary conditions.

..
jupyter-execute::

    # A lattice is a 3 x 3 matrix
    # The first index is the lattice vector (a, b, c)
    # The second index is a Cartesian index over (x, y, z)

    # Polonium with Simple Cubic Lattice
    po_lattice = torch.eye(3) * 3.340  # Cubic lattice with edges of length 3.34 Angstrom
    po_coords = torch.tensor([[0., 0., 0.,]])
    po_types = ['Po']

    # Silicon with Diamond Structure
    si_lattice = torch.tensor([
        [0.      , 2.734364, 2.734364],
        [2.734364, 0.      , 2.734364],
        [2.734364, 2.734364, 0.      ]
    ])
    si_coords = torch.tensor([
        [1.367182, 1.367182, 1.367182],
        [0.      , 0.      , 0.      ]
    ])
    si_types = ['Si', 'Si']

    po = ase.Atoms(symbols=po_types, positions=po_coords, cell=po_lattice, pbc=True)
    si = ase.Atoms(symbols=si_types, positions=si_coords, cell=si_lattice, pbc=True)

Create and store periodic graph data
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We use the `ase.neighborlist.neighbor_list` algorithm and a ``radial_cutoff`` distance to define which edges to include in the graph to represent interactions with neighboring atoms.

Note that for a convolutional network, the number of layers determines the receptive field, i.e. how “far out” any given atom can see.
So even if we use a ``radial_cutoff = 3.5``, a two-layer network effectively sees ``2 * 3.5 = 7`` distance units (in this case Angstroms) away and a three-layer network ``3 * 3.5 = 10.5`` distance units.
We then store our data in `torch_geometric.data.Data` objects that we will batch with `torch_geometric.data.DataLoader` below.

.. jupyter-execute::

    radial_cutoff = 3.5  # Only include edges for neighboring atoms within a radius of 3.5 Angstroms.
    type_encoding = {'Po': 0, 'Si': 1}
    type_onehot = torch.eye(len(type_encoding))

    dataset = []

    dummy_energies = torch.randn(2, 1, 1)  # dummy energies for example

    for crystal, energy in zip([po, si], dummy_energies):
        # edge_src and edge_dst are the indices of the central and neighboring atom, respectively
        # edge_shift indicates whether the neighbors are in different images / copies of the unit cell
        edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list("ijS", a=crystal, cutoff=radial_cutoff, self_interaction=True)

        data = torch_geometric.data.Data(
            pos=torch.tensor(crystal.get_positions()),
            lattice=torch.tensor(crystal.cell.array).unsqueeze(0),  # We add a dimension for batching
            x=type_onehot[[type_encoding[atom] for atom in crystal.symbols]],  # One-hot encoding of each atom's type
            edge_index=torch.stack([torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0),
            edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
            energy=energy  # dummy energy (assumed to be normalized "per atom")
        )

        dataset.append(data)

    print(dataset)

The first `torch_geometric.data.Data` object is for simple cubic Polonium, which has 7 edges: 6 for nearest neighbors and 1 as a “self” edge, ``6 + 1 = 7``.
The second `torch_geometric.data.Data` object is for diamond Silicon, which has 10 edges: 4 nearest neighbors for each of the two atoms and 2 “self” edges, one for each atom, ``4 * 2 + 1 * 2 = 10``.

The lattice of each structure has a shape of ``[1, 3, 3]`` such that when we batch examples, the batched lattices will have shape ``[batch_size, 3, 3]``.

Graph Batches
~~~~~~~~~~~~~

`torch_geometric.data.DataLoader` creates batches of differently sized structures and produces `torch_geometric.data.Data` objects containing a batch when iterated over.

..
jupyter-execute::

    batch_size = 2
    dataloader = torch_geometric.data.DataLoader(dataset, batch_size=batch_size)

    for data in dataloader:
        print(data)
        print(data.batch)
        print(data.pos)
        print(data.x)

``data.batch`` is the batch index, a tensor of shape ``[num_nodes]`` (one entry per atom in the batch) that stores which points or “atoms” belong to which example.
In this case, since we only have two examples in our batch, the batch tensor only contains the numbers ``0`` and ``1``.
The batch index is often passed to ``scatter`` `operations to aggregate per-example values `__, e.g. the total energy for a single crystal structure.

For more details on batching with ``torch_geometric``, please see `this page `__.

Relative distance vectors of edges with periodic boundaries
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To calculate the vectors associated with each edge for a given `torch_geometric.data.Data` object representing a single example, we use the following expression:

.. parsed-literal::

    edge_src, edge_dst = data['edge_index'][0], data['edge_index'][1]
    edge_vec = (data['pos'][edge_dst] - data['pos'][edge_src]
                + torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice']))

The first line in the definition of ``edge_vec`` is simply how one normally computes relative distance vectors given two points.

The second line adds the contribution to the relative distance vector due to crossing unit cell boundaries, i.e. if atoms belong to different images of the unit cell.
As we will see below, we can modify this expression to also include the ``data['batch']`` tensor when handling batched data.

One Approach: Adding a Preprocessing Method to the Network
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

While ``edge_vec`` can be stored in the `torch_geometric.data.Data` object, it can also be calculated by adding a preprocessing method to the Network.
For this example, we create a modified version of the example network ``SimpleNetwork`` `documented here `__ with `source code here `__.
``SimpleNetwork`` is a good starting point to check your data pipeline but should be replaced with a more tailored network for your specific application.

.. jupyter-execute::

    from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork
    from typing import Dict, Tuple, Union
    import torch_scatter

    class SimplePeriodicNetwork(SimpleNetwork):
        def __init__(self, **kwargs) -> None:
            """The keyword `pool_nodes` is used by SimpleNetwork to determine
            whether we sum over all atom contributions per example. In this example,
            we want to use a mean operation instead, so we will override this behavior.
            """
            pool = False
            if kwargs['pool_nodes']:
                kwargs['pool_nodes'] = False
                kwargs['num_nodes'] = 1.
                pool = True
            super().__init__(**kwargs)
            self.pool = pool

        # Override the preprocess method of SimpleNetwork to adapt it for periodic boundary data
        def preprocess(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> Tuple[torch.Tensor, ...]:
            if 'batch' in data:
                batch = data['batch']
            else:
                batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)

            edge_src = data['edge_index'][0]  # Edge source
            edge_dst = data['edge_index'][1]  # Edge destination

            # We need to compute this in the computation graph to backprop to positions
            # We are computing the relative distances + unit cell shifts from periodic boundaries
            edge_batch = batch[edge_src]
            edge_vec = (data['pos'][edge_dst]
                        - data['pos'][edge_src]
                        + torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice'][edge_batch]))

            return batch, data['x'], edge_src, edge_dst, edge_vec

        def forward(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
            # if pool_nodes was set to True, use scatter_mean to aggregate
            output = super().forward(data)
            if self.pool:
                return torch_scatter.scatter_mean(output, data['batch'], dim=0)  # Take mean over atoms per example
            else:
                return output
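
To see what the batched expression in ``preprocess`` computes, here is a NumPy sketch on a hand-built batch containing the Polonium and Silicon cells defined earlier. The function ``periodic_edge_vectors`` and the three edges are hypothetical choices for illustration; the shift convention is the one documented for ASE's ``neighbor_list`` (``D = positions[j] - positions[i] + S.dot(cell)``).

.. code:: python

    import numpy as np


    def periodic_edge_vectors(pos, lattice, batch, edge_src, edge_dst, edge_shift):
        """Relative edge vectors with periodic images, for a batch of crystals.

        pos:        (num_nodes, 3) Cartesian positions
        lattice:    (num_graphs, 3, 3) lattice vectors a, b, c as rows
        batch:      (num_nodes,) index of the crystal each atom belongs to
        edge_src, edge_dst: (num_edges,) atom indices of each edge
        edge_shift: (num_edges, 3) number of cells crossed along a, b, c
        """
        edge_lattice = lattice[batch[edge_src]]  # (num_edges, 3, 3): lattice of each edge's crystal
        # Same contraction as torch.einsum('ni,nij->nj', edge_shift, edge_lattice)
        cell_offset = np.einsum("ni,nij->nj", edge_shift, edge_lattice)
        return pos[edge_dst] - pos[edge_src] + cell_offset


    pos = np.array([
        [0.0, 0.0, 0.0],                 # Po (crystal 0)
        [1.367182, 1.367182, 1.367182],  # Si atom 0 (crystal 1)
        [0.0, 0.0, 0.0],                 # Si atom 1 (crystal 1)
    ])
    lattice = np.stack([
        np.eye(3) * 3.340,
        np.array([[0.0, 2.734364, 2.734364],
                  [2.734364, 0.0, 2.734364],
                  [2.734364, 2.734364, 0.0]]),
    ])
    batch = np.array([0, 1, 1])
    edge_src = np.array([0, 2, 1])
    edge_dst = np.array([0, 1, 2])
    edge_shift = np.array([
        [1.0, 0.0, 0.0],  # Po -> its own image one cell along a
        [0.0, 0.0, 0.0],  # Si atom 1 -> Si atom 0 in the same cell
        [1.0, 0.0, 0.0],  # Si atom 0 -> Si atom 1 in the image shifted by a
    ])

    vecs = periodic_edge_vectors(pos, lattice, batch, edge_src, edge_dst, edge_shift)
    print(vecs)

Both Silicon edges have length :math:`\sqrt{3} \times 1.367182`, about 2.368 Angstrom, even though the second one crosses a cell boundary.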
We define and run the network.

.. jupyter-execute::

    net = SimplePeriodicNetwork(
        irreps_in="2x0e",  # One hot scalars (L=0 and even parity) on each atom to represent atom type
        irreps_out="1x0e",  # Single scalar (L=0 and even parity) to output (for example) energy
        max_radius=radial_cutoff,  # Cutoff radius for convolution
        num_neighbors=10.0,  # scaling factor based on the typical number of neighbors
        pool_nodes=True,  # We pool (average) over atoms to predict one energy per example
    )

When we apply the network to our data, we get one scalar per example.

.. jupyter-execute::

    for data in dataloader:
        print(net(data).shape)  # One scalar per example
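
The pooling step in ``forward`` averages the per-atom outputs of each crystal using the batch index. A NumPy sketch of that aggregation follows; ``scatter_mean`` here is a hypothetical stand-in for ``torch_scatter.scatter_mean`` restricted to ``dim=0``, and the per-atom values are made up for illustration.

.. code:: python

    import numpy as np


    def scatter_mean(values, index, num_groups):
        """Average the rows of `values` that share the same entry in `index`.

        Assumes every group 0..num_groups-1 has at least one row.
        """
        sums = np.zeros((num_groups,) + values.shape[1:])
        np.add.at(sums, index, values)  # unbuffered scatter-add along the first axis
        counts = np.bincount(index, minlength=num_groups).astype(float)
        return sums / counts.reshape((-1,) + (1,) * (values.ndim - 1))


    # Per-atom outputs for a batch of two crystals: 1 Po atom, then 2 Si atoms
    per_atom = np.array([[0.5], [1.0], [3.0]])
    batch = np.array([0, 1, 1])

    out = scatter_mean(per_atom, batch, num_groups=2)
    print(out)  # per-crystal means: 0.5 for Po, (1.0 + 3.0) / 2 = 2.0 for Si

Averaging instead of summing keeps the output on a per-atom scale regardless of how many atoms each crystal has.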
6EYt:]笔JZ\\|A AL .B3$ftiR]fePi$1N{)pr'#Ǘ?>|K/oӗ4Zv-=CVSTTDQQ)E3f_LwCaX`#G|"[pX~( nv Zlx䈫p8,Xz^W._BS:f,?v UUU|7<\r%~֯_ρUU$:[B`ԨQ̚5jLMM N;g]vcVy:YfOȟ'rrrXx1ga̟?UV1rN8,Y%Kl23gՂxXv-*p8㡩Ӊl.ZW8&H{\yh4B\2gL r:J%Pݎ$33nwxj5e˖1vX獔(UÌ"=1cbb|>8@jj*^z| {z~򓟜ßL@ :`ODR3j#H { :3l00]v<)Ql/S[[СCXC}1#',"0bݺu,Yd,YrF37UVo ),,䩧"..4t:O<s9cahZ~aRSSlf…Ĝ~'Jv;H*- DRU=UW]`@RbI>|/}T*Q*x^pX h"233e1L1zF%y<N'vHZZٳG>/xǏoFF@pp;TP[f牞\&@Jv >-~H  c4\.0LdggVd2|>jz1LĐl6R:!fY~=vJ TVVbX֭ww΢EHNN>ǗA^3g/^N;+[JYf/ϢEHMM=} .}8p ?8&S[׳j*oߎ!>>sW=ׯ[nTTT~ u{<t:/Xnt: ɓxzRRRXx1w< GQVVF0$##\w?D^/GKbISvɰg?IIIŵ0d z\.+J Gb޽?J4P(PTM_P3 dgg$ r{/1o@ 8aL.sڻ>}X]Iԝ qFJKK(,,dԨQDl2*޿ۘf⨩a֭>|@VVI&dȐ!$&&ҵkW6l {@ a6yk2p@ 6m/gΜ9NZlJg555k駟sNn&nVnvꫯ駟2~x~ӟ2w\fϞ͒%Kx<<#mYl7|3wqG+ǎKRR<o޽{ P]]駟k׮6  O^(--2ɗ_%Vׯݻwd2a2d3zχůȨHaMnB,RTrK@ ' ]{E3I*]6ZjrdqNJjf4y;W@ 8Tk&ٹUuH, ~CQ]]rQSSVˀ0ͼ曤G8fQWWGcc#ʷ&PUU@vv6`Z->#G;.is|uR=z4fobJ# Ζo⤵ٳg"ٹs'۷oodܹ[ti1,O?Ν;뮻;ᮻ˗mp8իQ(̝;{S,]7|Scׯ?<@I]]s̡đ Hcy%7jZj9J2̏QRO~`p\=1)JǤIۥ:$233޽;]v+EJm۶QSS&53xs ȿ{mNҥKinn9tώe˖u )))L6 Nu|VIMM%77W^,B2)EE^F~ 0 G^%66yرCO:ԧTMtBcc#f"{IEť޽{SO1sLj|ڻ/Ü*"eR g#6r_ Y=ϙ@ !Mv;͆RdĈ1f#>>Bjj* ..o2(,,fh:u*?J`߿?zBr|`Zx1 89%rǓYP*hNY&#G jlٲP(bD O>p?;S}|r$hlld߾}8&.׿f֬Y#ӦMaB}~{'عw>o]o~Ov]{<(ĩŰȴFs +f0#11ɄB~b)]>ͶOB#f͚ d2zjx" B|7l6yx@{(JnO~V$8?p8:,XSblIII?Y~=p,˺Nj*=ک}8nN .T*wy'gŊPJ(d`lt:e0ݎ"OY<änn7.K~IX޽ig2BĖ/__Nzz@ "p;: }xY"//'x)--ȑ#13>}vm=0LdeesI8E0ɀ=[negu} B09zy@ " NE'51vXFyP^xwy~x%MbI%Ҫq8 y(`0P*p.2}#rb}}?ȱG2:t-Zu~9̜9{QpqFi{"y7:K/Ċ+:~Q>bqtMxQ۷mC=>*~v=Y8|p*J233;u|ŇhdObd7th.)Qb'J$ ʋ=fS .+j_I܀v#$1t:q:XVZ[[X,|>NJ RDLʨ7f@p6NEhp—^WZ-p,2E2/ذaAeX˗ƻKcc#Vˁ/|| F^KRYY)OV+Gaڵ,]5kDE!H`0Ȗ-[O|rer;.]ᅬl[z}Ԙf37n}.Ue]\sM2eJ* y6o6*~6w:scƌ9+Q [n5k0m4bbbtGI|Zvl6vՊft\$M |(b߾}<㴴>yl6xdz}("8bXSS<сc+W@=8tPdf sϞ=LvٮR$55<p Y\HJJ۷p"VٱcG ?7|s+:}1cZw|9r/I&1{ln!yWFn:Ǹqy=CIs-t9ŋy׈BeeeX,V,9(A,zQb:)B!|>s|ʥ$`EaBXKK MMMfN'111̝;Цr\sE%֮]+_fy@ kaiyIG nvZvɊ+">pm6{%33]ed8i&M6Knn.p,ťXK޽lR~GvvѣGIHH`ȑѧO}$J%z|( ޽;jNDbbbp\ן+֭1Ne2ZkעRPTĐDJJ '`Ւ-W$45@ 8 zMbb"ÇGS__n_lܸ$Ej xv*++tdddp&MD8+pQFEbpQ<Э[7bbb׿EUU=zhw5>pa*++&>>4ƏOee%< }EP0QZ0GRVVFAAY 
^zkj*>c,cդsrr2e _~9.9z( .GJb޼y7?;wRǍ5\#WdgΜ9<A:RSUY~=>(nk{Rĉ)((?dݺu|>5\ȑ#VϟÇ뮻4ik֬aʕTTTǢ箾jL" 茸#A^Eڊ^GVT*INN`0P[[㑷%1)m/~eV ;Q5ȗ-&d`xP*dddпe:y$i%KvvD}}=͘ԩS),,j;gΜv']v\.W_%//O~oXXr%ɲhq֫+zLi:ߨ.tgiT@yZ-#Fh4ҳgOf3z_~B!t:EEEWqq1F1*RDTFZZÇF իyyyAFEKK QCFF* Mnn.IIIh4wV%++݉D(rQXXnGPPTT^'>> &hHII!S\\5`0Hkk+W^y%GFjj*~;v<5k桇bС>Jbȑ9@ hOJOOO~;n1111tP(9s:+jkk;w. ̛7ONkk|AٓPàp8ݻ㏉%zϧǖ-[h4zG0"S&%H0JV#SJb d rWǃ锣BZӫW/ `CYY{ʿ)v6h1M~𒱩A{H+Ǘhhhd21lذ(ALROFARVQTQ8+ȾvZZZڴn3ތ wȇBFg0drWv|Zۡ} duɉS |Ƅ c"_~zь=Cѣ:%MDBI&O.]TГJƎiZ)..&vj]t?| 6{nZZZHII((4B*/\qz!~_p-tzVO@h6q /j4hsΥtߏbi^T~o8th"ܦTCTĜ+VRٓ &зo_J\eĉl޼m۶jQ*=?20d/ N@h&o,iXv؁JBizKLL 呒BJJ *JR>׭[Ǒ#GkMG0k.AlƍmL<0ׯ緿&B pZP.]D5FC0D>|_~3g2JJJHNNF 6LݻÓVb%a7{{( tF5Dvzx9995M $ڠP('>>"Zcc#;v/?`ܸqqg{V`P*݋RhcwQ2za+Z{= #++ Aii)]v+҅a?ԩS)**ه444i&Q*9i@7w_XZZZؼy3|g2j(9KR˖->W8~ir=OImc$x8ټ+KRaX9s& h4FEIYjZDINN*=޽{l\.OR"Vss3@b֭['b%I*ST2d'^S%TM׋^-[i& qqqtޝ]dMVjbrɮx'2n88"x.R{B]q#&YGNDtEWL9u]mԨQdeeɑT !1yd&MD]]#$oH|>T)ȏ##?|E :Ekk+G! Fvr0@RSS0a?0`< B}v, 111QPP O )--.#;;,***̎srr߿1'|µ^KϞ=磏>bqrJe]FMMPRRB0$//iӦzj9B o߾x^;Xرc9ry).6T*wq={d}̛7#Gz-z=7|3۶mc…ƒɪUHLLdƌdƍ/_aX^z|y(Y@B *tr-I(ttǂ|>\t IHHw|W466Rd/$YX,<_WL4W_t:ijj 11J` yGVzqh4YRT- Bl6[vڅ!==lRRRHKK/:thi 2)QM'0{9U=7QtE T{n6o֭[lT*&L*=dyu㇆B 33Lƍ6}h(..fĈB@ 8 9t萜2n8.ye۷swc0())a̘1Q ~-;wp1 ~~?fAΝ;Yb_l^e4 & )**BVxؼy3fRѣҥbX8rǏGѰyf<{͛GM׮]qva/F_gڴi̘1=7|-¶mxg`$''SO;x;/~ /ԩSOp8믿{GYpIXZ"_z^Nh410Ábfɂlf޽EEuI!$N9\.>ٳ'yyyL>׋la2p:tH^rʽ锏ȢX0Ф ]t!//:, V [$$$VYt)>`EbhR{Ȋ' )O3=t> eζY%Ay15551c ~?A1rHg=Ayf|M|M)..olc/# >K/? ZSFF&olݺcjƍGbb"1cPRR¾}hjj"--:ƎK `ӦM~***={RSS@^/~^?z(gڵ 843ݺu_dɒ%̟?Pz-F#Z%KPPPEd{ŋxOx7;vlT׎ ի=z4Ͽ #玓-( Z[[)//'%%NJ,I~_f& Q%%%RdԿ=D'ۍᐫPvMh*T*ʢݎV`2%L}}IhJE}}=ՊtQɳgPeW! 
+jGmf[a5%?bClPsa1L"_P|Cz,tOXHp,=ᮻ⮻͛7e>Cn=<@ 萞K^z<tXVrrr4hp] HOOJۗ&#Ay"2m4T*&qa2L2OBB~_o[oET4hlb|M7t:娳 QB!O=󑗗|gP($$$PVV֦X8&55 b4q:' `2X`׿?w… ;\`޽Bz-Ȃ j5GGz={6:gpz[2uTN\L c4H1 S VAi)))a޽-U'#B^%Wz|dggGyEi4a߾}x^4 Vd ᆪTFɃ,PWW믿_O.]HMMa۩&|>grZALEA+%꾷M( EWfK;Ba~91L"Q]jǠ3(%%%vm?bG* ,EL]]c߾}r%P(#%LPTn~?'zfҽ;2LJ$%%Gkk+pOyy9~ѵkW)++ѣ8cH '---r-F~?JRNUWWQ'΢ʤS]f{S?Ti 9`!(UUU\իWr$77|G{7ngϞL:qơ$++Av~ F`nff JaYYYrJ"ػw/Z={RdJVS]]}] e˖KQ7nɪlKh6lcD aذaK,X~2tP?Q[[wߥo߾\q߿k{}J%{/:e˖1h >$4~_3|p7o̞=' A)Hގ `0f&kvލf2P_h#¤dzJ=ȔHQ)҃RlIKKl6喖8@8lt:)))TTTR|OS93ȕ*ȴ oDW$%vMԏC bv֭ ;v0oeUA{gժUÆ bsN ⢣ɃM|w:x]z&q6K[O~BVVΏdOέʖ-[Xr%>|BqF $$$B ))=z'~Z-W#KLLlw ǧDZ~_IddY{m"-((,&`w[n}bb"v[P(L> 6P^^m0/2&g}}v/kZiiiARE RM^/))).BJ$%*RrX0Jфc +,/=käqJ>gCdd/BBBdeeѻwov;TWWG+JKxp 0 BAb5ZOU%ďؙ :RQ=ѣϵZ-wqyE5QFQSSCfffs/?4"Nj_'kv"3f9-d2n%fJ%))) {Vft:qqqx^jkk"vBMMMTVVRTTԩ zmmmUNFCCnӟTaN68N @ *K*x߿?~#Gȩ}WRvQfR &I$|i4j,I"$NS(<'v;ѥKT*wEaa!NC0dsr &k?LD3FlԺ['a^ujkk1bD &h/Rwޡ3fvF%?^"'8fz)S0`F6l؀fVo `ΝrwʓZt:+Wjrӭ[7l {aݺu][v,G"#^QZZx^V+&I6?^Ȓ$`0ȶm0H||mll+O#*Byxx0uT._ʕ+ٸq#m۶5tAk$Qv4=~SNagg ZubffѣGu4l 7oW~:'Oӓ *ٷo%K֭[xyy%$$k׮o>UL1!{6mʾ}2koo>iذ!?3qq+Ҵi m':`Rۛcccv \>PPɘRpqq͍`nܸǏ g.WP/Lbcci۶-.\ޒNj122:+N]2]|6qqqi25S III 6f͚e1ٚW6{$HlDO2x}쌍C,nV9y$uIsX Η+ ot}'''k 6664iJB_~cjj^$I|'TZcccpwwGT(RšWV-̪PKƍܶe˖YŽ֭[vwn[xqT6A'''}}FTE]PLᥛB4Z-O>\tI_^tK0=\rs9IHH&R nZnJYRt2R/ :{,j EXD@L]b*c.%D]R)ؤ(?&Qƾ*[0EM`̞͛=w믿RX>7oĭ[&..N>K;N>}$IᕘHBBXXXWN]p,uXꀘn],u LTf(" &PkGSC2_HGPT)ؕ'7xx%HIn-vs~JJ 'k֬̌ѣGSh\oWȿYp!6mbԨQxyyK J|嗹60O=0yݔd"""ؼyk1 {allLtt Yfoo?d(:;;Ӻuk:v옦]f)JFIɒ%ٺu+O>̌ƍӭ[7Q?LxLwpp_~̝;YSSgW2,,,%998޽$ITaHL/4h$I ĒYeIII$$$?jiZJ%jcNIHLLD_t4VK*UpppH]fUM ͛Ӹqcj讼W4oC2+^%8@e"%4$ RLyS޽{s-Vot‡Ke:ŋSNSd +ؾ};ׯ$Q*ԭ[޽{caK9q](dEBBW^%11[[[{V˵k׈ԔeˊRBesIDPPWҥKSdY֭[ܿR_DROe^RNʀe͛7#2vvvX[[ckk BJdPPaaaʃRP(HLL$&&FwjZ^ٟ4iÆ Oϙ" &>)IGBqYL€ϻFxx8 .㸹1a<== -A>rtn5 dYի޽/ =zG 0M@X`~$I5;v 22;;;,--155L?211pݻwQTJNN&666MWrrr(S+T~]}S&46+Fth"Ft;E Wo[\,!\2X=Rx-{{{L¡C1555tA B!2lA ~ʗ/7<}wrmdYٳgXxY?u+BZx1HK{JEJJ IIIio111j  bҥƢRRtte0c <==} LGĄWH6uؙ4ZT&JT([V۔*U*"&$I4hЀ:u6c@AA d#QdI<<ٳ댥˕>W@ 
[,)RIHHHЯ0[2999M@,>>Xj5eʔw2f}=3]{$kj:*BYرc)YW?d'JK{SL-D0Ø1c2dO<1twu7o'  3t+(f˖M&M )))iTL… p4YŒ$ajjJbň!66-::ZoTTőDӦMQ(T\m۶1vXlll#%%uROtpp`Ŋ[rvv/b =xs9w$N:W_}ҥKE ^x[AG gUTah4uS%e1B j*_yj,Yp(}Lodd$XYYQR%},,,۷/_~tҘcffFJJ XZZʘ1c8p;wNSD?UVW*1EN>jL65jK{bŊ4lؐ͛7s}&L/)  *Iٳ'w(J4 $b$'aeYqT*ɉLMMWjEFFMbb"Z;biivvvѬY3BBBŋ~dA%2)O d͛7?~<,YDÄebb˜1c/9w   [la…& ϋ/^ӧg\\\P*LԴZ-*nݺV+IFHH~IZMҥ{mv ggg<<2~"&&F.@fbbBRի˗ϗQTCq*80[!8i.JxÇ3c \]]0`#L&OB`ܹ888K   LX|| sss7oNs>gBB%QKIL9+uN8ɓE0L4ivvv̝;W,f"   tdYfѢEر޽{gj{AAA"3!NoOTp G \qP,YСCE0LΟ?τ `"(&  B֯_ώ;h֬]vD0,㒓8W21?n|w"&{+Vߟ2eʈ ]AAArM.O?/~m #)|H >>>   F HvXdωh7>ZpauL r9UJY-rTӑUuX̒ J +W0sL/ΤIPudj\ Ga˖-  P ǿ.ϿZмysM?Ŗg4u:sN|Gx$%/[>IؐMr [cHff|f|gx\,QR689`o]z=W*T7<ȍPV6N~nKinnnsKƞݜwˀؓ'O8q"L>KKKCwIx"7R]ꍠsOΝ;QՆ dV? ""]_P 8X;P擡3gwps/n$=S+B+ה>N4sA̙3رx>;w./_fԬYӐ]2XY[<|2-06w Zfĉh4Mɲ̩S9|#կ_?Ǐ; Sҁ ąj[&Aƻ{Q Gvb!zfkG41~]6ٗF9y~`ϭ cZq(EVX}%1Rxo?)E[k /22>Pߕс}{LY' s:1mWf҆ +)wpB IDAT}qK+#wyocIIV e&3a38M;\ #`hO%2E?,鳁 ~Ȗ4V׋L4 ON[=yi(}{O 蛽_1^$;6MYɣН]rϵ6~CoIo Wbq6O_ so-Z?dZȨ&6c&&˚#|_e՟AO:]wHl۶ǏT*cǎia$$QBw50TQY1,lMwccct邥%Ν;Zj%:={ݻY֭k. h|XJY%m^t3}L-qakQ8l?+HliY@,l9? fMMJ6b;-q[`gcU($ZvQ?71?Vٿ?>C%x̞PEɸI>YLR-gYgl92Ju\6m^\J`Sۥ,5ɜٸ8z0kP|Jg_ :|{gSZγUD:NgqTWtsD>'ҬNjXT)1NXHCt`( UűyV6O.˔I*\go$NO~#pssVmG:Pj EGvz,1YB7ϸOV1f]Z!(聃L"Wc*|:)$%pu*,7!<, ,l3s+aic x _+)1̍[[4D{˦Gcx?3.y5W3J.t-$ v-ݻߵkGGǼ 0AAAL0{{{~Gܛ[Yr͚5ѻwo Ç? .(('''q.0Yc*0JʖŒ}. 
_f=eqlꉵ3|DKi0{-0|Rt$Q(0J-{.bs$<ȉs\V;νV^dg5Zb"bU$\Ԯ+y!Q(-upĒB ߌ%=1g j\@UbED H9C%9['__I Jh59/C?.BI(H8=nؐ^Qd02z IANAna {dʘraoeW,\\X7o(YVߍac#..+}hOUc Ere^."qo:&xS^"8HTK$aaE c*OC[@~m֙̕Kڹ-~tza` rf-39tSڂի\<mPDc.ݻٽ!@ݽݻgጭéQ2.wC Q;-W ,jTà Rfiԛ@L~5Όb d… }6cǎMi9Ge/_&ҥKڢTf %JPti-[FũXT >ڱg$Ħje \JUJT'ɬ=gWgAt>Xf吩\3}nj:s2"+`t3 Z%#b÷Ci%iCeKݔ9?V}9E|:Ƙݽ17r,<Jb ] [(Csf:` 4fW؇Qcs#sBZ cOŏ}Y@vZBٍq:yqy9P6^猵BpgDFKyyȚڇ(PPȧ1Պ)2nF9R\0W:E=9j95-,iߴ"M])ٟ*1K&c4 Џ]:gaSo#'hAeGL?=3ndS OARjU<ب{f(Skǣ Y^#Oi&OGYiͿdC;7F0P~ct@B2.+F;뗶E}zMFP:Qoz˶]r,ԵxyNc6l(7lP|r^7/ڑ#G ʋ/Zj億ѣFZ+]vZmC%""Bҥܹsg9&&)pEV'9c[Cw[H+'g}픲ݫ_OMwCdxpRm&oYOMrۤ t(SQ6CXv+/)7#U2J-Pr9;V-,Ww0% YRȞ5rm]eW ˝6S,_&;FshX+"uק=gɃMx~jhJ/^ JGBKRa…TT:u$ISLΝ;XSS\mK֠Wz=~8kfٲe_|ٳ=f͚1bCwyÅ}앀}s|׿ic yں9ڵO(R-kyLH=E\,߼җ6Ę8Z-$d{WEV= Gvp5$?&XvJ_bSWY{CAA(y!R8'CW~/)u|VRUMMrunݺ F\\/M4e>{eyTjuAA>0z\Z~N#7՛߉ Y9(fh&Ѹ>𺷆x\\\hٲ%WFP P3N[T4AAA["d),l޼ ZbԬx$0Ӯp޺͹s爍^z"&\|DRc  ! c{eҥl߾=Sm[pno<9rIĢBE\\o$k H2AA &|*88KIǎ3_Lx"zveYøRD,_Btt4}eJ4%AAxߩcHw|He͛GRRG(K}ɗMw ,Lr3+1\ӪƧL=FXB"EfTw;t\|/ L:h)Rs~;EoؒswJHUr:ki`DI6t_r B>tΞ=K3ߡj[g\v$Ilܸd훭钲,sErr2˖-#,,]yGGGϟ:|7olneƃWd9 ƺ33/}gԩs 3`Pm)~UV68. 
aIftyY/~ћ݇M>; 3ʴI Pq8.OFX/E̩@n~ mr"RФFM{|iܢ9k9E A^sN\]]ҥK Wg(5#G$99X6芏I~ Ǽ?~kV!5t*rx G\m?¦ײַlnSPˉ<2Y"AJAM Yyv&MUl%2AxNom ݍ\!2!5kSNE|ZibJ#QF&O찫W2g  ꐽ I$11#ϜTjUVƍE^:$QS3,z%#'nf1?ej|z$+wÇM,?Cw Ma1%}RJ]ڌJk.rj\sb;N~Vtwh5 '0G|beUbgIѰab:Np%+ՠW(C툽u ^޸XSh=KO᭨^֎xV׈IaumgXֳ&^68vezy,v=_TRo%h>:,sC3>~[Z!=fa$Uo'%ٳFlBB1AȇqssҾօ?H.dƬYغu+=]vERP($Iˌ3fPR%ڵk>Olll3f k׮… {O>tsssCw%߱u6ׯ*_-kwbY!_^.(5IT &J /ebP:hm{5~$KO7Xƅ{ҶXZٚTۅqn4(툝fҏ2Ү|Q>ʼnݩW;+gJWd&|%ȱ&tR(ſ1z}g *S4#*@Y"[Sl\HcK8Hw6v|ff;!-%>%3Q59"k}f|זf㗱W>atw?]@)'npjCwvY=noCT_T<=쉬u-873mV?<ƽb3Io Wbq6O_h5Yj!Zh؜.IW8vgoKђK ҡWyELLV'81ƫfs'N`ܹS oooԩ?k֬Rvz%KdɒFedNŸq*oT=+؋`Oixxxq^n45F{R5N#[ŵ0sڟfɾo%]^~\k=j,Kʳ6X܏tp 7U/kVǼi5yüR~s{b1|㉊c_Lj* D\Ц\=FIJj t?I h_4C2^i<[odθ2NE}85ңFՓy:e{;O$HKyeϣR.ب6~e yh-[@P8?<ĸgΘDr$= Bh ex ڎPȢ;x4|7!К^-fs5cVԇ%4v k>ʒqQ|;?>o"=6oJMp<__۔ r49q9GG<&ҕ%2#ΰkv9LT+Nׂg ZN}m4gw-1nֆ;iCk ňK򉸸8LMM,ȍ{kv/Iξi4֭[GժUc/ڵ+:tٳg-Z]> fTiNlx"E@ZB!afiK)<,Q*Ebi B76oRNaZOvݕ9iG>1eg.|7trţ`\b}KYS IDAT>!OhtV&0*Ihc~nz0-3v/G|,`i;J0aKQfL gzQ:ԍ^qtQE|] e;3xhWvBZ64g*̓;3<)(I/ /P`!5N2Ѭ:{=x B\F\e*&Qjo &KZ7}Wr֍S*}/U0X'&ґ;l,bIsYrhaJ@̱ƾ׷|~ZYܺd>ӞL`|ڧCb@2=>_ I <-$aIh摆L#c/ߗl KM;̖nJ}@b>,>-jSӆSqxJFy" &IJe˸~:˖-8F&J+ͮɗ2FYnMؿ?:uB,˜;wYQ*`ff7@LLLۖ/_^!%$$pyQ(j ,,7n`ll,˘QlYwTXBѣG%)ek| 䝴s4N:{-~H;8#@=ism Nuˁ|3`aU>SϩHa_Vxt IbZN"걆JM\Ûε _,+_ҥ6VEosGeĮS&SryRI?{(GV+c8Jtg.D@&:ޜNsW|: (e)1k F/werxd+>UպrHM[rLκ2?KFAa5(݊a"}BR+S"NCHp**S-mɇN~8]}{V UiW 95m )bܹs{ҤIltw"&,?>FI~ʕ+v?tBBBvիWOspp0 ͛DEEPfMn݊$I4oޜ(LMM䯿|k"O]tΝ;Bbb"$ѣG~7N:EHJJ‚O>5kPPw‚J*q1u떡 9gϞɓ 4]YIQkIQkPL(bn [{;4D8TslJ(݃^}DrK?r(JQ{jؾ98,(Q[XJ ̭b9HľؙP={numV;Djƍ5kb龩~Ʉj"WNrL˜K\딧|]Sd40W@@&SJ=Y Pȇ~ǀC=gfkfjTD 慐0"¢PnGTbckhbBw⬰g v8.c,IAl3xIO*AV\J/ȑIDo+RX mo^m{aOckeB]x͛7 )ܻƿpxMGE7~*ȬM1*WgVZ-[Ѫe]-@?Â(bşd cw̹Lk {f9nHP h\L"L%6dH"""B#c={%@ssɛxdQE?NQu֣8 L-wK8GSFcyͻ%UOTӾ,;0j$s]^փ$Qur}YI)yox*G[ãem-uk\Ovc H8Ԅߔ vB2@Ο?ӧi߾=9v\IhukBۼzkM9ۂBni~ʗ/ρݗJ%66i/ wQTǿ $JBI衆H쨠Q)" A"(iJQDtK $T;f!@I<fwvf6̽PX1^xf̘#{e?p,X/ʕ+d2>xZ`xу}Yd 㫯bŊ̚5 OOO\]]- 3rHnݺ^ՕXRS;PH-Z!J` Vs?L3UrhUN1S$)+a'i@`ZlMSnE׻̿ױW䗫՛ѐL,&] GZty=[X&>=&=sWvi\ lٮb$EIOǡ #XOMuLC9;c!_TS\\)i $70Wpoi1iKɯ>1 IASFT[e 
S_Ĵ}5*\fCg/ah}Qήk78u+[neKD 7}VnƱkq|#[%rt;u.g|0l5[Rr!oOƭ:șK8'=~4=9"/H !Xp!...t-_Qە.#ܿ6.ƓVGofȐ!tؑ=z@xx8ڵo߾|w8;;?V QQQ899=aN㭷ٳL2sʕ+iРCGxx#c8t+&SL4 77nckk˨Qr W_~TVիWS|soݺETTT!EH"dff8N)w֭gl!i&TξM$5)/FTX)śdжDqqO;m[n!#Mӥ|åUL)ngF:^,?ڽvdҰ$P9eecا1\շ'fO pbϤ=>*WVc=*]Q|O ةNQ'8c 6veAtQFgF4~olk=8~+_giߓms7,o1 ~Ù:/RB i8_֥m' ZKg%1Prj ?`çrP+==|a|4f05_!M?8/>ސV߿'_1&L0C5Z9Ӛ<ʽi+^Yo\ӛk@v3>|a}>/2?CƱM.㯟ޠ9>C&>`(TjҏZS$IV";zjk&LApwmDJľ}k3g"%%ۙL&ѢE ggge˖:^vBBBܹs̙3o֬Yb"999G ˗V8c_އp={&hBND /_p×*ĥ<^ADDovSGՋBѻjmƉա{,ּ(WPEԿ(oE8m"y]T3ru\X'>Q-"'?PAEThs}6. }/yN"SpFTM$5xV=lE>sHX('ms{"S.NcLwz|CޚLsxů=Uz~|G"%+iHs1ÿ_JUĭfb棷C'aWo;|hSJz[$=k)ks/Ep͈$IVH=8}v( 8)S>K,?Ν;e&`Æ b˖-KLLܾ}6mPLlә1cfBK/ҥK)Y*:t7nо}{̄ 2@֭[wS)eܹ̞=5j<,3?1#᫞ztnIzs-KQMD^Eu/˓϶TSns# ʕa8Vq.SbilJōhr)7wV2O'-wi_OP$.&u6(Sf_<$I$2220a~*UtR4iܹseɓ曼ktڕ>gҶm[-[fJMMᆪdɒ9rggg̙%EUU-[رc>|8eʔO>!11>9shuV\˗9p͛7gرh4jÕ+Wboo_yBDD*Uv(чw{R]ݺlF$I$I$I6YT_eҤI>|b֥Kh޼9#F =UN]UUbccٱcݻwˋɓ'ӬY3~Gصk#GF݋N>ё۷o?~^$**%K0tP~m>s|M(,Yٳg@tt4 ",,,pwم PL3GGǧ&p=(=+d(I$I$IzɄ$YAFFׯ_v9ְaCN8СC"wѠAQn]?nqKhذ!.!!!}:uֵv8r!z6=ѣGkKB0av)Vn{t]RN 111j&UU)Uʕ{Yf-ܹs`z=BlmmY&aaaX$qtt_~ܾ}9s`2P~WX!C8C[JUa'O\=d1I$I$I $+0`Xf)))9ٳg[^n].]Jڵ{ŋiѢVʳPVb= 6mʰaHMMxtM?99+?h@)%$$xb6lH-%Yӷ*\L.$I$I#.I*`/^… ꫅2w/͚5+jɓz_fUVW^̛7'A޽{iԨOJ*ѣG|2;v\xJ& =d͚5Võöz?;S$I$IQ$ףԩCyb/"OO>m3f͚qK?c3==UUɓ4k ;J)R 6_0emL2CZ, 򟭭-/2gΜիVóp,W ^~K$I$IRNoϒTٳg-[hv̛7۷[$sWӧOd2ɓٹs}sBUU{=v۹q]tɳYMUUMƋ/|`_;wFQ6nh93Us$I$I$&bTbccPoCs۷ٳ|Gm :-[jE5j̙3qrr1F#L>~YfgL&o&F^RJѤIoNJJUbP/`j-JcX8_K$I$IȢTo9sgѦMݻ7nnsppoFgȑ|kVNT*'^Øh/zPƧh>G%I$I$I$)__[ުU+~bdmOݺu^.16+'bv-@(Qޥ`$I$IgLI֭[g}FTTv3ӧOwL9Iziəܸx4L&tEl)[(.ry$I$INFz=C+! 33EQH&$L&zð(ի-no֬M6hч/_3ѕS$I$IȡMׯ3fLR(; !XlQQQ 8@b2!&I >>^zѧO^z%kcBOGllc=^$I99ȶyJ!߫.s2͛ 8#GR^B9Chd8::2x`z}$#G@ŭ(B׮]9wÑlEGG~zFC$Iym4DUU\֟UU~ SNѦM ͎;HOO'11 gWff&cǎB tȤI裏Yf 1Igiii?~&M"AQwNv$Yhܸ1p7-IdM8'Nd=z3gҬY3ڴi믿Μ9s$)Jff&G͛3͛7ms/ϟ9z(j%cyj*Ο?ϰatDBLQʖ-΄ 3|@FF͛7v(Ok I\]]Y&JIg͚5رiӦѳgO"""mE_fĉZIZj!x&F? 
dTZEQ0Ly_k!U:@I",,np믿Nҥ6mZ'_(Igǎ???k"IR4k֌¬$I)!,XvQJpqqخZj4k ɄNW{$EQbŴhтJ*`^Na08}k.;J&$)׏YfhP 7Z;9ѤI@.$z,YDx ETTb;NGhh(888PJEI)!Xl)))tI;3ߤW||<.]nݺyýE/ɤQ]lR !8u{ |{/ 55#Gxb.]D޽Yx19̞=)S<ϧQFјoL=#c:ɿ9ۮEjRC{RJkQkٓxk"=<==i֬Ŋv($=bccٺu+e˖Zj%'PUmYv Iz111/^c4ٽ{7mڴ+O8%Jg sr~gs0'˙e]|RUEA֭sp̙0c 7o΂ h߾=͚5#<<\{NGn8<ϗcLzדƑu\;Gdos-'c82ۮq'*ڡJ[t)l۶ڡH0EQ0a;vv($=6oLjj*͚5Ύ4hmc0댛uLIe>gGBBu ( 'NֲAUUL&鄄d*III̜9+Vd8̙3 |sQUC!))'rƍ0= IDATn# --ɓ'sʕlgܸqp رck:;,**'NгgOz=pmt::u"==GGG:uLJOt:/\2_2!V N׉|x+Z27P@IR4j^)bQ3tؑ޽{h%I$)oW޽EQhҤ F bcccmFFgϞaÆZ\lԨQݻZOG*L&p sFFsΥk׮@EAA`` B|@~-C{{{mʕc8qsΝ;QUӼq…9r_m>pНӁxf9:NKeM$ۣ+WÇѣIJ&>>^x!ߛ[ɓE!%%]( (Q]va22!V]x1ZIF?ejժVQV-=7|`n;w.uVNz,[[; I"zbwoa2ÁXf :tvVY*?ŕǏ쌷7?%ͶoNxx8񄅅qW\a߾},Y$bȪN:ٳ#ƌ. |͛7sʔ)C۶mׯlٲiįףہ7mvUTt,?U\GGG*T1cիF3gxbFɸqruIh3Q!Vdff_d17I`0+ŋTZU^GvvvL27x=z@PPM6eĈ;;;;+G*FإK,jI0|=ʷ~ˍ7 קu֔/_!n9r4Xt)ꫯh߾Œ{u*Yd1)ӧQUJ*ѧO6lܹs?A\SUEQHNN&22[[[+MΨʙ3g0)Rsή]tL}:ZKEԩS\z@۶msԡy( 5jԠbŊ\|uQZKPfݻQU2eиqcڵkg4>>M6D"Ehذ!}Tr9(4mڔMjUX+eooO%}wL!cdÓ~&ػؠȄXeƌtڡ<7>}:8s 5bdffZ1Ջxbvmp,EG&$9dZ$ʔ)üyhܸV+]QUܹsZ2,Xx߱-Z-Z|hfddd(( /_EQ0! bȐ!KϞ-?|F4iu̺m@@s̙3iHNNFofm6-!\Pw/(:N8{nL&\2 %KZ]zz:k-GYyIfF )6- KzUYw1Df B`c`j<חvY;瞋 筷?:=zz1yd}$IOرU:g[3aR nnnkKL&7jղ8bŊpB=#uԱ!xwiժ3J.ҥKuk]eagٳ~: ,x-53B"##5j+Wqc޼y* 8PWf*T?I666uABgƌCժU6l5k$((#Fyf9<ŋ_Ύ!C/°a 4f[[[ʕ+ǭ[d29d>:t޽{9uSFl;fWRBBp.Ϻ"ElԃS*Utvb27nW\hѢԭ[T֬Yɓ'y&GUVm}z.׮]vvvZ`+իL4 2uTBP~}Zn9tL6RJѣG]]>Zg,=!$''ShQͅf͚LRR7nСCk-oMw^Nv:;;~W)zуU«v1l,?4 4,AˮTW*b8*S;nܗ 79GqR < :g񹺺믿aJ.~j׮O?/K23?x c2ғO?~Ð$qHŊ{aqv_,^bŊ;|xO~%eBl`ZHr cfgwߤk^LX~>} ݻ~O>ɳ6UUپ};/?hѢZ>ѣGl29zJ+VWZ; I Hjjj3]TQEvf(KrѣЦMl#!V"55gggyYxq U͙c0<b0?--W^yW^ymbbb EQUoƭ[(WE-e!&fIca׼ys:|ӧ]̏qš5k85|o_xdBLcs 3Sӱ*Q +\~m^x>s֮] @RR{fڵ,\r=q9r$6lо0kNdA}q*V γ>BsUV՚\ke-~~~9~L\\tIII!((5j`gg˗INNf͚r "##9G\=;XCn> 1)e~! 
TiT;yz-[km_ɒ%Yz5+W_~ZcsٳgT} ܊`htRڷo-4lڴ EQ~<8J*eh$Io%??n*BmV$h$88*Ud;KQGrr2ݺu˶S1G!33zY-r{2;00PkD /;Nʽ8\-wo)"5\ҢDPPG'(=bŊ:b(Z(n0wz5OBWaQ[zz wdX Eݏ?hfmcO8Abv)Zj9s{ݻӱcG?7}w!""[n<7ǍFs݈+˄R<{?`Iz6lJ"E۫)"˔%bP҉poO"^iqߌ3Qe0xxxЩSy2Rhذ!ϧrEBɒ%=111׃與GX*Y{dbkw…lHj_%88]渃]||<111qE\\>zx"))Y_{;%''s%n7BݤYffՒZ!̝;C?w;6?簪8q:uh<\]]y뭷8uZfo^x+Ҿ}{t:AAARLJ)RD9<׈˚KrL'Bf2/u(BϞ=i۶-| ~-9MӦp[q,wU5Җ$&&b >C/QӧH"QIm___ΣYdd$))).﻽f͚9n#Ά t IDATD3ߋ/ƆrqqC6m,?riiif[_LUUVZoF-f?h_~\~^^d2i6Y|9nnn9NcaUT)BCCIJJBU+%5c.\@-CQΟ?%Yfa~w,X[m(6{V~_ɅڵkSZ5^~eEAUUL&˗k>G)Rvvvpx >cGҥ0amt(Bll,iii~?3gB6m?׭nݺ%F#7n`0XˤB"b9,.";ٳgݻ7={wc2 τ0s%==3g/w޽>}{nݵdb-V e;9'L0? *X>vXߋZlu_Z]D.['V1")xB G2~ǝLZܦ*-Ҿ( FQ.w$Iz !hڴŀ|Ν;>66ӧOsܹ !ؿE1`EQpss㥗^~~T}%Jh_|ׯ(DGGOݺuپ};ժU{乮( 6sLNN8(R'9:OJJ 鏽o'zbR+Q{8a.Tx !0wFA^e/ h$r dXIzt:1b}[lgϞb2Xp!k֬!>>EQpuuK.|zRSSYz}֭nnnڲ BՑ1Nx>g#Ǐ?'!!]v{n\\\HKK#**  !55Uk71o%g?YըQCLJk?|aҼԫ{ʊ+⊊XJwf̘Afʹ@-997o"J*VY( -[ߟ˗oͻ{J'''>C vtEkR\9>-[F>}r܄ۛÇu5'ßhȄg 7.lNNW8"f 8T.>>b7-&Ė.]hӓW^yE"88X&ĞRQF9ڡHTEuִkbVX\\3f믿ft:ի,ZHFŋfuԡk׮$ 00bŊ}I7/´i8~VTQn߾ڵk-c0С}.1碝O|T\p9#44zQBfϞ_EΝqwwgoٲe 5kdرVu222t4nKEQaovm۶_F 'X;'̓-96)))DEEQT<߷LI\uWbo$}=9W+VϚ[2/Ʉسtɧ檀HM{;233Y~={˗O?=a.,IC1rH9~6pٶm\zUU0a۷ҥK 4M6A@@Bĉs=x63L>};@_Yyyy1{lY~=vҊu !\r4mڔN:Qr'N !8y$Bʗ/OşhvvvԫWpԩOng|g$&&~ ݻwגNNNZ9sPT)z|UPY...ڲ<{9y]-s{1VZ(ŋOv zyoޠvykI9z{9'+PxKۙG_J^կ@(##C)ꫯwϙM6I͚5xR 1Iz( =3f`˖-Zuihɇ%KҨQ#n݊ZlѣqwwƕX@z,on׮]#==///ZNDFFR\9+[UUnݺEll,k&Y^Oƍ1Lڵ ;;;~'<<<2 yh4rt:}W9x .]"22ҥKXQ-ʤIHHH 55^yws:EboooYhݺuJ'%K$<tkEޠA=u":q=||8L>Pt:IIInݚQFgĽ?QzuوBխ{!/_&!! e˖(Q۷o{nuı>ʽz\]]quu6₋S,ի?NGΝO皟7oˏL̟qJxx8>>>)SLN:t|?( 4}˄tzS;QM"3NO .tFJ^{O\GA<< FH.ZRRTjnZZkT]E[Zj ! 
dA"L21 <6ɝܒw%&***;VB:T5}"8p =*Tښeee% q鰲e˖4k֌ݻw~zvMbb"/_&5556NNN PJAvK3Fѽ{wcڽ{7]+HHH`̚5m۶1gyN8Aҥ :Cj*Xn}GX[[3b~wׯOn|t 15>څWrk !oK0O,[h6rn,wdՍ&++ ȟ\4#GAVVVNd;w.WNNcǎz}^իl2F)Ŏ;·b=7+W\\IIIy5}=sDFFqF-Z pRJâEW^}B4zɲe?2dH<aggǤI,Mpp0nݚf͚aoo_^ZZ6l~1$!& f̘A޽XNN:um'Měoit>|@+**#G] \r^/___^ĭ4]jQI1 Vfϟ֥ ;v\r&-1z;;;4iҥKoo .Pvvڅi8;; **T\]] *p8qttZjeǎ#11WW׼D\d?~E;R rrrر4w鮝]ޗ~:WrܿƍW̵m~ׯ'99>}fnŒriogff nY$a4~~~^fIFY5Mͅ&svTsH6Ջ={`kk[8B s~WJѳgO~2331 >kF~ [s TR׮]nY?,;wb0^:vvvW  mӦMzj׮mwξ_eAchooϐ!C8|0ݺu[[T'O#))QFѸqܹx{{uRnˋӧ3|p~m̙c &Lɓ'0aBfŽO4}Y̺uh߾},)VJ[\ 佷988Gj֬IӦMK/ 1!(7K7 %lڴPt%^ժUťw*q7 x%ѱKWycۤ.VB4:#GW/2μyreفRJEQhFPPO>$KN&7oN׮]>7~ѢE|'t:KBCC4D;@zn^lN:UnRvE333Yn]0˗/͍v{W)<q4ŋսۻw/7`0޽{E=C%>>>̚5'tR^yB 88}&jժEXXz%SJJ͚5`0t Ԯ]__ߛ>~;v >>SҤIy^އ4Mf͚mۖՍ4常8~m5ks=۱QV)W^f;vn7n̝;7*3Ą(r^|ҡ* .d…&ݦ@a|CBB/Hbۿ'kqAk>s"""x饗Nl222xWݻ7C t8B2db .^:tW^4MC׳|r zk7nX[[޲YZZGț= ȫV^δhт-ZŎ(n<QH_)Ebb"mY`iiixxxЧO+2^fΜycccԩStL81oݪu=i8::߲̓>X/?/H߾}Yl}5{B_TZSNo>V_ egg3o<:uNIfV\Sg+]"MӘ7oGa˖-TVvI%#wW'0VY.]tGɻ] ++ {{{:vHDDqY,Xit҅[#%%BBB ɓ$%%TxRԬYWWW~gdI}nݺlٲaÆ@jj*_}˗/à ̬ c4ewܙ4w>ލKo5w7f4MRJ 6yѦM*VhcaooOyӧK,a޼yO?1zh\]]z\˗yW>FejSR*ri|5r 8XQߙP1TALLL%Ip*W|Ͼٱ|r6lȉ'0 ݛ۷F.IIIxyyY8!DYN7.zGXl>}XfΜɥKҥgl=z@DDD^c9r^OxxxY"""ptt$!!9so`kk{ϾG޽;̘1?t =bŊ%>ʘYxcRn2f枱DuQvuޝ[2m4&O\`ٙG2bjժE믿̞=˛뉋㫯bҤIwewI a&gRXy4)KbV`fPu9 ',ҥK:tPرcf_~|wf髬P+WiӦB.]ضmO!wƧBN7zqvvfŊ|RF zA:uMyHH]t/4%<5k֐̙3 P)/VVV1{samm7vvvf@rB;b ۶;g" LwkUϟ/􊎎6zt:UV+{'%HYRvm;w\wެ^ G'Df !Ӈ}%t:!턅[otU*Uxoyۚ5k-Ӕ$MӴ'...ݴHBXNٙӧꫯCͲa^%\a„ ԫW}޵N7xdX~֊5N*ʻ7E; m.n[n[_#F0sL G&$55MdFD^lܬ~f.Ht:]owq!K4 ___f̘ aL)ҥKcȐ!wmXb0(ͤdXUVʸ{;|ۜ ʕ+&} gddԿѣGsA-ZYQ?#3`W7yB!/=|||O͍3s++L}s4M[7OBgI>>lْUVY:!BQӦM̙3fWրQDٙb'nw[[[BCC $ iD <3gd]@Ǐ L6n'?$""L)~Ā={6ǎ#00! ! v֭[;vEdmcBCgBsg~ݮ(:wlt[^_[IHH?_ /Y:uŅBQz).r-Dt;|4X=+Ra{!DʭmHwRZ5zܹsR oCB!D)x7Xd ?|иo k;Xۇ=RzafHaN111X8!1._ѣdZ:,!B ==1c0c qfΜY"1Ϛ/!SbPQ'"[,&&//n!-(}駟СƍߟcGNNB!-dggoAʕa}]bB=͟ fӌoҋFK *!9(8|0ժUt(B۳rJ-[̙3?~|ަ%R@2///:w /@X؝WU߿YfvZ.\@ff]TQ\J)~nӦ ÇK.XYYq%!&D1u~69lwbȬt\*##???jժeP&O>Ӈݻwoq-J)9qfb̙oߞS}y۶m5͛7ceeEV !///vJ15+Gb{#wlұ6.D&Jg]W^ᣏ>*ሄB!JX>S>3lllXnu͛ر# 6ʕ+[0b!DY! 
1!™+lY~#-Okuijxr {YfM G$BQ:߿6m`0b06nHPPB!dg?L^x6ZȌ0QHZZu%&&x!gkkk6oLƍvB!iݺ58p*U`ooϦM$&02)D ^d8}!Va9|0J)rrr nC׮]9y򤅣B!Zj1i$߿?III̜9SaB"b&JSԣ4lwg$(߿?+VdԩEQ7[&;; oy#;w9sлwojԨapBqMZZ$''Oll,ֲWt1Rfj67G9<Ӹɥ O\l>cKQ\FLLLoBy̝;7={0`9JءC,B!,ё 裏J2LQd3Bƕlb/qtٙz$d^ #55xK"ʰA1bĈ-[ƻk} $$‘!F˗‘!2IAdXlos!>1daÆxzzZ:Q}t1dzxb Fto۷o~~~8;;[:!BܠW^t‘!2!v6s"6u%|]~B&Mޞ͛7S~} GvoILLO>ݛ!CX:!B܄^/KBBfFfZ6b 8yBXreތÌ {1Ykf[liӦD!"0!DqIB6NJmͰ;IIHRbBYti^xvJz֘K:u8p E!BQBdmYsSZ 0SDr 'NzE#ΝСC~ӧ=Y0*!`P?q3dg4윭슃B!be(~C}SNe̘1X:qx>|x?'N`D~'vw"/exN{3GR8}8c{1v9I+U˺B|D,q8+C,yL&+!2&-4{;ePKbaeS2[C@&Mr t(GѶmۼO?p~cL匮 ~aٔ]ǦX*dqгkIN$3։S |r{.T{j3~kk8ZWw / 5vdXZz1*yhY ~qw|zd?ԗ`dq.2?'Ʀt]M#˛Ї":6Œps?g`L.Q\>G9[Z!v\=$D}n$ݫ/w(lp:}Ѹqc4M#22:uX:qfɒ%4jԈ#G0$UC?)>r[Ʋun,-=>vإVvv6EJaך\I46v'4x`d~b"+;oϯZI,d]k]r1a_$i;ĪѮs'zS%/HgɢG 5NC=Zܬ\9ofŭ78?>%|<|g{E.qtx6xvt4tOVt~`e|⿍܋u{{|(݉ 22RfE`ժUQ 3ɰؔ^9]lؿ?ÇgƍL9=dX؝織qBI;o,;7?ӋҖv<c/?Q%A Uxdr|O4t8$3 "$$^KRI~c|;46 5%u~4xv-i\=o KK $m mO-<=.8wng`jlџBRF0m˥"=I=֡΁F@:Xw?)ZS fWyĿ5#.? f9;w%g*<]5>4.MATtu\o-/Qe=P'1LfKPY$k֬˗/'11`K#1\Ie&fhOl\QJ)"##U8y(Lz=BHߣk]eC՘1n%^aYluzM ࢵym~t`zf=FR] ^Qv*X܃OC2?1y?xʠ::@* s.*KyʷmA R퓶UTJߌs7\(^"2r^P\Y 'а/m˰`|0mިΟSrtFLYiHMylFeVGoDi3\|ݡ"e쇌C2Vɚ^\:c7~V87agvExWBX9#˧< ?š49^is)T'2gceSxfG~RVZ7tfع4t9?_LeRݜ39 OPP$vt0f{Qh =x}쯼Pˡuz32wx‡LjuaLޕ_C65w$!SONt6/p L}7k+}Bf#ԋa._{1TY"Ό9 W~筆u-O[爜1sZFр1|s6zh?h j +@GPG7͖aD璭:yh;֥qؼ "\ٔI7nd#)p eH`7;hT6/u.4F9=8|n]F2yBx.Cω^ܻdrPl-c S7-TUqD%CCx$Yx\ۥ<~(ƿ٭,2Q_𔗙')qGIu_R|mԿcUfZ`t=BVuq Y,}7,*::ZiF-[ҡ)~3VƥXn2/[}+Ly[!5UÔ/o}զ'񃆩i^5%U)hMϪ\#_d9K8(GZ1S!}N~ʺ!^(Yjw"5>CsǮ Evvr=óꇝtRvQi:ꈻ"oثzT#bPg7OyWYNE4Ug׹ÿjἱ0$SgTdz_U*5]N[qjMvlrw8=~\Uԭ;P(WڨqǮ^5PuDS}6jnKJ=5| *I]SRÙy*uQ)T:<&XX{ɐz fJ_vʛj>Imyzp>uƠ~V=~r6m^tqKSAAR|BfJ)9T_Q>=SƨبfQzAuOnTט^y+|)u5lY{>{PR<ԣ۪eY2iJN<.wבuw0M%n:EFlVl[q/]j:-[Z:2%;egDY$Z7;ϋ>'32 . 
f2;K?q߳C*!b|7ݰ6ri^<AJO%< ?tI:5 C$ R!ؔNS6V'3 'l V"5BLUL` UjͶpjp$pJ4mv2/Reα$oH9m&э}Ś*zSoTrwGlb_Vۅq^~7[?a!C*fNJA{ Y#8ҼW p2* RXJjx%V!"HkG߿uĔRЪU+J[Mg:0EC->ih56Y 9Ƒ]7 ZҴ76Y\,JQعv[N`JytAp‡?s9iǻ8o Ԛ O煆-1pPV W )Wn7E:+.M ( m{o^y݇ܛTr%J?Q^nӈj6$L6#ۅ;j+ąs@xݭyu8ҳB+rD?G޳֑X<.(R}{89Q> IVvt2 \ĉwXkd|*1RŧC[e>~:/eXޙ6}?Suf7E%!fMQυ~.(g+4+4lmM-e+Dai83m.4Mc 2ST6&g'&`C#jӰyOjZ`&2_ |Y>`p+oÑe`U J%?3ZwB`G/'F^6\._n!@EvE傋' ɤ%r (|^Z*V-{qG2Mg :bg?gi==eO;Ћ~v?[׭ixͽH&4MV'0aK.OxbK"vXgleT2]}wtb Q@ċv u ]HYѺJKZI{{l&.~i`+o=zX IDAT:>|'tpwZIy%mR^+KXjR^ -^,=oSs8͸:Ad^MPƥɤM4;oҖW]͉ /YrWugrw:_ r|U~nBBLGcZU2ֲb4mAKCȈo#|-~%^<3[evK!fFw4|EW@s%F}5pꇍv5E2&Ԁ 8EN'/Gþy=\4BiĆ,KbҨMvd\w <vvv Yڤ'@Cwr6N[Y6=궢iwCz&v61›Ҁl@R5=5gcGhcC/p6S)J s|ֽ:*֡0` J,B'\]]qqq_~!'vI aJ! }w3vNE0feBGSF};&=֝3 ]𩚍O:{f:u༉)i^իY.z5WΎWKcœOy|MO2sѓmG1T U5UZ&Vu J=ʌahQx˿ _DQX!p6/T=ʾO0vFoc{9ӞW֡rz6K"9g3qQ,ԧٵuuƎ灩ݞ)dZ|_v.マZt|Bgvq^LW8c Wfͯ9@={_r,E?K`Lǣ_[gWwx,| t0q}ݳ|z#sÇwnd;n[pKÃHNqֲ>6/zYif%}>ahF~a舂5Όjgak# 'vyEۢwR>6)zN畆3M*&;8MiV hKX&XEHd!k\K6(k-P .x.JX˫g^єh,E?Y<1UjKLjgn:3|l I1ˉ<c;bV#G]/ G7<~!ԩ/a@4a(?BڵkU6mԆ ,.ElʼQ &2YԔ]eҠ.Pu]J*jw}_*G/5r)j{4k?[ rw ;Miةv3.+2яTTmW{*ʅj^ A+~w--߮r_TsR>( M:7tf8\{8PyZk 4e::}/UwT^Q讦uvbZ:n:7Eپ6n/xߝ5~7o2zƪYtrQgwWr[fSF]YIUp=czGS`<|[gD]&֑*Wy_mJeGC\WS6Aߩ>vXgKS-mJgUA ~Rt* w4RJ.휩oZI讎}zp&/}=Ux_}E {5<*4UxV}O3gzAYi(4+V4{JTe 7?/TOMQF9X=c*v4~菨Cspm l< ROQxɊܴ7~J)eȉQFW5\^o+QWz6M)K !)++'xʕ+3c K#뾈b{tۺ7#*?}Y^uڷoopʴc{e4v;gkwU( _ r81ѐS W IM0_L_9D q5slF>vψ{eH;=nf9ngO|]zғr^Ieo3,eғv4R(W.wiW9hM~9f;׷`[cە(C*')w\WEfy} I>*8}6 .V],.'IyPm*vμ i >q4.܇!mZE}y $Lಝ7JU~.ҹsg.\Hll,AA%Q@O bϯ'8۴ ~v{z EU\@DDC)C8ϝ 8@G?\$&,ʅU]os@ST}7cicyL]i8X:U5Wo8U & 7[չяΩ"~wۙ)t>:XgۋA&e=}նM{J2L!(tҬFիgT,c(LKlIs<T掣..'epP232`ebo57|dB!(dB{Vrr2K,_~8999q9#! qі~8>>yaoo/ A "I 227xæ &$IA6C Y8/СCڵkev|u)Ij8~ su^z%&Mzc=FJJ |z/T@ y0`O=Ten]gquwˎ$k,H ``3zh41ёW^yիWo>[A lٲ۷[ ff֬YfCEihQ_{hWMAP_~;v0c ! 
A-#Iߟwy+ R_Lvg٤jlk~z{=N8ams:A@[Egw >nf deel2hqVYySPN~ )SpuV\imS!1A#Rs|]%d|kβ$]H_bm3m{{{|Mõk׬m@`ؿ^b!]uf]; Nl@ TNvv6&MBј7[]{Ǹql:)@`+(Baa!g&::Enn.5d@ T(\v~Ho޼|Co`y>JNsҴiSgmsŅ(^yN`TΝKÆ ߿i4tO.4;S(h4#WFccc;w.~~~,[ '''u[O6jʖrk׮dҥT}S*;~~~co_qhё-[J>}$^NNNlٲ{{{vzv)> j'2MÌ$NI\݂ɓ'[Ve瓔ٳiժ}96eƴEQeV8MWVo!nhna'Ett4YYY_29w1 <3&InٳgYt)T&M?O?3[G$x~mѶmֶ$=z֭[\f!55Z@bZ<Ì\>r->6mQQQde QpF#^]@ (ώ;0 fyw=B-Hnn.1{/5 EQ'dΝdff2i$}LBvv6O?<8qM6݇1DDDΝ;VG A#˖_Am L!nhQ(DI[Űf͚pB<==m@ lEQ8z($ѼysmA`0o3uT6l(h4~mt:[3%,uV^yeŋn&=6oތS榆;v IСp~h4h5QBDD$##u )))!>>B B#(Bu=yd^'?nN37{{2i}e…899YJ1`Nדq;LJ%t6ghٲ%YYYpe]D ǎcƍ 6L)}}}t111YDA5OLL AAA4jԨF%2W^Ep ;vTd]Vvt EQرcooo4h@vv6.\-dpd)'2(-[Jֵ0ܰsQ3{r.Iz;IKY%v%İ;"R2t2²=| mH jm&lc@E18́(\^zo>uّ#GhҤ ȲLJJ ֚}( K.'t qNhӦ-Tǎ#'';vE?`̙w}̚5IPPkW_eΜ95.uYٵkG69cp vMcooO[j\>ͮp&j91 '{R_̳UG$B4yDR>ND aՈ(^Z,QQQ,^XawIGٳ,⮗2.st%pz+X(X ,sE! @1Ĭ,=JΝYz5Æ S~. yyWk6SP#Gp"""#,, ^ϩS֭p|W| IDATrrrY,dff"˲ps5e@DDD$I3Njj* 4P+WHPPPP4iBLL CoaųsT$jl8%rښ9lE$I]^{c[bGpt쏝êӧO_s%k#wwwk!EQH.FQ n%jKC] Yf9$ nd_+ lE%244j悗4WjՊ9soO3gΜwun`f6/0L2-X x {<ŋs9A`Ν;#0uTdͲ,c0̘1VZ˳>EQ(..f̘1ׯvnϷ~kmРAy饗ظqMCrssyWX|96GPMȲBt \:%¹$Ihт̮͛]ؽ{'wfϞ=ӪU+U#44oooOhѢ2I{k6)S?`Q4& vݻIKKUV4n'''Zl#wf׮]ٳNGhh*yzz;wՕ`4hP-HdQeҞUWdY&!!q1g?θqO裏6l'**'r (\xC2~2=r^^^,\+VаaCt:ۜ~:.]bΝ e쌍EQ6n?O9sf/_#G$裏rEjO>}طoddd0bĈrǨEܹGGGGxxۯ*p sd6lH-pttD$ ǧ֓pYډbe@QL?Nl)SɺZ@Va#NeK̰ nK˖-Y|9saŊt֍&MX,A=!%%^{ .O0dk$&Ғr(,(\8A@[$11"a$Ib̘10o( 999_Red< /ڗN#**]vn:f͚Ua.?͛s]wY6x'Xx1Æ ZuFO<@vv6[lӧ_|^Oll,7&o7o3<dҪU+.]7gϮcrwwSJHH`ԩѴiSVXڮ[n׏ɓ'sIx Kkڵ>|ѣG}@\\ӇGrQygǚʕ+hZ^|2a߄7|?Nhhh{{/_nE4i.cǎ(݃F!22qF^J>}ʉ=...>}s2tP h%&&bgg4v-x{{(Y=EIIռqe+ɶKL`xxxpB-Z$0A3<õkx7͈ 6sHͯ&k[Sӯ_2y~wuIh4tIK8˗~g:vzp&''#GV>ߟB֯_oU/1FROIIQ;;;UUcXQQ7o'HU?222$W=7$IT3׾j.Yf^[F[nͤIHMMeѢEjR}YRSS)..flٲY9z(ݻwG$Z-l۶PZ`U5H^^:>,K,!''?L1EQˈƶF_~hZ~gш#Ts}СC˴MҀBM4AQ._lǗYV <==뺠zyd fc+&i*y!888J2Lnz/ZWWW>Cnk$fŕTY&e܈T7n,3`1V3./))aÆ eA׮]cϞ=e<^KR }8*]Xp3tܙ;v؄Axxjټys9rD9tPubbbHOOg&m;==EQpvvHHNNFQ*yݻWqqqQi7 $1|ph$;͛Gǎ9r$,I޽h4L0///rrr4hg J_ңG?FQ=Jpp0888Nc۶mero믿ҫWj 잲tt֍[V(ծ];5gX,3o<"$I"55#Eqq1ׯ_7h-e$xf[QVCb6"[~sq '66}Y o}Q\jC 
0M'O$>>^˗)g۶meD_Nǀb(ݻ͛ӠA uϟϊ+r |嗼;:tT9}4V`)*(^={ȠaÆ okזgƍoߞ`o܎%Ea֭ ʎտ۴ispOOOի௿RiZ^xx zz=sxzz_xzZ N> IJJbɒ%̟?_4i޼9K,_W^HDJJ }EЪU+x0`pRSSDձ,] r1ϛ7͛7WKޢN>[znMFK.WM7nՕ_~Y p7,#26mU8PMHHEW>V^M^HIIg锔>}:mٲ%;w4:zIT:p3}ݧzom׭[G6mx衇v{%;;[-:{nV5iohnB}EQطoK.߿R/ӧOCn1ӸO'O,S(pӭӬ6}/{!%%ŋj~#44q1tPS= 2/^T[?j*UًtyF̓>XKNnx549‹/HϞ=Y|9p#ztt4Fb۶m(r۰[QTT`V;YpͣE[UJ\32Agh_2f̘29.\[oO?Ͱabƌlٲvکm7*pCpjģWh,WO>? Y~=?8Atf͚!˲\BCC ٳ|駤|rջ"0 h45Ԙ1 ڂ8g JϦxi4zԩSYl_}?8`{j^oAr4yx:Y?NA&Jz4 Zܽh1|֭[_|Ixꩧ1bD5Z.`ɒ%ݻu_pڝ0w'Gi"}Eaذa /5 {{{F'|o /})..EJ~oοo+ydL,?tP|}}-'"LOOtG͘1c͂XoU[EHUAä8p tڕ( ֭cĈl۶޽{v+^l b,_1w\uK/k**ӫ畭QTTāx'UeYfɒ% <8~x9Rv舳"gرxxxhx"z} c{uQRRRhZ ̱cx m[k.9,tԉ7|FD+13%dReUH+ux5rnoooV:IxGwիQzl`j5kfok))‰,.Ȥx7!~qr5ϕYpӵkW}*}h ;w> ''#1siϘ@P]hт> O?Df&<<{w/زe ;w2CLL ZΝ;xn)^١hzSٳfJoI%IbРA{"2֭cΝ7778w$I"!!͛7)Sʲ\劔cL:'&ξm[EQ&\]b)9s???\]]m'&&A짢(eBA-:)](飠j,^͛7ǼyĸEM($'ꕗZaS & ?z( <ꇻhb z=wm'q͛UʂIm + 2kע5kNdd$\~IXht٬^M aӷo_{=z=,S\\z9r*ӹsg|||n NS+!<(~:iiijՊDС*>sN RnHT.XE]UV"J >?$𣏕#G~z$I:u&ݛ~ EQ(((`̘1fW4Wj >`MF||<~! T>ܯ_?-[FZZׯ_СCDDD_m6t:ƍ#88X( x_XhڵU/1=JII Ȳ̾}| "##~:ˋݻ;}prr{&WC ?qlI3\ͺyf͛رc>}:n (:jпNfrxy'lD@'MZT\[ppb^ŰR>[#ͪ4,˼ۤ1m4/= ^{I‰ sK*\+<}a֭CMnׯK,Q+s**>>>\rE]ψ^G… lݺ{h U_9vC )˗$f͚UocƌaĈzN3eBXX͚5ҥ#F(3&I7[%g,b7QH+QU$I",,^zٳgo`a̙ /@aa! , **Yfɓ9sw͍ٳg߲o{{{v}Yw^/_΃>x[OF#:u⭷bΜ9̘1CnPk׮b vg}Vĩ$Iݛ˗/3m4ZiE4nY̴x_g۶mkC޽?UV%\pV[FW+WosQE̙3닟pKibccD6mDRR֭ϏH.]Jtt4~߳o߾2^(Y25)"۠At:ļyر#?w]ښDmsFhc[^{Ҭ;~a iшceX A/ _}$9mf4 /7f+j A`0m6vimS6|0o΍D ?! 
l Pc*NNN1ك"ڶmKrr2P!;;+W,lܸQ $8[x MLLɉ-ZToZVV8::bgg/ +kݺ5[C#\FVo$йsgv޽{ԮK.,\FttڕдiS^}U233 dUxڶm[ؗQusΪZ~='N`ҥUqssc֬Yر#G+PPPe˘1cF"66+Vмy $ )**bFo.}b.7aaa|8::WI&,VDRRe~8~8K,r_Νj5X///xbmeY&>>_xG)**ɉ3fгgO8}4g{!((sUNwww)F IDAT$Im$$I";;EQL*R\\̂ e^-R*Kʲ\-4Ex$IE 1_4-\iS.ͼ$7Yr%k֬a 2^xZ>.(߿?d:vH޽%;{-h4CNBhJ@Wk"TBBBp{wM61b*&ҥ dgg.bԨQݻS2}t\²eXfZVEQ 3[m Iw~-۷~I5j/fȑ )K*i4^bccٺu+={&FCٸq#+'NGQ5UV|8::b0駟x( $&&2l0_@ڵk_ŕ+WС]taƍܹs˄1J(W_}v  aҤI<899UzLsΡji߾=h"ȣ>ziiiCK'NԩSk^Ŷ2}&c=MB%{8qO?8<<<>}:C_EVv1 І860/@ Veʕ?a7*uЁ]vh4ӧG3VٳU!>3vXJJJt( sNkҸqcoP;ѣUs^^^(zܹsUE$\\\?~<ƍCe5Ѻ( 6ӧO3a|2 bpCqqZ8&MēO>?oڧ``޼ydee1h |IzEdd$=IBCC2e ۷ogɒ%CJ]sǎe˖V>!bggǠAؽ{7SL1;L$ΐ!C&ڼy3EEE̡PͯV:lN$~ax}*i4|*W {nn9|0%%%jIpssѱcG.]DZZ^zz: ̘1^tF;;۲ILL M4)iheQP[lwĖ5k0fVcggISbb"$Ѳez?fRf}&:BQ91nƍŋ/XnYuX,yN:ĉo>|a k}2G]na>@O h4Z--$[aÆ 8 6{I=foomSmIj 2ڍ-$IuRD^7#I7*ٕ M5Vp;wԩSjhE۷H'$$zI8<Mɻ^ Eɇ9;UbN"ISL!))Xߋz/)ST"❎׍JYYY'JWQ$IbĈl2C%%%((8ppuӓ=zKX"""d{!,, WW20rHR  *gϞ%==]+C$3f _};Hq4(_ܬL0FLԩk׮%66_|9VGҵkWΝ;LNN| >JnѢ[򀵔auE$IRƊ[߅0#F Ձ:e8Շ)dee!Igʔ)1Q={ˏa.SmܥSY;Q-6Amb&b޼yޑe+VĨQ,pwwGQNP)}G = Bvv6sLm۶ӧOZuaKb0((( ;;[Mm*Udj<?oո ӧ_SN:y$yyy#I׮]̙3oTΝ;G߾}/9sӾ}{s9::2sL"""kʕZ+W޽{iժI$/LPPsRQEQpppQQ5I*m߾ÇpB|fUMޟZV-P geu޽jEتC9j޼9mѾ5Ξ=(( n U@Ppok#9U /Ǝ7l;;;Fٸq#6l8᭠zٷoׯ'66QB~̘152 7>,w$۲U}Vh47ճgO\|}}Ce.]d fs,84nX͛ꫯG}g_Nժu]އC O?e̘1eTL/{{{UpL(Q]vIƏ;F\\s1oڴ)+V(㩧hXlY|h4ΛN+aŋUKnT{*Q(@Mݧ5 ]t;v^ bǏzhV#&1n`A׬Y3˄ &$I"';/>UΌ7￿ܲ<ܹsb V^ٳg;v,#FPYepDAE$5jdM nBҢE Ξ=i߾ េkuxU?Oqq1ܲ d߾}ӧjS6m/Tx뭷6$ &pa^u .7z7+-ݮAS]y\]]+_$5Z}uw}Ǒ#Gy$,qrrR|k<(\x4zarJC,)))Zu (1ҍB3ͻ&'M+/w]Tol2|}}9r$ kׯF.yW1RP }"44gϒdWc%QZ*]{XKի]vWWWu$<<9R+±`Ceʕ+'??9sЩS#&&FCϞ=MgCp,ӦMz.!$I"[cwhT6-"ҥM?x ĴiӰcŊ?ŋs kvG֭[>}:<ejcr*%wFV  c)"׏HNãL͍D_^,lذ]vgj*xppp **ln\{E$kr{[ZM ' 1!I-;{Ѱ#)3IOJU#x?H9r$wGalٲ///Znmm$EEElذ{r1Eݝ1cX^g<]5!;:( {X]) Ʈ]?;vlm'??/oEޞz)[rrrh֬YˡCP`\GbooO֭qtt-BT; ȸR@IptI\P5$I+].jJ.z:-PI/$IDDD-``DDDн{w<<#d[ffNGѨQ#imՐf߶a5Z$;NGv툉!>>=zݗ*#Fb iժUoXsoǏXS܈0xCUQիWӬY3"##.a*x*QΝ;J!Nc0}+ % 2S2Ofղ;f͚vZ.\ѣz*~)&Mlº©*N>cƌ33gaʭesb6 ZuR 51uLLEܩAu3vX<==Yvmy@FH,!Im۶-Sl&9uǎ㡇e50ȲLII A6e׮]TXQ z*EEEU~nPߚ1EJd~ȵ 
.. _transformer guide:

Transformer
===========

> The Transformer is a deep learning model introduced in 2017 that utilizes the mechanism of attention. It is used primarily in the field of natural language processing (NLP), but recent research has also developed its application in other tasks like video understanding. `Wikipedia`_

.. jupyter-execute::
    :hide-code:

    import torch
    import math

    from torch_cluster import radius_graph
    from torch_scatter import scatter
    from e3nn import o3, nn, io
    from e3nn.math import soft_unit_step, soft_one_hot_linspace
    import matplotlib.pyplot as plt

In this document we will see how to implement an equivariant attention mechanism with ``e3nn``.
We will implement formula (1) of `SE(3)-Transformers`_. The output features :math:`f'` are computed by
.. math::

    f'_i = \sum_{j=1}^n \alpha_{ij} v_j

    \alpha_{ij} = \frac{\exp(q_i^T k_j)}{\sum_{j'=1}^n \exp(q_i^T k_{j'})}

where :math:`q, k, v` are respectively called the queries, keys and values.
They are functions of the input features :math:`f`.

.. math::

    q = h_Q(f)

    k = h_K(f)

    v = h_V(f)

All these formulas are well illustrated by Figure 2 of the same article.

.. image:: transformer.png
    :width: 650

First we need to define the irreps of the inputs, the queries, the keys and the outputs.
Note that the outputs and the values share the same irreps.

.. jupyter-execute::

    # Just define arbitrary irreps
    irreps_input = o3.Irreps("10x0e + 5x1o + 2x2e")
    irreps_query = o3.Irreps("11x0e + 4x1o")
    irreps_key = o3.Irreps("12x0e + 3x1o")
    irreps_output = o3.Irreps("14x0e + 6x1o")  # also irreps of the values

Let's create a random graph on which we can apply the attention mechanism:

.. jupyter-execute::

    num_nodes = 20

    pos = torch.randn(num_nodes, 3)
    f = irreps_input.randn(num_nodes, -1)

    # create graph
    max_radius = 1.3
    edge_src, edge_dst = radius_graph(pos, max_radius)
    edge_vec = pos[edge_src] - pos[edge_dst]
    edge_length = edge_vec.norm(dim=1)

The queries :math:`q_i` are a linear combination of the input features :math:`f_i`.

.. jupyter-execute::

    h_q = o3.Linear(irreps_input, irreps_query)

To generate weights that depend on the radii, we project the edge lengths onto a basis:

.. jupyter-execute::

    number_of_basis = 10
    edge_length_embedded = soft_one_hot_linspace(
        edge_length,
        start=0.0,
        end=max_radius,
        number=number_of_basis,
        basis='smooth_finite',
        cutoff=True  # goes (smoothly) to zero at `start` and `end`
    )
    edge_length_embedded = edge_length_embedded.mul(number_of_basis**0.5)

We will also need a number between 0 and 1 that smoothly indicates whether the length of the edge is smaller than ``max_radius``.

.. jupyter-execute::

    edge_weight_cutoff = soft_unit_step(10 * (1 - edge_length / max_radius))

The figure below plots this cutoff as a function of the edge length.
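For reference, the cutoff can be reproduced outside e3nn. The minimal NumPy sketch below assumes that ``soft_unit_step(x)`` is the smooth step equal to :math:`0` for :math:`x \le 0` and :math:`\exp(-1/x)` for :math:`x > 0` (check the ``e3nn.math`` API reference before relying on this); ``soft_cutoff`` is a hypothetical helper, not an e3nn function. It shows that a zero-length edge gets weight :math:`e^{-1/10} \approx 0.905` and that the weight decreases monotonically to exactly zero at ``max_radius``.

```python
import numpy as np


def soft_unit_step(x):
    """Smooth unit step (assumed definition): 0 for x <= 0, exp(-1/x) for x > 0."""
    x = np.asarray(x, dtype=float)
    y = np.zeros_like(x)
    positive = x > 0
    y[positive] = np.exp(-1.0 / x[positive])
    return y


def soft_cutoff(edge_length, max_radius, sharpness=10.0):
    """Hypothetical helper: same expression as the tutorial's edge_weight_cutoff."""
    edge_length = np.asarray(edge_length, dtype=float)
    return soft_unit_step(sharpness * (1.0 - edge_length / max_radius))


max_radius = 1.3
r = np.linspace(0.0, 1.5, 151)  # same range as the plot
w = soft_cutoff(r, max_radius)
# w[0] == exp(-1/10) ≈ 0.905 for a zero-length edge;
# w decreases with r and is exactly 0 for r >= max_radius.
```

Because every derivative of :math:`\exp(-1/x)` tends to zero as :math:`x \to 0^+`, edges crossing ``max_radius`` enter and leave the sum without any jump, which is what the smoothness check later in this guide relies on.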
.. jupyter-execute::
    :hide-code:

    x = torch.linspace(0.0, 1.5, 100)
    plt.plot(x, soft_unit_step(10 * (1 - x / max_radius)))
    plt.xlabel('edge length')
    plt.ylabel('weight cutoff')
    plt.tight_layout();

To create the values and the keys, we have to use the relative positions of the edges.
We will use spherical harmonics to get a richer descriptor of the relative positions:

.. jupyter-execute::

    irreps_sh = o3.Irreps.spherical_harmonics(3)
    edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, True, normalization='component')

We will compute a tensor product between the input and the spherical harmonics to create the values and the keys.
Because we want the weights of these tensor products to depend on the edge length, we generate the weights using multilayer perceptrons.

.. jupyter-execute::

    tp_k = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_key, shared_weights=False)
    fc_k = nn.FullyConnectedNet([number_of_basis, 16, tp_k.weight_numel], act=torch.nn.functional.silu)

    tp_v = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_output, shared_weights=False)
    fc_v = nn.FullyConnectedNet([number_of_basis, 16, tp_v.weight_numel], act=torch.nn.functional.silu)

For the correspondence with the formula, ``tp_k, fc_k`` represent :math:`h_K` and ``tp_v, fc_v`` represent :math:`h_V`.
Then we need a way to compute the dot product between the queries and the keys:

.. jupyter-execute::

    dot = o3.FullyConnectedTensorProduct(irreps_query, irreps_key, "0e")

The operations ``tp_k``, ``tp_v`` and ``dot`` can be visualized as follows:

.. jupyter-execute::
    :hide-code:

    _, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=(9, 2.5))
    plt.sca(ax1)
    tp_k.visualize()
    plt.sca(ax2)
    tp_v.visualize()
    plt.sca(ax3)
    dot.visualize()
    plt.tight_layout()

Finally, we can use all the modules created above to compute the attention mechanism.
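The softmax in this step is normalized per destination node with ``scatter``: each edge contributes a numerator ``edge_weight_cutoff * exp(q_i^T k_j)``, the numerators of all edges pointing to the same node are summed, and each numerator is divided by its node's sum. A minimal NumPy sketch of this normalization, with ``np.add.at`` standing in for ``torch_scatter.scatter`` (sum reduction) and a hypothetical helper ``scatter_softmax``:

```python
import numpy as np


def scatter_softmax(scores, edge_dst, num_nodes, edge_weight=None):
    """Per-edge attention weights normalized over edges with the same destination node.

    Follows the tutorial's three steps: numerator = edge_weight * exp(score),
    denominator = per-node sum of numerators (np.add.at stands in for
    torch_scatter.scatter with sum reduction), alpha = numerator / denominator.
    """
    scores = np.asarray(scores, dtype=float)
    edge_dst = np.asarray(edge_dst, dtype=int)
    if edge_weight is None:
        edge_weight = np.ones_like(scores)
    edge_weight = np.asarray(edge_weight, dtype=float)

    numerator = edge_weight * np.exp(scores)
    denominator = np.zeros(num_nodes)
    np.add.at(denominator, edge_dst, numerator)  # sum over incoming edges
    denominator[denominator == 0] = 1.0  # isolated nodes (or all weights 0): avoid 0/0
    return numerator / denominator[edge_dst]


# Toy graph: node 0 receives two edges, node 1 receives one, node 2 receives none.
edge_dst = np.array([0, 0, 1])
scores = np.array([0.0, np.log(3.0), 5.0])
alpha = scatter_softmax(scores, edge_dst, num_nodes=3)
# alpha ≈ [0.25, 0.75, 1.0]
```

Setting an edge weight to zero removes that edge from both the numerator and the denominator, and a node whose edges all have zero weight receives zero attention instead of ``nan``, which is the role of ``z[z == 0] = 1`` in the e3nn code.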
.. jupyter-execute::

    # compute the queries (per node), keys (per edge) and values (per edge)
    q = h_q(f)
    k = tp_k(f[edge_src], edge_sh, fc_k(edge_length_embedded))
    v = tp_v(f[edge_src], edge_sh, fc_v(edge_length_embedded))

    # compute the softmax (per edge)
    exp = edge_weight_cutoff[:, None] * dot(q[edge_dst], k).exp()  # compute the numerator
    z = scatter(exp, edge_dst, dim=0, dim_size=len(f))  # compute the denominator (per node)
    z[z == 0] = 1  # to avoid 0/0 when all the neighbors are exactly at the cutoff
    alpha = exp / z[edge_dst]

    # compute the outputs (per node)
    f_out = scatter(alpha.relu().sqrt() * v, edge_dst, dim=0, dim_size=len(f))

Note that this implementation differs slightly from the article:

- Special care was taken to make the whole operation smooth when the points move (which deletes or creates edges).
  This is done via ``edge_weight_cutoff``, ``edge_length_embedded`` and the property :math:`f(0)=0` of the radial neural network.
- The output is weighted with :math:`\sqrt{\alpha_{ij}}` instead of :math:`\alpha_{ij}` to ensure a proper normalization.

Both are checked below, starting with the normalization.

.. jupyter-execute::

    f_out.mean().item(), f_out.std().item()

Let's put everything into a function to check the smoothness and the equivariance.
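Before doing so, a short aside on the normalization check above. Assuming the values :math:`v_j` are independent, with zero mean and unit variance, and that the attention weights :math:`\alpha_{ij}` of node :math:`i` sum to one, the variance of :math:`\sum_j \sqrt{\alpha_{ij}} v_j` is :math:`\sum_j \alpha_{ij} = 1`, whereas weighting by :math:`\alpha_{ij}` gives :math:`\sum_j \alpha_{ij}^2 \le 1` (only :math:`1/n` for uniform attention over :math:`n` neighbors). The NumPy simulation below checks this argument with random softmax weights standing in for real attention; ``aggregated_std`` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)


def aggregated_std(alpha, use_sqrt, num_samples=200_000):
    """Empirical std of sum_j w_j * v_j with i.i.d. v_j ~ N(0, 1).

    w_j = sqrt(alpha_j) if use_sqrt else alpha_j. Assumes the values are
    independent of the attention weights.
    """
    weights = np.sqrt(alpha) if use_sqrt else alpha
    values = rng.standard_normal((num_samples, alpha.size))
    return float((values @ weights).std())


# Random attention over 10 neighbors (softmax of random scores, sums to 1).
scores = rng.standard_normal(10)
alpha = np.exp(scores) / np.exp(scores).sum()

std_sqrt = aggregated_std(alpha, use_sqrt=True)    # close to sqrt(sum(alpha)) = 1
std_plain = aggregated_std(alpha, use_sqrt=False)  # close to sqrt(sum(alpha**2)) < 1
```

Nodes without any neighbor inside ``max_radius`` output zero, so the standard deviation measured over a whole random graph can come out somewhat below one.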
jupyter-execute:: def transformer(f, pos): edge_src, edge_dst = radius_graph(pos, max_radius) edge_vec = pos[edge_src] - pos[edge_dst] edge_length = edge_vec.norm(dim=1) edge_length_embedded = soft_one_hot_linspace( edge_length, start=0.0, end=max_radius, number=number_of_basis, basis='smooth_finite', cutoff=True ) edge_length_embedded = edge_length_embedded.mul(number_of_basis**0.5) edge_weight_cutoff = soft_unit_step(10 * (1 - edge_length / max_radius)) edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, True, normalization='component') q = h_q(f) k = tp_k(f[edge_src], edge_sh, fc_k(edge_length_embedded)) v = tp_v(f[edge_src], edge_sh, fc_v(edge_length_embedded)) exp = edge_weight_cutoff[:, None] * dot(q[edge_dst], k).exp() z = scatter(exp, edge_dst, dim=0, dim_size=len(f)) z[z == 0] = 1 alpha = exp / z[edge_dst] return scatter(alpha.relu().sqrt() * v, edge_dst, dim=0, dim_size=len(f)) Here is a smoothness check: two nodes are placed at a distance of 1 (``max_radius > 1``), so they see each other. A third node coming from far away moves slowly towards them. .. jupyter-execute:: :hide-output: f = irreps_input.randn(3, -1) xs = torch.linspace(-1.3, -1.0, 200) outputs = [] for x in xs: pos = torch.tensor([ [0.0, 0.5, 0.0], # this node always sees... [0.0, -0.5, 0.0], # ...this node [x.item(), 0.0, 0.0], # this node moves slowly ]) with torch.no_grad(): outputs.append(transformer(f, pos)) outputs = torch.stack(outputs) plt.plot(xs, outputs[:, 0, [0, 1, 14, 15, 16]], 'k') # plots 2 scalars and 1 vector plt.plot(xs, outputs[:, 1, [0, 1, 14, 15, 16]], 'g') plt.plot(xs, outputs[:, 2, [0, 1, 14, 15, 16]], 'r') .. 
jupyter-execute:: :hide-code: plt.plot(xs, outputs[:, 0, [0, 1, 14, 15, 16]], 'k') plt.plot(xs, outputs[:, 1, [0, 1, 14, 15, 16]], 'g') plt.plot(xs, outputs[:, 2, [0, 1, 14, 15, 16]], 'r') plt.xlabel('3rd node position') plt.ylabel('output features') plt.plot([], [], 'k', label='1st node') plt.plot([], [], 'g', label='2nd node') plt.plot([], [], 'r', label='3rd node') plt.legend() plt.tight_layout(); Finally we can check the equivariance: .. jupyter-execute:: f = irreps_input.randn(10, -1) pos = torch.randn(10, 3) rot = o3.rand_matrix() D_in = irreps_input.D_from_matrix(rot) D_out = irreps_output.D_from_matrix(rot) f_before = transformer(f @ D_in.T, pos @ rot.T) f_after = transformer(f, pos) @ D_out.T torch.allclose(f_before, f_after, atol=1e-3, rtol=1e-3) Extra sanity check of the backward pass: .. jupyter-execute:: for x in [0.0, 1e-6, max_radius / 2, max_radius - 1e-6, max_radius, max_radius + 1e-6, 2 * max_radius]: f = irreps_input.randn(2, -1, requires_grad=True) pos = torch.tensor([ [0.0, 0.0, 0.0], [x, 0.0, 0.0], ], requires_grad=True) transformer(f, pos).sum().backward() assert f.grad is None or torch.isfinite(f.grad).all() assert torch.isfinite(pos.grad).all() .. _SE(3)-Transformers: https://proceedings.neurips.cc/paper/2020/file/15231a7ce4ba789d13b722cc5c955834-Paper.pdf .. _Wikipedia: https://en.wikipedia.org/wiki/Transformer_(machine_learning_model) e3nn-0.6.0/docs/index.rst000066400000000000000000000054661514371756200151770ustar00rootroot00000000000000Euclidean neural networks ========================= What is ``e3nn``? ----------------- ``e3nn`` is a python library based on pytorch_ to create equivariant neural networks for the group :math:`O(3)`. Where to start? --------------- - Guide to the `e3nn.o3.Irreps`: :ref:`irreps guide` - Guide to implement a :ref:`conv guide` - The simplest example to start with is :ref:`tetris_poly`. - Guide to implement a :ref:`transformer guide` .. 
toctree:: :maxdepth: 2 api/e3nn guide/guide examples/examples Demonstration ------------- All the functions to manipulate rotations (rotation matrices, Euler angles, quaternions, conversions, ...) can be found in :ref:`Rotation functions`. The irreducible representations of :math:`O(3)` (more info at :ref:`Irreducible representations`) are represented by the class `e3nn.o3.Irrep`. The direct sum of multiple irreps is described by an object `e3nn.o3.Irreps`. If two tensors :math:`x` and :math:`y` transform as :math:`D_x = 2 \times 1_o` (two vectors) and :math:`D_y = 0_e + 1_e` (a scalar and a pseudovector) respectively, where the indices :math:`e` and :math:`o` stand for even and odd -- the representation of parity, .. jupyter-execute:: import torch from e3nn import o3 irreps_x = o3.Irreps('2x1o') irreps_y = o3.Irreps('0e + 1e') x = irreps_x.randn(-1) y = irreps_y.randn(-1) irreps_x.dim, irreps_y.dim their outer product is a :math:`6 \times 4` matrix of two indices :math:`A_{ij} = x_i y_j`. .. jupyter-execute:: A = torch.einsum('i,j', x, y) A If a rotation is applied to the system, this matrix will transform with the representation :math:`D_x \otimes D_y` (the tensor product representation). .. math:: A = x y^t \longrightarrow A' = D_x A D_y^t The representation :math:`D_x \otimes D_y` can be plotted as a matrix: .. jupyter-execute:: :hide-code: import matplotlib.pyplot as plt .. jupyter-execute:: R = o3.rand_matrix() D_x = irreps_x.D_from_matrix(R) D_y = irreps_y.D_from_matrix(R) plt.imshow(torch.kron(D_x, D_y), cmap='bwr', vmin=-1, vmax=1); This representation is reducible: it can be decomposed into irreps by a change of basis. The outer product followed by the change of basis is done by the class `e3nn.o3.FullTensorProduct`. .. jupyter-execute:: tp = o3.FullTensorProduct(irreps_x, irreps_y) print(tp) tp(x, y) As a sanity check, we can verify that the representation of the tensor product is block diagonal and of the same dimension. .. 
jupyter-execute:: D = tp.irreps_out.D_from_matrix(R) plt.imshow(D, cmap='bwr', vmin=-1, vmax=1); `e3nn.o3.FullTensorProduct` is a special case of `e3nn.o3.TensorProduct`; others, like `e3nn.o3.FullyConnectedTensorProduct`, can contain learnable weights, which is very useful for building neural networks. .. _pytorch: https://pytorch.org/ e3nn-0.6.0/docs/make.bat000066400000000000000000000014331514371756200147310ustar00rootroot00000000000000@ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd e3nn-0.6.0/docs/requirements.txt000066400000000000000000000004431514371756200166100ustar00rootroot00000000000000autodocsumm myst-parser sphinx sphinx-rtd-theme sympy ipykernel plotly jupyter-sphinx ase --find-links https://download.pytorch.org/whl/cpu torch==2.4.0 --find-links https://data.pyg.org/whl/torch-2.4.0+cpu.html torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric .e3nn-0.6.0/e3nn/000077500000000000000000000000001514371756200132365ustar00rootroot00000000000000e3nn-0.6.0/e3nn/__init__.py000066400000000000000000000052111514371756200153460ustar00rootroot00000000000000__version__ = "0.6.0" from typing import Dict import torch import packaging.version # torch.jit.script is deprecated in PT 2.10+ _TORCH_VERSION = 
packaging.version.parse(torch.__version__.split("+")[0]) 
_DEFAULT_JIT_MODE = "eager" if _TORCH_VERSION >= packaging.version.parse("2.10") else "script" _OPT_DEFAULTS: Dict[str, bool] = dict(specialized_code=True, optimize_einsums=True, jit_script_fx=True, jit_mode=_DEFAULT_JIT_MODE) def _handle_jit_script_fx_legacy(jit_script_fx: bool, current_jit_mode: str) -> str: """Handle the legacy jit_script_fx flag mapping to jit_mode. Parameters ---------- jit_script_fx : bool The legacy jit_script_fx flag value current_jit_mode : str The current jit_mode value Returns ------- str The new jit_mode value based on the legacy mapping rules """ if not jit_script_fx and current_jit_mode == "eager": # Keep it eager return "eager" elif not jit_script_fx: # Map False to eager if not already eager return "eager" elif jit_script_fx and current_jit_mode not in ["script", "inductor"]: # Map True to script only if not already script or inductor return "script" # In all other cases, keep current jit_mode return current_jit_mode def _validate_and_set_jit_mode(jit_mode: str) -> None: """Validate and set the jit_mode in _OPT_DEFAULTS.""" assert jit_mode in [ "script", "inductor", "eager", ], f"Invalid jit_mode: {jit_mode}. Expected 'script', 'inductor', or 'eager'." _OPT_DEFAULTS["jit_mode"] = jit_mode def set_optimization_defaults(**kwargs) -> None: r"""Globally set the default optimization settings. Parameters ---------- **kwargs Keyword arguments to set the default optimization settings. """ for k, v in kwargs.items(): if k not in _OPT_DEFAULTS: raise ValueError(f"Unknown optimization option: {k}") # Handles the legacy mapping for jit_script_fx # to jit_mode so that old code can still work # with the new defaults. 
if k == "jit_script_fx": # Update jit_mode based on the legacy mapping new_jit_mode = _handle_jit_script_fx_legacy(v, _OPT_DEFAULTS["jit_mode"]) _validate_and_set_jit_mode(new_jit_mode) _OPT_DEFAULTS[k] = v elif k == "jit_mode": # Validate and set the new jit_mode _validate_and_set_jit_mode(v) else: _OPT_DEFAULTS[k] = v def get_optimization_defaults() -> Dict[str, bool]: r"""Get the global default optimization settings.""" return dict(_OPT_DEFAULTS) e3nn-0.6.0/e3nn/io/000077500000000000000000000000001514371756200136455ustar00rootroot00000000000000e3nn-0.6.0/e3nn/io/__init__.py000066400000000000000000000002341514371756200157550ustar00rootroot00000000000000from ._cartesian_tensor import CartesianTensor from ._spherical_tensor import SphericalTensor __all__ = [ "CartesianTensor", "SphericalTensor", ] e3nn-0.6.0/e3nn/io/_cartesian_tensor.py000066400000000000000000000067411514371756200177310ustar00rootroot00000000000000from typing import Optional import torch from e3nn.o3._irreps import Irreps from e3nn.o3._reduce import ReducedTensorProducts class CartesianTensor(Irreps): r"""Representation of a Cartesian tensor as a direct sum of irreps Parameters ---------- formula : str index formula describing the symmetries of the tensor, e.g. ``"ij=-ji"`` for an antisymmetric matrix Examples -------- >>> import torch >>> CartesianTensor("ij=-ji") 1x1e >>> x = CartesianTensor("ijk=-jik=-ikj") >>> x.from_cartesian(torch.ones(3, 3, 3)) tensor([0.]) >>> x.from_vectors(torch.ones(3), torch.ones(3), torch.ones(3)) tensor([0.]) >>> x = CartesianTensor("ij=ji") >>> t = torch.arange(9).to(torch.float).view(3,3) >>> y = x.from_cartesian(t) >>> z = x.to_cartesian(y) >>> torch.allclose(z, (t + t.T)/2, atol=1e-5) True """ # pylint: disable=abstract-method # These are set in __new__ formula: str indices: str def __new__( # pylint: disable=signature-differs cls, formula, ): indices = formula.split("=")[0].replace("-", "") rtp = ReducedTensorProducts(formula, **{i: "1o" 
for i in indices}) ret = super().__new__(cls, rtp.irreps_out) ret.formula = formula ret.indices = indices return ret def 
from_cartesian(self, data, rtp=None): r"""convert cartesian tensor into irreps Parameters ---------- data : `torch.Tensor` cartesian tensor of shape ``(..., 3, 3, 3, ...)`` Returns ------- `torch.Tensor` irreps tensor of shape ``(..., self.dim)`` """ if rtp is None: rtp = self.reduced_tensor_products(data) Q = rtp.change_of_basis.flatten(-len(self.indices)) return data.flatten(-len(self.indices)) @ Q.T def from_vectors(self, *xs, rtp=None): r"""convert :math:`x_1 \otimes x_2 \otimes x_3 \otimes \dots` Parameters ---------- xs : list of `torch.Tensor` list of vectors of shape ``(..., 3)`` Returns ------- `torch.Tensor` irreps tensor of shape ``(..., self.dim)`` """ if rtp is None: rtp = self.reduced_tensor_products(xs[0]) return rtp(*xs) # pylint: disable=not-callable def to_cartesian(self, data, rtp=None): r"""convert irreps tensor to cartesian tensor This is the symmetry-aware inverse operation of ``from_cartesian()``. Parameters ---------- data : `torch.Tensor` irreps tensor of shape ``(..., D)``, where D is the dimension of the irreps, i.e. ``D=self.dim``. 
Returns ------- `torch.Tensor` cartesian tensor of shape ``(..., 3, 3, 3, ...)`` """ if rtp is None: rtp = self.reduced_tensor_products(data) Q = rtp.change_of_basis cartesian_tensor = data @ Q.flatten(-len(self.indices)) shape = list(data.shape[:-1]) + list(Q.shape[1:]) cartesian_tensor = cartesian_tensor.view(shape) return cartesian_tensor def reduced_tensor_products(self, data: Optional[torch.Tensor] = None) -> ReducedTensorProducts: r"""Build the reduced tensor products for this formula, moved to the device and dtype of ``data`` if it is given Returns ------- `e3nn.o3.ReducedTensorProducts` reduced tensor products """ rtp = ReducedTensorProducts(self.formula, **{i: "1o" for i in self.indices}) if data is not None: rtp = rtp.to(device=data.device, dtype=data.dtype) return rtp e3nn-0.6.0/e3nn/io/_spherical_tensor.py000066400000000000000000000312441514371756200177260ustar00rootroot00000000000000from math import pi from collections import namedtuple from typing import Tuple import scipy.signal import torch from e3nn import o3 from e3nn.o3 import FromS2Grid, ToS2Grid def _find_peaks_2d(x): iii = [] for i in range(x.shape[0]): jj, _ = scipy.signal.find_peaks(x[i, :]) iii += [(i, j) for j in jj] jjj = [] for j in range(x.shape[1]): ii, _ = scipy.signal.find_peaks(x[:, j]) jjj += [(i, j) for i in ii] return list(set(iii).intersection(set(jjj))) class SphericalTensor(o3.Irreps): r"""representation of a signal on the sphere A `SphericalTensor` contains the coefficients :math:`A^l` of a function :math:`f` defined on the sphere .. math:: f(x) = \sum_{l=0}^{l_\mathrm{max}} A^l \cdot Y^l(x) The way this function is transformed by parity :math:`f \longrightarrow P f` is described by the two parameters :math:`p_v` and :math:`p_a` .. 
math:: (P f)(x) &= p_v f(p_a x) &= \sum_{l=0}^{l_\mathrm{max}} p_v p_a^l A^l \cdot Y^l(x) Parameters ---------- lmax : int :math:`l_\mathrm{max}` p_val : {+1, -1} :math:`p_v` p_arg : {+1, -1} :math:`p_a` Examples -------- >>> SphericalTensor(3, 1, 1) 1x0e+1x1e+1x2e+1x3e >>> SphericalTensor(3, 1, -1) 1x0e+1x1o+1x2e+1x3o """ # pylint: disable=abstract-method def __new__( # pylint: disable=signature-differs cls, lmax, p_val, p_arg, ): return super().__new__(cls, [(1, (l, p_val * p_arg**l)) for l in range(lmax + 1)]) def with_peaks_at(self, vectors, values=None): r"""Create a spherical tensor with peaks The peaks are located at :math:`\vec r_i` and have amplitude :math:`\|\vec r_i \|` Parameters ---------- vectors : `torch.Tensor` :math:`\vec r_i` tensor of shape ``(N, 3)`` values : `torch.Tensor`, optional values at the peaks, tensor of shape ``(N)`` Returns ------- `torch.Tensor` tensor of shape ``(self.dim,)`` Examples -------- >>> s = SphericalTensor(4, 1, -1) >>> pos = torch.tensor([ ... [1.0, 0.0, 0.0], ... [3.0, 4.0, 0.0], ... ]) >>> x = s.with_peaks_at(pos) >>> s.signal_xyz(x, pos).long() tensor([1, 5]) >>> val = torch.tensor([ ... -1.5, ... 2.0, ... ]) >>> x = s.with_peaks_at(pos, val) >>> s.signal_xyz(x, pos) tensor([-1.5000, 2.0000]) """ if values is not None: vectors, values = torch.broadcast_tensors(vectors, values[..., None]) values = values[..., 0] # empty set of vectors returns a 0 spherical tensor if vectors.numel() == 0: return torch.zeros(vectors.shape[:-2] + (self.dim,)) assert ( self[0][1].p == 1 ), "since the value is set by the radius, which is even, p_val has to be 1" # pylint: disable=no-member assert vectors.dim() == 2 and vectors.shape[1] == 3 if values is None: values = vectors.norm(dim=1) # [batch] vectors = vectors[values != 0] # [batch, 3] values = values[values != 0] coeff = o3.spherical_harmonics(self, vectors, normalize=True) # [batch, l * m] A = torch.einsum("ai,bi->ab", coeff, coeff) # Y(v_a) . 
Y(v_b) solution_b = radii_a solution = torch.linalg.lstsq(A, values).solution.reshape(-1) # [b] assert (values - A @ solution).abs().max() < 1e-5 * values.abs().max() return solution @ coeff def sum_of_diracs(self, positions: torch.Tensor, values: torch.Tensor) -> torch.Tensor: r"""Sum of (approximate) Dirac deltas .. math:: f(x) = \sum_i v_i \delta^L(\vec r_i) where :math:`\delta^L` is the approximation of a Dirac delta. Parameters ---------- positions : `torch.Tensor` :math:`\vec r_i` tensor of shape ``(..., N, 3)`` values : `torch.Tensor` :math:`v_i` tensor of shape ``(..., N)`` Returns ------- `torch.Tensor` tensor of shape ``(..., self.dim)`` Examples -------- >>> s = SphericalTensor(7, 1, -1) >>> pos = torch.tensor([ ... [1.0, 0.0, 0.0], ... [0.0, 1.0, 0.0], ... ]) >>> val = torch.tensor([ ... -1.0, ... 1.0, ... ]) >>> x = s.sum_of_diracs(pos, val) >>> s.signal_xyz(x, torch.eye(3)).mul(10.0).round() tensor([-10., 10., -0.]) >>> s.sum_of_diracs(torch.empty(1, 0, 2, 3), torch.empty(2, 0, 1)).shape torch.Size([2, 0, 64]) >>> s.sum_of_diracs(torch.randn(1, 3, 2, 3), torch.randn(2, 1, 1)).shape torch.Size([2, 3, 64]) """ positions, values = torch.broadcast_tensors(positions, values[..., None]) values = values[..., 0] if positions.numel() == 0: return torch.zeros(values.shape[:-1] + (self.dim,)) y = o3.spherical_harmonics(self, positions, True) # [..., N, dim] v = values[..., None] return 4 * pi / (self.lmax + 1) ** 2 * (y * v).sum(-2) def from_samples_on_s2(self, positions: torch.Tensor, values: torch.Tensor, res: int = 100) -> torch.Tensor: r"""Convert a set of positions on the sphere and their values into a spherical tensor Parameters ---------- positions : `torch.Tensor` tensor of shape ``(..., N, 3)`` values : `torch.Tensor` tensor of shape ``(..., N)`` Returns ------- `torch.Tensor` tensor of shape ``(..., self.dim)`` Examples -------- >>> s = SphericalTensor(2, 1, 1) >>> pos = torch.tensor([ ... [ ... [0.0, 0.0, 1.0], ... [0.0, 0.0, -1.0], ... ], ... [ ... 
[0.0, 1.0, 0.0], ... [0.0, -1.0, 0.0], ... ], ... ], dtype=torch.float64) >>> val = torch.tensor([ ... [ ... 1.0, ... -1.0, ... ], ... [ ... 1.0, ... -1.0, ... ], ... ], dtype=torch.float64) >>> s.from_samples_on_s2(pos, val, res=200).long() tensor([[0, 0, 0, 3, 0, 0, 0, 0, 0], [0, 0, 3, 0, 0, 0, 0, 0, 0]]) >>> pos = torch.empty(2, 0, 10, 3) >>> val = torch.empty(2, 0, 10) >>> s.from_samples_on_s2(pos, val) tensor([], size=(2, 0, 9)) """ positions, values = torch.broadcast_tensors(positions, values[..., None]) values = values[..., 0] if positions.numel() == 0: return torch.zeros(values.shape[:-1] + (self.dim,)) positions = torch.nn.functional.normalize(positions, dim=-1) # forward 0's instead of nan for zero-radius size = positions.shape[:-2] n = positions.shape[-2] positions = positions.reshape(-1, n, 3) values = values.reshape(-1, n) s2 = FromS2Grid(res=res, lmax=self.lmax, normalization="integral", dtype=values.dtype, device=values.device) pos = s2.grid.reshape(1, -1, 3) cd = torch.cdist(pos, positions) # [batch, b*a, N] i = torch.arange(len(values)).view(-1, 1) # [batch, 1] j = cd.argmin(2) # [batch, b*a] val = values[i, j] # [batch, b*a] val = val.reshape(*size, s2.res_beta, s2.res_alpha) return s2(val) def norms(self, signal) -> torch.Tensor: r"""The norms of each l component Parameters ---------- signal : `torch.Tensor` tensor of shape ``(..., dim)`` Returns ------- `torch.Tensor` tensor of shape ``(..., lmax+1)`` Examples -------- >>> s = SphericalTensor(1, 1, -1) >>> s.norms(torch.tensor([1.5, 0.0, 3.0, 4.0])) tensor([1.5000, 5.0000]) """ i = 0 norms = [] for _, ir in self: norms += [signal[..., i : i + ir.dim].norm(dim=-1)] i += ir.dim return torch.stack(norms, dim=-1) def signal_xyz(self, signal, r) -> torch.Tensor: r"""Evaluate the signal on given points on the sphere .. 
math:: f(\vec x / \|\vec x\|) Parameters ---------- signal : `torch.Tensor` tensor of shape ``(*A, self.dim)`` r : `torch.Tensor` tensor of shape ``(*B, 3)`` Returns ------- `torch.Tensor` tensor of shape ``(*A, *B)`` Examples -------- >>> s = SphericalTensor(3, 1, -1) >>> s.signal_xyz(s.randn(2, 1, 3, -1), torch.randn(2, 4, 3)).shape torch.Size([2, 1, 3, 2, 4]) """ sh = o3.spherical_harmonics(self, r, normalize=True) dim = (self.lmax + 1) ** 2 output = torch.einsum("bi,ai->ab", sh.reshape(-1, dim), signal.reshape(-1, dim)) return output.reshape(signal.shape[:-1] + r.shape[:-1]) def signal_on_grid(self, signal, res: int = 100, normalization: str = "integral"): r"""Evaluate the signal on a grid on the sphere""" Ret = namedtuple("Return", "grid, values") s2 = ToS2Grid(lmax=self.lmax, res=res, normalization=normalization) return Ret(s2.grid, s2(signal)) def plotly_surface( self, signals, centers=None, res: int = 100, radius: bool = True, relu: bool = False, normalization: str = "integral" ): r"""Create traces for plotly Examples -------- >>> import plotly.graph_objects as go >>> x = SphericalTensor(4, +1, +1) >>> traces = x.plotly_surface(x.randn(-1)) >>> traces = [go.Surface(**d) for d in traces] >>> fig = go.Figure(data=traces) """ signals = signals.reshape(-1, self.dim) if centers is None: centers = [None] * len(signals) else: centers = centers.reshape(-1, 3) traces = [] for signal, center in zip(signals, centers): r, f = self.plot(signal, center, res, radius, relu, normalization) traces += [ { "x": r[:, :, 0].numpy(), "y": r[:, :, 1].numpy(), "z": r[:, :, 2].numpy(), "surfacecolor": f.numpy(), } ] return traces def plot( self, signal, center=None, res: int = 100, radius: bool = True, relu: bool = False, normalization: str = "integral" ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Create surface in order to make a plot""" assert signal.dim() == 1 r, f = self.signal_on_grid(signal, res, normalization) f = f.relu() if relu else f # beta: [0, pi] r[0] = 
r.new_tensor([0.0, 1.0, 0.0]) r[-1] = r.new_tensor([0.0, -1.0, 0.0]) f[0] = f[0].mean() f[-1] = f[-1].mean() # alpha: [0, 2pi] r = torch.cat([r, r[:, :1]], dim=1) # [beta, alpha, 3] f = torch.cat([f, f[:, :1]], dim=1) # [beta, alpha] if radius: r *= f.abs().unsqueeze(-1) if center is not None: r += center return r, f def find_peaks(self, signal, res: int = 100) -> Tuple[torch.Tensor, torch.Tensor]: r"""Locate peaks on the sphere Examples -------- >>> s = SphericalTensor(4, 1, -1) >>> pos = torch.tensor([ ... [4.0, 0.0, 4.0], ... [0.0, 5.0, 0.0], ... ]) >>> x = s.with_peaks_at(pos) >>> pos, val = s.find_peaks(x) >>> pos[val > 4.0].mul(10).round().abs() tensor([[ 7., 0., 7.], [ 0., 10., 0.]]) >>> val[val > 4.0].mul(10).round().abs() tensor([57., 50.]) """ x1, f1 = self.signal_on_grid(signal, res) abc = torch.tensor([pi / 2, pi / 2, pi / 2]) R = o3.angles_to_matrix(*abc) D = self.D_from_matrix(R) r_signal = D @ signal rx2, f2 = self.signal_on_grid(r_signal, res) x2 = torch.einsum("ij,baj->bai", R.T, rx2) ij = _find_peaks_2d(f1) x1p = torch.stack([x1[i, j] for i, j in ij]) f1p = torch.stack([f1[i, j] for i, j in ij]) ij = _find_peaks_2d(f2) x2p = torch.stack([x2[i, j] for i, j in ij]) f2p = torch.stack([f2[i, j] for i, j in ij]) # Union of the results mask = torch.cdist(x1p, x2p) < 2 * pi / res x = torch.cat([x1p[mask.sum(1) == 0], x2p]) f = torch.cat([f1p[mask.sum(1) == 0], f2p]) return x, f e3nn-0.6.0/e3nn/math/000077500000000000000000000000001514371756200141675ustar00rootroot00000000000000e3nn-0.6.0/e3nn/math/__init__.py000066400000000000000000000010341514371756200162760ustar00rootroot00000000000000from ._linalg import complete_basis, direct_sum, orthonormalize from ._normalize_activation import moment, normalize2mom from ._soft_unit_step import soft_unit_step from ._soft_one_hot_linspace import soft_one_hot_linspace from ._reduce import germinate_formulas, reduce_permutation from ._bessel import bessel __all__ = [ "complete_basis", "direct_sum", "orthonormalize", 
"moment", "normalize2mom", "soft_unit_step", "bessel", "soft_one_hot_linspace", "germinate_formulas", "reduce_permutation", ] e3nn-0.6.0/e3nn/math/_bessel.py000066400000000000000000000021641514371756200161600ustar00rootroot00000000000000# Porting from https://github.com/e3nn/e3nn-jax/blob/403f0ef159c537c2efa9d3d05799da85a86575d5/e3nn_jax/_src/radial.py#L213-L244 import torch import numpy as np def bessel(x: torch.Tensor, n: int, x_max: float=1.0) -> torch.Tensor: r"""Bessel basis functions. They obey the following normalization: .. math:: \int_0^c r^2 B_n(r, c) B_m(r, c) dr = \delta_{nm} Args: x (torch.Tensor): input of shape ``[...]`` n (int): number of basis functions x_max (float): maximum value of the input Returns: torch.Tensor: basis functions of shape ``[..., n]`` Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020. Equation (7) """ assert isinstance(n, int) x = x[..., None] n = torch.arange(1, n + 1, dtype=x.dtype, device=x.device) x_nonzero = torch.where(x == 0.0, 1.0, x) return np.sqrt(2.0 / x_max) * torch.where( x == 0, n * torch.pi / x_max, torch.sin(n * torch.pi / x_max * x_nonzero) / x_nonzero )e3nn-0.6.0/e3nn/math/_linalg.py000066400000000000000000000052611514371756200161520ustar00rootroot00000000000000from typing import Tuple import torch from e3nn import get_optimization_defaults def _conditional_script(fn): """apply torch.jit.script only if jit_mode is 'script'""" if get_optimization_defaults()["jit_mode"] == "script": return torch.jit.script(fn) return fn def direct_sum(*matrices): r"""Direct sum of matrices, put them in the diagonal""" front_indices = matrices[0].shape[:-2] m = sum(x.size(-2) for x in matrices) n = sum(x.size(-1) for x in matrices) total_shape = list(front_indices) + [m, n] out = matrices[0].new_zeros(total_shape) i, j = 0, 0 for x in matrices: m, n = x.shape[-2:] out[..., i : i + m, j : j + n] = x i += m j += n return out @_conditional_script def orthonormalize(original: 
torch.Tensor, eps: float = 1e-9) -> Tuple[torch.Tensor, torch.Tensor]: r"""orthonormalize vectors Parameters ---------- original : `torch.Tensor` list of the original vectors :math:`x` eps : float a small number Returns ------- final : `torch.Tensor` list of orthonormalized vectors :math:`y` matrix : `torch.Tensor` the matrix :math:`A` such that :math:`y = A x` """ assert original.dim() == 2 dim = original.shape[1] final = [] matrix = [] for i, x in enumerate(original): # x = sum_i cx_i original_i cx = x.new_zeros(len(original)) cx[i] = 1 for j, y in enumerate(final): c = torch.dot(x, y) x = x - c * y cx = cx - c * matrix[j] if x.norm() > 2 * eps: c = 1 / x.norm() x = c * x cx = c * cx x[x.abs() < eps] = 0 cx[cx.abs() < eps] = 0 c = x[x.nonzero()[0, 0]].sign() x = c * x cx = c * cx final += [x] matrix += [cx] final = torch.stack(final) if len(final) > 0 else original.new_zeros((0, dim)) matrix = torch.stack(matrix) if len(matrix) > 0 else original.new_zeros((0, len(original))) return final, matrix @_conditional_script def complete_basis(vecs: torch.Tensor, eps: float = 1e-9) -> torch.Tensor: assert vecs.dim() == 2 dim = vecs.shape[1] base = [x / x.norm() for x in vecs] expand = [] for x in torch.eye(dim, device=vecs.device, dtype=vecs.dtype): for y in base + expand: x -= torch.dot(x, y) * y if x.norm() > 2 * eps: x /= x.norm() x[x.abs() < eps] = x.new_zeros(()) x *= x[x.nonzero()[0, 0]].sign() expand += [x] expand = torch.stack(expand) if len(expand) > 0 else vecs.new_zeros(0, dim) return expand e3nn-0.6.0/e3nn/math/_normalize_activation.py000066400000000000000000000035711514371756200211270ustar00rootroot00000000000000from typing import Dict, List, Tuple import torch from e3nn.util.default_type import explicit_default_types from e3nn.util.jit import compile_mode def moment(f, n, dtype=None, device=None): r""" compute the n-th moment of f(z) for z drawn from a standard normal distribution """ dtype, device = explicit_default_types(dtype, device) gen = 
torch.Generator(device=device).manual_seed(0) z = 
torch.randn(1_000_000, generator=gen, dtype=torch.float64, device=device).to(dtype=dtype, device=device) return f(z).pow(n).mean() @compile_mode("trace") class normalize2mom(torch.nn.Module): _is_id: bool cst: float def __init__( # pylint: disable=unused-argument self, f, dtype=None, device=None, ) -> None: super().__init__() # Try to infer a device: if device is None and isinstance(f, torch.nn.Module): # Avoid circular import from e3nn.util._argtools import _get_device device = _get_device(f) with torch.no_grad(): cst = moment(f, 2, dtype=torch.float64, device=device).pow(-0.5).item() if abs(cst - 1) < 1e-4: self._is_id = True else: self._is_id = False self.f = f self.cst = cst def forward(self, x): if self._is_id: return self.f(x) else: return self.f(x).mul(self.cst) @staticmethod def _make_tracing_inputs( # pylint: disable=unused-argument n: int, ) -> List[Dict[str, Tuple[torch.Tensor]]]: # No reason to trace this with more than one tiny input, # since f is assumed by `moment` to be an elementwise scalar # function return [ { "forward": ( torch.zeros( size=(1,), ), ) } ] e3nn-0.6.0/e3nn/math/_reduce.py000066400000000000000000000075161514371756200161600ustar00rootroot00000000000000import itertools import torch from e3nn.math import perm def germinate_formulas(formula): formulas = [(-1 if f.startswith("-") else 1, f.replace("-", "")) for f in formula.split("=")] s0, f0 = formulas[0] assert s0 == 1 for _s, f in formulas: if len(set(f)) != len(f) or set(f) != set(f0): raise RuntimeError(f"{f} is not a permutation of {f0}") if len(f0) != len(f): raise RuntimeError(f"{f0} and {f} don't have the same number of indices") # `formulas` is a list of (sign, permutation of indices) # each formula can be viewed as a permutation of the original formula formulas = {(s, tuple(f.index(i) for i in f0)) for s, f in formulas} # set of generators (permutations) # they can be composed, for instance if you have ijk=jik=ikj # you also have ijk=jki # applying all possible compositions 
creates an entire group while True: n = len(formulas) formulas = formulas.union([(s, perm.inverse(p)) for s, p in formulas]) formulas = formulas.union([(s1 * s2, perm.compose(p1, p2)) for s1, p1 in formulas for s2, p2 in formulas]) if len(formulas) == n: break # we break when the set is stable => it is now a group \o/ return f0, formulas def reduce_permutation(f0, formulas, dtype=None, device=None, **dims): r""" Parameters ---------- f0 : str formulas : list of tuple (int, str) dims : dict of str -> int Examples -------- >>> Q, ret = reduce_permutation(*germinate_formulas("ij=-ji"), i=2) >>> Q.shape, len(ret) (torch.Size([1, 2, 2]), 1) """ # here we check that each index has one and only one dimension for _s, p in formulas: f = "".join(f0[i] for i in p) for i, j in zip(f0, f): if i in dims and j in dims and dims[i] != dims[j]: raise RuntimeError(f"dimension of {i} and {j} should be the same") if i in dims: dims[j] = dims[i] if j in dims: dims[i] = dims[j] for i in f0: if i not in dims: raise RuntimeError(f"index {i} has no dimension associated to it") dims = [dims[i] for i in f0] full_base = list(itertools.product(*(range(d) for d in dims))) # (0, 0, 0), (0, 0, 1), (0, 0, 2), ... (3, 3, 3) # len(full_base) degrees of freedom in an unconstrained tensor # but there is constraints given by the group `formulas` # For instance if `ij=-ji`, then 00=-00, 01=-01 and so on base = set() for x in full_base: # T[x] is a coefficient of the tensor T and is related to other coefficient T[y] # if x and y are related by a formula xs = {(s, tuple(x[i] for i in p)) for s, p in formulas} # s * T[x] are all equal for all (s, x) in xs # if T[x] = -T[x] it is then equal to 0 and we lose this degree of freedom if (-1, x) not in xs: # the sign is arbitrary, put both possibilities base.add(frozenset({frozenset(xs), frozenset({(-s, x) for s, x in xs})})) # len(base) is the number of degrees of freedom in the tensor. 
base = sorted( [sorted([sorted(xs) for xs in x]) for x in base] ) # required for python 3.7 but not for 3.8 (probably a bug in 3.7) # First we compute the change of basis (projection) between full_base and base d_sym = len(base) Q = torch.zeros(d_sym, len(full_base), dtype=dtype, device=device) ret = [] for i, x in enumerate(base): x = max(x, key=lambda xs: sum(s for s, x in xs)) ret.append(x) for s, e in x: # j = full_base.index(e) j = 0 for k, d in zip(e, dims): j *= d j += k Q[i, j] = s / len(x) ** 0.5 # assert torch.allclose(Q @ Q.T, torch.eye(d_sym)) Q = Q.reshape(d_sym, *dims) return Q, ret e3nn-0.6.0/e3nn/math/_soft_one_hot_linspace.py000066400000000000000000000106041514371756200212450ustar00rootroot00000000000000import math import torch from e3nn.math import soft_unit_step def soft_one_hot_linspace(x: torch.Tensor, start, end, number, basis=None, cutoff=None) -> torch.Tensor: r"""Projection onto a basis of functions Returns a set of :math:`\{y_i(x)\}_{i=1}^N`, .. math:: y_i(x) = \frac{1}{Z} f_i(x) where :math:`x` is the input and :math:`f_i` is the ith basis function. :math:`Z` is a constant defined (if possible) such that .. math:: \langle \sum_{i=1}^N y_i(x)^2 \rangle_x \approx 1 See the last plot below. Note that ``bessel`` basis cannot be normalized. Parameters ---------- x : `torch.Tensor` tensor of shape :math:`(...)` start : float minimum value spanned by the basis end : float maximum value spanned by the basis number : int number of basis functions :math:`N` basis : {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'} choice of basis family; note that due to the :math:`1/x` term, ``bessel`` basis does not satisfy the normalization of other basis choices cutoff : bool if ``cutoff=True`` then for all :math:`x` outside of the interval defined by ``(start, end)``, :math:`\forall i, \; f_i(x) \approx 0` Returns ------- `torch.Tensor` tensor of shape :math:`(..., N)` Examples -------- .. 
jupyter-execute:: :hide-code: import torch from e3nn.math import soft_one_hot_linspace import matplotlib.pyplot as plt .. jupyter-execute:: bases = ['gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'] x = torch.linspace(-1.0, 2.0, 100) .. jupyter-execute:: fig, axss = plt.subplots(len(bases), 2, figsize=(9, 6), sharex=True, sharey=True) for axs, b in zip(axss, bases): for ax, c in zip(axs, [True, False]): plt.sca(ax) plt.plot(x, soft_one_hot_linspace(x, -0.5, 1.5, number=4, basis=b, cutoff=c)) plt.plot([-0.5]*2, [-2, 2], 'k-.') plt.plot([1.5]*2, [-2, 2], 'k-.') plt.title(f"{b}" + (" with cutoff" if c else "")) plt.ylim(-1, 1.5) plt.tight_layout() .. jupyter-execute:: fig, axss = plt.subplots(len(bases), 2, figsize=(9, 6), sharex=True, sharey=True) for axs, b in zip(axss, bases): for ax, c in zip(axs, [True, False]): plt.sca(ax) plt.plot(x, soft_one_hot_linspace(x, -0.5, 1.5, number=4, basis=b, cutoff=c).pow(2).sum(1)) plt.plot([-0.5]*2, [-2, 2], 'k-.') plt.plot([1.5]*2, [-2, 2], 'k-.') plt.title(f"{b}" + (" with cutoff" if c else "")) plt.ylim(0, 2) plt.tight_layout() """ # pylint: disable=misplaced-comparison-constant if cutoff not in [True, False]: raise ValueError("cutoff must be specified") if not cutoff: values = torch.linspace(start, end, number, dtype=x.dtype, device=x.device) step = values[1] - values[0] else: values = torch.linspace(start, end, number + 2, dtype=x.dtype, device=x.device) step = values[1] - values[0] values = values[1:-1] diff = (x[..., None] - values) / step if basis == "gaussian": return diff.pow(2).neg().exp().div(1.12) if basis == "cosine": return torch.cos(math.pi / 2 * diff) * (diff < 1) * (-1 < diff) if basis == "smooth_finite": return 1.14136 * torch.exp(torch.tensor(2.0)) * soft_unit_step(diff + 1) * soft_unit_step(1 - diff) if basis == "fourier": x = (x[..., None] - start) / (end - start) if not cutoff: i = torch.arange(0, number, dtype=x.dtype, device=x.device) return torch.cos(math.pi * i * x) / math.sqrt(0.25 + number / 
2) else: i = torch.arange(1, number + 1, dtype=x.dtype, device=x.device) return torch.sin(math.pi * i * x) / math.sqrt(0.25 + number / 2) * (0 < x) * (x < 1) if basis == "bessel": x = x[..., None] - start c = end - start bessel_roots = torch.arange(1, number + 1, dtype=x.dtype, device=x.device) * math.pi out = math.sqrt(2 / c) * torch.sin(bessel_roots * x / c) / x if not cutoff: return out else: return out * ((x / c) < 1) * (0 < x) raise ValueError(f'basis="{basis}" is not a valid entry') e3nn-0.6.0/e3nn/math/_soft_unit_step.py000066400000000000000000000025721514371756200177530ustar00rootroot00000000000000import torch class _SoftUnitStep(torch.autograd.Function): # pylint: disable=arguments-differ @staticmethod def forward(ctx, x) -> torch.Tensor: ctx.save_for_backward(x) y = torch.zeros_like(x) mask = x > 0.0 safe_x = torch.where(mask, x, torch.ones_like(x)) # Avoid division by zero y = torch.where(mask, torch.exp(-1.0 / safe_x), torch.zeros_like(x)) return y @staticmethod def backward(ctx, dy) -> torch.Tensor: (x,) = ctx.saved_tensors mask = x > 0.0 safe_x = torch.where(mask, x, torch.ones_like(x)) # Avoid division by zero dx = torch.where(mask, torch.exp(-1.0 / safe_x) / (safe_x * safe_x), torch.zeros_like(x)) return dx * dy def soft_unit_step(x): r"""smooth :math:`C^\infty` version of the unit step function .. math:: x \mapsto \theta(x) e^{-1/x} Parameters ---------- x : `torch.Tensor` tensor of shape :math:`(...)` Returns ------- `torch.Tensor` tensor of shape :math:`(...)` Examples -------- .. jupyter-execute:: :hide-code: import torch from e3nn.math import soft_unit_step import matplotlib.pyplot as plt .. 
jupyter-execute:: x = torch.linspace(-1.0, 10.0, 1000) plt.plot(x, soft_unit_step(x)); """ return _SoftUnitStep.apply(x) e3nn-0.6.0/e3nn/math/perm.py000066400000000000000000000060771514371756200155160ustar00rootroot00000000000000from typing import Tuple, Set, Optional import random import math import torch from e3nn.math import complete_basis TY_PERM = Tuple[int, ...] def is_perm(p: TY_PERM) -> bool: return sorted(set(p)) == list(range(len(p))) def identity(n: int) -> TY_PERM: return tuple(i for i in range(n)) def compose(p1: TY_PERM, p2: TY_PERM) -> TY_PERM: r""" compute p1 . p2 """ assert is_perm(p1) and is_perm(p2) assert len(p1) == len(p2) # p: i |-> p[i] # [p1.p2](i) = p1(p2(i)) = p1[p2[i]] return tuple(p1[p2[i]] for i in range(len(p1))) def inverse(p: TY_PERM) -> TY_PERM: r""" compute the inverse permutation """ return tuple(p.index(i) for i in range(len(p))) def rand(n: int) -> TY_PERM: i = random.randint(0, math.factorial(n) - 1) return from_int(i, n) def from_int(i: int, n: int) -> TY_PERM: pool = list(range(n)) p = [] for _ in range(n): j = i % n i = i // n p.append(pool.pop(j)) n -= 1 return tuple(p) def to_int(p: TY_PERM) -> int: n = len(p) pool = list(range(n)) i = 0 m = 1 for j in p: k = pool.index(j) i += k * m m *= len(pool) pool.pop(k) return i def group(n: int) -> Set[TY_PERM]: return {from_int(i, n) for i in range(math.factorial(n))} def germinate(subset: Set[TY_PERM]) -> Set[TY_PERM]: while True: n = len(subset) subset = subset.union([inverse(p) for p in subset]) subset = subset.union([compose(p1, p2) for p1 in subset for p2 in subset]) if len(subset) == n: return subset def is_group(g: Set[TY_PERM]) -> bool: if len(g) == 0: return False n = len(next(iter(g))) for p in g: assert len(p) == n, p if identity(n) not in g: return False for p in g: if inverse(p) not in g: return False for p1 in g: for p2 in g: if compose(p1, p2) not in g: return False return True def to_cycles(p: TY_PERM) -> Set[Tuple[int, ...]]: n = len(p) cycles = set() for i in range(n): c = [i]
while p[i] != c[0]: i = p[i] c += [i] if len(c) >= 2: i = c.index(min(c)) c = c[i:] + c[:i] cycles.add(tuple(c)) return cycles def sign(p: TY_PERM) -> int: s = 1 for c in to_cycles(p): if len(c) % 2 == 0: s = -s return s def standard_representation( p: TY_PERM, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None ) -> torch.Tensor: r"""standard irrep of :math:`S_n`, of dimension :math:`n - 1`""" A = complete_basis(torch.ones(1, len(p), dtype=dtype, device=device), eps=0.1 / len(p)) return A @ natural_representation(p) @ A.T def natural_representation( p: TY_PERM, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None ) -> torch.Tensor: r"""natural (permutation matrix) representation of :math:`S_n`""" n = len(p) ip = inverse(p) d = torch.zeros(n, n, dtype=dtype, device=device) for a in range(n): d[a, ip[a]] = 1 return d e3nn-0.6.0/e3nn/nn/000077500000000000000000000000001514371756200136515ustar00rootroot00000000000000e3nn-0.6.0/e3nn/nn/__init__.py000066400000000000000000000010471514371756200157640ustar00rootroot00000000000000from ._extract import Extract, ExtractIr from ._activation import Activation from ._batchnorm import BatchNorm from ._fc import FullyConnectedNet from ._gate import Gate from ._identity import Identity from ._s2act import S2Activation from ._so3act import SO3Activation from ._normact import NormActivation from ._dropout import Dropout __all__ = [ "Extract", "ExtractIr", "BatchNorm", "FullyConnectedNet", "Activation", "Gate", "Identity", "S2Activation", "SO3Activation", "NormActivation", "Dropout", ] e3nn-0.6.0/e3nn/nn/_activation.py000066400000000000000000000070101514371756200165210ustar00rootroot00000000000000import torch from e3nn.o3._irreps import Irreps from e3nn.math import normalize2mom from e3nn.util.jit import compile_mode @compile_mode("trace") class Activation(torch.nn.Module): r"""Scalar activation function. Odd scalar inputs require activation functions with a defined parity (odd or even).
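The parity rule above can be sketched in plain Python. This is a simplified illustration only: `activation_parity` and `output_parity` are hypothetical helpers, not part of e3nn. The sketch evaluates the activation on a grid of non-negative points and compares `act(x)` with `act(-x)`, mirroring the `torch.linspace(0, 10, 256)` check performed in `Activation.__init__`. Unlike e3nn, it skips the second-moment normalization applied through `normalize2mom`.

```python
import math


def activation_parity(act, n=256, x_max=10.0, tol=1e-5):
    # Hypothetical helper: +1 if act is even, -1 if odd, 0 if neither.
    xs = [x_max * k / (n - 1) for k in range(n)]
    if max(abs(act(x) - act(-x)) for x in xs) < tol:
        return 1  # act(x) == act(-x): even
    if max(abs(act(x) + act(-x)) for x in xs) < tol:
        return -1  # act(x) == -act(-x): odd
    return 0  # no definite parity


def output_parity(p_in, p_act):
    # An odd scalar (p_in = -1) takes the parity of the activation;
    # an even scalar stays even whatever the activation is.
    return p_act if p_in == -1 else p_in


print(activation_parity(math.tanh))  # -1 (odd)
print(activation_parity(abs))  # 1 (even)
print(activation_parity(lambda x: max(x, 0.0)))  # 0 (ReLU has no parity)
print(output_parity(-1, activation_parity(abs)))  # 1: a 0o input through abs gives 0e
```

For an odd input, an output parity of `0` corresponds to the case where `Activation` raises a `ValueError`, because the result would have no well-defined parity.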
Parameters ---------- irreps_in : `e3nn.o3.Irreps` representation of the input acts : list of function or None one activation function per entry of ``irreps_in``, ``None`` for non-scalar entries or for the identity Examples -------- >>> a = Activation("256x0o", [torch.abs]) >>> a.irreps_out 256x0e >>> a = Activation("256x0o+16x1e", [None, None]) >>> a.irreps_out 256x0o+16x1e """ def __init__(self, irreps_in, acts) -> None: super().__init__() irreps_in = Irreps(irreps_in) if len(irreps_in) != len(acts): raise ValueError(f"irreps_in has {len(irreps_in)} entries but {len(acts)} activation functions were given: irreps_in={irreps_in}, acts={acts}") # normalize the second moment acts = [normalize2mom(act) if act is not None else None for act in acts] from e3nn.util._argtools import _get_device irreps_out = [] for (mul, (l_in, p_in)), act in zip(irreps_in, acts): if act is not None: if l_in != 0: raise ValueError("Activation: cannot apply an activation function to a non-scalar input.") x = torch.linspace(0, 10, 256, device=_get_device(act)) a1, a2 = act(x), act(-x) if (a1 - a2).abs().max() < 1e-5: p_act = 1 elif (a1 + a2).abs().max() < 1e-5: p_act = -1 else: p_act = 0 p_out = p_act if p_in == -1 else p_in irreps_out.append((mul, (0, p_out))) if p_out == 0: raise ValueError( "Activation: the parity is violated! The input scalar is odd but the activation is neither " "even nor odd."
) else: irreps_out.append((mul, (l_in, p_in))) self.irreps_in = irreps_in self.irreps_out = Irreps(irreps_out) self.acts = torch.nn.ModuleList(acts) self.paths = [(mul, (l, p), act) for (mul, (l, p)), act in zip(self.irreps_in, self.acts)] assert len(self.irreps_in) == len(self.acts) def __repr__(self) -> str: acts = "".join(["x" if a is not None else " " for a in self.acts]) return f"{self.__class__.__name__} [{acts}] ({self.irreps_in} -> {self.irreps_out})" def forward(self, features, dim: int = -1): """evaluate Parameters ---------- features : `torch.Tensor` tensor of shape ``(...)`` Returns ------- `torch.Tensor` tensor of shape the same shape as the input """ # - PROFILER - with torch.autograd.profiler.record_function(repr(self)): output = [] index = 0 for mul, (l, _), act in self.paths: ir_dim = 2 * l + 1 if act is not None: output.append(act(features.narrow(dim, index, mul))) else: output.append(features.narrow(dim, index, mul * ir_dim)) index += mul * ir_dim if len(output) > 1: return torch.cat(output, dim=dim) elif len(output) == 1: return output[0] else: return torch.zeros_like(features) e3nn-0.6.0/e3nn/nn/_batchnorm.py000066400000000000000000000162711514371756200163460ustar00rootroot00000000000000import torch from torch import nn from e3nn import o3 from e3nn.util.jit import compile_mode @compile_mode("unsupported") class BatchNorm(nn.Module): """Batch normalization for orthonormal representations It normalizes by the norm of the representations. Note that the norm is invariant only for orthonormal representations. Irreducible representations `wigner_D` are orthonormal. 
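Below is a rough plain-Python sketch of what this normalization does for one irrep channel. The function `batchnorm_one_channel` is hypothetical and not e3nn's API. It covers a single multiplicity slot in training mode with `reduce='mean'` and no affine parameters. Scalars are centered first. Every irrep is then divided by the square root of its mean squared norm (``normalization='norm'``) or its mean squared component (``normalization='component'``). A non-scalar irrep is only rescaled by a rotation-invariant number and never shifted, so equivariance is preserved.

```python
def batchnorm_one_channel(field, is_scalar, normalization="component", eps=1e-5):
    # Hypothetical sketch, not e3nn's API.
    # field: list of samples, each a list with the 2l+1 components of one irrep copy
    if is_scalar:
        # only scalars are centered: subtracting a mean vector would break equivariance
        mean = sum(v[0] for v in field) / len(field)
        field = [[v[0] - mean] for v in field]
    if normalization == "norm":
        sq = [sum(c * c for c in v) for v in field]  # squared norm per sample
    elif normalization == "component":
        sq = [sum(c * c for c in v) / len(v) for v in field]  # mean squared component
    else:
        raise ValueError(f"Invalid normalization option {normalization}")
    var = sum(sq) / len(sq)  # reduce="mean" over all samples
    scale = (var + eps) ** -0.5  # one rotation-invariant factor for the whole channel
    return [[c * scale for c in v] for v in field]


vectors = [[1.0, 2.0, 2.0], [2.0, 1.0, 2.0]]  # two l=1 samples, squared norm 9 each
out = batchnorm_one_channel(vectors, is_scalar=False, normalization="norm")
print(sum(c * c for v in out for c in v) / len(out))  # mean squared norm, close to 1
```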
Parameters ---------- irreps : `o3.Irreps` representation eps : float avoid division by zero when we normalize by the variance momentum : float momentum of the running average affine : bool whether to use learnable weight and bias parameters reduce : {'mean', 'max'} method used to reduce the squared norms over the sample (non-batch) dimensions instance : bool apply instance norm instead of batch norm include_bias : bool include a learnable bias term for the scalars (only used when ``affine=True``) normalization : {'norm', 'component'} ``norm`` sums the squared components of each irrep, ``component`` averages them """ __constants__ = ["instance", "normalization", "irs", "affine"] def __init__( self, irreps: o3.Irreps, eps: float = 1e-5, momentum: float = 0.1, affine: bool = True, reduce: str = "mean", instance: bool = False, include_bias: bool = True, normalization: str = "component", ) -> None: super().__init__() self.irreps = o3.Irreps(irreps) self.eps = eps self.momentum = momentum self.affine = affine self.instance = instance self.include_bias = include_bias num_scalar = sum(mul for mul, ir in self.irreps if ir.is_scalar()) num_features = self.irreps.num_irreps self.features = [] if self.instance: self.register_buffer("running_mean", None) self.register_buffer("running_var", None) else: self.register_buffer("running_mean", torch.zeros(num_scalar)) self.register_buffer("running_var", torch.ones(num_features)) if affine: self.weight = nn.Parameter(torch.ones(num_features)) if self.include_bias: self.bias = nn.Parameter(torch.zeros(num_scalar)) else: self.register_parameter("weight", None) if self.include_bias: self.register_parameter("bias", None) assert isinstance(reduce, str), "reduce should be passed as a string value" assert reduce in ["mean", "max"], "reduce needs to be 'mean' or 'max'" self.reduce = reduce irs = [] for mul, ir in self.irreps: irs.append((mul, ir.dim, ir.is_scalar())) self.irs = irs assert normalization in ["norm", "component"], "normalization needs to be 'norm' or 'component'" self.normalization = normalization def __repr__(self) -> str: return f"{self.__class__.__name__}
({self.irreps}, eps={self.eps}, momentum={self.momentum})" def _roll_avg(self, curr, update) -> torch.Tensor: return (1 - self.momentum) * curr + self.momentum * update.detach() def forward(self, input) -> torch.Tensor: """evaluate Parameters ---------- input : `torch.Tensor` tensor of shape ``(batch, ..., irreps.dim)`` Returns ------- `torch.Tensor` tensor of shape ``(batch, ..., irreps.dim)`` """ orig_shape = input.shape batch = input.shape[0] dim = input.shape[-1] input = input.reshape(batch, -1, dim) # [batch, sample, stacked features] if self.training and not self.instance: new_means = [] new_vars = [] fields = [] ix = 0 irm = 0 irv = 0 iw = 0 ib = 0 for mul, d, is_scalar in self.irs: field = input[:, :, ix : ix + mul * d] # [batch, sample, mul * repr] ix += mul * d # [batch, sample, mul, repr] field = field.reshape(batch, -1, mul, d) if is_scalar: if self.training or self.instance: if self.instance: field_mean = field.mean(1).reshape(batch, mul) # [batch, mul] else: field_mean = field.mean([0, 1]).reshape(mul) # [mul] new_means.append(self._roll_avg(self.running_mean[irm : irm + mul], field_mean)) else: field_mean = self.running_mean[irm : irm + mul] irm += mul # [batch, sample, mul, repr] field = field - field_mean.reshape(-1, 1, mul, 1) if self.training or self.instance: if self.normalization == "norm": field_norm = field.pow(2).sum(3) # [batch, sample, mul] elif self.normalization == "component": field_norm = field.pow(2).mean(3) # [batch, sample, mul] else: raise ValueError(f"Invalid normalization option {self.normalization}") if self.reduce == "mean": field_norm = field_norm.mean(1) # [batch, mul] elif self.reduce == "max": field_norm = field_norm.max(1).values # [batch, mul] else: raise ValueError(f"Invalid reduce option {self.reduce}") if not self.instance: field_norm = field_norm.mean(0) # [mul] new_vars.append(self._roll_avg(self.running_var[irv : irv + mul], field_norm)) else: field_norm = self.running_var[irv : irv + mul] irv += mul field_norm =
(field_norm + self.eps).pow(-0.5) # [(batch,) mul] if self.affine: weight = self.weight[iw : iw + mul] # [mul] iw += mul field_norm = field_norm * weight # [(batch,) mul] field = field * field_norm.reshape(-1, 1, mul, 1) # [batch, sample, mul, repr] if self.affine and self.include_bias and is_scalar: bias = self.bias[ib : ib + mul] # [mul] ib += mul field += bias.reshape(mul, 1) # [batch, sample, mul, repr] fields.append(field.reshape(batch, -1, mul * d)) # [batch, sample, mul * repr] torch._assert(ix == dim, f"`ix` should have reached input.size(-1) ({dim}), but it ended at {ix}") if self.training and not self.instance: torch._assert(irm == self.running_mean.numel(), "irm == self.running_mean.numel()") torch._assert(irv == self.running_var.size(0), "irv == self.running_var.size(0)") if self.affine: torch._assert(iw == self.weight.size(0), "iw == self.weight.size(0)") if self.include_bias: torch._assert(ib == self.bias.numel(), "ib == self.bias.numel()") if self.training and not self.instance: if len(new_means) > 0: torch.cat(new_means, out=self.running_mean) if len(new_vars) > 0: torch.cat(new_vars, out=self.running_var) output = torch.cat(fields, dim=2) # [batch, sample, stacked features] return output.reshape(orig_shape) e3nn-0.6.0/e3nn/nn/_dropout.py000066400000000000000000000036211514371756200160600ustar00rootroot00000000000000import torch from e3nn import o3 from e3nn.util.jit import compile_mode @compile_mode("script") class Dropout(torch.nn.Module): r"""Equivariant Dropout :math:`A_{zai}` is the input and :math:`B_{zai}` is the output where - ``z`` is the batch index - ``a`` is any non-batch and non-irrep index - ``i`` is the irrep index, for instance if ``irreps="0e + 2x1e"`` then ``i=2`` selects the *second vector* .. math:: B_{zai} = \frac{x_{zi}}{1-p} A_{zai} where :math:`p` is the dropout probability and :math:`x` is a Bernoulli random variable with parameter :math:`1-p`.
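The key point of the formula above is that one Bernoulli draw :math:`x_{zi}` is shared by all ``2l+1`` components of an irrep copy, so a vector is either kept whole or dropped whole. The plain-Python sketch below builds such a mask for one batch element. It mirrors the logic of `Dropout.forward`, but the function name `equivariant_dropout_mask` and the `(mul, l)` list format are assumptions made for illustration.

```python
import random


def equivariant_dropout_mask(irreps, p, rng=random):
    # Hypothetical sketch. irreps: list of (mul, l) pairs,
    # e.g. "0e + 2x1e" -> [(1, 0), (2, 1)]
    # returns one multiplicative factor per feature component
    mask = []
    for mul, l in irreps:
        dim = 2 * l + 1
        for _ in range(mul):
            if p >= 1:
                keep = 0.0
            elif p <= 0:
                keep = 1.0
            else:
                # Bernoulli(1 - p) draw, rescaled so the expected factor is 1
                keep = (1.0 if rng.random() < 1 - p else 0.0) / (1 - p)
            mask.extend([keep] * dim)  # same factor for all 2l+1 components
    return mask


print(equivariant_dropout_mask([(1, 0), (2, 1)], p=0.5, rng=random.Random(0)))
```

Multiplying the features component by component with this mask scales each irrep copy by a scalar, and scaling by a scalar commutes with rotations.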
Parameters ---------- irreps : `o3.Irreps` representation p : float probability of dropping each irrep copy (all of its ``2l+1`` components are dropped together) """ def __init__(self, irreps, p) -> None: super().__init__() self.irreps = o3.Irreps(irreps) self.p = p def __repr__(self) -> str: return f"{self.__class__.__name__} ({self.irreps}, p={self.p})" def forward(self, x): """evaluate Parameters ---------- x : `torch.Tensor` tensor of shape ``(batch, ..., irreps.dim)`` Returns ------- `torch.Tensor` tensor of shape ``(batch, ..., irreps.dim)`` """ if not self.training: return x batch = x.shape[0] noises = [] for mul, (l, _p) in self.irreps: dim = 2 * l + 1 noise = x.new_empty(batch, mul) if self.p >= 1: noise.fill_(0) elif self.p <= 0: noise.fill_(1) else: noise.bernoulli_(1 - self.p).div_(1 - self.p) noise = noise[:, :, None].expand(-1, -1, dim).reshape(batch, mul * dim) noises.append(noise) noise = torch.cat(noises, dim=-1) while noise.dim() < x.dim(): noise = noise[:, None] return x * noise e3nn-0.6.0/e3nn/nn/_extract.py000066400000000000000000000065541514371756200160460ustar00rootroot00000000000000from typing import Tuple import torch from torch import fx from e3nn.util.codegen import CodeGenMixin from e3nn.util.jit import compile_mode from e3nn.o3._irreps import Irrep, Irreps @compile_mode("script") class Extract(CodeGenMixin, torch.nn.Module): # pylint: disable=abstract-method def __init__(self, irreps_in, irreps_outs, instructions, squeeze_out: bool = False) -> None: r"""Extract subsets of irreps Parameters ---------- irreps_in : `e3nn.o3.Irreps` representation of the input irreps_outs : list of `e3nn.o3.Irreps` list of representations of the outputs instructions : list of tuple of int list of tuples, one per output, each containing ``len(irreps_outs[i])`` indices into ``irreps_in`` squeeze_out : bool, default False if ``squeeze_out`` and only one output exists, a ``torch.Tensor`` will be returned instead of a ``Tuple[torch.Tensor]`` Examples -------- >>> c = Extract('1e + 0e + 0e', ['0e', '0e'], [(1,), (2,)]) >>> c(torch.tensor([0.0, 0.0, 0.0,
1.0, 2.0])) (tensor([1.]), tensor([2.])) """ super().__init__() self.irreps_in = Irreps(irreps_in) self.irreps_outs = tuple(Irreps(irreps) for irreps in irreps_outs) self.instructions = instructions assert len(self.irreps_outs) == len(self.instructions) for irreps_out, ins in zip(self.irreps_outs, self.instructions): assert len(irreps_out) == len(ins) # == generate code == graph = fx.Graph() x = fx.Proxy(graph.placeholder("x", torch.Tensor)) torch._assert(x.shape[-1] == self.irreps_in.dim, "invalid input shape") out = [] for irreps in self.irreps_outs: out.append(x.new_zeros(x.shape[:-1] + (irreps.dim,))) for i, (irreps_out, ins) in enumerate(zip(self.irreps_outs, self.instructions)): if ins == tuple(range(len(self.irreps_in))): out[i].copy_(x) else: for s_out, i_in in zip(irreps_out.slices(), ins): i_start = self.irreps_in[:i_in].dim i_len = self.irreps_in[i_in].dim out[i].narrow(-1, s_out.start, s_out.stop - s_out.start).copy_(x.narrow(-1, i_start, i_len)) out = tuple(e.node for e in out) if squeeze_out and len(out) == 1: graph.output(out[0], torch.Tensor) else: graph.output(out, Tuple[(torch.Tensor,) * len(self.irreps_outs)]) self._codegen_register({"_compiled_forward": fx.GraphModule({}, graph)}) def forward(self, x: torch.Tensor): return self._compiled_forward(x) @compile_mode("script") class ExtractIr(Extract): # pylint: disable=abstract-method def __init__(self, irreps_in, ir) -> None: r"""Extract ``ir`` from irreps Parameters ---------- irreps_in : `e3nn.o3.Irreps` representation of the input ir : `e3nn.o3.Irrep` representation to extract """ ir = Irrep(ir) irreps_in = Irreps(irreps_in) self.irreps_out = Irreps([mul_ir for mul_ir in irreps_in if mul_ir.ir == ir]) instructions = [tuple(i for i, mul_ir in enumerate(irreps_in) if mul_ir.ir == ir)] super().__init__(irreps_in, [self.irreps_out], instructions, squeeze_out=True) e3nn-0.6.0/e3nn/nn/_fc.py000066400000000000000000000047431514371756200147620ustar00rootroot00000000000000from typing import List import 
torch from e3nn.math import normalize2mom from e3nn.util.jit import compile_mode @compile_mode("script") class _Layer(torch.nn.Module): h_in: float h_out: float var_in: float var_out: float _profiling_str: str def __init__(self, h_in, h_out, act, var_in, var_out) -> None: super().__init__() self.weight = torch.nn.Parameter(torch.randn(h_in, h_out)) self.act = act self.h_in = h_in self.h_out = h_out self.var_in = var_in self.var_out = var_out self._profiling_str = repr(self) def __repr__(self) -> str: act = self.act if hasattr(act, "__name__"): act = act.__name__ elif isinstance(act, torch.nn.Module): act = act.__class__.__name__ return f"Layer({self.h_in}->{self.h_out}, act={act})" def forward(self, x: torch.Tensor): # - PROFILER - with torch.autograd.profiler.record_function(self._profiling_str): if self.act is not None: w = self.weight / (self.h_in * self.var_in) ** 0.5 x = x @ w x = self.act(x) x = x * self.var_out**0.5 else: w = self.weight / (self.h_in * self.var_in / self.var_out) ** 0.5 x = x @ w return x @compile_mode("script") class FullyConnectedNet(torch.nn.Sequential): r"""Fully-connected Neural Network Parameters ---------- hs : list of int input, internal and output dimensions act : function activation function :math:`\phi`, it will be automatically normalized by a scaling factor such that .. 
math:: \int_{-\infty}^{\infty} \phi(z)^2 \frac{e^{-z^2/2}}{\sqrt{2\pi}} dz = 1 variance_in : float variance of the input features variance_out : float variance of the output features out_act : bool whether to also apply ``act`` after the last layer """ hs: List[int] def __init__(self, hs, act=None, variance_in: float = 1, variance_out: float = 1, out_act: bool = False) -> None: super().__init__() self.hs = list(hs) if act is not None: act = normalize2mom(act) var_in = variance_in for i, (h1, h2) in enumerate(zip(self.hs, self.hs[1:])): if i == len(self.hs) - 2: var_out = variance_out a = act if out_act else None else: var_out = 1 a = act layer = _Layer(h1, h2, a, var_in, var_out) setattr(self, f"layer{i}", layer) var_in = var_out def __repr__(self) -> str: return f"{self.__class__.__name__}{self.hs}" e3nn-0.6.0/e3nn/nn/_gate.py000066400000000000000000000127141514371756200153070ustar00rootroot00000000000000import torch from e3nn.o3._irreps import Irreps from e3nn.o3._tensor_product._sub import ElementwiseTensorProduct from ._extract import Extract from ._activation import Activation from e3nn.util.jit import compile_mode @compile_mode("script") class _Sortcut(torch.nn.Module): def __init__(self, *irreps_outs) -> None: super().__init__() self.irreps_outs = tuple(Irreps(irreps).simplify() for irreps in irreps_outs) irreps_in = sum(self.irreps_outs, Irreps([])) i = 0 instructions = [] for irreps_out in self.irreps_outs: instructions += [tuple(range(i, i + len(irreps_out)))] i += len(irreps_out) assert len(irreps_in) == i, (len(irreps_in), i) irreps_in, p, _ = irreps_in.sort() instructions = [tuple(p[i] for i in x) for x in instructions] self.cut = Extract(irreps_in, self.irreps_outs, instructions) self.irreps_in = irreps_in.simplify() def forward(self, x): return self.cut(x) @compile_mode("script") class Gate(torch.nn.Module): r"""Gate activation function. The gate activation is a direct sum of two sets of irreps. The first set of irreps is ``irreps_scalars`` passed through activation functions ``act_scalars``.
The second set of irreps is ``irreps_gated`` multiplied by the scalars ``irreps_gates`` passed through activation functions ``act_gates``. Mathematically, this can be written as: .. math:: \left(\bigoplus_i \phi_i(x_i) \right) \oplus \left(\bigoplus_j \phi_j(g_j) y_j \right) where :math:`x_i` and :math:`\phi_i` are from ``irreps_scalars`` and ``act_scalars``, and :math:`g_j`, :math:`\phi_j`, and :math:`y_j` are from ``irreps_gates``, ``act_gates``, and ``irreps_gated``. The parameters passed in should adhere to the following conditions: 1. ``len(irreps_scalars) == len(act_scalars)``. 2. ``len(irreps_gates) == len(act_gates)``. 3. ``irreps_gates.num_irreps == irreps_gated.num_irreps``. Parameters ---------- irreps_scalars : `e3nn.o3.Irreps` Representation of the scalars that will be passed through the activation functions ``act_scalars``. act_scalars : list of function or None Activation functions acting on the scalars. irreps_gates : `e3nn.o3.Irreps` Representation of the scalars that will be passed through the activation functions ``act_gates`` and multiplied by the ``irreps_gated``. act_gates : list of function or None Activation functions acting on the gates. The number of functions in the list should match the number of irrep groups in ``irreps_gates``. irreps_gated : `e3nn.o3.Irreps` Representation of the gated tensors. 
``irreps_gates.num_irreps == irreps_gated.num_irreps`` Examples -------- >>> g = Gate("16x0o", [torch.tanh], "32x0o", [torch.tanh], "16x1e+16x1o") >>> g.irreps_out 16x0o+16x1o+16x1e """ def __init__(self, irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated) -> None: super().__init__() irreps_scalars = Irreps(irreps_scalars) irreps_gates = Irreps(irreps_gates) irreps_gated = Irreps(irreps_gated) if len(irreps_gates) > 0 and irreps_gates.lmax > 0: raise ValueError(f"Gate scalars must be scalars, instead got irreps_gates = {irreps_gates}") if len(irreps_scalars) > 0 and irreps_scalars.lmax > 0: raise ValueError(f"Scalars must be scalars, instead got irreps_scalars = {irreps_scalars}") if irreps_gates.num_irreps != irreps_gated.num_irreps: raise ValueError( f"There are {irreps_gated.num_irreps} irreps in irreps_gated, but a different number " f"({irreps_gates.num_irreps}) of gate scalars in irreps_gates" ) self.sc = _Sortcut(irreps_scalars, irreps_gates, irreps_gated) self.irreps_scalars, self.irreps_gates, self.irreps_gated = self.sc.irreps_outs self._irreps_in = self.sc.irreps_in self.act_scalars = Activation(irreps_scalars, act_scalars) irreps_scalars = self.act_scalars.irreps_out self.act_gates = Activation(irreps_gates, act_gates) irreps_gates = self.act_gates.irreps_out self.mul = ElementwiseTensorProduct(irreps_gated, irreps_gates) irreps_gated = self.mul.irreps_out self._irreps_out = irreps_scalars + irreps_gated def __repr__(self) -> str: return f"{self.__class__.__name__} ({self.irreps_in} -> {self.irreps_out})" def forward(self, features): """Evaluate the gated activation function. 
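Here is a plain-Python sketch of the computation performed after the input has been split into scalars, gates and gated irreps. The function `gate_sketch` and its list-based inputs are hypothetical, and the activations here are not second-moment normalized the way e3nn's `Activation` normalizes them. Scalars go through their activation. Each gate scalar goes through its own activation and multiplies every component of one gated irrep copy. The results are then concatenated.

```python
import math


def gate_sketch(scalars, gates, gated, act_scalar=math.tanh, act_gate=math.tanh):
    # Hypothetical sketch. scalars: list of floats
    # gates: list of floats, one per gated irrep copy
    # gated: list of irrep copies, each a list of its 2l+1 components
    if len(gates) != len(gated):
        raise ValueError("need exactly one gate scalar per gated irrep copy")
    out = [act_scalar(s) for s in scalars]
    for g, components in zip(gates, gated):
        w = act_gate(g)  # invariant scalar weight
        out.extend(w * c for c in components)  # rescales the whole irrep copy
    return out


print(gate_sketch([0.0], [100.0, 0.0], [[1.0, -2.0, 0.5], [3.0, 3.0, 3.0]]))
# [0.0, 1.0, -2.0, 0.5, 0.0, 0.0, 0.0]
```

Each gated irrep is only multiplied by an invariant scalar, so its direction is preserved and equivariance holds. This is also why ``irreps_gates.num_irreps`` must equal ``irreps_gated.num_irreps``.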
Parameters ---------- features : `torch.Tensor` tensor of shape ``(..., irreps_in.dim)`` Returns ------- `torch.Tensor` tensor of shape ``(..., irreps_out.dim)`` """ # - PROFILER - with torch.autograd.profiler.record_function('Gate'): scalars, gates, gated = self.sc(features) scalars = self.act_scalars(scalars) if gates.shape[-1]: gates = self.act_gates(gates) gated = self.mul(gated, gates) features = torch.cat([scalars, gated], dim=-1) else: features = scalars return features @property def irreps_in(self): """Input representations.""" return self._irreps_in @property def irreps_out(self): """Output representations.""" return self._irreps_out e3nn-0.6.0/e3nn/nn/_identity.py000066400000000000000000000016431514371756200162170ustar00rootroot00000000000000import torch from e3nn import o3 from e3nn.util.jit import compile_mode @compile_mode("trace") class Identity(torch.nn.Module): r"""Identity operation Parameters ---------- irreps_in : `e3nn.o3.Irreps` irreps_out : `e3nn.o3.Irreps` """ def __init__(self, irreps_in, irreps_out) -> None: super().__init__() self.irreps_in = o3.Irreps(irreps_in).simplify() self.irreps_out = o3.Irreps(irreps_out).simplify() assert self.irreps_in == self.irreps_out output_mask = torch.cat([torch.ones(mul * (2 * l + 1)) for mul, (l, _p) in self.irreps_out]) self.register_buffer("output_mask", output_mask) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.irreps_in} -> {self.irreps_out})" def forward( # pylint: disable=no-self-use self, features, ): """evaluate""" return features e3nn-0.6.0/e3nn/nn/_normact.py000066400000000000000000000074351514371756200160360ustar00rootroot00000000000000from typing import Callable, Optional import torch from e3nn.o3._irreps import Irreps from e3nn.o3._norm import Norm from e3nn.o3._tensor_product._sub import ElementwiseTensorProduct from e3nn.util.jit import compile_mode @compile_mode("trace") class NormActivation(torch.nn.Module): r"""Norm-based activation function Applies a scalar 
nonlinearity to the norm of each irrep and outputs a (normalized) version of that irrep multiplied by the scalar output of the scalar nonlinearity. Parameters ---------- irreps_in : `e3nn.o3.Irreps` representation of the input scalar_nonlinearity : callable scalar nonlinearity such as ``torch.sigmoid`` normalize : bool whether to normalize the input features before multiplying them by the scalars from the nonlinearity epsilon : float, optional when ``normalize`` is True, norms smaller than ``epsilon`` will be clamped up to ``epsilon`` to avoid division by zero and NaN gradients; defaults to ``1e-8``. Not allowed when ``normalize`` is False. bias : bool whether to apply a learnable additive bias to the inputs of the ``scalar_nonlinearity`` Examples -------- >>> n = NormActivation("2x1e", torch.sigmoid) >>> feats = torch.ones(1, 2*3) >>> print(feats.reshape(1, 2, 3).norm(dim=-1)) tensor([[1.7321, 1.7321]]) >>> print(torch.sigmoid(feats.reshape(1, 2, 3).norm(dim=-1))) tensor([[0.8497, 0.8497]]) >>> print(n(feats).reshape(1, 2, 3).norm(dim=-1)) tensor([[0.8497, 0.8497]]) """ epsilon: Optional[float] _eps_squared: float def __init__( self, irreps_in: Irreps, scalar_nonlinearity: Callable, normalize: bool = True, epsilon: Optional[float] = None, bias: bool = False, ) -> None: super().__init__() self.irreps_in = Irreps(irreps_in) self.irreps_out = Irreps(irreps_in) if epsilon is None and normalize: epsilon = 1e-8 elif epsilon is not None and not normalize: raise ValueError("epsilon and normalize = False don't make sense together") elif epsilon is not None and not epsilon > 0: raise ValueError(f"epsilon {epsilon} is invalid, must be strictly positive.") self.epsilon = epsilon if self.epsilon is not None: self._eps_squared = epsilon * epsilon else: self._eps_squared = 0.0 # doesn't matter # if we have an epsilon, use squared and do the sqrt ourselves self.norm = Norm(irreps_in, squared=(epsilon is not None)) self.scalar_nonlinearity = scalar_nonlinearity self.normalize = normalize self.bias = bias if self.bias: self.biases =
torch.nn.Parameter(torch.zeros(irreps_in.num_irreps)) self.scalar_multiplier = ElementwiseTensorProduct( irreps_in1=self.norm.irreps_out, irreps_in2=irreps_in, ) def forward(self, features): """evaluate Parameters ---------- features : `torch.Tensor` tensor of shape ``(..., irreps_in.dim)`` Returns ------- `torch.Tensor` tensor of shape ``(..., irreps_in.dim)`` """ norms = self.norm(features) if self._eps_squared > 0: # See TFN for the original version of this approach: # https://github.com/tensorfieldnetworks/tensorfieldnetworks/blob/master/tensorfieldnetworks/utils.py#L22 norms[norms < self._eps_squared] = self._eps_squared norms = norms.sqrt() nonlin_arg = norms if self.bias: nonlin_arg = nonlin_arg + self.biases scalings = self.scalar_nonlinearity(nonlin_arg) if self.normalize: scalings = scalings / norms return self.scalar_multiplier(scalings, features) e3nn-0.6.0/e3nn/nn/_s2act.py000066400000000000000000000076601514371756200154070ustar00rootroot00000000000000import torch from e3nn import o3 from e3nn.math import normalize2mom from e3nn.util.jit import compile_mode @compile_mode("script") class S2Activation(torch.nn.Module): r"""Apply a nonlinearity to a signal on the sphere | Maps to the sphere, applies the nonlinearity pointwise and projects back. | The signal on the sphere is a quasiregular representation of :math:`O(3)` and we can apply a pointwise operation on | these representations. ..
math:: \{A^l\}_l \mapsto \{\int \phi(\sum_l A^l \cdot Y^l(x)) Y^j(x) dx\}_j Parameters ---------- irreps : `o3.Irreps` input representation of the form ``[(1, (l, p_val * (p_arg)^l)) for l in [0, ..., lmax]]`` act : function activation function :math:`\phi` res : int resolution of the grid on the sphere (the higher the more accurate) normalization : {'norm', 'component'} lmax_out : int, optional maximum ``l`` of the output random_rot : bool rotate randomly the grid Examples -------- >>> from e3nn import io >>> m = S2Activation(io.SphericalTensor(5, p_val=+1, p_arg=-1), torch.tanh, 100) """ def __init__( self, irreps: o3.Irreps, act, res, normalization: str = "component", lmax_out=None, random_rot: bool = False ) -> None: super().__init__() irreps = o3.Irreps(irreps).simplify() _, (_, p_val) = irreps[0] _, (lmax, _) = irreps[-1] assert all(mul == 1 for mul, _ in irreps) assert irreps.ls == list(range(lmax + 1)) if all(p == p_val for _, (l, p) in irreps): p_arg = 1 elif all(p == p_val * (-1) ** l for _, (l, p) in irreps): p_arg = -1 else: assert False, "the parity of the input is not well defined" self.irreps_in = irreps # the input transforms as : A_l ---> p_val * (p_arg)^l * A_l # the sphere signal transforms as : f(r) ---> p_val * f(p_arg * r) if lmax_out is None: lmax_out = lmax if p_val in (0, +1): self.irreps_out = o3.Irreps([(1, (l, p_val * p_arg**l)) for l in range(lmax_out + 1)]) if p_val == -1: x = torch.linspace(0, 10, 256) a1, a2 = act(x), act(-x) if (a1 - a2).abs().max() < a1.abs().max() * 1e-10: # p_act = 1 self.irreps_out = o3.Irreps([(1, (l, p_arg**l)) for l in range(lmax_out + 1)]) elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10: # p_act = -1 self.irreps_out = o3.Irreps([(1, (l, -(p_arg**l))) for l in range(lmax_out + 1)]) else: # p_act = 0 raise ValueError("warning! 
the parity is violated") self.to_s2 = o3.ToS2Grid(lmax, res, normalization=normalization) self.from_s2 = o3.FromS2Grid(res, lmax_out, normalization=normalization, lmax_in=lmax) self.act = normalize2mom(act) self.random_rot = random_rot def __repr__(self) -> str: return f"{self.__class__.__name__} ({self.irreps_in} -> {self.irreps_out})" def forward(self, features): r"""evaluate Parameters ---------- features : `torch.Tensor` tensor :math:`\{A^l\}_l` of shape ``(..., self.irreps_in.dim)`` Returns ------- `torch.Tensor` tensor of shape ``(..., self.irreps_out.dim)`` """ assert features.shape[-1] == self.irreps_in.dim if self.random_rot: abc = o3.rand_angles(dtype=features.dtype, device=features.device) features = torch.einsum("ij,...j->...i", self.irreps_in.D_from_angles(*abc), features) features = self.to_s2(features) # [..., beta, alpha] features = self.act(features) features = self.from_s2(features) if self.random_rot: features = torch.einsum("ij,...j->...i", self.irreps_out.D_from_angles(*abc).T, features) return features e3nn-0.6.0/e3nn/nn/_so3act.py000066400000000000000000000030651514371756200155620ustar00rootroot00000000000000import torch from e3nn.math import normalize2mom from e3nn.util.jit import compile_mode from e3nn.o3 import SO3Grid @compile_mode("script") class SO3Activation(torch.nn.Module): r"""Apply non linearity on the signal on SO(3) Parameters ---------- lmax_in : int input lmax lmax_out : int output lmax act : function activation function :math:`\phi` resolution : int SO(3) grid resolution normalization : {'norm', 'component'} """ def __init__(self, lmax_in, lmax_out, act, resolution, *, normalization: str = "component", aspect_ratio: int = 2) -> None: super().__init__() self.grid_in = SO3Grid(lmax_in, resolution, normalization=normalization, aspect_ratio=aspect_ratio) self.grid_out = SO3Grid(lmax_out, resolution, normalization=normalization, aspect_ratio=aspect_ratio) self.act = normalize2mom(act) self.lmax_in = lmax_in self.lmax_out = lmax_out 
    def __repr__(self) -> str:
        return f"{self.__class__.__name__} ({self.lmax_in} -> {self.lmax_out})"

    def forward(self, features) -> torch.Tensor:
        r"""evaluate

        Parameters
        ----------

        features : `torch.Tensor`
            tensor of shape ``(..., sum((2 * l + 1) ** 2 for l in range(lmax_in + 1)))``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(..., sum((2 * l + 1) ** 2 for l in range(lmax_out + 1)))``
        """
        features = self.grid_in.to_grid(features)
        features = self.act(features)
        features = self.grid_out.from_grid(features)

        return features
e3nn-0.6.0/e3nn/nn/models/
e3nn-0.6.0/e3nn/nn/models/__init__.py
e3nn-0.6.0/e3nn/nn/models/gate_points_2101.py
"""model with self-interactions and gates

Exact equivariance to :math:`E(3)`

version of January 2021
"""
import math
from typing import Dict, Optional

import torch

from e3nn import o3
from e3nn.math import soft_one_hot_linspace
from e3nn.nn import FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct
from e3nn.util.jit import compile_mode


def scatter(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    # special case of torch_scatter.scatter with dim=0
    out = src.new_zeros(dim_size, src.shape[1])
    index = index.reshape(-1, 1).expand_as(src)
    return out.scatter_add_(0, index, src)


def radius_graph(pos, r_max, batch) -> torch.Tensor:
    # naive and inefficient version of torch_cluster.radius_graph
    r = torch.cdist(pos, pos)
    index = ((r < r_max) & (r > 0)).nonzero().T
    index = index[:, batch[index[0]] == batch[index[1]]]
    return index


@compile_mode("script")
class Convolution(torch.nn.Module):
    r"""equivariant convolution

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input node features

    irreps_node_attr : `e3nn.o3.Irreps`
        representation of the node attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes

    irreps_out : `e3nn.o3.Irreps` or None
        representation of the output node features

    number_of_basis : int
        number of basis functions onto which the edge lengths are projected

    radial_layers : int
        number of hidden layers in the radial fully connected network

    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network

    num_neighbors : float
        typical number of nodes convolved over
    """

    def __init__(
        self,
        irreps_in,
        irreps_node_attr,
        irreps_edge_attr,
        irreps_out,
        number_of_basis,
        radial_layers,
        radial_neurons,
        num_neighbors,
    ) -> None:
        super().__init__()
        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_out = o3.Irreps(irreps_out)
        self.num_neighbors = num_neighbors

        self.sc = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_out)

        self.lin1 = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_in)

        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_in):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_out:
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, "uvu", True))
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

        tp = TensorProduct(
            self.irreps_in,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(
            [number_of_basis] + radial_layers * [radial_neurons] + [tp.weight_numel], torch.nn.functional.silu
        )
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_out)

    def forward(self, node_input, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedded) -> torch.Tensor:
        weight = self.fc(edge_length_embedded)

        x = node_input

        s = self.sc(x, node_attr)
        x = self.lin1(x, node_attr)

        edge_features = self.tp(x[edge_src], edge_attr, weight)
        x = scatter(edge_features, edge_dst, dim_size=x.shape[0]).div(self.num_neighbors**0.5)

        x = self.lin2(x, node_attr)

        c_s, c_x = math.sin(math.pi / 8), math.cos(math.pi / 8)
        m = self.sc.output_mask
        c_x = (1 - m) + c_x * m
        return c_s * s + c_x * x


def smooth_cutoff(x):
    u = 2 * (x - 1)
    y = (math.pi * u).cos().neg().add(1).div(2)
    y[u > 0] = 0
    y[u < -1] = 1
    return y


def tp_path_exists(irreps_in1, irreps_in2, ir_out):
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    ir_out = o3.Irrep(ir_out)

    for _, ir1 in irreps_in1:
        for _, ir2 in irreps_in2:
            if ir_out in ir1 * ir2:
                return True
    return False


class Compose(torch.nn.Module):
    def __init__(self, first, second) -> None:
        super().__init__()
        self.first = first
        self.second = second
        self.irreps_in = self.first.irreps_in
        self.irreps_out = self.second.irreps_out

    def forward(self, *input):
        x = self.first(*input)
        return self.second(x)


class Network(torch.nn.Module):
    r"""equivariant neural network

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps` or None
        representation of the input features
        can be set to ``None`` if nodes don't have input features

    irreps_hidden : `e3nn.o3.Irreps`
        representation of the hidden features

    irreps_out : `e3nn.o3.Irreps`
        representation of the output features

    irreps_node_attr : `e3nn.o3.Irreps` or None
        representation of the node attributes
        can be set to ``None`` if nodes don't have attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes
        the edge attributes are :math:`h(r) Y(\vec r / r)`
        where :math:`h` is a smooth function that goes to zero at ``max_radius``
        and :math:`Y` are the spherical harmonics polynomials

    layers : int
        number of gates (nonlinearities)

    max_radius : float
        maximum radius for the convolution

    number_of_basis : int
        number of basis functions onto which the edge lengths are projected
    radial_layers : int
        number of hidden layers in the radial fully connected network

    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network

    num_neighbors : float
        typical number of nodes within a distance ``max_radius``

    num_nodes : float
        typical number of nodes in a graph
    """

    def __init__(
        self,
        irreps_in: Optional[o3.Irreps],
        irreps_hidden: o3.Irreps,
        irreps_out: o3.Irreps,
        irreps_node_attr: Optional[o3.Irreps],
        irreps_edge_attr: o3.Irreps,
        layers: int,
        max_radius: float,
        number_of_basis: int,
        radial_layers: int,
        radial_neurons: int,
        num_neighbors: float,
        num_nodes: float,
        reduce_output: bool = True,
    ) -> None:
        super().__init__()
        self.max_radius = max_radius
        self.number_of_basis = number_of_basis
        self.num_neighbors = num_neighbors
        self.num_nodes = num_nodes
        self.reduce_output = reduce_output

        self.irreps_in = o3.Irreps(irreps_in) if irreps_in is not None else None
        self.irreps_hidden = o3.Irreps(irreps_hidden)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr) if irreps_node_attr is not None else o3.Irreps("0e")
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        self.input_has_node_in = irreps_in is not None
        self.input_has_node_attr = irreps_node_attr is not None

        irreps = self.irreps_in if self.irreps_in is not None else o3.Irreps("0e")

        act = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        act_gates = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            irreps_scalars = o3.Irreps(
                [
                    (mul, ir)
                    for mul, ir in self.irreps_hidden
                    if ir.l == 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)
                ]
            )
            irreps_gated = o3.Irreps(
                [(mul, ir) for mul, ir in self.irreps_hidden if ir.l > 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)]
            )
            ir = "0e" if tp_path_exists(irreps, self.irreps_edge_attr, "0e") else "0o"
            irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated])

            gate = Gate(
                irreps_scalars,
                [act[ir.p] for _, ir in irreps_scalars],  # scalar
                irreps_gates,
                [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
                irreps_gated,  # gated tensors
            )
            conv = Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                gate.irreps_in,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
            irreps = gate.irreps_out
            self.layers.append(Compose(conv, gate))

        self.layers.append(
            Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                self.irreps_out,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
        )

    def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        """evaluate the network

        Parameters
        ----------
        data : `torch_geometric.data.Data` or dict
            data object containing
            - ``pos`` the position of the nodes (atoms)
            - ``x`` the input features of the nodes, optional
            - ``z`` the attributes of the nodes, for instance the atom type, optional
            - ``batch`` the index of the graph to which each node belongs, optional
        """
        if "batch" in data:
            batch = data["batch"]
        else:
            batch = data["pos"].new_zeros(data["pos"].shape[0], dtype=torch.long)

        edge_index = radius_graph(data["pos"], self.max_radius, batch)
        edge_src = edge_index[0]
        edge_dst = edge_index[1]
        edge_vec = data["pos"][edge_src] - data["pos"][edge_dst]
        edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization="component")
        edge_length = edge_vec.norm(dim=1)
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_length, start=0.0, end=self.max_radius, number=self.number_of_basis, basis="gaussian", cutoff=False
        ).mul(self.number_of_basis**0.5)
        edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

        if self.input_has_node_in and "x" in data:
            assert self.irreps_in is not None
            x = data["x"]
        else:
            assert self.irreps_in is None
            x = data["pos"].new_ones((data["pos"].shape[0], 1))

        if self.input_has_node_attr and "z" in data:
            z = data["z"]
        else:
            assert self.irreps_node_attr == o3.Irreps("0e")
            z = data["pos"].new_ones((data["pos"].shape[0], 1))

        for lay in self.layers:
            x = lay(x, z, edge_src, edge_dst,
                    edge_attr, edge_length_embedded)

        if self.reduce_output:
            return scatter(x, batch, dim_size=int(batch.max()) + 1).div(self.num_nodes**0.5)
        else:
            return x
e3nn-0.6.0/e3nn/nn/models/gate_points_2102.py
"""model with self-interactions and gates

Exact equivariance to :math:`E(3)`

version of February 2021
"""
import math
from typing import Dict, Optional

import torch

from e3nn import o3
from e3nn.math import soft_one_hot_linspace
from e3nn.nn import ExtractIr, FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct
from e3nn.util.jit import compile_mode


def scatter(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    # special case of torch_scatter.scatter with dim=0
    out = src.new_zeros(dim_size, src.shape[1])
    index = index.reshape(-1, 1).expand_as(src)
    return out.scatter_add_(0, index, src)


def radius_graph(pos, r_max, batch) -> torch.Tensor:
    # naive and inefficient version of torch_cluster.radius_graph
    r = torch.cdist(pos, pos)
    index = ((r < r_max) & (r > 0)).nonzero().T
    index = index[:, batch[index[0]] == batch[index[1]]]
    return index


@compile_mode("script")
class Convolution(torch.nn.Module):
    r"""equivariant convolution

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input node features

    irreps_node_attr : `e3nn.o3.Irreps`
        representation of the node attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes

    irreps_out : `e3nn.o3.Irreps` or None
        representation of the output node features

    number_of_edge_features : int
        number of scalar (0e) features of the edge used to feed the FC network

    radial_layers : int
        number of hidden layers in the radial fully connected network

    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network

    num_neighbors : float
        typical number of nodes convolved over
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        irreps_node_attr: o3.Irreps,
        irreps_edge_attr: o3.Irreps,
        irreps_out: Optional[o3.Irreps],
        number_of_edge_features: int,
        radial_layers: int,
        radial_neurons: int,
        num_neighbors: float,
    ) -> None:
        super().__init__()
        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_out = o3.Irreps(irreps_out)
        self.num_neighbors = num_neighbors

        self.sc = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_out)

        self.lin1 = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_in)

        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_in):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_out:
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, "uvu", True))
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

        tp = TensorProduct(
            self.irreps_in,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(
            [number_of_edge_features] + radial_layers * [radial_neurons] + [tp.weight_numel], torch.nn.functional.silu
        )
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_out)

    def forward(self, node_input, node_attr, edge_src, edge_dst, edge_attr, edge_features) -> torch.Tensor:
        weight = self.fc(edge_features)

        x = node_input

        s = self.sc(x, node_attr)
        x = self.lin1(x, node_attr)

        edge_features = self.tp(x[edge_src], edge_attr, weight)
        x = scatter(edge_features, edge_dst, dim_size=x.shape[0]).div(self.num_neighbors**0.5)

        x = self.lin2(x, node_attr)

        c_s, c_x = math.sin(math.pi / 8), math.cos(math.pi / 8)
        m = self.sc.output_mask
        c_x = (1 - m) + c_x * m
        return c_s * s + c_x * x


def smooth_cutoff(x):
    u = 2 * (x - 1)
    y = (math.pi * u).cos().neg().add(1).div(2)
    y[u > 0] = 0
    y[u < -1] = 1
    return y


def tp_path_exists(irreps_in1, irreps_in2, ir_out) -> bool:
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    ir_out = o3.Irrep(ir_out)

    for _, ir1 in irreps_in1:
        for _, ir2 in irreps_in2:
            if ir_out in ir1 * ir2:
                return True
    return False


class Compose(torch.nn.Module):
    def __init__(self, first, second) -> None:
        super().__init__()
        self.first = first
        self.second = second
        self.irreps_in = self.first.irreps_in
        self.irreps_out = self.second.irreps_out

    def forward(self, *input):
        x = self.first(*input)
        return self.second(x)


class Network(torch.nn.Module):
    r"""equivariant neural network

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps` or None
        representation of the input features
        can be set to ``None`` if nodes don't have input features

    irreps_hidden : `e3nn.o3.Irreps`
        representation of the hidden features

    irreps_out : `e3nn.o3.Irreps`
        representation of the output features

    irreps_node_attr : `e3nn.o3.Irreps` or None
        representation of the node attributes
        can be set to ``None`` if nodes don't have attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes
        the edge attributes are :math:`h(r) Y(\vec r / r)`
        where :math:`h` is a smooth function that goes to zero at ``max_radius``
        and :math:`Y` are the spherical harmonics polynomials

    layers : int
        number of gates (nonlinearities)

    max_radius : float
        maximum radius for the convolution

    number_of_basis : int
        number of basis functions onto which the edge lengths are projected

    radial_layers : int
        number of hidden layers in the radial fully connected network

    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network

    num_neighbors : float
        typical number of nodes within a distance ``max_radius``

    num_nodes : float
        typical number of nodes in a graph
    """

    def __init__(
        self,
        irreps_in: Optional[o3.Irreps],
        irreps_hidden: o3.Irreps,
        irreps_out: o3.Irreps,
        irreps_node_attr: Optional[o3.Irreps],
        irreps_edge_attr: o3.Irreps,
        layers: int,
        max_radius: float,
        number_of_basis: int,
        radial_layers: int,
        radial_neurons: int,
        num_neighbors: float,
        num_nodes: float,
        reduce_output: bool = True,
    ) -> None:
        super().__init__()
        self.max_radius = max_radius
        self.number_of_basis = number_of_basis
        self.num_neighbors = num_neighbors
        self.num_nodes = num_nodes
        self.reduce_output = reduce_output

        self.irreps_in = o3.Irreps(irreps_in) if irreps_in is not None else None
        self.irreps_hidden = o3.Irreps(irreps_hidden)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr) if irreps_node_attr is not None else o3.Irreps("0e")
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        self.input_has_node_in = irreps_in is not None
        self.input_has_node_attr = irreps_node_attr is not None

        self.ext_z = ExtractIr(self.irreps_node_attr, "0e")
        number_of_edge_features = number_of_basis + 2 * self.irreps_node_attr.count("0e")

        irreps = self.irreps_in if self.irreps_in is not None else o3.Irreps("0e")

        act = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        act_gates = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            irreps_scalars = o3.Irreps(
                [
                    (mul, ir)
                    for mul, ir in self.irreps_hidden
                    if ir.l == 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)
                ]
            )
            irreps_gated = o3.Irreps(
                [(mul, ir) for mul, ir in self.irreps_hidden if ir.l > 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)]
            )
            ir = "0e" if tp_path_exists(irreps, self.irreps_edge_attr, "0e") else "0o"
            irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated])

            gate = Gate(
                irreps_scalars,
                [act[ir.p] for _, ir in irreps_scalars],  # scalar
                irreps_gates,
                [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
                irreps_gated,  # gated tensors
            )
            conv = Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                gate.irreps_in,
                number_of_edge_features,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
            irreps = gate.irreps_out
            self.layers.append(Compose(conv, gate))

        self.layers.append(
            Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                self.irreps_out,
                number_of_edge_features,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
        )

    def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        """evaluate the network

        Parameters
        ----------
        data : `torch_geometric.data.Data` or dict
            data object containing
            - ``pos`` the position of the nodes (atoms)
            - ``x`` the input features of the nodes, optional
            - ``z`` the attributes of the nodes, for instance the atom type, optional
            - ``batch`` the index of the graph to which each node belongs, optional
        """
        if "batch" in data:
            batch = data["batch"]
        else:
            batch = data["pos"].new_zeros(data["pos"].shape[0], dtype=torch.long)

        edge_index = radius_graph(data["pos"], self.max_radius, batch)
        edge_src = edge_index[0]
        edge_dst = edge_index[1]
        edge_vec = data["pos"][edge_src] - data["pos"][edge_dst]
        edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization="component")
        edge_length = edge_vec.norm(dim=1)
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_length, start=0.0, end=self.max_radius, number=self.number_of_basis, basis="gaussian", cutoff=False
        ).mul(self.number_of_basis**0.5)
        edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

        if self.input_has_node_in and "x" in data:
            assert self.irreps_in is not None
            x = data["x"]
        else:
            assert self.irreps_in is None
            x = data["pos"].new_ones((data["pos"].shape[0], 1))

        if self.input_has_node_attr and "z" in data:
            z = data["z"]
        else:
            assert self.irreps_node_attr == o3.Irreps("0e")
            z = data["pos"].new_ones((data["pos"].shape[0], 1))

        scalar_z = self.ext_z(z)
        edge_features = torch.cat([edge_length_embedded, scalar_z[edge_src], scalar_z[edge_dst]], dim=1)

        for lay in self.layers:
            x = lay(x, z, edge_src, edge_dst, edge_attr, edge_features)

        if self.reduce_output:
            return scatter(x, batch, dim_size=int(batch.max()) + 1).div(self.num_nodes**0.5)
        else:
            return x
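

# Illustrative sketch (hypothetical helper, not part of the e3nn API).
# Convolution.forward above combines the self-connection ``s`` and the
# convolution output ``x`` as ``c_s * s + c_x * x`` with c_s = sin(pi / 8) and
# c_x = cos(pi / 8). Because c_s**2 + c_x**2 == 1, mixing two independent,
# zero-mean, unit-variance signals this way keeps the variance at 1, which is
# presumably the reason for this choice. Where ``self.sc.output_mask`` is 0
# (no self-connection path for that output), ``c_x`` is reset to 1 so ``x``
# passes through unscaled. This helper only checks the variance claim
# numerically, using the standard library.
def _residual_mix_variance(n: int = 200_000, seed: int = 0) -> float:
    """Empirical variance of sin(pi/8) * s + cos(pi/8) * x for independent s, x ~ N(0, 1)."""
    import math
    import random

    rng = random.Random(seed)
    c_s, c_x = math.sin(math.pi / 8), math.cos(math.pi / 8)
    samples = [c_s * rng.gauss(0.0, 1.0) + c_x * rng.gauss(0.0, 1.0) for _ in range(n)]
    mean = sum(samples) / n
    return sum((v - mean) ** 2 for v in samples) / n


# Usage: _residual_mix_variance() should be close to 1.0, up to sampling noise.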
e3nn-0.6.0/e3nn/nn/models/v2103/
e3nn-0.6.0/e3nn/nn/models/v2103/__init__.py
e3nn-0.6.0/e3nn/nn/models/v2103/conv_points_in_out.py
r"""example of a graph convolution where the input and output nodes are different

>>> test()
"""
from typing import Optional

import torch
from torch_scatter import scatter

from e3nn import o3
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import TensorProduct, FullyConnectedTensorProduct
from e3nn.util.jit import compile_mode


@compile_mode("script")
class Convolution(torch.nn.Module):
    r"""equivariant convolution

    Parameters
    ----------
    irreps_node_input : `e3nn.o3.Irreps`
        representation of the input node features

    irreps_node_output : `e3nn.o3.Irreps` or None
        representation of the output node features

    irreps_node_attr_input : `e3nn.o3.Irreps`
        representation of the input node attributes

    irreps_node_attr_output : `e3nn.o3.Irreps`
        representation of the output node attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes

    num_edge_scalar_attr : int
        number of scalar (0e) attributes of the edge used to feed the FC network

    radial_layers : int
        number of hidden layers in the radial fully connected network

    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network

    num_neighbors : float
        typical number of nodes convolved over
    """

    def __init__(
        self,
        irreps_node_input: o3.Irreps,
        irreps_node_output: Optional[o3.Irreps],
        irreps_node_attr_input: o3.Irreps,
        irreps_node_attr_output: o3.Irreps,
        irreps_edge_attr: o3.Irreps,
        num_edge_scalar_attr: int,
        radial_layers: int,
        radial_neurons: int,
        num_neighbors: float,
    ) -> None:
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr_input = o3.Irreps(irreps_node_attr_input)
        self.irreps_node_attr_output = o3.Irreps(irreps_node_attr_output)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr_input, self.irreps_node_input)

        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_node_input):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_node_output:
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, "uvu", True))
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

        tp = TensorProduct(
            self.irreps_node_input,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(
            [num_edge_scalar_attr] + radial_layers * [radial_neurons] + [tp.weight_numel], torch.nn.functional.silu
        )
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr_output, self.irreps_node_output)

    def forward(
        self, node_input, node_attr_input, node_attr_output, edge_src, edge_dst, edge_attr, edge_scalar_attr
    ) -> torch.Tensor:
        weight = self.fc(edge_scalar_attr)

        node_input = self.lin1(node_input, node_attr_input)

        edge_features = self.tp(node_input[edge_src], edge_attr, weight)
        node_output = scatter(edge_features, edge_dst, dim=0, dim_size=node_attr_output.shape[0])
        node_output.div_(self.num_neighbors**0.5)

        return self.lin2(node_output, node_attr_output)


def test() -> None:
    from torch_cluster import radius
    from e3nn.math import soft_one_hot_linspace

    conv = Convolution(
        irreps_node_input="0e + 1e",
        irreps_node_output="0e + 1e",
        irreps_node_attr_input="2x0e",
        irreps_node_attr_output="3x0e",
        irreps_edge_attr="0e + 1e",
        num_edge_scalar_attr=4,
        radial_layers=1,
        radial_neurons=50,
        num_neighbors=3.0,
    )

    pos_in = torch.randn(5, 3)
    pos_out = torch.randn(2, 3)

    node_input = torch.randn(5, 4)
    node_attr_input = torch.randn(5, 2)
    node_attr_output = torch.randn(2, 3)

    edge_src, edge_dst = radius(pos_out, pos_in, r=2.0)
    edge_vec = pos_in[edge_src] - pos_out[edge_dst]
    edge_attr = o3.spherical_harmonics([0, 1], edge_vec, True)
    edge_scalar_attr = soft_one_hot_linspace(
        x=edge_vec.norm(dim=1), start=0.0, end=2.0, number=4, basis="smooth_finite", cutoff=True
    )

    conv(node_input, node_attr_input, node_attr_output, edge_src, edge_dst, edge_attr, edge_scalar_attr)
e3nn-0.6.0/e3nn/nn/models/v2103/gate_points_message_passing.py
"""
>>> test()
"""
import torch

from e3nn import o3
from e3nn.nn import Gate

from .points_convolution import Convolution


def tp_path_exists(irreps_in1, irreps_in2, ir_out) -> bool:
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    ir_out = o3.Irrep(ir_out)

    for _, ir1 in irreps_in1:
        for _, ir2 in irreps_in2:
            if ir_out in ir1 * ir2:
                return True
    return False


class Compose(torch.nn.Module):
    def __init__(self, first, second) -> None:
        super().__init__()
        self.first = first
        self.second = second

    def forward(self, *input):
        x = self.first(*input)
        return self.second(x)


class MessagePassing(torch.nn.Module):
    r"""equivariant message passing network: ``layers`` blocks of convolution + gate, followed by a final convolution

    Parameters
    ----------
    irreps_node_input : `e3nn.o3.Irreps`
        representation of the input features

    irreps_node_hidden : `e3nn.o3.Irreps`
        representation of the hidden features

    irreps_node_output : `e3nn.o3.Irreps`
        representation of the output features

    irreps_node_attr : `e3nn.o3.Irreps`
        representation of the node attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes

    layers : int
        number of gates (nonlinearities)

    fc_neurons : list of int
        number of neurons per layer of the fully connected network,
        for the first layer and the hidden layers but not the output layer

    num_neighbors : float
        typical number of nodes convolved over
    """

    def __init__(
        self,
        irreps_node_input,
        irreps_node_hidden,
        irreps_node_output,
        irreps_node_attr,
        irreps_edge_attr,
        layers,
        fc_neurons,
        num_neighbors,
    ) -> None:
        super().__init__()
        self.num_neighbors = num_neighbors

        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_hidden = o3.Irreps(irreps_node_hidden)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        irreps_node = self.irreps_node_input

        act = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        act_gates = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            irreps_scalars = o3.Irreps(
                [
                    (mul, ir)
                    for mul, ir in self.irreps_node_hidden
                    if ir.l == 0 and tp_path_exists(irreps_node, self.irreps_edge_attr, ir)
                ]
            ).simplify()
            irreps_gated = o3.Irreps(
                [
                    (mul, ir)
                    for mul, ir in self.irreps_node_hidden
                    if ir.l > 0 and tp_path_exists(irreps_node, self.irreps_edge_attr, ir)
                ]
            )
            ir = "0e" if tp_path_exists(irreps_node, self.irreps_edge_attr, "0e") else "0o"
            irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()

            gate = Gate(
                irreps_scalars,
                [act[ir.p] for _, ir in irreps_scalars],  # scalar
                irreps_gates,
                [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
                irreps_gated,  # gated tensors
            )
            conv = Convolution(
                irreps_node, self.irreps_node_attr, self.irreps_edge_attr, gate.irreps_in, fc_neurons, num_neighbors
            )
            irreps_node = gate.irreps_out
            self.layers.append(Compose(conv, gate))

        self.layers.append(
            Convolution(
                irreps_node, self.irreps_node_attr, self.irreps_edge_attr, self.irreps_node_output, fc_neurons, num_neighbors
            )
        )

    def forward(self, node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor:
        for lay in self.layers:
            node_features = lay(node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars)

        return node_features


def test() -> None:
    from torch_cluster import radius_graph
    from e3nn.util.test import assert_equivariant, assert_auto_jitable

    mp = MessagePassing(
        irreps_node_input="0e",
        irreps_node_hidden="0e + 1e",
        irreps_node_output="1e",
        irreps_node_attr="0e + 1e",
        irreps_edge_attr="1e",
        layers=3,
        fc_neurons=[2, 100],
        num_neighbors=3.0,
    )

    num_nodes = 4
    node_pos = torch.randn(num_nodes, 3)
    edge_index = radius_graph(node_pos, 3.0)
    edge_src, edge_dst = edge_index
    num_edges = edge_index.shape[1]
    edge_attr = node_pos[edge_index[0]] - node_pos[edge_index[1]]

    node_features = torch.randn(num_nodes, 1)
    node_attr = torch.randn(num_nodes, 4)
    edge_scalars = torch.randn(num_edges, 2)

    assert mp(node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars).shape == (num_nodes, 3)

    assert_equivariant(
        mp,
        irreps_in=[mp.irreps_node_input, mp.irreps_node_attr, None, None, mp.irreps_edge_attr, None],
        args_in=[node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars],
        irreps_out=[mp.irreps_node_output],
    )

    assert_auto_jitable(mp.layers[0].first)
e3nn-0.6.0/e3nn/nn/models/v2103/gate_points_networks.py
"""
>>> test_simple_network()
>>> test_network_for_a_graph_with_attributes()
"""
from typing import Dict, Tuple, Union

import torch
from torch_cluster import radius_graph
from torch_geometric.data import Data
from torch_scatter import scatter

from e3nn import o3
from e3nn.math import soft_one_hot_linspace

from .gate_points_message_passing import MessagePassing


class SimpleNetwork(torch.nn.Module):
    def __init__(
        self,
        irreps_in,
        irreps_out,
        max_radius,
        num_neighbors: int,
        num_nodes: int,
        mul: int = 50,
        layers: int = 3,
        lmax: int = 2,
        pool_nodes: bool = True,
    ) -> None:
        super().__init__()

        self.lmax = lmax
        self.max_radius = max_radius
        self.number_of_basis = 10
        self.num_nodes = num_nodes
        self.pool_nodes = pool_nodes

        irreps_node_hidden = o3.Irreps([(mul, (l, p)) for l in range(lmax + 1) for p in [-1, 1]])

        self.mp = MessagePassing(
            irreps_node_input=irreps_in,
            irreps_node_hidden=irreps_node_hidden,
            irreps_node_output=irreps_out,
            irreps_node_attr="0e",
            irreps_edge_attr=o3.Irreps.spherical_harmonics(lmax),
            layers=layers,
            fc_neurons=[self.number_of_basis, 100],
            num_neighbors=num_neighbors,
        )

        self.irreps_in = self.mp.irreps_node_input
        self.irreps_out = self.mp.irreps_node_output

    def preprocess(
        self, data: Union[Data, Dict[str, torch.Tensor]]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # returns (batch, node features, edge_src, edge_dst, edge_vec)
        if "batch" in data:
            batch = data["batch"]
        else:
            batch = data["pos"].new_zeros(data["pos"].shape[0], dtype=torch.long)

        # Create graph
        edge_index = radius_graph(data["pos"], self.max_radius, batch, max_num_neighbors=len(data["pos"]) - 1)
        edge_src = edge_index[0]
        edge_dst = edge_index[1]

        # Edge attributes
        edge_vec = data["pos"][edge_src] - data["pos"][edge_dst]

        return batch, data["x"], edge_src, edge_dst, edge_vec

    def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        batch, node_inputs, edge_src, edge_dst, edge_vec = self.preprocess(data)
        del data

        edge_attr = o3.spherical_harmonics(range(self.lmax + 1), edge_vec, True, normalization="component")

        # Edge length embedding
        edge_length = edge_vec.norm(dim=1)
        edge_length_embedding = soft_one_hot_linspace(
            edge_length,
            0.0,
            self.max_radius,
            self.number_of_basis,
            basis="cosine",  # the cosine basis with cutoff = True goes to zero at max_radius
            cutoff=True,  # no need for an additional smooth cutoff
        ).mul(self.number_of_basis**0.5)

        # Node attributes are not used here
        node_attr = node_inputs.new_ones(node_inputs.shape[0], 1)

        node_outputs = self.mp(node_inputs, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedding)

        if self.pool_nodes:
            return scatter(node_outputs, batch, dim=0).div(self.num_nodes**0.5)
        else:
            return node_outputs


class NetworkForAGraphWithAttributes(torch.nn.Module):
    def __init__(
        self,
        irreps_node_input: o3.Irreps,
        irreps_node_attr: o3.Irreps,
        irreps_edge_attr: o3.Irreps,
        irreps_node_output: o3.Irreps,
        max_radius,
        num_neighbors: int,
        num_nodes: int,
        mul: int = 50,
        layers: int = 3,
        lmax: int = 2,
        pool_nodes: bool = True,
    ) ->
None: super().__init__() self.lmax = lmax self.max_radius = max_radius self.number_of_basis = 10 self.num_nodes = num_nodes self.irreps_edge_attr = o3.Irreps(irreps_edge_attr) self.pool_nodes = pool_nodes irreps_node_hidden = o3.Irreps([(mul, (l, p)) for l in range(lmax + 1) for p in [-1, 1]]) self.mp = MessagePassing( irreps_node_input=irreps_node_input, irreps_node_hidden=irreps_node_hidden, irreps_node_output=irreps_node_output, irreps_node_attr=irreps_node_attr, irreps_edge_attr=self.irreps_edge_attr + o3.Irreps.spherical_harmonics(lmax), layers=layers, fc_neurons=[self.number_of_basis, 100], num_neighbors=num_neighbors, ) self.irreps_node_input = self.mp.irreps_node_input self.irreps_node_attr = self.mp.irreps_node_attr self.irreps_node_output = self.mp.irreps_node_output def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor: if "batch" in data: batch = data["batch"] else: batch = data["pos"].new_zeros(data["pos"].shape[0], dtype=torch.long) # The graph edge_src = data["edge_index"][0] edge_dst = data["edge_index"][1] # Edge attributes edge_vec = data["pos"][edge_src] - data["pos"][edge_dst] edge_sh = o3.spherical_harmonics(range(self.lmax + 1), edge_vec, True, normalization="component") edge_attr = torch.cat([data["edge_attr"], edge_sh], dim=1) # Edge length embedding edge_length = edge_vec.norm(dim=1) edge_length_embedding = soft_one_hot_linspace( edge_length, 0.0, self.max_radius, self.number_of_basis, basis="cosine", # the cosine basis with cutoff = True goes to zero at max_radius cutoff=True, # no need for an additional smooth cutoff ).mul(self.number_of_basis**0.5) node_outputs = self.mp(data["node_input"], data["node_attr"], edge_src, edge_dst, edge_attr, edge_length_embedding) if self.pool_nodes: return scatter(node_outputs, batch, dim=0).div(self.num_nodes**0.5) else: return node_outputs def test_simple_network() -> None: net = SimpleNetwork("3x0e + 2x1o", "4x0e + 1x1o", max_radius=2.0, num_neighbors=3.0, num_nodes=5.0) 
    net({"pos": torch.randn(5, 3), "x": net.irreps_in.randn(5, -1)})


def test_network_for_a_graph_with_attributes() -> None:
    net = NetworkForAGraphWithAttributes(
        "3x0e + 2x1o", "4x0e + 1x1o", "1e", "3x0o + 1e", max_radius=2.0, num_neighbors=3.0, num_nodes=5.0
    )
    net(
        {
            "pos": torch.randn(3, 3),
            "edge_index": torch.tensor([[0, 1, 2], [1, 2, 0]]),
            "node_input": net.irreps_node_input.randn(3, -1),
            "node_attr": net.irreps_node_attr.randn(3, -1),
            "edge_attr": net.irreps_edge_attr.randn(3, -1),
        }
    )

e3nn-0.6.0/e3nn/nn/models/v2103/points_convolution.py

import torch
from torch_scatter import scatter

from e3nn import o3
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct
from e3nn.util.jit import compile_mode


@compile_mode("script")
class Convolution(torch.nn.Module):
    r"""equivariant convolution

    Parameters
    ----------
    irreps_node_input : `e3nn.o3.Irreps`
        representation of the input node features
    irreps_node_attr : `e3nn.o3.Irreps`
        representation of the node attributes
    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes
    irreps_node_output : `e3nn.o3.Irreps` or None
        representation of the output node features
    fc_neurons : list of int
        number of neurons per layer in the fully connected network
        first layer and hidden layers but not the output layer
    num_neighbors : float
        typical number of nodes convolved over
    """

    def __init__(
        self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output, fc_neurons, num_neighbors
    ) -> None:
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        self.sc = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)

        self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_node_input):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1):
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, "uvu", True))
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

        tp = TensorProduct(
            self.irreps_node_input,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output)
        self.lin3 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e")

    def forward(self, node_input, node_attr, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor:
        weight = self.fc(edge_scalars)

        node_self_connection = self.sc(node_input, node_attr)
        node_features = self.lin1(node_input, node_attr)

        edge_features = self.tp(node_features[edge_src], edge_attr, weight)
        node_features = scatter(edge_features, edge_dst, dim=0, dim_size=node_input.shape[0]).div(self.num_neighbors**0.5)

        node_conv_out = self.lin2(node_features, node_attr)
        node_angle = 0.1 * self.lin3(node_features, node_attr)
        #            ^^^------ start small, favor self-connection

        cos, sin = node_angle.cos(), node_angle.sin()
        m = self.sc.output_mask
        sin = (1 - m) + sin * m
        return cos * node_self_connection + sin * node_conv_out

e3nn-0.6.0/e3nn/nn/models/v2103/voxel_convolution.py

r"""
This is a tentative implementation of voxel convolution

>>> test()
"""
import torch

from e3nn import o3
from e3nn.o3 import FullyConnectedTensorProduct, Linear
from e3nn.math import soft_one_hot_linspace


class Convolution(torch.nn.Module):
    r"""convolution on voxels

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
    irreps_out : `e3nn.o3.Irreps`
    irreps_sh : `e3nn.o3.Irreps`
        set typically to ``o3.Irreps.spherical_harmonics(lmax)``
    size : int
    steps : tuple of int
    """

    def __init__(self, irreps_in, irreps_out, irreps_sh, size, steps=(1, 1, 1), **kwargs) -> None:
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)
        self.size = size
        self.num_rbfs = self.size

        if "padding" not in kwargs:
            kwargs["padding"] = self.size // 2
        self.kwargs = kwargs

        # self-connection
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # connection with neighbors
        r = torch.linspace(-1, 1, self.size)
        x = r * steps[0] / min(steps)
        x = x[x.abs() <= 1]
        y = r * steps[1] / min(steps)
        y = y[y.abs() <= 1]
        z = r * steps[2] / min(steps)
        z = z[z.abs() <= 1]
        lattice = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1)  # [x, y, z, R^3]

        emb = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=1.0,
            number=self.num_rbfs,
            basis="smooth_finite",
            cutoff=True,
        )
        self.register_buffer("emb", emb)

        sh = o3.spherical_harmonics(self.irreps_sh, lattice, True, "component")  # [x, y, z, irreps_sh.dim]
        self.register_buffer("sh", sh)

        self.tp = FullyConnectedTensorProduct(
            self.irreps_in,
            self.irreps_sh,
            self.irreps_out,
            shared_weights=False,
            compile_left_right=False,
            compile_right=True,
        )

        self.weight = torch.nn.Parameter(torch.randn(self.num_rbfs, self.tp.weight_numel))

    def forward(self, x):
        r"""
        Parameters
        ----------
        x : `torch.Tensor`
            tensor of shape ``(batch, irreps_in.dim, x, y, z)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(batch, irreps_out.dim, x, y, z)``
        """
        sc = self.sc(x.transpose(1, 4)).transpose(1, 4)

        weight = self.emb @ self.weight
        weight = weight / (self.size ** (3 / 2))
        kernel = self.tp.right(self.sh, weight)  # [x, y, z, irreps_in.dim, irreps_out.dim]
        kernel = torch.einsum("xyzio->oixyz", kernel)
        return sc + 0.1 * torch.nn.functional.conv3d(x, kernel, **self.kwargs)


class LowPassFilter(torch.nn.Module):
    def __init__(self, scale, stride: int = 1, transposed: bool = False, steps=(1, 1, 1)) -> None:
        super().__init__()

        sigma = 0.5 * (scale**2 - 1) ** 0.5

        size = int(1 + 2 * 2.5 * sigma)
        if size % 2 == 0:
            size += 1

        r = torch.linspace(-1, 1, size)
        x = r * steps[0] / min(steps)
        x = x[x.abs() <= 1]
        y = r * steps[1] / min(steps)
        y = y[y.abs() <= 1]
        z = r * steps[2] / min(steps)
        z = z[z.abs() <= 1]
        lattice = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1)  # [x, y, z, R^3]
        lattice = (size // 2) * lattice

        kernel = torch.exp(-lattice.norm(dim=-1).pow(2) / (2 * sigma**2))
        kernel = kernel / kernel.sum()
        if transposed:
            kernel = kernel * stride**3
        kernel = kernel[None, None]
        self.register_buffer("kernel", kernel)

        self.scale = scale
        self.stride = stride
        self.size = size
        self.transposed = transposed

    def forward(self, image):
        """
        Parameters
        ----------
        image : `torch.Tensor`
            tensor of shape ``(..., x, y, z)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(..., x, y, z)``
        """
        if self.scale <= 1:
            assert self.stride == 1
            return image

        out = image
        out = out.reshape(-1, 1, *out.shape[-3:])
        if self.transposed:
            out = torch.nn.functional.conv_transpose3d(out, self.kernel, padding=self.size // 2, stride=self.stride)
        else:
            out = torch.nn.functional.conv3d(out, self.kernel, padding=self.size // 2, stride=self.stride)
        out = out.reshape(*image.shape[:-3], *out.shape[-3:])
        return out


def test() -> None:
    conv = Convolution("0e + 1e", "0e + 1e + 1o + 2e + 2o", o3.Irreps.spherical_harmonics(lmax=3), 5)

    x = torch.randn(10, 4, 32, 32, 32)
    conv(x)

    fi = LowPassFilter(2.0)
    fi(x)
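As a reading aid for `LowPassFilter` above: the constructor sets `sigma = 0.5 * (scale**2 - 1) ** 0.5`, sizes the kernel to cover roughly plus or minus 2.5 sigma (bumped to an odd number of voxels), normalizes a Gaussian to unit sum, and multiplies by `stride**3` when `transposed=True`. The following is a minimal standard-library sketch of that arithmetic, assuming isotropic voxels (`steps=(1, 1, 1)`), for which the lattice offsets are the integers from `-(size // 2)` to `size // 2`. The helper names `gaussian_sigma_and_size` and `gaussian_kernel_3d` are hypothetical and are not part of e3nn.

```python
import math
from typing import List, Tuple


def gaussian_sigma_and_size(scale: float) -> Tuple[float, int]:
    """Hypothetical helper mirroring the sigma/size arithmetic of LowPassFilter.__init__."""
    if scale <= 1:
        # LowPassFilter.forward returns its input unchanged in this case
        raise ValueError("scale must be > 1 to build a smoothing kernel")
    sigma = 0.5 * (scale**2 - 1) ** 0.5
    size = int(1 + 2 * 2.5 * sigma)  # covers roughly +/- 2.5 sigma
    if size % 2 == 0:
        size += 1  # odd size so the kernel has a central voxel
    return sigma, size


def gaussian_kernel_3d(scale: float, stride: int = 1, transposed: bool = False) -> List[List[List[float]]]:
    """Normalized isotropic Gaussian on integer offsets (assumes steps=(1, 1, 1))."""
    sigma, size = gaussian_sigma_and_size(scale)
    half = size // 2
    offsets = range(-half, half + 1)
    kernel = [
        [[math.exp(-(x * x + y * y + z * z) / (2 * sigma**2)) for z in offsets] for y in offsets]
        for x in offsets
    ]
    total = sum(v for plane in kernel for row in plane for v in row)
    # unit sum; a transposed (upsampling) filter is additionally scaled by stride**3
    factor = (stride**3 if transposed else 1) / total
    return [[[v * factor for v in row] for row in plane] for plane in kernel]


sigma, size = gaussian_sigma_and_size(2.0)
print(round(sigma, 4), size)  # 0.866 5
k = gaussian_kernel_3d(2.0)
print(round(sum(v for plane in k for row in plane for v in row), 6))  # 1.0
```

For `scale=2.0` this yields a 5x5x5 kernel, which matches the `padding=self.size // 2` used in `forward` to keep the spatial shape when `stride=1`.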
e3nn-0.6.0/e3nn/nn/models/v2104/

e3nn-0.6.0/e3nn/nn/models/v2104/__init__.py

e3nn-0.6.0/e3nn/nn/models/v2104/voxel_convolution.py

r"""
This is an implementation of voxel convolution

>>> test()
"""
import math

import torch

from e3nn import o3
from e3nn.o3 import FullyConnectedTensorProduct, Linear
from e3nn.math import soft_one_hot_linspace


class Convolution(torch.nn.Module):
    r"""convolution on voxels

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        input irreps
    irreps_out : `e3nn.o3.Irreps`
        output irreps
    irreps_sh : `e3nn.o3.Irreps`
        set typically to ``o3.Irreps.spherical_harmonics(lmax)``
    diameter : float
        diameter of the filter in physical units
    num_radial_basis : int
        number of radial basis functions
    steps : tuple of float
        size of the pixel in physical units
    """

    def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs) -> None:
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)

        self.num_radial_basis = num_radial_basis

        # self-connection
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # connection with neighbors
        r = diameter / 2

        s = math.floor(r / steps[0])
        x = torch.arange(-s, s + 1.0) * steps[0]

        s = math.floor(r / steps[1])
        y = torch.arange(-s, s + 1.0) * steps[1]

        s = math.floor(r / steps[2])
        z = torch.arange(-s, s + 1.0) * steps[2]

        lattice = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1)  # [x, y, z, R^3]
        self.register_buffer("lattice", lattice)

        if "padding" not in kwargs:
            kwargs["padding"] = tuple(s // 2 for s in lattice.shape[:3])
        self.kwargs = kwargs

        emb = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=r,
            number=self.num_radial_basis,
            basis="smooth_finite",
            cutoff=True,
        )
        self.register_buffer("emb", emb)

        sh = o3.spherical_harmonics(
            l=self.irreps_sh, x=lattice, normalize=True, normalization="component"
        )  # [x, y, z, irreps_sh.dim]
        self.register_buffer("sh", sh)

        self.tp = FullyConnectedTensorProduct(
            self.irreps_in,
            self.irreps_sh,
            self.irreps_out,
            shared_weights=False,
            compile_left_right=False,
            compile_right=True,
        )

        self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))

    def kernel(self) -> torch.Tensor:
        weight = self.emb @ self.weight
        weight = weight / (self.sh.shape[0] * self.sh.shape[1] * self.sh.shape[2])
        kernel = self.tp.right(self.sh, weight)  # [x, y, z, irreps_in.dim, irreps_out.dim]
        kernel = torch.einsum("xyzio->oixyz", kernel)
        return kernel

    def forward(self, x):
        r"""
        Parameters
        ----------
        x : `torch.Tensor`
            tensor of shape ``(batch, irreps_in.dim, x, y, z)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(batch, irreps_out.dim, x, y, z)``
        """
        sc = self.sc(x.transpose(1, 4)).transpose(1, 4)

        return sc + torch.nn.functional.conv3d(x, self.kernel(), **self.kwargs)


class LowPassFilter(torch.nn.Module):
    def __init__(self, scale, stride: int = 1, transposed: bool = False, steps=(1, 1, 1)) -> None:
        super().__init__()

        sigma = 0.5 * (scale**2 - 1) ** 0.5

        size = int(1 + 2 * 2.5 * sigma)
        if size % 2 == 0:
            size += 1

        r = torch.linspace(-1, 1, size)
        x = r * steps[0] / min(steps)
        x = x[x.abs() <= 1]
        y = r * steps[1] / min(steps)
        y = y[y.abs() <= 1]
        z = r * steps[2] / min(steps)
        z = z[z.abs() <= 1]
        lattice = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1)  # [x, y, z, R^3]
        lattice = (size // 2) * lattice

        kernel = torch.exp(-lattice.norm(dim=-1).pow(2) / (2 * sigma**2))
        kernel = kernel / kernel.sum()
        if transposed:
            kernel = kernel * stride**3
        kernel = kernel[None, None]
        self.register_buffer("kernel", kernel)

        self.scale = scale
        self.stride = stride
        self.size = size
        self.transposed = transposed

    def forward(self, image):
        """
        Parameters
        ----------
        image : `torch.Tensor`
            tensor of shape ``(..., x, y, z)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(..., x, y, z)``
        """
        if self.scale <= 1:
            assert self.stride == 1
            return image

        out = image
        out = out.reshape(-1, 1, *out.shape[-3:])
        if self.transposed:
            out = torch.nn.functional.conv_transpose3d(out, self.kernel, padding=self.size // 2, stride=self.stride)
        else:
            out = torch.nn.functional.conv3d(out, self.kernel, padding=self.size // 2, stride=self.stride)
        out = out.reshape(*image.shape[:-3], *out.shape[-3:])
        return out


def test() -> None:
    conv = Convolution(
        "0e + 1e",
        "0e + 1e + 1o + 2e + 2o",
        o3.Irreps.spherical_harmonics(lmax=3),
        diameter=5,
        num_radial_basis=5,
        steps=(1, 1, 1),
    )

    x = torch.randn(10, 4, 32, 32, 32)
    conv(x)

    fi = LowPassFilter(2.0)
    fi(x)

e3nn-0.6.0/e3nn/nn/models/v2106/

e3nn-0.6.0/e3nn/nn/models/v2106/__init__.py

e3nn-0.6.0/e3nn/nn/models/v2106/gate_points_message_passing.py

"""
>>> test()
"""
import torch

from e3nn import o3
from e3nn.nn import Gate

from .points_convolution import Convolution


def tp_path_exists(irreps_in1, irreps_in2, ir_out) -> bool:
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    ir_out = o3.Irrep(ir_out)

    for _, ir1 in irreps_in1:
        for _, ir2 in irreps_in2:
            if ir_out in ir1 * ir2:
                return True
    return False


class Compose(torch.nn.Module):
    def __init__(self, first, second) -> None:
        super().__init__()
        self.first = first
        self.second = second

    def forward(self, *input):
        x = self.first(*input)
        return self.second(x)


class MessagePassing(torch.nn.Module):
    r"""

    Parameters
    ----------
    irreps_node_sequence : list of `e3nn.o3.Irreps`
        representation of the input/hidden/output features;
        one gate (non linearity) is built per hidden entry, i.e. ``len(irreps_node_sequence) - 2`` gates
    irreps_node_attr : `e3nn.o3.Irreps`
        representation of the node attributes
    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes
    fc_neurons : list of int
        number of neurons per layer in the fully connected network
        first layer and hidden layers but not the output layer
    num_neighbors : float
        typical number of nodes convolved over
    """

    def __init__(
        self,
        irreps_node_sequence,
        irreps_node_attr,
        irreps_edge_attr,
        fc_neurons,
        num_neighbors,
    ) -> None:
        super().__init__()
        self.num_neighbors = num_neighbors

        irreps_node_sequence = [o3.Irreps(irreps) for irreps in irreps_node_sequence]
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        act = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        act_gates = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        self.irreps_node_sequence = [irreps_node_sequence[0]]
        irreps_node = irreps_node_sequence[0]

        for irreps_node_hidden in irreps_node_sequence[1:-1]:
            irreps_scalars = o3.Irreps(
                [
                    (mul, ir)
                    for mul, ir in irreps_node_hidden
                    if ir.l == 0 and tp_path_exists(irreps_node, self.irreps_edge_attr, ir)
                ]
            ).simplify()
            irreps_gated = o3.Irreps(
                [
                    (mul, ir)
                    for mul, ir in irreps_node_hidden
                    if ir.l > 0 and tp_path_exists(irreps_node, self.irreps_edge_attr, ir)
                ]
            )
            if irreps_gated.dim > 0:
                if tp_path_exists(irreps_node, self.irreps_edge_attr, "0e"):
                    ir = "0e"
                elif tp_path_exists(irreps_node, self.irreps_edge_attr, "0o"):
                    ir = "0o"
                else:
                    raise ValueError(
                        f"irreps_node={irreps_node} times irreps_edge_attr={self.irreps_edge_attr} is unable to produce gates "
                        f"needed for irreps_gated={irreps_gated}"
                    )
            else:
                ir = None
            irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify()

            gate = Gate(
                irreps_scalars,
                [act[ir.p] for _, ir in irreps_scalars],  # scalar
                irreps_gates,
                [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
                irreps_gated,  # gated tensors
            )
            conv = Convolution(
                irreps_node, self.irreps_node_attr, self.irreps_edge_attr, gate.irreps_in, fc_neurons, num_neighbors
            )
            self.layers.append(Compose(conv, gate))
            irreps_node = gate.irreps_out
            self.irreps_node_sequence.append(irreps_node)

        irreps_node_output = irreps_node_sequence[-1]
        self.layers.append(
            Convolution(
                irreps_node, self.irreps_node_attr, self.irreps_edge_attr, irreps_node_output, fc_neurons, num_neighbors
            )
        )
        self.irreps_node_sequence.append(irreps_node_output)

        self.irreps_node_input = self.irreps_node_sequence[0]
        self.irreps_node_output = self.irreps_node_sequence[-1]

    def forward(self, node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor:
        for lay in self.layers:
            node_features = lay(node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars)

        return node_features


def radius_graph(pos, r_max) -> torch.Tensor:
    # naive version of torch_cluster.radius_graph
    r = torch.cdist(pos, pos)
    return ((r < r_max) & (r > 0)).nonzero().T


def test() -> None:
    from e3nn.util.test import assert_equivariant, assert_auto_jitable

    mp = MessagePassing(
        irreps_node_sequence=["0e", "0e + 1e", "0e + 1e", "0e + 1e", "1e"],
        irreps_node_attr="0e + 1e",
        irreps_edge_attr="0e + 1e",
        fc_neurons=[2, 100],
        num_neighbors=3.0,
    )

    num_nodes = 4
    node_pos = torch.randn(num_nodes, 3)
    edge_index = radius_graph(node_pos, 3.0)
    edge_src, edge_dst = edge_index
    num_edges = edge_index.shape[1]
    edge_attr = o3.spherical_harmonics(
        [0, 1], node_pos[edge_src] - node_pos[edge_dst], normalize=True, normalization="component"
    )

    node_features = torch.randn(num_nodes, 1)
    node_attr = torch.randn(num_nodes, 4)
    edge_scalars = torch.randn(num_edges, 2)

    assert mp(node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars).shape == (num_nodes, 3)

    assert_equivariant(
        mp,
        irreps_in=[mp.irreps_node_input, mp.irreps_node_attr, None, None, mp.irreps_edge_attr, None],
        args_in=[node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars],
        irreps_out=[mp.irreps_node_output],
    )

    assert_auto_jitable(mp.layers[0].first)
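The naive `radius_graph` above (and its batched counterpart in `v2106/gate_points_networks.py`) builds the full pairwise distance matrix with `torch.cdist`, keeps the pairs with `0 < r < r_max`, and reads the edges off `nonzero().T`, whose first row becomes `edge_src` and second row `edge_dst`. A plain-Python sketch of the same O(N^2) logic, for illustration only, might look like the following; the function name `naive_radius_graph` is hypothetical, and the optional `batch` argument mimics the batched variant by keeping only pairs whose nodes share a batch index.

```python
import math
from typing import List, Optional, Sequence, Tuple


def naive_radius_graph(
    pos: Sequence[Sequence[float]], r_max: float, batch: Optional[Sequence[int]] = None
) -> Tuple[List[int], List[int]]:
    """Hypothetical O(N^2) edge list: (i, j) is an edge when 0 < |pos[i] - pos[j]| < r_max."""
    n = len(pos)
    if batch is None:
        batch = [0] * n  # a single graph
    edge_src: List[int] = []
    edge_dst: List[int] = []
    for i in range(n):  # row-major order, like nonzero() on the distance matrix
        for j in range(n):
            r = math.dist(pos[i], pos[j])
            if 0 < r < r_max and batch[i] == batch[j]:
                edge_src.append(i)
                edge_dst.append(j)
    return edge_src, edge_dst


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0)]
print(naive_radius_graph(pos, 2.0))  # ([0, 1], [1, 0])
print(naive_radius_graph(pos, 2.0, batch=[0, 1, 1]))  # ([], [])
```

Because the test is `r > 0`, self-loops are excluded, but so are pairs of distinct nodes that sit at exactly the same position.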
e3nn-0.6.0/e3nn/nn/models/v2106/gate_points_networks.py000066400000000000000000000165101514371756200225370ustar00rootroot00000000000000""" >>> test_simple_network() >>> test_network_for_a_graph_with_attributes() """ from typing import Dict import torch from e3nn import o3 from e3nn.math import soft_one_hot_linspace from .gate_points_message_passing import MessagePassing def scatter(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor: # special case of torch_scatter.scatter with dim=0 out = src.new_zeros(dim_size, src.shape[1]) index = index.reshape(-1, 1).expand_as(src) return out.scatter_add_(0, index, src) def radius_graph(pos, r_max, batch) -> torch.Tensor: # naive and inefficient version of torch_cluster.radius_graph r = torch.cdist(pos, pos) index = ((r < r_max) & (r > 0)).nonzero().T index = index[:, batch[index[0]] == batch[index[1]]] return index class SimpleNetwork(torch.nn.Module): def __init__( self, irreps_in, irreps_out, max_radius, num_neighbors, num_nodes, mul=50, layers=3, lmax=2, pool_nodes=True, ) -> None: super().__init__() self.lmax = lmax self.max_radius = max_radius self.number_of_basis = 10 self.num_nodes = num_nodes self.pool_nodes = pool_nodes irreps_node_hidden = o3.Irreps([(mul, (l, p)) for l in range(lmax + 1) for p in [-1, 1]]) self.mp = MessagePassing( irreps_node_sequence=[irreps_in] + layers * [irreps_node_hidden] + [irreps_out], irreps_node_attr="0e", irreps_edge_attr=o3.Irreps.spherical_harmonics(lmax), fc_neurons=[self.number_of_basis, 100], num_neighbors=num_neighbors, ) self.irreps_in = self.mp.irreps_node_input self.irreps_out = self.mp.irreps_node_output def preprocess(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: if "batch" in data: batch = data["batch"] else: batch = data["pos"].new_zeros(data["pos"].shape[0], dtype=torch.long) # Create graph edge_index = radius_graph(data["pos"], self.max_radius, batch) edge_src = edge_index[0] edge_dst = edge_index[1] # Edge attributes edge_vec = 
data["pos"][edge_src] - data["pos"][edge_dst] return batch, data["x"], edge_src, edge_dst, edge_vec def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: batch, node_inputs, edge_src, edge_dst, edge_vec = self.preprocess(data) del data edge_attr = o3.spherical_harmonics(range(self.lmax + 1), edge_vec, True, normalization="component") # Edge length embedding edge_length = edge_vec.norm(dim=1) edge_length_embedding = soft_one_hot_linspace( edge_length, 0.0, self.max_radius, self.number_of_basis, basis="smooth_finite", # the smooth_finite basis with cutoff = True goes to zero at max_radius cutoff=True, # no need for an additional smooth cutoff ).mul(self.number_of_basis**0.5) # Node attributes are not used here node_attr = node_inputs.new_ones(node_inputs.shape[0], 1) node_outputs = self.mp(node_inputs, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedding) if self.pool_nodes: return scatter(node_outputs, batch, int(batch.max()) + 1).div(self.num_nodes**0.5) else: return node_outputs class NetworkForAGraphWithAttributes(torch.nn.Module): def __init__( self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output, max_radius, num_neighbors, num_nodes, mul=50, layers=3, lmax=2, pool_nodes=True, ) -> None: super().__init__() self.lmax = lmax self.max_radius = max_radius self.number_of_basis = 10 self.num_nodes = num_nodes self.irreps_edge_attr = o3.Irreps(irreps_edge_attr) self.pool_nodes = pool_nodes irreps_node_hidden = o3.Irreps([(mul, (l, p)) for l in range(lmax + 1) for p in [-1, 1]]) self.mp = MessagePassing( irreps_node_sequence=[irreps_node_input] + layers * [irreps_node_hidden] + [irreps_node_output], irreps_node_attr=irreps_node_attr, irreps_edge_attr=self.irreps_edge_attr + o3.Irreps.spherical_harmonics(lmax), fc_neurons=[self.number_of_basis, 100], num_neighbors=num_neighbors, ) self.irreps_node_input = self.mp.irreps_node_input self.irreps_node_attr = self.mp.irreps_node_attr self.irreps_node_output = 
self.mp.irreps_node_output def preprocess(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: if "batch" in data: batch = data["batch"] else: batch = data["pos"].new_zeros(data["pos"].shape[0], dtype=torch.long) # Create graph if "edge_index" in data: edge_src = data["edge_index"][0] edge_dst = data["edge_index"][1] else: edge_index = radius_graph(data["pos"], self.max_radius, batch) edge_src = edge_index[0] edge_dst = edge_index[1] # Edge attributes edge_vec = data["pos"][edge_src] - data["pos"][edge_dst] if "x" in data: node_input = data["x"] else: node_input = data["node_input"] node_attr = data["node_attr"] edge_attr = data["edge_attr"] return batch, node_input, node_attr, edge_attr, edge_src, edge_dst, edge_vec def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: batch, node_input, node_attr, edge_attr, edge_src, edge_dst, edge_vec = self.preprocess(data) del data # Edge attributes edge_sh = o3.spherical_harmonics(range(self.lmax + 1), edge_vec, True, normalization="component") edge_attr = torch.cat([edge_attr, edge_sh], dim=1) # Edge length embedding edge_length = edge_vec.norm(dim=1) edge_length_embedding = soft_one_hot_linspace( edge_length, 0.0, self.max_radius, self.number_of_basis, basis="smooth_finite", # the smooth_finite basis with cutoff = True goes to zero at max_radius cutoff=True, # no need for an additional smooth cutoff ).mul(self.number_of_basis**0.5) node_outputs = self.mp(node_input, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedding) if self.pool_nodes: return scatter(node_outputs, batch, int(batch.max()) + 1).div(self.num_nodes**0.5) else: return node_outputs def test_simple_network() -> None: net = SimpleNetwork("3x0e + 2x1o", "4x0e + 1x1o", max_radius=2.0, num_neighbors=3.0, num_nodes=5.0) net({"pos": torch.randn(5, 3), "x": net.irreps_in.randn(5, -1)}) def test_network_for_a_graph_with_attributes() -> None: net = NetworkForAGraphWithAttributes( "3x0e + 2x1o", "4x0e + 1x1o", "1e", "3x0o + 1e", max_radius=2.0, 
num_neighbors=3.0, num_nodes=5.0 ) net( { "pos": torch.randn(3, 3), "edge_index": torch.tensor([[0, 1, 2], [1, 2, 0]]), "node_input": net.irreps_node_input.randn(3, -1), "node_attr": net.irreps_node_attr.randn(3, -1), "edge_attr": net.irreps_edge_attr.randn(3, -1), } ) e3nn-0.6.0/e3nn/nn/models/v2106/points_convolution.py000066400000000000000000000106021514371756200222360ustar00rootroot00000000000000""" Compare to v2103 - replaced the angle trick by a factor alpha inspired by https://arxiv.org/pdf/2002.10444.pdf """ import torch from e3nn import o3 from e3nn.nn import FullyConnectedNet from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct from e3nn.util.jit import compile_mode def scatter(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor: # special case of torch_scatter.scatter with dim=0 out = src.new_zeros(dim_size, src.shape[1]) index = index.reshape(-1, 1).expand_as(src) return out.scatter_add_(0, index, src) @compile_mode("script") class Convolution(torch.nn.Module): r"""equivariant convolution Parameters ---------- irreps_node_input : `e3nn.o3.Irreps` representation of the input node features irreps_node_attr : `e3nn.o3.Irreps` representation of the node attributes irreps_edge_attr : `e3nn.o3.Irreps` representation of the edge attributes irreps_node_output : `e3nn.o3.Irreps` or None representation of the output node features fc_neurons : list of int number of neurons per layers in the fully connected network first layer and hidden layers but not the output layer num_neighbors : float typical number of nodes convolved over """ def __init__( self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output, fc_neurons, num_neighbors ) -> None: super().__init__() self.irreps_node_input = o3.Irreps(irreps_node_input) self.irreps_node_attr = o3.Irreps(irreps_node_attr) self.irreps_edge_attr = o3.Irreps(irreps_edge_attr) self.irreps_node_output = o3.Irreps(irreps_node_output) self.num_neighbors = num_neighbors self.sc = 
FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output) self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input) irreps_mid = [] instructions = [] for i, (mul, ir_in) in enumerate(self.irreps_node_input): for j, (_, ir_edge) in enumerate(self.irreps_edge_attr): for ir_out in ir_in * ir_edge: if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1): k = len(irreps_mid) irreps_mid.append((mul, ir_out)) instructions.append((i, j, k, "uvu", True)) irreps_mid = o3.Irreps(irreps_mid) irreps_mid, p, _ = irreps_mid.sort() assert irreps_mid.dim > 0, ( f"irreps_node_input={self.irreps_node_input} time irreps_edge_attr={self.irreps_edge_attr} produces nothing " f"in irreps_node_output={self.irreps_node_output}" ) instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions] tp = TensorProduct( self.irreps_node_input, self.irreps_edge_attr, irreps_mid, instructions, internal_weights=False, shared_weights=False, ) self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu) self.tp = tp self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output) # inspired by https://arxiv.org/pdf/2002.10444.pdf self.alpha = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e") with torch.no_grad(): self.alpha.weight.zero_() assert ( self.alpha.output_mask[0] == 1.0 ), f"irreps_mid={irreps_mid} and irreps_node_attr={self.irreps_node_attr} are not able to generate scalars" def forward(self, node_input, node_attr, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor: weight = self.fc(edge_scalars) node_self_connection = self.sc(node_input, node_attr) node_features = self.lin1(node_input, node_attr) edge_features = self.tp(node_features[edge_src], edge_attr, weight) node_features = scatter(edge_features, edge_dst, 
dim_size=node_input.shape[0]).div(self.num_neighbors**0.5) node_conv_out = self.lin2(node_features, node_attr) alpha = self.alpha(node_features, node_attr) m = self.sc.output_mask alpha = (1 - m) + alpha * m return node_self_connection + alpha * node_conv_out e3nn-0.6.0/e3nn/nn/models/v2203/000077500000000000000000000000001514371756200157105ustar00rootroot00000000000000e3nn-0.6.0/e3nn/nn/models/v2203/sparse_voxel_convolution.py000066400000000000000000000073711514371756200234430ustar00rootroot00000000000000import math import torch from e3nn import o3 from e3nn.math import soft_one_hot_linspace from e3nn.o3 import FullyConnectedTensorProduct, Linear try: from MinkowskiEngine import KernelGenerator, MinkowskiConvolutionFunction, SparseTensor from MinkowskiEngineBackend._C import ConvolutionMode except ImportError: pass class Convolution(torch.nn.Module): r"""convolution on voxels Parameters ---------- irreps_in : `e3nn.o3.Irreps` input irreps irreps_out : `e3nn.o3.Irreps` output irreps irreps_sh : `e3nn.o3.Irreps` set typically to ``o3.Irreps.spherical_harmonics(lmax)`` diameter : float diameter of the filter in physical units num_radial_basis : int number of radial basis functions steps : tuple of float size of the pixel in physical units """ def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0)) -> None: super().__init__() self.irreps_in = o3.Irreps(irreps_in) self.irreps_out = o3.Irreps(irreps_out) self.irreps_sh = o3.Irreps(irreps_sh) self.num_radial_basis = num_radial_basis # self-connection self.sc = Linear(self.irreps_in, self.irreps_out) # connection with neighbors r = diameter / 2 s = math.floor(r / steps[0]) x = torch.arange(-s, s + 1.0) * steps[0] s = math.floor(r / steps[1]) y = torch.arange(-s, s + 1.0) * steps[1] s = math.floor(r / steps[2]) z = torch.arange(-s, s + 1.0) * steps[2] lattice = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1) # [x, y, z, R^3] self.register_buffer("lattice", 
                             lattice)

        emb = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=r,
            number=self.num_radial_basis,
            basis="smooth_finite",
            cutoff=True,
        )
        self.register_buffer("emb", emb)

        sh = o3.spherical_harmonics(
            l=self.irreps_sh, x=lattice, normalize=True, normalization="component"
        )  # [x, y, z, irreps_sh.dim]
        self.register_buffer("sh", sh)

        self.tp = FullyConnectedTensorProduct(
            self.irreps_in,
            self.irreps_sh,
            self.irreps_out,
            shared_weights=False,
            compile_left_right=False,
            compile_right=True,
        )

        self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))

        self.kernel_generator = KernelGenerator(lattice.shape[:3], dimension=3)

        self.conv_fn = MinkowskiConvolutionFunction()

    def kernel(self) -> torch.Tensor:
        weight = self.emb @ self.weight
        weight = weight / (self.sh.shape[0] * self.sh.shape[1] * self.sh.shape[2])
        kernel = self.tp.right(self.sh, weight)  # [x, y, z, irreps_in.dim, irreps_out.dim]

        # TODO: understand why this is necessary
        kernel = torch.einsum("xyzij->zyxij", kernel)  # [z, y, x, irreps_in.dim, irreps_out.dim]

        kernel = kernel.reshape(-1, *kernel.shape[-2:])  # [z * y * x, irreps_in.dim, irreps_out.dim]
        return kernel

    def forward(self, x):
        r"""
        Parameters
        ----------
        x : SparseTensor

        Returns
        -------
        SparseTensor
        """
        sc = self.sc(x.F)

        out = self.conv_fn.apply(
            x.F,
            self.kernel(),
            self.kernel_generator,
            ConvolutionMode.DEFAULT,
            x.coordinate_map_key,
            x.coordinate_map_key,
            x._manager,
        )

        return SparseTensor(
            sc + out,
            coordinate_map_key=x.coordinate_map_key,
            coordinate_manager=x._manager,
        )

e3nn-0.6.0/e3nn/o3/__init__.py

from ._rotation import (
    rand_matrix,
    identity_angles,
    rand_angles,
    compose_angles,
    inverse_angles,
    identity_quaternion,
    rand_quaternion,
    compose_quaternion,
    inverse_quaternion,
    rand_axis_angle,
    compose_axis_angle,
    matrix_x,
    matrix_y,
    matrix_z,
    angles_to_matrix,
    matrix_to_angles,
    angles_to_quaternion,
    matrix_to_quaternion,
    axis_angle_to_quaternion,
    quaternion_to_axis_angle,
    matrix_to_axis_angle,
    angles_to_axis_angle,
    axis_angle_to_matrix,
    quaternion_to_matrix,
    quaternion_to_angles,
    axis_angle_to_angles,
    angles_to_xyz,
    xyz_to_angles,
)
from ._wigner import wigner_D, wigner_3j, change_basis_real_to_complex, su2_generators, so3_generators
from ._irreps import Irrep, Irreps
from ._tensor_product import (
    Instruction,
    TensorProduct,
    FullyConnectedTensorProduct,
    ElementwiseTensorProduct,
    FullTensorProduct,
    TensorSquare,
)
from .experimental import FullTensorProductv2
from ._spherical_harmonics import SphericalHarmonics, spherical_harmonics
from ._angular_spherical_harmonics import (
    SphericalHarmonicsAlphaBeta,
    spherical_harmonics_alpha_beta,
    spherical_harmonics_alpha,
    Legendre,
)
from ._reduce import ReducedTensorProducts
from ._s2grid import (
    s2_grid,
    spherical_harmonics_s2_grid,
    rfft,
    irfft,
    ToS2Grid,
    FromS2Grid,
)
from ._so3grid import SO3Grid
from ._linear import Linear
from ._norm import Norm


__all__ = [
    "rand_matrix",
    "identity_angles",
    "rand_angles",
    "compose_angles",
    "inverse_angles",
    "identity_quaternion",
    "rand_quaternion",
    "compose_quaternion",
    "inverse_quaternion",
    "rand_axis_angle",
    "compose_axis_angle",
    "matrix_x",
    "matrix_y",
    "matrix_z",
    "angles_to_matrix",
    "matrix_to_angles",
    "angles_to_quaternion",
    "matrix_to_quaternion",
    "axis_angle_to_quaternion",
    "quaternion_to_axis_angle",
    "matrix_to_axis_angle",
    "angles_to_axis_angle",
    "axis_angle_to_matrix",
    "quaternion_to_matrix",
    "quaternion_to_angles",
    "axis_angle_to_angles",
    "angles_to_xyz",
    "xyz_to_angles",
    "wigner_D",
    "wigner_3j",
    "change_basis_real_to_complex",
    "su2_generators",
    "so3_generators",
    "Irrep",
    "Irreps",
    "irrep",
    "Instruction",
    "TensorProduct",
    "FullyConnectedTensorProduct",
    "ElementwiseTensorProduct",
    "FullTensorProduct",
    "FullTensorProductv2",
    "TensorSquare",
    "SphericalHarmonics",
    "spherical_harmonics",
    "SphericalHarmonicsAlphaBeta",
    "spherical_harmonics_alpha_beta",
    "spherical_harmonics_alpha",
    "Legendre",
    "ReducedTensorProducts",
    "s2_grid",
    "spherical_harmonics_s2_grid",
    "rfft",
    "irfft",
    "ToS2Grid",
    "FromS2Grid",
    "SO3Grid",
    "Linear",
    "Norm",
]

e3nn-0.6.0/e3nn/o3/_angular_spherical_harmonics.py

r"""Spherical Harmonics as functions of Euler angles"""

import math
from typing import List, Tuple

import torch
from torch import fx
from sympy import Integer, Poly, diff, factorial, pi, sqrt, symbols

from e3nn.util.jit import compile_mode
from e3nn import o3, get_optimization_defaults


def _conditional_script(fn):
    """apply torch.jit.script only if jit_mode is 'script'"""
    if get_optimization_defaults()["jit_mode"] == "script":
        return torch.jit.script(fn)
    return fn


@compile_mode("script")
class SphericalHarmonicsAlphaBeta(torch.nn.Module):
    """JITable module version of :meth:`e3nn.o3.spherical_harmonics_alpha_beta`.

    Parameters are identical to :meth:`e3nn.o3.spherical_harmonics_alpha_beta`.
    """

    normalization: str
    _ls_list: List[int]
    _lmax: int

    def __init__(self, l, normalization: str = "integral") -> None:
        super().__init__()

        if isinstance(l, o3.Irreps):
            ls = [l for mul, (l, p) in l for _ in range(mul)]
        elif isinstance(l, int):
            ls = [l]
        else:
            ls = list(l)

        self._ls_list = ls
        self._lmax = max(ls)
        self.legendre = Legendre(ls)
        self.normalization = normalization

    def forward(self, alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        y, z = beta.cos(), beta.sin()
        sha = spherical_harmonics_alpha(self._lmax, alpha.flatten())  # [z, m]
        shy = self.legendre(y.flatten(), z.flatten())  # [z, l * m]
        out = _mul_m_lm([(1, l) for l in self._ls_list], sha, shy)

        if self.normalization == "norm":
            out.div_(
                torch.cat(
                    [
                        (math.sqrt(2 * l + 1) / math.sqrt(4 * math.pi))
                        * torch.ones(2 * l + 1, dtype=out.dtype, device=out.device)
                        for l in self._ls_list
                    ]
                )
            )
        elif self.normalization == "component":
            out.mul_(math.sqrt(4 * math.pi))

        return out.reshape(alpha.shape + (shy.shape[1],))


def spherical_harmonics_alpha_beta(l, alpha, beta, *, normalization: str = "integral"):
    r"""Spherical harmonics of :math:`\vec r = R_y(\alpha) R_x(\beta) e_y`

    .. math:: Y^l(\alpha, \beta) = S^l(\alpha) P^l(\cos(\beta))

    where :math:`P^l` are the `Legendre` polynomials


    Parameters
    ----------
    l : int or list of int
        degree of the spherical harmonics.

    alpha : `torch.Tensor`
        tensor of shape ``(...)``.

    beta : `torch.Tensor`
        tensor of shape ``(...)``.

    Returns
    -------
    `torch.Tensor`
        a tensor of shape ``(..., 2l+1)``
    """
    sh = SphericalHarmonicsAlphaBeta(l, normalization=normalization)
    return sh(alpha, beta)


@_conditional_script
def spherical_harmonics_alpha(l: int, alpha: torch.Tensor) -> torch.Tensor:
    r""":math:`S^l(\alpha)` of `spherical_harmonics_alpha_beta`

    Parameters
    ----------
    l : int
        degree of the spherical harmonics.

    alpha : `torch.Tensor`
        tensor of shape ``(...)``.
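
A plain-Python sketch (no torch; the name ``s_alpha`` is hypothetical) of the vector this function returns for a single angle: the sine terms in decreasing frequency, a constant 1 at index ``l``, then the cosine terms in increasing frequency. Each frequency contributes 2 (sin^2 + cos^2) = 2 to the sum of squares, so the squared entries always sum to 2l + 1.

```python
import math


def s_alpha(l: int, alpha: float) -> list:
    # Layout mirrored from spherical_harmonics_alpha:
    # [sqrt(2) sin(l a), ..., sqrt(2) sin(a), 1, sqrt(2) cos(a), ..., sqrt(2) cos(l a)]
    sin_part = [math.sqrt(2) * math.sin(m * alpha) for m in range(l, 0, -1)]
    cos_part = [math.sqrt(2) * math.cos(m * alpha) for m in range(1, l + 1)]
    return sin_part + [1.0] + cos_part


values = s_alpha(3, 0.7)  # 2 * 3 + 1 = 7 entries, the constant 1 sits at index 3
```
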

    Returns
    -------
    `torch.Tensor`
        a tensor of shape ``(..., 2l+1)``
    """
    alpha = alpha.unsqueeze(-1)  # [..., 1]

    m = torch.arange(1, l + 1, dtype=alpha.dtype, device=alpha.device)  # [1, 2, 3, ..., l]
    cos = torch.cos(m * alpha)  # [..., m]

    m = torch.arange(l, 0, -1, dtype=alpha.dtype, device=alpha.device)  # [l, l-1, l-2, ..., 1]
    sin = torch.sin(m * alpha)  # [..., m]

    out = torch.cat(
        [
            math.sqrt(2) * sin,
            torch.ones_like(alpha),
            math.sqrt(2) * cos,
        ],
        dim=alpha.ndim - 1,
    )

    return out  # [..., m]


@compile_mode("script")
class Legendre(fx.GraphModule):  # pylint: disable=abstract-method
    def __init__(self, ls) -> None:
        super().__init__(self, fx.Graph())

        # == generate code ==
        graph = self.graph
        z = fx.Proxy(graph.placeholder("z", torch.Tensor))
        y = fx.Proxy(graph.placeholder("y", torch.Tensor))

        out = z.new_zeros(z.shape + (sum(2 * l + 1 for l in ls),))

        i = 0
        for l in ls:
            leg = []
            for m in range(l + 1):
                p = _poly_legendre(l, m)
                p = list(p.items())

                (zn, yn), c = p[0]
                x = float(c) * z**zn * y**yn

                for (zn, yn), c in p[1:]:
                    x += float(c) * z**zn * y**yn

                leg.append(x.unsqueeze(-1))

            for m in range(-l, l + 1):
                out.narrow(-1, i, 1).copy_(leg[abs(m)])
                i += 1

        graph.output(out.node, torch.Tensor)

        self.recompile()


def _poly_legendre(l, m):
    r"""
    polynomial coefficients of legendre

    y = sqrt(1 - z^2)
    """
    z, y = symbols("z y", real=True)
    return Poly(_sympy_legendre(l, m), domain="R", gens=(z, y)).as_dict()


def _sympy_legendre(l, m) -> float:
    r"""
    en.wikipedia.org/wiki/Associated_Legendre_polynomials
    - remove two times (-1)^m
    - use another normalization such that P(l, -m) = P(l, m)
    - remove (-1)^l

    y = sqrt(1 - z^2)
    """
    l = Integer(l)
    m = Integer(abs(m))
    z, y = symbols("z y", real=True)
    ex = 1 / (2**l * factorial(l)) * y**m * diff((z**2 - 1) ** l, z, l + m)
    ex *= sqrt((2 * l + 1) / (4 * pi) * factorial(l - m) / factorial(l + m))
    return ex


@_conditional_script
def _mul_m_lm(mul_l: List[Tuple[int, int]], x_m: torch.Tensor, x_lm: torch.Tensor) -> torch.Tensor:
    """
    multiply tensor [..., l * m] by
    [..., m]
    """
    l_max = x_m.shape[-1] // 2
    out = []
    i = 0
    for mul, l in mul_l:
        d = mul * (2 * l + 1)
        x1 = x_lm[..., i : i + d]  # [..., mul * m]
        x1 = x1.reshape(x1.shape[:-1] + (mul, 2 * l + 1))  # [..., mul, m]
        x2 = x_m[..., l_max - l : l_max + l + 1]  # [..., m]
        x2 = x2.reshape(x2.shape[:-1] + (1, 2 * l + 1))  # [..., mul=1, m]
        x = x1 * x2
        x = x.reshape(x.shape[:-2] + (d,))
        out.append(x)
        i += d
    return torch.cat(out, dim=-1)

e3nn-0.6.0/e3nn/o3/_irreps.py

import itertools
import collections
from typing import List, Union, Callable

import torch

from e3nn.math import direct_sum, perm

# These imports avoid cyclic reference from o3 itself
from . import _rotation
from . import _wigner


class Irrep(tuple):
    r"""Irreducible representation of :math:`O(3)`

    This class does not contain any data; it is a structure that describes the representation.
    It is typically used as an argument to other classes of the library to define the input and output representations
    of functions.

    Parameters
    ----------
    l : int
        non-negative integer, the degree of the representation, :math:`l = 0, 1, \dots`

    p : {1, -1}
        the parity of the representation

    Examples
    --------
    Create a scalar representation (:math:`l=0`) of even parity.

    >>> Irrep(0, 1)
    0e

    Create a pseudotensor representation (:math:`l=2`) of odd parity.

    >>> Irrep(2, -1)
    2o

    Create a vector representation (:math:`l=1`) with the parity of the spherical harmonics (:math:`(-1)^l` gives odd
    parity).

    >>> Irrep("1y")
    1o

    >>> Irrep("2o").dim
    5

    >>> Irrep("2e") in Irrep("1o") * Irrep("1o")
    True

    >>> Irrep("1o") + Irrep("2o")
    1x1o+1x2o
    """

    def __new__(cls, l: Union[int, "Irrep", str, tuple], p=None):
        if p is None:
            if isinstance(l, Irrep):
                return l

            if isinstance(l, _MulIr):
                # a multiplicity/irrep pair converts to its irrep
                return l.ir

            if isinstance(l, str):
                try:
                    name = l.strip()
                    l = int(name[:-1])
                    assert l >= 0
                    p = {
                        "e": 1,
                        "o": -1,
                        "y": (-1) ** l,
                    }[name[-1]]
                except Exception:
                    raise ValueError(f'unable to convert string "{name}" into an Irrep')
            elif isinstance(l, tuple):
                l, p = l

        if not isinstance(l, int) or l < 0:
            raise ValueError(f"l must be a non-negative integer, got {l}")
        if p not in (-1, 1):
            raise ValueError(f"parity must be one of (-1, 1), got {p}")
        return super().__new__(cls, (l, p))

    @property
    def l(self) -> int:  # noqa: E743
        r"""The degree of the representation, :math:`l = 0, 1, \dots`."""
        return self[0]

    @property
    def p(self) -> int:
        r"""The parity of the representation, :math:`p = \pm 1`."""
        return self[1]

    def __repr__(self) -> str:
        p = {+1: "e", -1: "o"}[self.p]
        return f"{self.l}{p}"

    @classmethod
    def iterator(cls, lmax=None):
        r"""Iterator through all the irreps of :math:`O(3)`

        Examples
        --------
        >>> it = Irrep.iterator()
        >>> next(it), next(it), next(it), next(it)
        (0e, 0o, 1o, 1e)
        """
        for l in itertools.count():
            yield Irrep(l, (-1) ** l)
            yield Irrep(l, -((-1) ** l))

            if l == lmax:
                break

    def D_from_angles(self, alpha, beta, gamma, k=None) -> torch.Tensor:
        r"""Matrix :math:`p^k D^l(\alpha, \beta, \gamma)`

        (matrix) Representation of :math:`O(3)`. :math:`D` is the representation of :math:`SO(3)`, see `wigner_D`.

        Parameters
        ----------
        alpha : `torch.Tensor`
            tensor of shape :math:`(...)`
            Rotation :math:`\alpha` around Y axis, applied third.

        beta : `torch.Tensor`
            tensor of shape :math:`(...)`
            Rotation :math:`\beta` around X axis, applied second.

        gamma : `torch.Tensor`
            tensor of shape :math:`(...)`
            Rotation :math:`\gamma` around Y axis, applied first.

        k : `torch.Tensor`, optional
            tensor of shape :math:`(...)`
            How many times the parity is applied.

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., 2l+1, 2l+1)`

        See Also
        --------
        o3.wigner_D
        Irreps.D_from_angles
        """
        if k is None:
            k = torch.zeros_like(alpha)

        alpha, beta, gamma, k = torch.broadcast_tensors(alpha, beta, gamma, k)
        return _wigner.wigner_D(self.l, alpha, beta, gamma) * self.p ** k[..., None, None]

    def D_from_quaternion(self, q, k=None) -> torch.Tensor:
        r"""Matrix of the representation, see `Irrep.D_from_angles`

        Parameters
        ----------
        q : `torch.Tensor`
            tensor of shape :math:`(..., 4)`

        k : `torch.Tensor`, optional
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., 2l+1, 2l+1)`
        """
        return self.D_from_angles(*_rotation.quaternion_to_angles(q), k)

    def D_from_matrix(self, R) -> torch.Tensor:
        r"""Matrix of the representation, see `Irrep.D_from_angles`

        Parameters
        ----------
        R : `torch.Tensor`
            tensor of shape :math:`(..., 3, 3)`. An improper rotation (:math:`\det R = -1`) applies the parity once.

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., 2l+1, 2l+1)`

        Examples
        --------
        >>> m = Irrep(1, -1).D_from_matrix(-torch.eye(3))
        >>> m.long()
        tensor([[-1,  0,  0],
                [ 0, -1,  0],
                [ 0,  0, -1]])
        """
        d = torch.det(R).sign()
        R = d[..., None, None] * R
        k = (1 - d) / 2
        return self.D_from_angles(*_rotation.matrix_to_angles(R), k)

    def D_from_axis_angle(self, axis, angle) -> torch.Tensor:
        r"""Matrix of the representation, see `Irrep.D_from_angles`

        Parameters
        ----------
        axis : `torch.Tensor`
            tensor of shape :math:`(..., 3)`

        angle : `torch.Tensor`
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., 2l+1, 2l+1)`
        """
        return self.D_from_angles(*_rotation.axis_angle_to_angles(axis, angle))

    @property
    def dim(self) -> int:
        """The dimension of the representation, :math:`2 l + 1`."""
        return 2 * self.l + 1

    def is_scalar(self) -> bool:
        """Equivalent to ``l == 0 and p == 1``"""
        return self.l == 0 and self.p == 1

    def __mul__(self, other):
        r"""Generate the irreps from the product of two irreps.

        Returns
        -------
        generator of `e3nn.o3.Irrep`
        """
        other = Irrep(other)
        p = self.p * other.p
        lmin = abs(self.l - other.l)
        lmax = self.l + other.l
        for l in range(lmin, lmax + 1):
            yield Irrep(l, p)

    def count(self, _value):
        raise NotImplementedError

    def index(self, _value):
        raise NotImplementedError

    def __rmul__(self, other):
        r"""
        >>> 3 * Irrep('1e')
        3x1e
        """
        assert isinstance(other, int)
        return Irreps([(other, self)])

    def __add__(self, other):
        return Irreps(self) + Irreps(other)

    def __contains__(self, _object):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class _MulIr(tuple):
    def __new__(cls, mul, ir=None):
        if ir is None:
            mul, ir = mul

        assert isinstance(mul, int)
        assert isinstance(ir, Irrep)
        return super().__new__(cls, (mul, ir))

    @property
    def mul(self) -> int:
        return self[0]

    @property
    def ir(self) -> Irrep:
        return self[1]

    @property
    def dim(self) -> int:
        return self.mul * self.ir.dim

    def __repr__(self) -> str:
        return f"{self.mul}x{self.ir}"

    def __getitem__(self, item) -> Union[int, Irrep]:  # pylint: disable=useless-super-delegation
        return super().__getitem__(item)

    def count(self, _value):
        raise NotImplementedError

    def index(self, _value):
        raise NotImplementedError


class Irreps(tuple):
    r"""Direct sum of irreducible representations of :math:`O(3)`

    This class does not contain any data; it is a structure that describes the representation.
    It is typically used as an argument to other classes of the library to define the input and output representations
    of functions.

    Attributes
    ----------
    dim : int
        the total dimension of the representation

    num_irreps : int
        number of irreps, i.e. the sum of the multiplicities

    ls : list of int
        list of :math:`l` values

    lmax : int
        maximum :math:`l` value

    Examples
    --------
    Create a representation of 100 :math:`l=0` of even parity and 50 pseudo-vectors.

    >>> x = Irreps([(100, (0, 1)), (50, (1, 1))])
    >>> x
    100x0e+50x1e

    >>> x.dim
    250

    The same representation can be created from a string.

    >>> Irreps("100x0e + 50x1e")
    100x0e+50x1e

    >>> Irreps("100x0e + 50x1e + 0x2e")
    100x0e+50x1e+0x2e

    >>> Irreps("100x0e + 50x1e + 0x2e").lmax
    1

    >>> Irrep("2e") in Irreps("0e + 2e")
    True

    Empty Irreps

    >>> Irreps(), Irreps("")
    (, )
    """

    # Marker attribute to identify Irreps instances across different class definitions
    # (e.g., when using torch.package). This enables isinstance-like checks via hasattr
    # when the class identity differs between packaged and environment code.
    _e3nn_irreps_marker = True

    def __new__(cls, irreps=None) -> Union[_MulIr, "Irreps"]:
        if isinstance(irreps, Irreps):
            return super().__new__(cls, irreps)

        out = []
        if isinstance(irreps, Irrep):
            out.append(_MulIr(1, Irrep(irreps)))
        elif isinstance(irreps, str):
            try:
                if irreps.strip() != "":
                    for mul_ir in irreps.split("+"):
                        if "x" in mul_ir:
                            mul, ir = mul_ir.split("x")
                            mul = int(mul)
                            ir = Irrep(ir)
                        else:
                            mul = 1
                            ir = Irrep(mul_ir)

                        assert isinstance(mul, int) and mul >= 0
                        out.append(_MulIr(mul, ir))
            except Exception:
                raise ValueError(f'Unable to convert string "{irreps}" into an Irreps')
        elif irreps is None:
            pass
        else:
            for mul_ir in irreps:
                mul = None
                ir = None

                if isinstance(mul_ir, str):
                    mul = 1
                    ir = Irrep(mul_ir)
                elif isinstance(mul_ir, Irrep):
                    mul = 1
                    ir = mul_ir
                elif isinstance(mul_ir, _MulIr):
                    mul, ir = mul_ir
                elif len(mul_ir) == 2:
                    mul, ir = mul_ir
                    ir = Irrep(ir)

                if not (isinstance(mul, int) and mul >= 0 and ir is not None):
                    raise ValueError(f'Unable to interpret "{mul_ir}" as an irrep.')

                out.append(_MulIr(mul, ir))
        return super().__new__(cls, out)

    @staticmethod
    def spherical_harmonics(lmax: int, p: int = -1) -> "Irreps":
        r"""representation of the spherical harmonics

        Parameters
        ----------
        lmax : int
            maximum :math:`l`

        p : {1, -1}
            the parity of the representation

        Returns
        -------
        `e3nn.o3.Irreps`
            representation of :math:`(Y^0, Y^1, \dots,
            Y^{\mathrm{lmax}})`

        Examples
        --------
        >>> Irreps.spherical_harmonics(3)
        1x0e+1x1o+1x2e+1x3o

        >>> Irreps.spherical_harmonics(4, p=1)
        1x0e+1x1e+1x2e+1x3e+1x4e
        """
        return Irreps([(1, (l, p**l)) for l in range(lmax + 1)])

    def slices(self):
        r"""List of slices corresponding to indices for each irrep.

        Examples
        --------

        >>> Irreps('2x0e + 1e').slices()
        [slice(0, 2, None), slice(2, 5, None)]
        """
        s = []
        i = 0
        for mul_ir in self:
            s.append(slice(i, i + mul_ir.dim))
            i += mul_ir.dim
        return s

    def randn(
        self, *size: int, normalization: str = "component", requires_grad: bool = False, dtype=None, device=None
    ) -> torch.Tensor:
        r"""Random tensor.

        Parameters
        ----------
        *size : list of int
            size of the output tensor, must contain exactly one ``-1``

        normalization : {'component', 'norm'}

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``size`` where ``-1`` is replaced by ``self.dim``

        Examples
        --------

        >>> Irreps("5x0e + 10x1o").randn(5, -1, 5, normalization='norm').shape
        torch.Size([5, 35, 5])

        >>> random_tensor = Irreps("2o").randn(2, -1, 3, normalization='norm')
        >>> random_tensor.norm(dim=1).sub(1).abs().max().item() < 1e-5
        True
        """
        di = size.index(-1)
        lsize = size[:di]
        rsize = size[di + 1 :]

        if normalization == "component":
            return torch.randn(*lsize, self.dim, *rsize, requires_grad=requires_grad, dtype=dtype, device=device)
        elif normalization == "norm":
            x = torch.zeros(*lsize, self.dim, *rsize, requires_grad=requires_grad, dtype=dtype, device=device)
            with torch.no_grad():
                for s, (mul, ir) in zip(self.slices(), self):
                    r = torch.randn(*lsize, mul, ir.dim, *rsize, dtype=dtype, device=device)
                    r.div_(r.norm(2, dim=di + 1, keepdim=True))
                    x.narrow(di, s.start, mul * ir.dim).copy_(r.reshape(*lsize, -1, *rsize))
            return x
        else:
            raise ValueError("Normalization needs to be 'norm' or 'component'")

    def __getitem__(self, i) -> Union[_MulIr, "Irreps"]:
        x = super().__getitem__(i)
        if isinstance(i, slice):
            return Irreps(x)
        return x

    def __contains__(self, ir) -> bool:
        ir = Irrep(ir)
        return ir in (irrep for _,
                         irrep in self)

    def count(self, ir) -> int:
        r"""Multiplicity of ``ir``.

        Parameters
        ----------
        ir : `e3nn.o3.Irrep`

        Returns
        -------
        `int`
            total multiplicity of ``ir``
        """
        ir = Irrep(ir)
        return sum(mul for mul, irrep in self if ir == irrep)

    def index(self, _object):
        raise NotImplementedError

    def __add__(self, irreps) -> "Irreps":
        irreps = Irreps(irreps)
        return Irreps(super().__add__(irreps))

    def __mul__(self, other) -> "Irreps":
        r"""
        >>> (Irreps('2x1e') * 3).simplify()
        6x1e
        """
        if isinstance(other, Irreps):
            raise NotImplementedError("Use o3.TensorProduct for this, see the documentation")
        return Irreps(super().__mul__(other))

    def __rmul__(self, other) -> "Irreps":
        r"""
        >>> 2 * Irreps('0e + 1e')
        1x0e+1x1e+1x0e+1x1e
        """
        return Irreps(super().__rmul__(other))

    def simplify(self) -> "Irreps":
        """Simplify the representations.

        Returns
        -------
        `e3nn.o3.Irreps`

        Examples
        --------

        Note that simplify does not sort the representations.

        >>> Irreps("1e + 1e + 0e").simplify()
        2x1e+1x0e

        Equivalent representations which are separated from each other are not combined.

        >>> Irreps("1e + 1e + 0e + 1e").simplify()
        2x1e+1x0e+1x1e
        """
        out = []
        for mul, ir in self:
            if out and out[-1][1] == ir:
                out[-1] = (out[-1][0] + mul, ir)
            elif mul > 0:
                out.append((mul, ir))
        return Irreps(out)

    def remove_zero_multiplicities(self) -> "Irreps":
        """Remove any irreps with multiplicities of zero.

        Returns
        -------
        `e3nn.o3.Irreps`

        Examples
        --------

        >>> Irreps("4x0e + 0x1o + 2x3e").remove_zero_multiplicities()
        4x0e+2x3e
        """
        out = [(mul, ir) for mul, ir in self if mul > 0]
        return Irreps(out)

    def sort(self):
        r"""Sort the representations.
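
A plain-Python sketch of what ``sort`` computes, writing each irrep as an ``(l, p)`` tuple so that tuples compare the way `Irrep` does. ``sort_with_perm`` is a hypothetical name, and the inverse permutation is built by hand here rather than with ``e3nn.math.perm.inverse``. Ties between equal irreps are broken by original position, so the sort is stable.

```python
def sort_with_perm(entries):
    # entries: list of (mul, (l, p)) pairs
    order = sorted((ir, i, mul) for i, (mul, ir) in enumerate(entries))
    inv = tuple(i for _, i, _ in order)  # inv[j]: original index of the j-th sorted entry
    p = [0] * len(inv)
    for j, i in enumerate(inv):
        p[i] = j  # p[i]: sorted position of the i-th original entry
    irreps = [(mul, ir) for ir, _, mul in order]
    return irreps, tuple(p), inv


# "2o + 1e + 0e + 1e"
irreps, p, inv = sort_with_perm([(1, (2, -1)), (1, (1, 1)), (1, (0, 1)), (1, (1, 1))])
# p == (3, 1, 0, 2) and inv == (2, 1, 3, 0), matching the examples that follow
```
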

        Returns
        -------
        irreps : `e3nn.o3.Irreps`
            the sorted irreps
        p : tuple of int
            ``p[i]`` is the position, in the sorted irreps, of the ``i``-th original entry
        inv : tuple of int
            the inverse permutation: ``inv[j]`` is the original index of the ``j``-th sorted entry

        Examples
        --------
        >>> Irreps("1e + 0e + 1e").sort().irreps
        1x0e+1x1e+1x1e

        >>> Irreps("2o + 1e + 0e + 1e").sort().p
        (3, 1, 0, 2)

        >>> Irreps("2o + 1e + 0e + 1e").sort().inv
        (2, 1, 3, 0)
        """
        Ret = collections.namedtuple("sort", ["irreps", "p", "inv"])
        out = [(ir, i, mul) for i, (mul, ir) in enumerate(self)]
        out = sorted(out)
        inv = tuple(i for _, i, _ in out)
        p = perm.inverse(inv)
        irreps = Irreps([(mul, ir) for ir, _, mul in out])
        return Ret(irreps, p, inv)

    def regroup(self) -> "Irreps":
        r"""Regroup the same irreps together.

        Equivalent to :meth:`sort` followed by :meth:`simplify`.

        Returns
        -------
        irreps: `e3nn.o3.Irreps`

        Examples
        --------
        >>> Irreps("1e + 0e + 1e + 0x2e").regroup()
        1x0e+2x1e
        """
        return self.sort().irreps.simplify()

    def filter(
        self,
        keep: Union["Irreps", List[Irrep], Callable[[_MulIr], bool]] = None,
        *,
        drop: Union["Irreps", List[Irrep], Callable[[_MulIr], bool]] = None,
        lmax: int = None,
    ) -> "Irreps":
        r"""Filter the irreps.
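
A plain-Python sketch of the ``keep`` / ``drop`` / ``lmax`` semantics implemented below, with irreps written as ``(l, p)`` tuples and the callable form of ``keep`` and ``drop`` left out; ``filter_irreps`` is a hypothetical name. Multiplicities in ``keep`` and ``drop`` are ignored, the order of the entries is preserved, and the three options are mutually exclusive.

```python
def filter_irreps(entries, keep=None, drop=None, lmax=None):
    # entries: list of (mul, (l, p)); keep / drop: iterables of (l, p)
    if sum(arg is not None for arg in (keep, drop, lmax)) > 1:
        raise ValueError("keep, drop and lmax are mutually exclusive")
    if keep is not None:
        wanted = set(keep)
        return [(mul, ir) for mul, ir in entries if ir in wanted]
    if drop is not None:
        unwanted = set(drop)
        return [(mul, ir) for mul, ir in entries if ir not in unwanted]
    if lmax is not None:
        return [(mul, ir) for mul, ir in entries if ir[0] <= lmax]
    return list(entries)


entries = [(1, (1, 1)), (1, (2, 1)), (1, (0, 1))]  # 1e + 2e + 0e
kept = filter_irreps(entries, keep=[(2, 1), (1, 1)])  # keeps 1e and 2e, in their original order
```
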

        At most one of ``keep``, ``drop`` and ``lmax`` can be given.

        Args:
            keep (`Irreps` or list of `Irrep` or function): irreps to keep (multiplicities are ignored)
            drop (`Irreps` or list of `Irrep` or function): irreps to drop (multiplicities are ignored)
            lmax (int): maximum :math:`l` value

        Returns:
            `Irreps`: filtered irreps

        Examples:
            >>> Irreps("1e + 2e + 0e").filter(keep=["0e", "1e"])
            1x1e+1x0e

            >>> Irreps("1e + 2e + 0e").filter(keep="2e + 2x1e")
            1x1e+1x2e

            >>> Irreps("1e + 2e + 0e").filter(drop="2e + 2x1e")
            1x0e

            >>> Irreps("1e + 2e + 0e").filter(lmax=1)
            1x1e+1x0e
        """
        if keep is None and drop is None and lmax is None:
            return self
        if keep is not None and drop is not None:
            raise ValueError("Cannot specify both keep and drop")
        if keep is not None and lmax is not None:
            raise ValueError("Cannot specify both keep and lmax")
        if drop is not None and lmax is not None:
            raise ValueError("Cannot specify both drop and lmax")

        if keep is not None:
            if isinstance(keep, str):
                keep = Irreps(keep)
            if isinstance(keep, Irrep):
                keep = [keep]
            if isinstance(keep, _MulIr):
                keep = [keep.ir]
            if callable(keep):
                return Irreps([mul_ir for mul_ir in self if keep(mul_ir)])
            keep = {Irrep(ir) for ir in keep}
            return Irreps([(mul, ir) for mul, ir in self if ir in keep])

        if drop is not None:
            if isinstance(drop, str):
                drop = Irreps(drop)
            if isinstance(drop, Irrep):
                drop = [drop]
            if isinstance(drop, _MulIr):
                drop = [drop.ir]
            if callable(drop):
                return Irreps([mul_ir for mul_ir in self if not drop(mul_ir)])
            drop = {Irrep(ir) for ir in drop}
            return Irreps([(mul, ir) for mul, ir in self if ir not in drop])

        if lmax is not None:
            return Irreps([(mul, ir) for mul, ir in self if ir.l <= lmax])

    @property
    def slice_by_mul(self):
        r"""Slice by multiplicity: ``irreps.slice_by_mul[start:stop]`` counts individual copies rather than entries,
        splitting an entry when needed.
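
A plain-Python sketch of the counting done by ``_MulIndexSliceHelper`` further below: indices run over individual copies (the sum of the multiplicities), and an entry that straddles a boundary is cut down to its overlapping copies. ``slice_by_mul`` here is a hypothetical free function; unlike the real helper it takes non-negative ``start`` and ``stop`` directly instead of normalizing a ``slice`` with ``slice.indices``.

```python
def slice_by_mul(entries, start, stop):
    # entries: list of (mul, irrep); start / stop count copies, not entries
    out = []
    i = 0  # index of the first copy of the current entry
    for mul, ir in entries:
        if start <= i and i + mul <= stop:
            out.append((mul, ir))  # entry lies entirely inside [start, stop)
        elif start < i + mul and i < stop:
            out.append((min(stop, i + mul) - max(start, i), ir))  # partial overlap
        i += mul
    return out


# 2x0e + 3x1o + 1x2e has 6 copies; copies 1, 2 and 3 are one 0e and two 1o
print(slice_by_mul([(2, "0e"), (3, "1o"), (1, "2e")], 1, 4))  # [(1, '0e'), (2, '1o')]
```
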
        """
        return _MulIndexSliceHelper(self)

    @property
    def dim(self) -> int:
        return sum(mul * ir.dim for mul, ir in self)

    @property
    def num_irreps(self) -> int:
        return sum(mul for mul, _ in self)

    @property
    def ls(self) -> List[int]:
        return [l for mul, (l, p) in self for _ in range(mul)]

    @property
    def lmax(self) -> int:
        if len(self) == 0:
            raise ValueError("Cannot get lmax of empty Irreps")
        return max(self.ls)

    def __repr__(self) -> str:
        return "+".join(f"{mul_ir}" for mul_ir in self)

    def D_from_angles(self, alpha, beta, gamma, k=None):
        r"""Matrix of the representation

        Parameters
        ----------
        alpha : `torch.Tensor`
            tensor of shape :math:`(...)`

        beta : `torch.Tensor`
            tensor of shape :math:`(...)`

        gamma : `torch.Tensor`
            tensor of shape :math:`(...)`

        k : `torch.Tensor`, optional
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
        """
        blocks = []
        for mul, ir in self:
            D = ir.D_from_angles(alpha, beta, gamma, k)
            blocks.extend([D] * mul)
        return direct_sum(*blocks)

    def D_from_quaternion(self, q, k=None):
        r"""Matrix of the representation

        Parameters
        ----------
        q : `torch.Tensor`
            tensor of shape :math:`(..., 4)`

        k : `torch.Tensor`, optional
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
        """
        return self.D_from_angles(*_rotation.quaternion_to_angles(q), k)

    def D_from_matrix(self, R):
        r"""Matrix of the representation

        Parameters
        ----------
        R : `torch.Tensor`
            tensor of shape :math:`(..., 3, 3)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
        """
        d = torch.det(R).sign()
        R = d[..., None, None] * R
        k = (1 - d) / 2
        return self.D_from_angles(*_rotation.matrix_to_angles(R), k)

    def D_from_axis_angle(self, axis, angle):
        r"""Matrix of the representation

        Parameters
        ----------
        axis : `torch.Tensor`
            tensor of shape :math:`(..., 3)`

        angle : `torch.Tensor`
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of
            shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
        """
        return self.D_from_angles(*_rotation.axis_angle_to_angles(axis, angle))


class _MulIndexSliceHelper:
    irreps: Irreps

    def __init__(self, irreps) -> None:
        self.irreps = irreps

    def __getitem__(self, index: slice) -> Irreps:
        if not isinstance(index, slice):
            raise IndexError("Irreps.slice_by_mul only supports slices.")

        start, stop, stride = index.indices(self.irreps.num_irreps)
        if stride != 1:
            raise NotImplementedError("Irreps.slice_by_mul does not support strides.")

        out = []
        i = 0
        for mul, ir in self.irreps:
            if start <= i and i + mul <= stop:
                out.append((mul, ir))
            elif start < i + mul and i < stop:
                out.append((min(stop, i + mul) - max(start, i), ir))
            i += mul
        return Irreps(out)

e3nn-0.6.0/e3nn/o3/_linear.py

from typing import List, NamedTuple, Optional, Tuple, Union

from opt_einsum_fx import optimize_einsums_full
import torch
from torch import fx

import e3nn
from e3nn.o3._irreps import Irreps
from e3nn.util import prod
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode

from ._tensor_product._codegen import _sum_tensors


class Instruction(NamedTuple):
    i_in: int
    i_out: int
    path_shape: tuple
    path_weight: float


# TODO: Need a better that also accounts for the shape
class LinearSlices(NamedTuple):
    slice_1D: slice
    shape_2D: tuple


@compile_mode("script")
class Linear(CodeGenMixin, torch.nn.Module):
    r"""Linear operation equivariant to :math:`O(3)`

    Notes
    -----
    `e3nn.o3.Linear` objects created with different partitionings of the same irreps, such as ``Linear("10x0e", "0e")``
    and ``Linear("3x0e + 7x0e", "0e")``, are *not* equivalent in general: the second module has more instructions,
    which affects normalization when ``path_normalization="path"`` (each path is then normalized by its own input
    multiplicity times the number of paths into the same output). With the default ``path_normalization="element"``,
    every path is normalized by the total input multiplicity feeding its output, so the two modules above compute the
    same function for the same weights.
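
A plain-Python sketch of the per-path weight ``alpha(ins) ** -0.5`` computed in ``__init__`` below, assuming every input block feeds the same single output block and ``f_in`` is ``None``; ``path_weights`` is a hypothetical name. It shows that splitting the input into several blocks leaves the weights unchanged under ``path_normalization="element"`` but not under ``path_normalization="path"``.

```python
def path_weights(in_muls, path_normalization="element"):
    # in_muls: multiplicities of the input blocks that all feed one output block
    weights = []
    for own_mul in in_muls:
        if path_normalization == "element":
            x = sum(in_muls)  # total fan-in of the output
        else:  # "path": own multiplicity, counted once per path into the output
            x = own_mul * len(in_muls)
        weights.append((1.0 if x == 0 else x) ** -0.5)
    return weights


print(path_weights([10]))             # Linear("10x0e", "0e")
print(path_weights([3, 7]))           # Linear("3x0e + 7x0e", "0e"), element normalization
print(path_weights([3, 7], "path"))   # same module, path normalization
```

Under ``"element"`` both modules use 10 ** -0.5 on every path; under ``"path"`` the two paths of the second module get 6 ** -0.5 and 14 ** -0.5.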

    In a rough sense:

        Linear("10x0e", "0e") = normalization_coeff_0 * W_0 @ input
        Linear("3x0e + 7x0e", "0e") = normalization_coeff_1 * W_1 @ input[:3] + normalization_coeff_2 * W_2 @ input[3:]

    To make them equivalent, simplify ``irreps_in`` before constructing network modules:

        o3.Irreps("3x0e + 7x0e").simplify()  # => 10x0e


    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input

    irreps_out : `e3nn.o3.Irreps`
        representation of the output

    internal_weights : bool
        whether the `e3nn.o3.Linear` should store its own weights. Defaults to ``True`` unless ``shared_weights`` is
        explicitly set to ``False``, for consistency with `e3nn.o3.TensorProduct`.

    shared_weights : bool
        whether the same weights are used for every input in a batch. If ``False``, a separate set of weights is
        provided for each input. Defaults to ``True``. Cannot be ``False`` if ``internal_weights`` is ``True``.

    instructions : list of 2-tuples, optional
        list of tuples ``(i_in, i_out)`` indicating which irreps in ``irreps_in`` should contribute to which irreps in
        ``irreps_out``. If ``None`` (the default), all allowable instructions will be created: every ``(i_in, i_out)``
        such that ``irreps_in[i_in].ir == irreps_out[i_out].ir``.

    biases : list of bool, optional
        indicates for each element of ``irreps_out`` if it has a bias. By default there is no bias.
        If ``biases=True`` it gives bias to all scalars (l=0 and p=1).

    Attributes
    ----------
    weight_numel : int
        the size of the weights for this `e3nn.o3.Linear`

    Examples
    --------
    Linearly combines 4 scalars into 8 scalars and 16 vectors into 8 vectors.

    >>> lin = Linear("4x0e+16x1o", "8x0e+8x1o")
    >>> lin.weight_numel
    160

    Create a "block sparse" linear that does not combine two different groups of scalars;
    note that the number of weights is 4*4 + 3*3 = 25:

    >>> lin = Linear("4x0e + 3x0e", "4x0e + 3x0e", instructions=[(0, 0), (1, 1)])
    >>> lin.weight_numel
    25

    With the default ``path_normalization="element"``, the following two modules contain the same "connections"
    split into different instructions, and given identical weights they produce the same output:

    >>> lin1 = Linear("10x0e", "0e")
    >>> lin2 = Linear("3x0e + 7x0e", "0e")
    >>> lin1.weight_numel == lin2.weight_numel
    True
    >>> with torch.no_grad():
    ...     lin1.weight.fill_(1.0)
    ...     lin2.weight.fill_(1.0)
    Parameter containing:
    ...
    >>> x = torch.arange(10.0)
    >>> (lin1(x) - lin2(x)).abs().item() < 1e-5
    True
    """

    weight_numel: int
    internal_weights: bool
    shared_weights: bool

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_out: Irreps,
        *,
        f_in: Optional[int] = None,
        f_out: Optional[int] = None,
        internal_weights: Optional[bool] = None,
        shared_weights: Optional[bool] = None,
        instructions: Optional[List[Tuple[int, int]]] = None,
        biases: Union[bool, List[bool]] = False,
        path_normalization: str = "element",
        _optimize_einsums: Optional[bool] = None,
    ) -> None:
        super().__init__()

        assert path_normalization in ["element", "path"]

        irreps_in = Irreps(irreps_in)
        irreps_out = Irreps(irreps_out)

        if instructions is None:
            # By default, make all possible connections
            instructions = [
                (i_in, i_out)
                for i_in, (_, ir_in) in enumerate(irreps_in)
                for i_out, (_, ir_out) in enumerate(irreps_out)
                if ir_in == ir_out
            ]

        instructions = [
            Instruction(
                i_in=i_in,
                i_out=i_out,
                path_shape=(irreps_in[i_in].mul, irreps_out[i_out].mul),
                path_weight=1,
            )
            for i_in, i_out in instructions
        ]

        def alpha(ins) -> float:
            x = sum(
                irreps_in[i.i_in if path_normalization == "element" else ins.i_in].mul
                for i in instructions
                if i.i_out == ins.i_out
            )
            if f_in is not None:
                x *= f_in
            return 1.0 if x == 0 else x

        instructions = [
            Instruction(i_in=ins.i_in,
                        i_out=ins.i_out, path_shape=ins.path_shape, path_weight=alpha(ins) ** (-0.5))
            for ins in instructions
        ]

        for ins in instructions:
            if not ins.i_in < len(irreps_in):
                raise IndexError(f"{ins.i_in} is not a valid index for irreps_in")
            if not ins.i_out < len(irreps_out):
                raise IndexError(f"{ins.i_out} is not a valid index for irreps_out")
            if not (ins.i_in == -1 or irreps_in[ins.i_in].ir == irreps_out[ins.i_out].ir):
                raise ValueError(f"{ins.i_in} and {ins.i_out} do not have the same irrep")

        if biases is None:
            biases = len(irreps_out) * (False,)
        if isinstance(biases, bool):
            biases = [biases and ir.is_scalar() for _, ir in irreps_out]

        assert len(biases) == len(irreps_out)
        assert all(ir.is_scalar() or (not b) for b, (_, ir) in zip(biases, irreps_out))

        instructions += [
            Instruction(i_in=-1, i_out=i_out, path_shape=(mul_ir.dim,), path_weight=1.0)
            for i_out, (bias, mul_ir) in enumerate(zip(biases, irreps_out))
            if bias
        ]

        # == Process arguments ==
        if shared_weights is False and internal_weights is None:
            internal_weights = False

        if shared_weights is None:
            shared_weights = True

        if internal_weights is None:
            internal_weights = True

        assert shared_weights or not internal_weights
        self.internal_weights = internal_weights
        self.shared_weights = shared_weights

        self.irreps_in = irreps_in
        self.irreps_out = irreps_out
        self.instructions = instructions

        opt_defaults = e3nn.get_optimization_defaults()
        self._optimize_einsums = _optimize_einsums if _optimize_einsums is not None else opt_defaults["optimize_einsums"]
        del opt_defaults

        # == Generate code ==
        graphmod, self.weight_numel, self.bias_numel = _codegen_linear(
            self.irreps_in,
            self.irreps_out,
            self.instructions,
            f_in,
            f_out,
            shared_weights=shared_weights,
            optimize_einsums=self._optimize_einsums,
        )
        self._codegen_register({"_compiled_main": graphmod})

        # == Generate weights ==
        if internal_weights and self.weight_numel > 0:
            assert self.shared_weights, "Having internal weights imposes shared weights"
            self.weight = torch.nn.Parameter(torch.randn(*((f_in, f_out) if f_in is not None else ()), self.weight_numel))
        else:
            # For TorchScript, there always has to be some kind of defined .weight
            self.register_buffer("weight", torch.Tensor())

        # == Generate biases ==
        if internal_weights and self.bias_numel > 0:
            assert self.shared_weights, "Having internal weights imposes shared weights"
            self.bias = torch.nn.Parameter(
                torch.zeros(*((f_out,) if f_out is not None else ()), self.bias_numel)
            )  # see appendix C.1 and Eq.5 of https://arxiv.org/pdf/2011.14522.pdf
        else:
            self.register_buffer("bias", torch.Tensor())

        # == Compute output mask ==
        if self.irreps_out.dim > 0:
            output_mask = torch.cat(
                [
                    (
                        torch.ones(mul_ir.dim)
                        if any((ins.i_out == i_out) and (0 not in ins.path_shape) for ins in self.instructions)
                        else torch.zeros(mul_ir.dim)
                    )
                    for i_out, mul_ir in enumerate(self.irreps_out)
                ]
            )
        else:
            output_mask = torch.ones(0)
        self.register_buffer("output_mask", output_mask)

        # Register 2D weight slices
        self.weight_index_slices = []
        for i, ins in enumerate(self.instructions):
            offset = sum(prod(ins_pre.path_shape) for ins_pre in self.instructions[:i])  # TODO: Slop
            self.weight_index_slices.append(
                LinearSlices(slice(offset, offset + prod(ins.path_shape), None), ins.path_shape)
            )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.irreps_in} -> {self.irreps_out} | {self.weight_numel} weights)"

    def forward(self, features, weight: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None):
        """evaluate

        Parameters
        ----------
        features : `torch.Tensor`
            tensor of shape ``(..., irreps_in.dim)``

        weight : `torch.Tensor`, optional
            required if ``internal_weights`` is `False`

        bias : `torch.Tensor`, optional
            required if ``internal_weights`` is `False` and the module has biases

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(..., irreps_out.dim)``
        """
        if weight is None:
            if self.weight_numel > 0 and not self.internal_weights:
                raise RuntimeError("Weights must be provided when internal_weights = False")
            weight = self.weight
        if bias is None:
            if self.bias_numel > 0 and not self.internal_weights:
                raise
RuntimeError("Biases must be provided when internal_weights = False") bias = self.bias return self._compiled_main(features, weight, bias) def weight_view_for_instruction(self, instruction: int, weight: Optional[torch.Tensor] = None) -> torch.Tensor: r"""View of weights corresponding to ``instruction``. Parameters ---------- instruction : int The index of the instruction to get a view on the weights for. weight : `torch.Tensor`, optional like ``weight`` argument to ``forward()`` Returns ------- `torch.Tensor` A view on ``weight`` or this object's internal weights for the weights corresponding to the ``instruction`` th instruction. """ if weight is None: assert self.internal_weights, "Weights must be provided when internal_weights = False" weight = self.weight batchshape = weight.shape[:-1] offset = sum(prod(ins.path_shape) for ins in self.instructions[:instruction]) ins = self.instructions[instruction] return weight.narrow(-1, offset, prod(ins.path_shape)).view(batchshape + ins.path_shape) def weight_views(self, weight: Optional[torch.Tensor] = None, yield_instruction: bool = False): r"""Iterator over weight views for all instructions. Parameters ---------- weight : `torch.Tensor`, optional like ``weight`` argument to ``forward()`` yield_instruction : `bool`, default False Whether to also yield the corresponding instruction. Yields ------ If ``yield_instruction`` is ``True``, yields ``(instruction_index, instruction, weight_view)``. Otherwise, yields ``weight_view``. 
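Examples
--------
A small illustration (the irreps are chosen arbitrarily): a layer with two instructions yields one view per instruction. Because internal weights carry no batch dimensions, each view has exactly the shape ``path_shape`` of its instruction:

>>> lin = Linear("2x0e + 3x1o", "4x0e + 1x1o")
>>> [tuple(w.shape) for w in lin.weight_views()]
[(2, 4), (3, 1)]
>>> for i, ins, w in lin.weight_views(yield_instruction=True):
...     print(i, ins.i_in, ins.i_out, tuple(w.shape))
0 0 0 (2, 4)
1 1 1 (3, 1)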
""" if weight is None: assert self.internal_weights, "Weights must be provided when internal_weights = False" weight = self.weight batchshape = weight.shape[:-1] offset = 0 for ins_i, ins in enumerate(self.instructions): flatsize = prod(ins.path_shape) this_weight = weight.narrow(-1, offset, flatsize).view(batchshape + ins.path_shape) offset += flatsize if yield_instruction: yield ins_i, ins, this_weight else: yield this_weight def _codegen_linear( irreps_in: Irreps, irreps_out: Irreps, instructions: List[Instruction], f_in: Optional[int] = None, f_out: Optional[int] = None, shared_weights: bool = False, optimize_einsums: bool = True, ) -> Tuple[fx.GraphModule, int, int]: graph_out = fx.Graph() tracer_out = fx.proxy.GraphAppendingTracer(graph_out) # = Function definitions = x = fx.Proxy(graph_out.placeholder("x", torch.Tensor), tracer_out) ws = fx.Proxy(graph_out.placeholder("w", torch.Tensor), tracer_out) bs = fx.Proxy(graph_out.placeholder("b", torch.Tensor), tracer_out) if f_in is None: size = x.shape[:-1] outsize = size + (irreps_out.dim,) else: size = x.shape[:-2] outsize = size + ( f_out, irreps_out.dim, ) bias_numel = sum(irreps_out[i.i_out].dim for i in instructions if i.i_in == -1) if bias_numel > 0: if f_out is None: bs = bs.reshape(-1, bias_numel) else: bs = bs.reshape(-1, f_out, bias_numel) # = Short-circuit for nothing to do = # We produce no code for empty instructions instructions = [ins for ins in instructions if 0 not in ins.path_shape] if len(instructions) == 0 and bias_numel == 0: out = x.new_zeros(outsize) graph_out.output(out.node, torch.Tensor) # Short-circuit: both weight_numel and bias_numel are 0 return fx.GraphModule({}, graph_out, "linear_forward"), 0, 0 if f_in is None: x = x.reshape(-1, irreps_in.dim) else: x = x.reshape(-1, f_in, irreps_in.dim) batch_out = x.shape[0] weight_numel = sum(prod(ins.path_shape) for ins in instructions if ins.i_in != -1) if weight_numel > 0: ws = ws.reshape(-1, weight_numel) if f_in is None else ws.reshape(-1, f_in, f_out,
weight_numel) # = extract individual input irreps = if len(irreps_in) == 1: x_list = [x.reshape(batch_out, *(() if f_in is None else (f_in,)), irreps_in[0].mul, irreps_in[0].ir.dim)] else: x_list = [ x.narrow(-1, i.start, mul_ir.dim).reshape(batch_out, *(() if f_in is None else (f_in,)), mul_ir.mul, mul_ir.ir.dim) for i, mul_ir in zip(irreps_in.slices(), irreps_in) ] z = "" if shared_weights else "z" flat_weight_index = 0 flat_bias_index = 0 out_list = [] for ins in instructions: mul_ir_out = irreps_out[ins.i_out] if ins.i_in == -1: # = bias = b = bs.narrow(-1, flat_bias_index, prod(ins.path_shape)) flat_bias_index += prod(ins.path_shape) out_list += [(ins.path_weight * b).reshape(1, *(() if f_out is None else (f_out,)), mul_ir_out.dim)] else: mul_ir_in = irreps_in[ins.i_in] # Short-circuit for empty irreps if mul_ir_in.dim == 0 or mul_ir_out.dim == 0: continue # Extract the weight from the flattened weight tensor path_nweight = prod(ins.path_shape) if len(instructions) == 1: # Avoid unnecessary view when there is only one weight w = ws else: w = ws.narrow(-1, flat_weight_index, path_nweight) w = w.reshape((() if shared_weights else (-1,)) + (() if f_in is None else (f_in, f_out)) + ins.path_shape) flat_weight_index += path_nweight if f_in is None: ein_out = torch.einsum(f"{z}uw,zui->zwi", w, x_list[ins.i_in]) else: ein_out = torch.einsum(f"{z}xyuw,zxui->zywi", w, x_list[ins.i_in]) ein_out = ins.path_weight * ein_out out_list += [ein_out.reshape(batch_out, *(() if f_out is None else (f_out,)), mul_ir_out.dim)] # = Return the result = out = [ _sum_tensors( [out for ins, out in zip(instructions, out_list) if ins.i_out == i_out], shape=(batch_out, *(() if f_out is None else (f_out,)), mul_ir_out.dim), like=x, ) for i_out, mul_ir_out in enumerate(irreps_out) if mul_ir_out.mul > 0 ] if len(out) > 1: out = torch.cat(out, dim=-1) else: out = out[0] out = out.reshape(outsize) graph_out.output(out.node, torch.Tensor) # check graphs graph_out.lint() graphmod_out =
fx.GraphModule({}, graph_out, "linear_forward") # TODO: when eliminate_dead_code() is in PyTorch stable, use that if optimize_einsums: # See _tensor_product/_codegen.py for notes batchdim = 4 example_inputs = ( torch.zeros((batchdim, *(() if f_in is None else (f_in,)), irreps_in.dim)), torch.zeros( 1 if shared_weights else batchdim, f_in or 1, f_out or 1, weight_numel, ), torch.zeros( 1 if shared_weights else batchdim, f_out or 1, bias_numel, ), ) graphmod_out = optimize_einsums_full(graphmod_out, example_inputs) return graphmod_out, weight_numel, bias_numel e3nn-0.6.0/e3nn/o3/_norm.py000066400000000000000000000034521514371756200152470ustar00rootroot00000000000000import torch from e3nn.o3._irreps import Irreps from e3nn.o3._tensor_product._tensor_product import TensorProduct from e3nn.util.jit import compile_mode @compile_mode("trace") class Norm(torch.nn.Module): r"""Norm of each irrep in a direct sum of irreps. Parameters ---------- irreps_in : `e3nn.o3.Irreps` representation of the input squared : bool, optional Whether to return the squared norm. ``False`` by default, i.e. the norm itself (sqrt of squared norm) is returned. Examples -------- Compute the norms of 17 vectors. >>> norm = Norm("17x1o") >>> norm(torch.randn(17 * 3)).shape torch.Size([17]) """ squared: bool def __init__(self, irreps_in, squared: bool = False) -> None: super().__init__() irreps_in = Irreps(irreps_in).simplify() irreps_out = Irreps([(mul, "0e") for mul, _ in irreps_in]) instr = [(i, i, i, "uuu", False, ir.dim) for i, (mul, ir) in enumerate(irreps_in)] self.tp = TensorProduct(irreps_in, irreps_in, irreps_out, instr, irrep_normalization="component") self.irreps_in = irreps_in self.irreps_out = irreps_out.simplify() self.squared = squared def __repr__(self) -> str: return f"{self.__class__.__name__}({self.irreps_in})" def forward(self, features): """Compute norms of irreps in ``features``. 
Parameters ---------- features : `torch.Tensor` tensor of shape ``(..., irreps_in.dim)`` Returns ------- `torch.Tensor` tensor of shape ``(..., irreps_out.dim)`` """ out = self.tp(features, features) if self.squared: return out else: # ReLU fixes gradients at zero return out.relu().sqrt() e3nn-0.6.0/e3nn/o3/_reduce.py000066400000000000000000000253471514371756200155520ustar00rootroot00000000000000import collections import torch from torch import fx from e3nn.o3._irreps import Irrep, Irreps from e3nn.o3._wigner import wigner_3j from e3nn.o3._tensor_product._tensor_product import TensorProduct from e3nn.math import germinate_formulas, orthonormalize, reduce_permutation from e3nn.util import explicit_default_types from e3nn.util.codegen import CodeGenMixin from e3nn.util.jit import compile_mode _TP = collections.namedtuple("tp", "op, args") _INPUT = collections.namedtuple("input", "tensor, start, stop") def _wigner_nj(*irrepss, normalization: str = "component", filter_ir_mid=None, dtype=None, device=None): irrepss = [Irreps(irreps) for irreps in irrepss] if filter_ir_mid is not None: filter_ir_mid = [Irrep(ir) for ir in filter_ir_mid] if len(irrepss) == 1: (irreps,) = irrepss ret = [] e = torch.eye(irreps.dim, dtype=dtype, device=device) i = 0 for mul, ir in irreps: for _ in range(mul): sl = slice(i, i + ir.dim) ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] i += ir.dim return ret *irrepss_left, irreps_right = irrepss ret = [] for ir_left, path_left, C_left in _wigner_nj( *irrepss_left, normalization=normalization, filter_ir_mid=filter_ir_mid, dtype=dtype, device=device ): i = 0 for mul, ir in irreps_right: for ir_out in ir_left * ir: if filter_ir_mid is not None and ir_out not in filter_ir_mid: continue C = wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype, device=device) if normalization == "component": C *= ir_out.dim**0.5 if normalization == "norm": C *= ir_left.dim**0.5 * ir.dim**0.5 C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) C = 
C.reshape(ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim) for u in range(mul): E = torch.zeros( ir_out.dim, *(irreps.dim for irreps in irrepss_left), irreps_right.dim, dtype=dtype, device=device ) sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) E[..., sl] = C ret += [ ( ir_out, _TP(op=(ir_left, ir, ir_out), args=(path_left, _INPUT(len(irrepss_left), sl.start, sl.stop))), E, ) ] i += mul * ir.dim return sorted(ret, key=lambda x: x[0]) def _get_ops(path): if isinstance(path, _INPUT): return assert isinstance(path, _TP) yield path.op for op in _get_ops(path.args[0]): yield op @compile_mode("trace") class ReducedTensorProducts(CodeGenMixin, torch.nn.Module): r"""reduce a tensor with symmetries into irreducible representations Parameters ---------- formula : str String made of index letters and the characters ``-`` and ``=`` that represent the index symmetries of the tensor. For instance ``ij=ji`` means that the tensor has two indices and if they are exchanged, its value is the same. ``ij=-ji`` means that the tensor changes its sign if the two indices are exchanged. filter_ir_out : list of `e3nn.o3.Irrep`, optional List of allowed irreps in the output filter_ir_mid : list of `e3nn.o3.Irrep`, optional List of allowed irreps in the intermediate operations **kwargs : dict of `e3nn.o3.Irreps` each letter present in the formula has to be present in the ``irreps`` dictionary, unless it can be inferred from the formula. For instance if the formula is ``ij=ji`` you can provide the representation of ``i`` only: ``ReducedTensorProducts('ij=ji', i='1o')``.
Attributes ---------- irreps_in : list of `e3nn.o3.Irreps` input representations irreps_out : `e3nn.o3.Irreps` output representation change_of_basis : `torch.Tensor` tensor of shape ``(irreps_out.dim, irreps_in[0].dim, ..., irreps_in[-1].dim)`` Examples -------- >>> tp = ReducedTensorProducts('ij=-ji', i='1o') >>> x = torch.tensor([1.0, 0.0, 0.0]) >>> y = torch.tensor([0.0, 1.0, 0.0]) >>> tp(x, y) + tp(y, x) tensor([0., 0., 0.]) >>> tp = ReducedTensorProducts('ijkl=jikl=ikjl=ijlk', i="1e") >>> tp.irreps_out 1x0e+1x2e+1x4e >>> tp = ReducedTensorProducts('ij=ji', i='1o') >>> x, y = torch.randn(2, 3) >>> a = torch.einsum('zij,i,j->z', tp.change_of_basis, x, y) >>> b = tp(x, y) >>> assert torch.allclose(a, b, atol=1e-3, rtol=1e-3) """ # pylint: disable=abstract-method def __init__(self, formula, filter_ir_out=None, filter_ir_mid=None, eps: float = 1e-9, **irreps) -> None: super().__init__() if filter_ir_out is not None: try: filter_ir_out = [Irrep(ir) for ir in filter_ir_out] except ValueError: raise ValueError(f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep") if filter_ir_mid is not None: try: filter_ir_mid = [Irrep(ir) for ir in filter_ir_mid] except ValueError: raise ValueError(f"filter_ir_mid (={filter_ir_mid}) must be an iterable of e3nn.o3.Irrep") f0, formulas = germinate_formulas(formula) irreps = {i: Irreps(irs) for i, irs in irreps.items()} for i in irreps: if len(i) != 1: raise TypeError(f"got an unexpected keyword argument '{i}'") for _sign, p in formulas: f = "".join(f0[i] for i in p) for i, j in zip(f0, f): if i in irreps and j in irreps and irreps[i] != irreps[j]: raise RuntimeError(f"irreps of {i} and {j} should be the same") if i in irreps: irreps[j] = irreps[i] if j in irreps: irreps[i] = irreps[j] for i in f0: if i not in irreps: raise RuntimeError(f"index {i} has no irreps associated with it") for i in irreps: if i not in f0: raise RuntimeError(f"index {i} has irreps but does not appear in the formula") base_perm, _ =
reduce_permutation(f0, formulas, dtype=torch.float64, **{i: irs.dim for i, irs in irreps.items()}) Ps = collections.defaultdict(list) for ir, path, base_o3 in _wigner_nj(*[irreps[i] for i in f0], filter_ir_mid=filter_ir_mid, dtype=torch.float64): if filter_ir_out is None or ir in filter_ir_out: # P = base_o3.flatten(1) @ base_perm.flatten(1).T # if P.norm() > eps: # if this Irrep is present in the permutation basis we keep it Ps[ir].append((path, base_o3)) outputs = [] change_of_basis = [] irreps_out = [] P = base_perm.flatten(1) # [permutation basis, input basis] (a,omega) PP = P @ P.T # (a,a) for ir in Ps: mul = len(Ps[ir]) paths = [path for path, _ in Ps[ir]] base_o3 = torch.stack([R for _, R in Ps[ir]]) R = base_o3.flatten(2) # [multiplicity, ir, input basis] (u,j,omega) proj_s = [] # list of projectors into vector space for j in range(ir.dim): # Solve X @ R[:, j] = Y @ P, but keep only X RR = R[:, j] @ R[:, j].T # (u,u) RP = R[:, j] @ P.T # (u,a) prob = torch.cat([torch.cat([RR, -RP], dim=1), torch.cat([-RP.T, PP], dim=1)], dim=0) eigenvalues, eigenvectors = torch.linalg.eigh(prob) X = eigenvectors[:, eigenvalues < eps][:mul].T # [solutions, multiplicity] proj_s.append(X.T @ X) break # do not check all components because it is too time-consuming for p in proj_s: assert (p - proj_s[0]).abs().max() < eps, f"found different solutions for irrep {ir}" # look for an X such that X.T @ X = Projector X, _ = orthonormalize(proj_s[0], eps) for x in X: C = torch.einsum("u,ui...->i...", x, base_o3) correction = (ir.dim / C.pow(2).sum()) ** 0.5 C = correction * C outputs.append([((correction * v).item(), p) for v, p in zip(x, paths) if v.abs() > eps]) change_of_basis.append(C) irreps_out.append((1, ir)) dtype, _ = explicit_default_types(None, None) self.register_buffer("change_of_basis", torch.cat(change_of_basis).to(dtype=dtype)) tps = set() for vp_list in outputs: for v, p in vp_list: for op in _get_ops(p): tps.add(op) root = torch.nn.Module() tps = list(tps) for i, op in
enumerate(tps): tp = TensorProduct(op[0], op[1], op[2], [(0, 0, 0, "uuu", False)]) setattr(root, f"tp{i}", tp) graph = fx.Graph() tracer = torch.fx.proxy.GraphAppendingTracer(graph) inputs = [fx.Proxy(graph.placeholder(f"x{i}", torch.Tensor), tracer) for i in f0] self.irreps_in = [irreps[i] for i in f0] self.irreps_out = Irreps(irreps_out).simplify() values = {} def evaluate(path): if path in values: return values[path] if isinstance(path, _INPUT): out = inputs[path.tensor] if (path.start, path.stop) != (0, self.irreps_in[path.tensor].dim): out = out.narrow(-1, path.start, path.stop - path.start) if isinstance(path, _TP): x1 = evaluate(path.args[0]).node x2 = evaluate(path.args[1]).node out = fx.Proxy(graph.call_module(f"tp{tps.index(path.op)}", (x1, x2)), tracer) values[path] = out return out outs = [] for vp_list in outputs: v, p = vp_list[0] out = evaluate(p) if abs(v - 1.0) > eps: out = v * out for v, p in vp_list[1:]: t = evaluate(p) if abs(v - 1.0) > eps: t = v * t out = out + t outs.append(out) out = torch.cat(outs, dim=-1) graph.output(out.node) graphmod = fx.GraphModule(root, graph, "main") self._codegen_register({"main": graphmod}) def __repr__(self) -> str: return ( f"ReducedTensorProducts(\n" f" in: {' times '.join(map(repr, self.irreps_in))}\n" f" out: {self.irreps_out}\n" ")" ) def forward(self, *xs): return self.main(*xs) e3nn-0.6.0/e3nn/o3/_rotation.py000066400000000000000000000410671514371756200161370ustar00rootroot00000000000000import math import torch # matrix def rand_matrix(*shape, requires_grad: bool = False, dtype=None, device=None): r"""random rotation matrix Parameters ---------- *shape : int Returns ------- `torch.Tensor` tensor of shape :math:`(\mathrm{shape}, 3, 3)` """ R = angles_to_matrix(*rand_angles(*shape, dtype=dtype, device=device)) return R.detach().requires_grad_(requires_grad) # angles def identity_angles(*shape, requires_grad: bool = False, dtype=None, device=None): r"""angles of the identity rotation Parameters ---------- 
*shape : int Returns ------- alpha : `torch.Tensor` tensor of shape :math:`(\mathrm{shape})` beta : `torch.Tensor` tensor of shape :math:`(\mathrm{shape})` gamma : `torch.Tensor` tensor of shape :math:`(\mathrm{shape})` """ abc = torch.zeros(3, *shape, dtype=dtype, device=device) return abc[0].requires_grad_(requires_grad), abc[1].requires_grad_(requires_grad), abc[2].requires_grad_(requires_grad) def rand_angles(*shape, requires_grad: bool = False, dtype=None, device=None): r"""random rotation angles Parameters ---------- *shape : int Returns ------- alpha : `torch.Tensor` tensor of shape :math:`(\mathrm{shape})` beta : `torch.Tensor` tensor of shape :math:`(\mathrm{shape})` gamma : `torch.Tensor` tensor of shape :math:`(\mathrm{shape})` """ alpha, gamma = 2 * math.pi * torch.rand(2, *shape, dtype=dtype, device=device) beta = torch.rand(shape, dtype=dtype, device=device).mul(2).sub(1).acos() alpha = alpha.detach().requires_grad_(requires_grad) beta = beta.detach().requires_grad_(requires_grad) gamma = gamma.detach().requires_grad_(requires_grad) return alpha, beta, gamma def compose_angles(a1, b1, c1, a2, b2, c2): r"""compose angles Computes :math:`(a, b, c)` such that :math:`R(a, b, c) = R(a_1, b_1, c_1) \circ R(a_2, b_2, c_2)` Parameters ---------- a1 : `torch.Tensor` tensor of shape :math:`(...)`, (applied second) b1 : `torch.Tensor` tensor of shape :math:`(...)`, (applied second) c1 : `torch.Tensor` tensor of shape :math:`(...)`, (applied second) a2 : `torch.Tensor` tensor of shape :math:`(...)`, (applied first) b2 : `torch.Tensor` tensor of shape :math:`(...)`, (applied first) c2 : `torch.Tensor` tensor of shape :math:`(...)`, (applied first) Returns ------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor of shape :math:`(...)` gamma : `torch.Tensor` tensor of shape :math:`(...)` """ a1, b1, c1, a2, b2, c2 = torch.broadcast_tensors(a1, b1, c1, a2, b2, c2) return matrix_to_angles(angles_to_matrix(a1, b1, c1) @ 
angles_to_matrix(a2, b2, c2)) def inverse_angles(a, b, c): r"""angles of the inverse rotation Parameters ---------- a : `torch.Tensor` tensor of shape :math:`(...)` b : `torch.Tensor` tensor of shape :math:`(...)` c : `torch.Tensor` tensor of shape :math:`(...)` Returns ------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor of shape :math:`(...)` gamma : `torch.Tensor` tensor of shape :math:`(...)` """ return -c, -b, -a # quaternions def identity_quaternion(*shape, requires_grad: bool = False, dtype=None, device=None): r"""quaternion of identity rotation Parameters ---------- *shape : int Returns ------- `torch.Tensor` tensor of shape :math:`(\mathrm{shape}, 4)` """ q = torch.zeros(*shape, 4, dtype=dtype, device=device) q[..., 0] = 1 # or -1... q = q.detach().requires_grad_(requires_grad) return q def rand_quaternion(*shape, requires_grad: bool = False, dtype=None, device=None): r"""generate random quaternion Parameters ---------- *shape : int Returns ------- `torch.Tensor` tensor of shape :math:`(\mathrm{shape}, 4)` """ q = angles_to_quaternion(*rand_angles(*shape, dtype=dtype, device=device)) q = q.detach().requires_grad_(requires_grad) return q def compose_quaternion(q1, q2) -> torch.Tensor: r"""compose two quaternions: :math:`q_1 \circ q_2` Parameters ---------- q1 : `torch.Tensor` tensor of shape :math:`(..., 4)`, (applied second) q2 : `torch.Tensor` tensor of shape :math:`(..., 4)`, (applied first) Returns ------- `torch.Tensor` tensor of shape :math:`(..., 4)` """ q1, q2 = torch.broadcast_tensors(q1, q2) return torch.stack( [ q1[..., 0] * q2[..., 0] - q1[..., 1] * q2[..., 1] - q1[..., 2] * q2[..., 2] - q1[..., 3] * q2[..., 3], q1[..., 1] * q2[..., 0] + q1[..., 0] * q2[..., 1] + q1[..., 2] * q2[..., 3] - q1[..., 3] * q2[..., 2], q1[..., 0] * q2[..., 2] - q1[..., 1] * q2[..., 3] + q1[..., 2] * q2[..., 0] + q1[..., 3] * q2[..., 1], q1[..., 0] * q2[..., 3] + q1[..., 1] * q2[..., 2] - q1[..., 2] * q2[..., 1] + q1[..., 3] * 
q2[..., 0], ], dim=-1, ) def inverse_quaternion(q): r"""inverse of a quaternion Works only for unit quaternions. Parameters ---------- q : `torch.Tensor` tensor of shape :math:`(..., 4)` Returns ------- `torch.Tensor` tensor of shape :math:`(..., 4)` """ q = q.clone() q[..., 1:].neg_() return q # axis-angle def rand_axis_angle(*shape, requires_grad: bool = False, dtype=None, device=None): r"""generate random rotation as axis-angle Parameters ---------- *shape : int Returns ------- axis : `torch.Tensor` tensor of shape :math:`(\mathrm{shape}, 3)` angle : `torch.Tensor` tensor of shape :math:`(\mathrm{shape})` """ axis, angle = angles_to_axis_angle(*rand_angles(*shape, dtype=dtype, device=device)) axis = axis.detach().requires_grad_(requires_grad) angle = angle.detach().requires_grad_(requires_grad) return axis, angle def compose_axis_angle(axis1, angle1, axis2, angle2): r"""compose :math:`(\vec x_1, \alpha_1)` with :math:`(\vec x_2, \alpha_2)` Parameters ---------- axis1 : `torch.Tensor` tensor of shape :math:`(..., 3)`, (applied second) angle1 : `torch.Tensor` tensor of shape :math:`(...)`, (applied second) axis2 : `torch.Tensor` tensor of shape :math:`(..., 3)`, (applied first) angle2 : `torch.Tensor` tensor of shape :math:`(...)`, (applied first) Returns ------- axis : `torch.Tensor` tensor of shape :math:`(..., 3)` angle : `torch.Tensor` tensor of shape :math:`(...)` """ return quaternion_to_axis_angle( compose_quaternion(axis_angle_to_quaternion(axis1, angle1), axis_angle_to_quaternion(axis2, angle2)) ) # conversions def matrix_x(angle: torch.Tensor) -> torch.Tensor: r"""matrix of rotation around X axis Parameters ---------- angle : `torch.Tensor` tensor of any shape :math:`(...)` Returns ------- `torch.Tensor` matrices of shape :math:`(..., 3, 3)` """ c = angle.cos() s = angle.sin() o = torch.ones_like(angle) z = torch.zeros_like(angle) return torch.stack( [ torch.stack([o, z, z], dim=-1), torch.stack([z, c, -s], dim=-1), torch.stack([z, s, c], dim=-1), ], 
dim=-2, ) def matrix_y(angle: torch.Tensor) -> torch.Tensor: r"""matrix of rotation around Y axis Parameters ---------- angle : `torch.Tensor` tensor of any shape :math:`(...)` Returns ------- `torch.Tensor` matrices of shape :math:`(..., 3, 3)` """ c = angle.cos() s = angle.sin() o = torch.ones_like(angle) z = torch.zeros_like(angle) return torch.stack( [ torch.stack([c, z, s], dim=-1), torch.stack([z, o, z], dim=-1), torch.stack([-s, z, c], dim=-1), ], dim=-2, ) def matrix_z(angle: torch.Tensor) -> torch.Tensor: r"""matrix of rotation around Z axis Parameters ---------- angle : `torch.Tensor` tensor of any shape :math:`(...)` Returns ------- `torch.Tensor` matrices of shape :math:`(..., 3, 3)` """ c = angle.cos() s = angle.sin() o = torch.ones_like(angle) z = torch.zeros_like(angle) return torch.stack( [torch.stack([c, -s, z], dim=-1), torch.stack([s, c, z], dim=-1), torch.stack([z, z, o], dim=-1)], dim=-2 ) def angles_to_matrix(alpha, beta, gamma) -> torch.Tensor: r"""conversion from angles to matrix Parameters ---------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor of shape :math:`(...)` gamma : `torch.Tensor` tensor of shape :math:`(...)` Returns ------- `torch.Tensor` matrices of shape :math:`(..., 3, 3)` """ alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma) return matrix_y(alpha) @ matrix_x(beta) @ matrix_y(gamma) def matrix_to_angles(R): r"""conversion from matrix to angles Parameters ---------- R : `torch.Tensor` matrices of shape :math:`(..., 3, 3)` Returns ------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor of shape :math:`(...)` gamma : `torch.Tensor` tensor of shape :math:`(...)` """ assert torch.allclose(torch.det(R), R.new_tensor(1)) x = R @ R.new_tensor([0.0, 1.0, 0.0]) a, b = xyz_to_angles(x) R = angles_to_matrix(a, b, torch.zeros_like(a)).transpose(-1, -2) @ R c = torch.atan2(R[..., 0, 2], R[..., 0, 0]) return a, b, c def angles_to_quaternion(alpha, beta, 
gamma) -> torch.Tensor: r"""conversion from angles to quaternion Parameters ---------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor of shape :math:`(...)` gamma : `torch.Tensor` tensor of shape :math:`(...)` Returns ------- `torch.Tensor` tensor of shape :math:`(..., 4)` """ alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma) qa = axis_angle_to_quaternion(alpha.new_tensor([0.0, 1.0, 0.0]), alpha) qb = axis_angle_to_quaternion(beta.new_tensor([1.0, 0.0, 0.0]), beta) qc = axis_angle_to_quaternion(gamma.new_tensor([0.0, 1.0, 0.0]), gamma) return compose_quaternion(qa, compose_quaternion(qb, qc)) def matrix_to_quaternion(R) -> torch.Tensor: r"""conversion from matrix :math:`R` to quaternion :math:`q` Parameters ---------- R : `torch.Tensor` tensor of shape :math:`(..., 3, 3)` Returns ------- `torch.Tensor` tensor of shape :math:`(..., 4)` """ return axis_angle_to_quaternion(*matrix_to_axis_angle(R)) def axis_angle_to_quaternion(xyz, angle) -> torch.Tensor: r"""conversion from axis-angle to quaternion Parameters ---------- xyz : `torch.Tensor` tensor of shape :math:`(..., 3)` angle : `torch.Tensor` tensor of shape :math:`(...)` Returns ------- `torch.Tensor` tensor of shape :math:`(..., 4)` """ xyz, angle = torch.broadcast_tensors(xyz, angle[..., None]) xyz = torch.nn.functional.normalize(xyz, dim=-1) c = torch.cos(angle[..., :1] / 2) s = torch.sin(angle / 2) return torch.cat([c, xyz * s], dim=-1) def quaternion_to_axis_angle(q): r"""conversion from quaternion to axis-angle Parameters ---------- q : `torch.Tensor` tensor of shape :math:`(..., 4)` Returns ------- axis : `torch.Tensor` tensor of shape :math:`(..., 3)` angle : `torch.Tensor` tensor of shape :math:`(...)` """ angle = 2 * torch.acos(q[..., 0].clamp(-1, 1)) axis = torch.nn.functional.normalize(q[..., 1:], dim=-1) return axis, angle def matrix_to_axis_angle(R): r"""conversion from matrix to axis-angle Parameters ---------- R : `torch.Tensor` tensor of shape
:math:`(..., 3, 3)` Returns ------- axis : `torch.Tensor` tensor of shape :math:`(..., 3)` angle : `torch.Tensor` tensor of shape :math:`(...)` """ assert torch.allclose(torch.det(R), R.new_tensor(1)) tr = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] angle = torch.acos(tr.sub(1).div(2).clamp(-1, 1)) axis = torch.stack( [ R[..., 2, 1] - R[..., 1, 2], R[..., 0, 2] - R[..., 2, 0], R[..., 1, 0] - R[..., 0, 1], ], dim=-1, ) axis = torch.nn.functional.normalize(axis, dim=-1) return axis, angle def angles_to_axis_angle(alpha, beta, gamma): r"""conversion from angles to axis-angle Parameters ---------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor of shape :math:`(...)` gamma : `torch.Tensor` tensor of shape :math:`(...)` Returns ------- axis : `torch.Tensor` tensor of shape :math:`(..., 3)` angle : `torch.Tensor` tensor of shape :math:`(...)` """ return matrix_to_axis_angle(angles_to_matrix(alpha, beta, gamma)) def axis_angle_to_matrix(axis, angle) -> torch.Tensor: r"""conversion from axis-angle to matrix Parameters ---------- axis : `torch.Tensor` tensor of shape :math:`(..., 3)` angle : `torch.Tensor` tensor of shape :math:`(...)` Returns ------- `torch.Tensor` tensor of shape :math:`(..., 3, 3)` """ axis, angle = torch.broadcast_tensors(axis, angle[..., None]) alpha, beta = xyz_to_angles(axis) R = angles_to_matrix(alpha, beta, torch.zeros_like(beta)) Ry = matrix_y(angle[..., 0]) return R @ Ry @ R.transpose(-2, -1) def quaternion_to_matrix(q) -> torch.Tensor: r"""conversion from quaternion to matrix Parameters ---------- q : `torch.Tensor` tensor of shape :math:`(..., 4)` Returns ------- `torch.Tensor` tensor of shape :math:`(..., 3, 3)` """ return axis_angle_to_matrix(*quaternion_to_axis_angle(q)) def quaternion_to_angles(q): r"""conversion from quaternion to angles Parameters ---------- q : `torch.Tensor` tensor of shape :math:`(..., 4)` Returns ------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor
of shape :math:`(...)` gamma : `torch.Tensor` tensor of shape :math:`(...)` """ return matrix_to_angles(quaternion_to_matrix(q)) def axis_angle_to_angles(axis, angle): r"""conversion from axis-angle to angles Parameters ---------- axis : `torch.Tensor` tensor of shape :math:`(..., 3)` angle : `torch.Tensor` tensor of shape :math:`(...)` Returns ------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor of shape :math:`(...)` gamma : `torch.Tensor` tensor of shape :math:`(...)` """ return matrix_to_angles(axis_angle_to_matrix(axis, angle)) # point on the sphere def angles_to_xyz(alpha, beta) -> torch.Tensor: r"""convert :math:`(\alpha, \beta)` into a point :math:`(x, y, z)` on the sphere Parameters ---------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor of shape :math:`(...)` Returns ------- `torch.Tensor` tensor of shape :math:`(..., 3)` Examples -------- >>> angles_to_xyz(torch.tensor(1.7), torch.tensor(0.0)).abs() tensor([0., 1., 0.]) """ alpha, beta = torch.broadcast_tensors(alpha, beta) x = torch.sin(beta) * torch.sin(alpha) y = torch.cos(beta) z = torch.sin(beta) * torch.cos(alpha) return torch.stack([x, y, z], dim=-1) def xyz_to_angles(xyz): r"""convert a point :math:`\vec r = (x, y, z)` on the sphere into angles :math:`(\alpha, \beta)` .. math:: \vec r = R(\alpha, \beta, 0) \vec e_y Parameters ---------- xyz : `torch.Tensor` tensor of shape :math:`(..., 3)` Returns ------- alpha : `torch.Tensor` tensor of shape :math:`(...)` beta : `torch.Tensor` tensor of shape :math:`(...)` """ xyz = torch.nn.functional.normalize(xyz, p=2.0, dim=-1) # forward 0's instead of nan for zero-radius xyz = xyz.clamp(-1, 1) beta = torch.acos(xyz[..., 1]) alpha = torch.atan2(xyz[..., 0], xyz[..., 2]) return alpha, beta e3nn-0.6.0/e3nn/o3/_s2grid.py000066400000000000000000000365741514371756200155010ustar00rootroot00000000000000r"""Transformation between two representations of a signal on the sphere. ..
math:: f: S^2 \longrightarrow \mathbb{R} is a signal on the sphere. One representation, which we like to call a "spherical tensor", is .. math:: f(x) = \sum_{l=0}^{l_{\mathit{max}}} F^l \cdot Y^l(x) It is made of :math:`(l_{\mathit{max}} + 1)^2` real numbers, represented in the above formula by the family of vectors :math:`F^l \in \mathbb{R}^{2l+1}`. Another representation is the discretization on the sphere. For this representation we chose a particular grid of size :math:`(N, M)` .. math:: x_{ij} &= (\sin(\beta_i) \sin(\alpha_j), \cos(\beta_i), \sin(\beta_i) \cos(\alpha_j)) \beta_i &= \pi (i + 0.5) / N \alpha_j &= 2 \pi j / M In the code, :math:`N` is called ``res_beta`` and :math:`M` is ``res_alpha``. The discrete representation is therefore .. math:: \{ h_{ij} = f(x_{ij}) \}_{ij} """ import math import torch import torch.fft from e3nn import o3 from e3nn.util import explicit_default_types from e3nn.util.jit import compile_mode def _quadrature_weights(b, dtype=None, device=None): """ function copied from ``lie_learn.spaces.S3`` Compute quadrature weights for the grid used by Kostelec & Rockmore [1, 2]. """ k = torch.arange(b, device=device) w = torch.tensor( [ ( (2.0 / b) * torch.sin(math.pi * (2.0 * j + 1.0) / (4.0 * b)) * ((1.0 / (2 * k + 1)) * torch.sin((2 * j + 1) * (2 * k + 1) * math.pi / (4.0 * b))).sum() ) for j in torch.arange(2 * b, device=device) ], dtype=dtype, device=device, ) w /= 2.0 * ((2 * b) ** 2) return w def s2_grid(res_beta, res_alpha, dtype=None, device=None): r"""grid on the sphere Parameters ---------- res_beta : int :math:`N` res_alpha : int :math:`M` dtype : torch.dtype or None ``dtype`` of the returned tensors. If ``None`` then set to ``torch.get_default_dtype()``. device : torch.device or None ``device`` of the returned tensors. If ``None`` then set to the default device of the current context.
Returns ------- betas : `torch.Tensor` tensor of shape ``(res_beta)`` alphas : `torch.Tensor` tensor of shape ``(res_alpha)`` """ dtype, device = explicit_default_types(dtype, device) i = torch.arange(res_beta, dtype=dtype, device=device) betas = (i + 0.5) / res_beta * math.pi i = torch.arange(res_alpha, dtype=dtype, device=device) alphas = i / res_alpha * 2 * math.pi return betas, alphas def spherical_harmonics_s2_grid(lmax, res_beta, res_alpha, dtype=None, device=None): r"""spherical harmonics evaluated on the grid on the sphere .. math:: f(x) = \sum_{l=0}^{l_{\mathit{max}}} F^l \cdot Y^l(x) f(\beta, \alpha) = \sum_{l=0}^{l_{\mathit{max}}} F^l \cdot S^l(\alpha) P^l(\cos(\beta)) Parameters ---------- lmax : int :math:`l_{\mathit{max}}` res_beta : int :math:`N` res_alpha : int :math:`M` Returns ------- betas : `torch.Tensor` tensor of shape ``(res_beta)`` alphas : `torch.Tensor` tensor of shape ``(res_alpha)`` shb : `torch.Tensor` tensor of shape ``(res_beta, (lmax + 1)**2)`` sha : `torch.Tensor` tensor of shape ``(res_alpha, 2 lmax + 1)`` """ betas, alphas = s2_grid(res_beta, res_alpha, dtype=dtype, device=device) shb = o3.Legendre(list(range(lmax + 1)))(betas.cos(), betas.sin().abs()) # [b, l * m] sha = o3.spherical_harmonics_alpha(lmax, alphas) # [a, m] return betas, alphas, shb, sha def _complete_lmax_res(lmax, res_beta, res_alpha): """ try to use FFT i.e. 2 * lmax + 1 == res_alpha """ if res_beta is None: if lmax is not None: res_beta = 2 * (lmax + 1) # minimum req. to go on sphere and back elif res_alpha is not None: res_beta = 2 * ((res_alpha + 1) // 2) else: raise ValueError("All the entries are None") if res_alpha is None: if lmax is not None: if res_beta is not None: res_alpha = max(2 * lmax + 1, res_beta - 1) else: res_alpha = 2 * lmax + 1 # minimum req. 
to go on sphere and back elif res_beta is not None: res_alpha = res_beta - 1 if lmax is None: lmax = min(res_beta // 2 - 1, (res_alpha - 1) // 2) # maximum possible to go on sphere and back # see tests -------------------------------^ assert res_beta % 2 == 0 assert lmax + 1 <= res_beta // 2 return lmax, res_beta, res_alpha def _expand_matrix(ls, like=None, dtype=None, device=None): """ conversion matrix between a flattened vector of (l, m) pairs like this (0, 0) (1, -1) (1, 0) (1, 1) (2, -2) (2, -1) (2, 0) (2, 1) (2, 2) and a two-dimensional matrix representation like this (0, 0) (1, -1) (1, 0) (1, 1) (2, -2) (2, -1) (2, 0) (2, 1) (2, 2) :return: tensor [l, m, l * m] """ lmax = max(ls) if like is None: m = torch.zeros(len(ls), 2 * lmax + 1, sum(2 * l + 1 for l in ls), dtype=dtype, device=device) else: m = like.new_zeros((len(ls), 2 * lmax + 1, sum(2 * l + 1 for l in ls)), dtype=dtype, device=device) i = 0 for j, l in enumerate(ls): m[j, lmax - l : lmax + l + 1, i : i + 2 * l + 1] = torch.eye(2 * l + 1, dtype=dtype, device=device) i += 2 * l + 1 return m def rfft(x, l) -> torch.Tensor: r"""Real Fourier transform Parameters ---------- x : `torch.Tensor` tensor of shape ``(..., res)`` l : int maximum frequency, the output has ``2 l + 1`` components Returns ------- `torch.Tensor` tensor of shape ``(..., 2 l + 1)`` Examples -------- >>> lmax = 8 >>> res = 101 >>> _betas, _alphas, _shb, sha = spherical_harmonics_s2_grid(lmax, res, res) >>> x = torch.randn(res) >>> (rfft(x, lmax) - x @ sha).abs().max().item() < 1e-4 True """ *size, res = x.shape x = x.reshape(-1, res) x = torch.fft.rfft(x, dim=1) x = torch.cat( [ x[:, 1 : l + 1].imag.flip(1).mul(-math.sqrt(2)), x[:, :1].real, x[:, 1 : l + 1].real.mul(math.sqrt(2)), ], dim=1, ) return x.reshape(*size, 2 * l + 1) def irfft(x, res): r"""Inverse of the real Fourier transform Parameters ---------- x : `torch.Tensor` tensor of shape ``(..., 2 l + 1)`` res : int output resolution, has to be an odd number Returns ------- `torch.Tensor`
tensor of shape ``(..., res)`` Examples -------- >>> lmax = 8 >>> res = 101 >>> _betas, _alphas, _shb, sha = spherical_harmonics_s2_grid(lmax, res, res) >>> x = torch.randn(2 * lmax + 1) >>> (irfft(x, res) - sha @ x).abs().max().item() < 1e-4 True """ assert res % 2 == 1 *size, sm = x.shape x = x.reshape(-1, sm) x = torch.cat( [ x.new_zeros((x.shape[0], (res - sm) // 2)), x, x.new_zeros((x.shape[0], (res - sm) // 2)), ], dim=-1, ) assert x.shape[1] == res l = res // 2 x = torch.complex( torch.cat([x[:, l : l + 1], x[:, l + 1 :].div(math.sqrt(2))], dim=1), torch.cat( [ torch.zeros_like(x[:, :1]), x[:, :l].flip(-1).div(-math.sqrt(2)), ], dim=1, ), ) x = torch.fft.irfft(x, n=res, dim=1) * res return x.reshape(*size, res) @compile_mode("trace") class ToS2Grid(torch.nn.Module): r"""Transform a spherical tensor into a signal on the sphere The inverse transformation of `FromS2Grid` Parameters ---------- lmax : int res : int, tuple of int resolution in ``beta`` and in ``alpha`` normalization : {'norm', 'component', 'integral'} dtype : torch.dtype or None, optional device : torch.device or None, optional Examples -------- >>> m = ToS2Grid(6, (100, 101)) >>> x = torch.randn(3, 49) >>> m(x).shape torch.Size([3, 100, 101]) `ToS2Grid` and `FromS2Grid` are inverses of each other >>> m = ToS2Grid(6, (100, 101)) >>> k = FromS2Grid((100, 101), 6) >>> x = torch.randn(3, 49) >>> y = k(m(x)) >>> (x - y).abs().max().item() < 1e-4 True Attributes ---------- grid : `torch.Tensor` positions on the sphere, tensor of shape ``(res_beta, res_alpha, 3)`` """ def __init__(self, lmax=None, res=None, normalization: str = "component", dtype=None, device=None) -> None: super().__init__() assert normalization in ["norm", "component", "integral"] or torch.is_tensor( normalization ), "normalization needs to be 'norm', 'component' or 'integral'" if isinstance(res, int) or res is None: lmax, res_beta, res_alpha = _complete_lmax_res(lmax, res, None) else: lmax, res_beta, res_alpha =
_complete_lmax_res(lmax, *res) betas, alphas, shb, sha = spherical_harmonics_s2_grid(lmax, res_beta, res_alpha, dtype=dtype, device=device) n = None if normalization == "component": # normalize such that all l have the same variance on the sphere # given that every component has mean 0 and variance 1 n = ( math.sqrt(4 * math.pi) * betas.new_tensor([1 / math.sqrt(2 * l + 1) for l in range(lmax + 1)]) / math.sqrt(lmax + 1) ) if normalization == "norm": # normalize such that all l have the same variance on the sphere # given that every component has mean 0 and variance 1/(2l+1) n = math.sqrt(4 * math.pi) * betas.new_ones(lmax + 1) / math.sqrt(lmax + 1) if normalization == "integral": n = betas.new_ones(lmax + 1) if torch.is_tensor(normalization): n = normalization m = _expand_matrix(range(lmax + 1), dtype=dtype, device=device) # [l, m, i] shb = torch.einsum("lmj,bj,lmi,l->mbi", m, shb, m, n) # [m, b, i] self.lmax, self.res_beta, self.res_alpha = lmax, res_beta, res_alpha self.register_buffer("alphas", alphas) self.register_buffer("betas", betas) self.register_buffer("sha", sha) self.register_buffer("shb", shb) def __repr__(self) -> str: return f"{self.__class__.__name__}(lmax={self.lmax} res={self.res_beta}x{self.res_alpha} (beta x alpha))" @property def grid(self) -> torch.Tensor: beta, alpha = torch.meshgrid(self.betas, self.alphas, indexing="ij") return o3.angles_to_xyz(alpha, beta) def forward(self, x): r"""Evaluate Parameters ---------- x : `torch.Tensor` tensor of shape ``(..., (l+1)^2)`` Returns ------- `torch.Tensor` tensor of shape ``[..., beta, alpha]`` """ size = x.shape[:-1] x = x.reshape(-1, x.shape[-1]) x = torch.einsum("mbi,zi->zbm", self.shb, x) # [batch, beta, m] sa, sm = self.sha.shape if sa >= sm and sa % 2 == 1: x = irfft(x, sa) else: x = torch.einsum("am,zbm->zba", self.sha, x) return x.reshape(*size, *x.shape[1:]) def _make_tracing_inputs(self, n: int): return [{"forward": (torch.randn((self.lmax + 1) ** 2),)} for _ in range(n)] @compile_mode("trace") class
FromS2Grid(torch.nn.Module): r"""Transform signal on the sphere into spherical tensor The inverse transformation of `ToS2Grid` Parameters ---------- res : int, tuple of int resolution in ``beta`` and in ``alpha`` lmax : int normalization : {'norm', 'component', 'integral'} lmax_in : int, optional dtype : torch.dtype or None, optional device : torch.device or None, optional Examples -------- >>> m = FromS2Grid((100, 101), 6) >>> x = torch.randn(3, 100, 101) >>> m(x).shape torch.Size([3, 49]) `ToS2Grid` and `FromS2Grid` are inverse of each other >>> m = FromS2Grid((100, 101), 6) >>> k = ToS2Grid(6, (100, 101)) >>> x = torch.randn(3, 100, 101) >>> x = k(m(x)) # remove high frequencies >>> y = k(m(x)) >>> (x - y).abs().max().item() < 1e-4 True Attributes ---------- grid : `torch.Tensor` positions on the sphere, tensor of shape ``(res_beta, res_alpha, 3)`` """ def __init__(self, res=None, lmax=None, normalization: str = "component", lmax_in=None, dtype=None, device=None) -> None: super().__init__() assert normalization in ["norm", "component", "integral"] or torch.is_tensor( normalization ), "normalization needs to be 'norm', 'component' or 'integral'" if isinstance(res, int) or res is None: lmax, res_beta, res_alpha = _complete_lmax_res(lmax, res, None) else: lmax, res_beta, res_alpha = _complete_lmax_res(lmax, *res) if lmax_in is None: lmax_in = lmax betas, alphas, shb, sha = spherical_harmonics_s2_grid(lmax, res_beta, res_alpha, dtype=dtype, device=device) # normalize such that it is the inverse of ToS2Grid n = None if normalization == "component": n = ( math.sqrt(4 * math.pi) * betas.new_tensor([math.sqrt(2 * l + 1) for l in range(lmax + 1)]) * math.sqrt(lmax_in + 1) ) if normalization == "norm": n = math.sqrt(4 * math.pi) * betas.new_ones(lmax + 1) * math.sqrt(lmax_in + 1) if normalization == "integral": n = 4 * math.pi * betas.new_ones(lmax + 1) if torch.is_tensor(normalization): n = normalization m = _expand_matrix(range(lmax + 1), dtype=dtype, device=device) # 
[l, m, i] assert res_beta % 2 == 0 qw = _quadrature_weights(res_beta // 2, dtype=dtype, device=device) * res_beta**2 / res_alpha # [b] shb = torch.einsum("lmj,bj,lmi,l,b->mbi", m, shb, m, n, qw) # [m, b, i] self.lmax, self.res_beta, self.res_alpha = lmax, res_beta, res_alpha self.register_buffer("alphas", alphas) self.register_buffer("betas", betas) self.register_buffer("sha", sha) self.register_buffer("shb", shb) def __repr__(self) -> str: return f"{self.__class__.__name__}(lmax={self.lmax} res={self.res_beta}x{self.res_alpha} (beta x alpha))" @property def grid(self) -> torch.Tensor: beta, alpha = torch.meshgrid(self.betas, self.alphas, indexing="ij") return o3.angles_to_xyz(alpha, beta) def forward(self, x) -> torch.Tensor: r"""Evaluate Parameters ---------- x : `torch.Tensor` tensor of shape ``[..., beta, alpha]`` Returns ------- `torch.Tensor` tensor of shape ``(..., (l+1)^2)`` """ size = x.shape[:-2] res_beta, res_alpha = x.shape[-2:] x = x.reshape(-1, res_beta, res_alpha) sa, sm = self.sha.shape if sm <= sa and sa % 2 == 1: x = rfft(x, sm // 2) else: x = torch.einsum("am,zba->zbm", self.sha, x) x = torch.einsum("mbi,zbm->zi", self.shb, x) return x.reshape(*size, x.shape[1]) def _make_tracing_inputs(self, n: int): return [{"forward": (torch.randn(self.res_beta, self.res_alpha),)} for _ in range(n)] e3nn-0.6.0/e3nn/o3/_so3grid.py import torch from e3nn.util.jit import compile_mode from ._wigner import wigner_D from ._s2grid import _quadrature_weights, s2_grid def flat_wigner(lmax: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor: return torch.cat([(2 * l + 1) ** 0.5 * wigner_D(l, alpha, beta, gamma).flatten(-2) for l in range(lmax + 1)], dim=-1) @compile_mode("script") class SO3Grid(torch.nn.Module): # pylint: disable=abstract-method r"""Transform a signal on SO(3) between its Wigner D coefficients and its values on a grid Parameters ---------- lmax : int irreps representation ``[(2 * l + 1,
(l, p_val)) for l in [0, ..., lmax]]`` resolution : int SO(3) grid resolution normalization : {'component'} only ``'component'`` is currently supported aspect_ratio : float default value (2) should be optimal """ def __init__(self, lmax, resolution, *, normalization: str = "component", aspect_ratio: int = 2) -> None: super().__init__() assert normalization == "component" nb = 2 * resolution na = round(2 * aspect_ratio * resolution) b, a = s2_grid(nb, na) self.register_buffer("D", flat_wigner(lmax, a[:, None, None], b[None, :, None], a[None, None, :])) qw = _quadrature_weights(nb // 2) * nb**2 / na**2 self.register_buffer("qw", qw) self.register_buffer("alpha", a) self.register_buffer("beta", b) self.register_buffer("gamma", a) self.lmax = lmax self.res_alpha = na self.res_beta = nb self.res_gamma = na def __repr__(self) -> str: return f"{self.__class__.__name__} ({self.lmax})" def to_grid(self, features) -> torch.Tensor: r"""Evaluate the signal on the SO(3) grid Parameters ---------- features : `torch.Tensor` tensor of shape ``(..., self.D.shape[-1])``, where ``self.D.shape[-1]`` is :math:`\sum_{l=0}^{l_{\mathit{max}}} (2l+1)^2` Returns ------- `torch.Tensor` tensor of shape ``(..., self.res_alpha, self.res_beta, self.res_gamma)`` """ return torch.einsum("...i,abci->...abc", features, self.D) / self.D.shape[-1] ** 0.5 def from_grid(self, features) -> torch.Tensor: r"""Compute the Wigner D coefficients of a signal given on the SO(3) grid Parameters ---------- features : `torch.Tensor` tensor of shape ``(..., self.res_alpha, self.res_beta, self.res_gamma)`` Returns ------- `torch.Tensor` tensor of shape ``(..., self.D.shape[-1])`` """ return torch.einsum("...abc,abci,b->...i", features, self.D, self.qw) * self.D.shape[-1] ** 0.5 e3nn-0.6.0/e3nn/o3/_spherical_harmonics.py r"""Spherical Harmonics as polynomials of x, y, z""" from typing import Union, List, Any import math import torch from e3nn.o3._irreps import Irreps from e3nn import get_optimization_defaults from e3nn.util.jit import compile_mode @compile_mode("script") class SphericalHarmonics(torch.nn.Module): """JITable module version of
:meth:`e3nn.o3.spherical_harmonics`. Parameters are identical to :meth:`e3nn.o3.spherical_harmonics`. """ normalize: bool normalization: str _ls_list: List[int] _lmax: int _is_range_lmax: bool _prof_str: str def __init__( self, irreps_out: Union[int, List[int], str, Irreps], normalize: bool, normalization: str = "integral", irreps_in: Any = None, ) -> None: super().__init__() self.normalize = normalize self.normalization = normalization assert normalization in ["integral", "component", "norm"] if isinstance(irreps_out, str): irreps_out = Irreps(irreps_out) if isinstance(irreps_out, Irreps) and irreps_in is None: for mul, (l, p) in irreps_out: if l % 2 == 1 and p == 1: irreps_in = Irreps("1e") if irreps_in is None: irreps_in = Irreps("1o") irreps_in = Irreps(irreps_in) if irreps_in not in (Irreps("1x1o"), Irreps("1x1e")): raise ValueError( f"irreps_in for SphericalHarmonics must be either a vector (`1x1o`) or a pseudovector (`1x1e`), " f"not `{irreps_in}`" ) self.irreps_in = irreps_in input_p = irreps_in[0].ir.p # pylint: disable=no-member if isinstance(irreps_out, Irreps): ls = [] for mul, (l, p) in irreps_out: if p != input_p**l: raise ValueError( f"irreps_out `{irreps_out}` passed to SphericalHarmonics asked for an output of l = {l} with parity " f"p = {p}, which is inconsistent with the input parity {input_p} — the output parity should have been " f"p = {input_p**l}" ) ls.extend([l] * mul) elif isinstance(irreps_out, int): ls = [irreps_out] else: ls = list(irreps_out) irreps_out = Irreps([(1, (l, input_p**l)) for l in ls]).simplify() self.irreps_out = irreps_out self._ls_list = ls self._lmax = max(ls) self._is_range_lmax = ls == list(range(max(ls) + 1)) self._prof_str = f"spherical_harmonics({ls})" _lmax = 12 if self._lmax > _lmax: raise NotImplementedError( f"spherical_harmonics maximum l implemented is {_lmax}, send us an email to ask for more" ) if get_optimization_defaults()["jit_mode"] == "script": self.sph_func = torch.jit.script(_spherical_harmonics) elif 
get_optimization_defaults()["jit_mode"] == "inductor": self.sph_func = torch.compile(_spherical_harmonics, fullgraph=True) else: self.sph_func = _spherical_harmonics def forward(self, x: torch.Tensor) -> torch.Tensor: # - PROFILER - with torch.autograd.profiler.record_function(self._prof_str): if self.normalize: x = torch.nn.functional.normalize(x, dim=-1) # forward 0's instead of nan for zero-radius sh = self.sph_func(self._lmax, x[..., 0], x[..., 1], x[..., 2]) if not self._is_range_lmax: sh = torch.cat([sh[..., l * l : (l + 1) * (l + 1)] for l in self._ls_list], dim=-1) if self.normalization == "integral": sh.div_(math.sqrt(4 * math.pi)) elif self.normalization == "norm": sh.div_( torch.cat( [math.sqrt(2 * l + 1) * torch.ones(2 * l + 1, dtype=sh.dtype, device=sh.device) for l in self._ls_list] ) ) return sh def spherical_harmonics( l: Union[int, List[int], str, Irreps], x: torch.Tensor, normalize: bool, normalization: str = "integral" ): r"""Spherical harmonics .. image:: https://user-images.githubusercontent.com/333780/79220728-dbe82c00-7e54-11ea-82c7-b3acbd9b2246.gif | Polynomials defined on 3D space :math:`Y^l: \mathbb{R}^3 \longrightarrow \mathbb{R}^{2l+1}` | Usually restricted to the sphere (with ``normalize=True``) :math:`Y^l: S^2 \longrightarrow \mathbb{R}^{2l+1}` | which satisfy the following properties: * they are polynomials of the Cartesian coordinates ``x, y, z`` * they are equivariant :math:`Y^l(R x) = D^l(R) Y^l(x)` * they are orthogonal :math:`\int_{S^2} Y^l_m(x) Y^j_n(x) dx = \text{cste} \; \delta_{lj} \delta_{mn}` The value of the constant depends on the choice of normalization. They obey the following recurrence relations: .. math:: Y^{l+1}_i(x) &= \text{cste}(l) \; & C_{ijk} Y^l_j(x) x_k \partial_k Y^{l+1}_i(x) &= \text{cste}(l) \; (l+1) & C_{ijk} Y^l_j(x) where :math:`C` are the `wigner_3j` coefficients. ..
note:: This function matches the table of standard real spherical harmonics from Wikipedia_ when ``normalize=True``, ``normalization='integral'`` and it is called with the arguments in the order ``y,z,x`` (instead of ``x,y,z``). .. _Wikipedia: https://en.wikipedia.org/wiki/Table_of_spherical_harmonics#Real_spherical_harmonics Parameters ---------- l : int, list of int, str or `Irreps` degree(s) of the spherical harmonics. x : `torch.Tensor` tensor :math:`x` of shape ``(..., 3)``. normalize : bool whether to normalize ``x`` to unit vectors that lie on the sphere before projecting onto the spherical harmonics normalization : {'integral', 'component', 'norm'} normalization of the output tensors. Note that this option is independent of ``normalize``, which controls the processing of the *input* rather than the output. Valid options: * *component*: :math:`\|Y^l(x)\|^2 = 2l+1, x \in S^2` * *norm*: :math:`\|Y^l(x)\| = 1, x \in S^2`, ``component / sqrt(2l+1)`` * *integral*: :math:`\int_{S^2} Y^l_m(x)^2 dx = 1`, ``component / sqrt(4pi)`` Returns ------- `torch.Tensor` a tensor of shape ``(..., 2l+1)`` ..
math:: Y^l(x) Examples -------- >>> spherical_harmonics(0, torch.randn(2, 3), False, normalization='component') tensor([[1.], [1.]]) See Also -------- wigner_D wigner_3j """ sh = SphericalHarmonics(l, normalize, normalization) return sh(x) def _spherical_harmonics(lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor: sh_0_0 = torch.ones_like(x) if lmax == 0: return torch.stack( [ sh_0_0, ], dim=-1, ) sh_1_0 = math.sqrt(3) * x sh_1_1 = math.sqrt(3) * y sh_1_2 = math.sqrt(3) * z if lmax == 1: return torch.stack([sh_0_0, sh_1_0, sh_1_1, sh_1_2], dim=-1) sh_2_0 = math.sqrt(15) * x * z sh_2_1 = math.sqrt(15) * x * y y2 = y.pow(2) x2z2 = x.pow(2) + z.pow(2) sh_2_2 = math.sqrt(5) * (y2 - (1 / 2) * x2z2) sh_2_3 = math.sqrt(15) * y * z sh_2_4 = (1 / 2) * math.sqrt(15) * (z.pow(2) - x.pow(2)) if lmax == 2: return torch.stack([sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4], dim=-1) sh_3_0 = (1 / 6) * math.sqrt(42) * (sh_2_0 * z + sh_2_4 * x) sh_3_1 = math.sqrt(7) * sh_2_0 * y sh_3_2 = (1 / 8) * math.sqrt(168) * (4.0 * y2 - x2z2) * x sh_3_3 = (1 / 2) * math.sqrt(7) * y * (2.0 * y2 - 3.0 * x2z2) sh_3_4 = (1 / 8) * math.sqrt(168) * z * (4.0 * y2 - x2z2) sh_3_5 = math.sqrt(7) * sh_2_4 * y sh_3_6 = (1 / 6) * math.sqrt(42) * (sh_2_4 * z - sh_2_0 * x) if lmax == 3: return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, ], dim=-1, ) sh_4_0 = (3 / 4) * math.sqrt(2) * (sh_3_0 * z + sh_3_6 * x) sh_4_1 = (3 / 4) * sh_3_0 * y + (3 / 8) * math.sqrt(6) * sh_3_1 * z + (3 / 8) * math.sqrt(6) * sh_3_5 * x sh_4_2 = ( -3 / 56 * math.sqrt(14) * sh_3_0 * z + (3 / 14) * math.sqrt(21) * sh_3_1 * y + (3 / 56) * math.sqrt(210) * sh_3_2 * z + (3 / 56) * math.sqrt(210) * sh_3_4 * x + (3 / 56) * math.sqrt(14) * sh_3_6 * x ) sh_4_3 = ( -3 / 56 * math.sqrt(42) * sh_3_1 * z + (3 / 28) * math.sqrt(105) * sh_3_2 * y + (3 / 28) * math.sqrt(70) * sh_3_3 * x + (3 / 56) 
* math.sqrt(42) * sh_3_5 * x ) sh_4_4 = -3 / 28 * math.sqrt(42) * sh_3_2 * x + (3 / 7) * math.sqrt(7) * sh_3_3 * y - 3 / 28 * math.sqrt(42) * sh_3_4 * z sh_4_5 = ( -3 / 56 * math.sqrt(42) * sh_3_1 * x + (3 / 28) * math.sqrt(70) * sh_3_3 * z + (3 / 28) * math.sqrt(105) * sh_3_4 * y - 3 / 56 * math.sqrt(42) * sh_3_5 * z ) sh_4_6 = ( -3 / 56 * math.sqrt(14) * sh_3_0 * x - 3 / 56 * math.sqrt(210) * sh_3_2 * x + (3 / 56) * math.sqrt(210) * sh_3_4 * z + (3 / 14) * math.sqrt(21) * sh_3_5 * y - 3 / 56 * math.sqrt(14) * sh_3_6 * z ) sh_4_7 = -3 / 8 * math.sqrt(6) * sh_3_1 * x + (3 / 8) * math.sqrt(6) * sh_3_5 * z + (3 / 4) * sh_3_6 * y sh_4_8 = (3 / 4) * math.sqrt(2) * (-sh_3_0 * x + sh_3_6 * z) if lmax == 4: return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, ], dim=-1, ) sh_5_0 = (1 / 10) * math.sqrt(110) * (sh_4_0 * z + sh_4_8 * x) sh_5_1 = (1 / 5) * math.sqrt(11) * sh_4_0 * y + (1 / 5) * math.sqrt(22) * sh_4_1 * z + (1 / 5) * math.sqrt(22) * sh_4_7 * x sh_5_2 = ( -1 / 30 * math.sqrt(22) * sh_4_0 * z + (4 / 15) * math.sqrt(11) * sh_4_1 * y + (1 / 15) * math.sqrt(154) * sh_4_2 * z + (1 / 15) * math.sqrt(154) * sh_4_6 * x + (1 / 30) * math.sqrt(22) * sh_4_8 * x ) sh_5_3 = ( -1 / 30 * math.sqrt(66) * sh_4_1 * z + (1 / 15) * math.sqrt(231) * sh_4_2 * y + (1 / 30) * math.sqrt(462) * sh_4_3 * z + (1 / 30) * math.sqrt(462) * sh_4_5 * x + (1 / 30) * math.sqrt(66) * sh_4_7 * x ) sh_5_4 = ( -1 / 15 * math.sqrt(33) * sh_4_2 * z + (2 / 15) * math.sqrt(66) * sh_4_3 * y + (1 / 15) * math.sqrt(165) * sh_4_4 * x + (1 / 15) * math.sqrt(33) * sh_4_6 * x ) sh_5_5 = ( -1 / 15 * math.sqrt(110) * sh_4_3 * x + (1 / 3) * math.sqrt(11) * sh_4_4 * y - 1 / 15 * math.sqrt(110) * sh_4_5 * z ) sh_5_6 = ( -1 / 15 * math.sqrt(33) * sh_4_2 * x + (1 / 15) * math.sqrt(165) * sh_4_4 * z + (2 / 15) * math.sqrt(66) * sh_4_5 * y - 1 / 15 * 
math.sqrt(33) * sh_4_6 * z ) sh_5_7 = ( -1 / 30 * math.sqrt(66) * sh_4_1 * x - 1 / 30 * math.sqrt(462) * sh_4_3 * x + (1 / 30) * math.sqrt(462) * sh_4_5 * z + (1 / 15) * math.sqrt(231) * sh_4_6 * y - 1 / 30 * math.sqrt(66) * sh_4_7 * z ) sh_5_8 = ( -1 / 30 * math.sqrt(22) * sh_4_0 * x - 1 / 15 * math.sqrt(154) * sh_4_2 * x + (1 / 15) * math.sqrt(154) * sh_4_6 * z + (4 / 15) * math.sqrt(11) * sh_4_7 * y - 1 / 30 * math.sqrt(22) * sh_4_8 * z ) sh_5_9 = -1 / 5 * math.sqrt(22) * sh_4_1 * x + (1 / 5) * math.sqrt(22) * sh_4_7 * z + (1 / 5) * math.sqrt(11) * sh_4_8 * y sh_5_10 = (1 / 10) * math.sqrt(110) * (-sh_4_0 * x + sh_4_8 * z) if lmax == 5: return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, ], dim=-1, ) sh_6_0 = (1 / 6) * math.sqrt(39) * (sh_5_0 * z + sh_5_10 * x) sh_6_1 = ( (1 / 6) * math.sqrt(13) * sh_5_0 * y + (1 / 12) * math.sqrt(130) * sh_5_1 * z + (1 / 12) * math.sqrt(130) * sh_5_9 * x ) sh_6_2 = ( -1 / 132 * math.sqrt(286) * sh_5_0 * z + (1 / 33) * math.sqrt(715) * sh_5_1 * y + (1 / 132) * math.sqrt(286) * sh_5_10 * x + (1 / 44) * math.sqrt(1430) * sh_5_2 * z + (1 / 44) * math.sqrt(1430) * sh_5_8 * x ) sh_6_3 = ( -1 / 132 * math.sqrt(858) * sh_5_1 * z + (1 / 22) * math.sqrt(429) * sh_5_2 * y + (1 / 22) * math.sqrt(286) * sh_5_3 * z + (1 / 22) * math.sqrt(286) * sh_5_7 * x + (1 / 132) * math.sqrt(858) * sh_5_9 * x ) sh_6_4 = ( -1 / 66 * math.sqrt(429) * sh_5_2 * z + (2 / 33) * math.sqrt(286) * sh_5_3 * y + (1 / 66) * math.sqrt(2002) * sh_5_4 * z + (1 / 66) * math.sqrt(2002) * sh_5_6 * x + (1 / 66) * math.sqrt(429) * sh_5_8 * x ) sh_6_5 = ( -1 / 66 * math.sqrt(715) * sh_5_3 * z + (1 / 66) * math.sqrt(5005) * sh_5_4 * y + (1 / 66) * math.sqrt(3003) * sh_5_5 * x + (1 / 66) * math.sqrt(715) * sh_5_7 * x ) 
sh_6_6 = ( -1 / 66 * math.sqrt(2145) * sh_5_4 * x + (1 / 11) * math.sqrt(143) * sh_5_5 * y - 1 / 66 * math.sqrt(2145) * sh_5_6 * z ) sh_6_7 = ( -1 / 66 * math.sqrt(715) * sh_5_3 * x + (1 / 66) * math.sqrt(3003) * sh_5_5 * z + (1 / 66) * math.sqrt(5005) * sh_5_6 * y - 1 / 66 * math.sqrt(715) * sh_5_7 * z ) sh_6_8 = ( -1 / 66 * math.sqrt(429) * sh_5_2 * x - 1 / 66 * math.sqrt(2002) * sh_5_4 * x + (1 / 66) * math.sqrt(2002) * sh_5_6 * z + (2 / 33) * math.sqrt(286) * sh_5_7 * y - 1 / 66 * math.sqrt(429) * sh_5_8 * z ) sh_6_9 = ( -1 / 132 * math.sqrt(858) * sh_5_1 * x - 1 / 22 * math.sqrt(286) * sh_5_3 * x + (1 / 22) * math.sqrt(286) * sh_5_7 * z + (1 / 22) * math.sqrt(429) * sh_5_8 * y - 1 / 132 * math.sqrt(858) * sh_5_9 * z ) sh_6_10 = ( -1 / 132 * math.sqrt(286) * sh_5_0 * x - 1 / 132 * math.sqrt(286) * sh_5_10 * z - 1 / 44 * math.sqrt(1430) * sh_5_2 * x + (1 / 44) * math.sqrt(1430) * sh_5_8 * z + (1 / 33) * math.sqrt(715) * sh_5_9 * y ) sh_6_11 = ( -1 / 12 * math.sqrt(130) * sh_5_1 * x + (1 / 6) * math.sqrt(13) * sh_5_10 * y + (1 / 12) * math.sqrt(130) * sh_5_9 * z ) sh_6_12 = (1 / 6) * math.sqrt(39) * (-sh_5_0 * x + sh_5_10 * z) if lmax == 6: return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, ], dim=-1, ) sh_7_0 = (1 / 14) * math.sqrt(210) * (sh_6_0 * z + sh_6_12 * x) sh_7_1 = (1 / 7) * math.sqrt(15) * sh_6_0 * y + (3 / 7) * math.sqrt(5) * sh_6_1 * z + (3 / 7) * math.sqrt(5) * sh_6_11 * x sh_7_2 = ( -1 / 182 * math.sqrt(390) * sh_6_0 * z + (6 / 91) * math.sqrt(130) * sh_6_1 * y + (3 / 91) * math.sqrt(715) * sh_6_10 * x + (1 / 182) * math.sqrt(390) * sh_6_12 * x + (3 / 91) * math.sqrt(715) * sh_6_2 * z ) 
sh_7_3 = ( -3 / 182 * math.sqrt(130) * sh_6_1 * z + (3 / 182) * math.sqrt(130) * sh_6_11 * x + (3 / 91) * math.sqrt(715) * sh_6_2 * y + (5 / 182) * math.sqrt(858) * sh_6_3 * z + (5 / 182) * math.sqrt(858) * sh_6_9 * x ) sh_7_4 = ( (3 / 91) * math.sqrt(65) * sh_6_10 * x - 3 / 91 * math.sqrt(65) * sh_6_2 * z + (10 / 91) * math.sqrt(78) * sh_6_3 * y + (15 / 182) * math.sqrt(78) * sh_6_4 * z + (15 / 182) * math.sqrt(78) * sh_6_8 * x ) sh_7_5 = ( -5 / 91 * math.sqrt(39) * sh_6_3 * z + (15 / 91) * math.sqrt(39) * sh_6_4 * y + (3 / 91) * math.sqrt(390) * sh_6_5 * z + (3 / 91) * math.sqrt(390) * sh_6_7 * x + (5 / 91) * math.sqrt(39) * sh_6_9 * x ) sh_7_6 = ( -15 / 182 * math.sqrt(26) * sh_6_4 * z + (12 / 91) * math.sqrt(65) * sh_6_5 * y + (2 / 91) * math.sqrt(1365) * sh_6_6 * x + (15 / 182) * math.sqrt(26) * sh_6_8 * x ) sh_7_7 = ( -3 / 91 * math.sqrt(455) * sh_6_5 * x + (1 / 13) * math.sqrt(195) * sh_6_6 * y - 3 / 91 * math.sqrt(455) * sh_6_7 * z ) sh_7_8 = ( -15 / 182 * math.sqrt(26) * sh_6_4 * x + (2 / 91) * math.sqrt(1365) * sh_6_6 * z + (12 / 91) * math.sqrt(65) * sh_6_7 * y - 15 / 182 * math.sqrt(26) * sh_6_8 * z ) sh_7_9 = ( -5 / 91 * math.sqrt(39) * sh_6_3 * x - 3 / 91 * math.sqrt(390) * sh_6_5 * x + (3 / 91) * math.sqrt(390) * sh_6_7 * z + (15 / 91) * math.sqrt(39) * sh_6_8 * y - 5 / 91 * math.sqrt(39) * sh_6_9 * z ) sh_7_10 = ( -3 / 91 * math.sqrt(65) * sh_6_10 * z - 3 / 91 * math.sqrt(65) * sh_6_2 * x - 15 / 182 * math.sqrt(78) * sh_6_4 * x + (15 / 182) * math.sqrt(78) * sh_6_8 * z + (10 / 91) * math.sqrt(78) * sh_6_9 * y ) sh_7_11 = ( -3 / 182 * math.sqrt(130) * sh_6_1 * x + (3 / 91) * math.sqrt(715) * sh_6_10 * y - 3 / 182 * math.sqrt(130) * sh_6_11 * z - 5 / 182 * math.sqrt(858) * sh_6_3 * x + (5 / 182) * math.sqrt(858) * sh_6_9 * z ) sh_7_12 = ( -1 / 182 * math.sqrt(390) * sh_6_0 * x + (3 / 91) * math.sqrt(715) * sh_6_10 * z + (6 / 91) * math.sqrt(130) * sh_6_11 * y - 1 / 182 * math.sqrt(390) * sh_6_12 * z - 3 / 91 * math.sqrt(715) * sh_6_2 * x ) sh_7_13 = 
-3 / 7 * math.sqrt(5) * sh_6_1 * x + (3 / 7) * math.sqrt(5) * sh_6_11 * z + (1 / 7) * math.sqrt(15) * sh_6_12 * y sh_7_14 = (1 / 14) * math.sqrt(210) * (-sh_6_0 * x + sh_6_12 * z) if lmax == 7: return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14, ], dim=-1, ) sh_8_0 = (1 / 4) * math.sqrt(17) * (sh_7_0 * z + sh_7_14 * x) sh_8_1 = ( (1 / 8) * math.sqrt(17) * sh_7_0 * y + (1 / 16) * math.sqrt(238) * sh_7_1 * z + (1 / 16) * math.sqrt(238) * sh_7_13 * x ) sh_8_2 = ( -1 / 240 * math.sqrt(510) * sh_7_0 * z + (1 / 60) * math.sqrt(1785) * sh_7_1 * y + (1 / 240) * math.sqrt(46410) * sh_7_12 * x + (1 / 240) * math.sqrt(510) * sh_7_14 * x + (1 / 240) * math.sqrt(46410) * sh_7_2 * z ) sh_8_3 = ( (1 / 80) * math.sqrt(2) * ( -math.sqrt(85) * sh_7_1 * z + math.sqrt(2210) * sh_7_11 * x + math.sqrt(85) * sh_7_13 * x + math.sqrt(2210) * sh_7_2 * y + math.sqrt(2210) * sh_7_3 * z ) ) sh_8_4 = ( (1 / 40) * math.sqrt(935) * sh_7_10 * x + (1 / 40) * math.sqrt(85) * sh_7_12 * x - 1 / 40 * math.sqrt(85) * sh_7_2 * z + (1 / 10) * math.sqrt(85) * sh_7_3 * y + (1 / 40) * math.sqrt(935) * sh_7_4 * z ) sh_8_5 = ( (1 / 48) * math.sqrt(2) * ( math.sqrt(102) * sh_7_11 * x - math.sqrt(102) * sh_7_3 * z + math.sqrt(1122) * sh_7_4 * y + math.sqrt(561) * sh_7_5 * z + math.sqrt(561) * sh_7_9 * x ) ) sh_8_6 = ( (1 / 16) * math.sqrt(34) * sh_7_10 * x - 1 / 16 * math.sqrt(34) * sh_7_4 * z + (1 / 4) * math.sqrt(17) * sh_7_5 * y + (1 / 16) * math.sqrt(102) * sh_7_6 * z + (1 / 16) * math.sqrt(102) * sh_7_8 * x ) sh_8_7 = ( -1 / 
80 * math.sqrt(1190) * sh_7_5 * z + (1 / 40) * math.sqrt(1785) * sh_7_6 * y + (1 / 20) * math.sqrt(255) * sh_7_7 * x + (1 / 80) * math.sqrt(1190) * sh_7_9 * x ) sh_8_8 = ( -1 / 60 * math.sqrt(1785) * sh_7_6 * x + (1 / 15) * math.sqrt(255) * sh_7_7 * y - 1 / 60 * math.sqrt(1785) * sh_7_8 * z ) sh_8_9 = ( -1 / 80 * math.sqrt(1190) * sh_7_5 * x + (1 / 20) * math.sqrt(255) * sh_7_7 * z + (1 / 40) * math.sqrt(1785) * sh_7_8 * y - 1 / 80 * math.sqrt(1190) * sh_7_9 * z ) sh_8_10 = ( -1 / 16 * math.sqrt(34) * sh_7_10 * z - 1 / 16 * math.sqrt(34) * sh_7_4 * x - 1 / 16 * math.sqrt(102) * sh_7_6 * x + (1 / 16) * math.sqrt(102) * sh_7_8 * z + (1 / 4) * math.sqrt(17) * sh_7_9 * y ) sh_8_11 = ( (1 / 48) * math.sqrt(2) * ( math.sqrt(1122) * sh_7_10 * y - math.sqrt(102) * sh_7_11 * z - math.sqrt(102) * sh_7_3 * x - math.sqrt(561) * sh_7_5 * x + math.sqrt(561) * sh_7_9 * z ) ) sh_8_12 = ( (1 / 40) * math.sqrt(935) * sh_7_10 * z + (1 / 10) * math.sqrt(85) * sh_7_11 * y - 1 / 40 * math.sqrt(85) * sh_7_12 * z - 1 / 40 * math.sqrt(85) * sh_7_2 * x - 1 / 40 * math.sqrt(935) * sh_7_4 * x ) sh_8_13 = ( (1 / 80) * math.sqrt(2) * ( -math.sqrt(85) * sh_7_1 * x + math.sqrt(2210) * sh_7_11 * z + math.sqrt(2210) * sh_7_12 * y - math.sqrt(85) * sh_7_13 * z - math.sqrt(2210) * sh_7_3 * x ) ) sh_8_14 = ( -1 / 240 * math.sqrt(510) * sh_7_0 * x + (1 / 240) * math.sqrt(46410) * sh_7_12 * z + (1 / 60) * math.sqrt(1785) * sh_7_13 * y - 1 / 240 * math.sqrt(510) * sh_7_14 * z - 1 / 240 * math.sqrt(46410) * sh_7_2 * x ) sh_8_15 = ( -1 / 16 * math.sqrt(238) * sh_7_1 * x + (1 / 16) * math.sqrt(238) * sh_7_13 * z + (1 / 8) * math.sqrt(17) * sh_7_14 * y ) sh_8_16 = (1 / 4) * math.sqrt(17) * (-sh_7_0 * x + sh_7_14 * z) if lmax == 8: return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, 
sh_5_7, sh_5_8, sh_5_9, sh_5_10, sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14, sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16, ], dim=-1, ) sh_9_0 = (1 / 6) * math.sqrt(38) * (sh_8_0 * z + sh_8_16 * x) sh_9_1 = (1 / 9) * math.sqrt(19) * (sh_8_0 * y + 2 * sh_8_1 * z + 2 * sh_8_15 * x) sh_9_2 = ( -1 / 306 * math.sqrt(646) * sh_8_0 * z + (4 / 153) * math.sqrt(646) * sh_8_1 * y + (2 / 153) * math.sqrt(4845) * sh_8_14 * x + (1 / 306) * math.sqrt(646) * sh_8_16 * x + (2 / 153) * math.sqrt(4845) * sh_8_2 * z ) sh_9_3 = ( -1 / 306 * math.sqrt(1938) * sh_8_1 * z + (1 / 306) * math.sqrt(67830) * sh_8_13 * x + (1 / 306) * math.sqrt(1938) * sh_8_15 * x + (1 / 51) * math.sqrt(1615) * sh_8_2 * y + (1 / 306) * math.sqrt(67830) * sh_8_3 * z ) sh_9_4 = ( (1 / 306) * math.sqrt(58786) * sh_8_12 * x + (1 / 153) * math.sqrt(969) * sh_8_14 * x - 1 / 153 * math.sqrt(969) * sh_8_2 * z + (2 / 153) * math.sqrt(4522) * sh_8_3 * y + (1 / 306) * math.sqrt(58786) * sh_8_4 * z ) sh_9_5 = ( (1 / 153) * math.sqrt(12597) * sh_8_11 * x + (1 / 153) * math.sqrt(1615) * sh_8_13 * x - 1 / 153 * math.sqrt(1615) * sh_8_3 * z + (1 / 153) * math.sqrt(20995) * sh_8_4 * y + (1 / 153) * math.sqrt(12597) * sh_8_5 * z ) sh_9_6 = ( (1 / 153) * math.sqrt(10659) * sh_8_10 * x + (1 / 306) * math.sqrt(9690) * sh_8_12 * x - 1 / 306 * math.sqrt(9690) * sh_8_4 * z + (2 / 51) * math.sqrt(646) * sh_8_5 * y + (1 / 153) * math.sqrt(10659) * sh_8_6 * z ) sh_9_7 = ( (1 / 306) * math.sqrt(13566) * sh_8_11 * x - 1 / 306 * math.sqrt(13566) * sh_8_5 * z + (1 / 153) * math.sqrt(24871) * sh_8_6 * y + (1 / 306) * math.sqrt(35530) * sh_8_7 * z + (1 / 306) * math.sqrt(35530) * sh_8_9 * x ) sh_9_8 = ( (1 / 153) * math.sqrt(4522) * sh_8_10 * x - 1 / 153 * 
math.sqrt(4522) * sh_8_6 * z + (4 / 153) * math.sqrt(1615) * sh_8_7 * y + (1 / 51) * math.sqrt(1615) * sh_8_8 * x ) sh_9_9 = (1 / 51) * math.sqrt(323) * (-2 * sh_8_7 * x + 3 * sh_8_8 * y - 2 * sh_8_9 * z) sh_9_10 = ( -1 / 153 * math.sqrt(4522) * sh_8_10 * z - 1 / 153 * math.sqrt(4522) * sh_8_6 * x + (1 / 51) * math.sqrt(1615) * sh_8_8 * z + (4 / 153) * math.sqrt(1615) * sh_8_9 * y ) sh_9_11 = ( (1 / 153) * math.sqrt(24871) * sh_8_10 * y - 1 / 306 * math.sqrt(13566) * sh_8_11 * z - 1 / 306 * math.sqrt(13566) * sh_8_5 * x - 1 / 306 * math.sqrt(35530) * sh_8_7 * x + (1 / 306) * math.sqrt(35530) * sh_8_9 * z ) sh_9_12 = ( (1 / 153) * math.sqrt(10659) * sh_8_10 * z + (2 / 51) * math.sqrt(646) * sh_8_11 * y - 1 / 306 * math.sqrt(9690) * sh_8_12 * z - 1 / 306 * math.sqrt(9690) * sh_8_4 * x - 1 / 153 * math.sqrt(10659) * sh_8_6 * x ) sh_9_13 = ( (1 / 153) * math.sqrt(12597) * sh_8_11 * z + (1 / 153) * math.sqrt(20995) * sh_8_12 * y - 1 / 153 * math.sqrt(1615) * sh_8_13 * z - 1 / 153 * math.sqrt(1615) * sh_8_3 * x - 1 / 153 * math.sqrt(12597) * sh_8_5 * x ) sh_9_14 = ( (1 / 306) * math.sqrt(58786) * sh_8_12 * z + (2 / 153) * math.sqrt(4522) * sh_8_13 * y - 1 / 153 * math.sqrt(969) * sh_8_14 * z - 1 / 153 * math.sqrt(969) * sh_8_2 * x - 1 / 306 * math.sqrt(58786) * sh_8_4 * x ) sh_9_15 = ( -1 / 306 * math.sqrt(1938) * sh_8_1 * x + (1 / 306) * math.sqrt(67830) * sh_8_13 * z + (1 / 51) * math.sqrt(1615) * sh_8_14 * y - 1 / 306 * math.sqrt(1938) * sh_8_15 * z - 1 / 306 * math.sqrt(67830) * sh_8_3 * x ) sh_9_16 = ( -1 / 306 * math.sqrt(646) * sh_8_0 * x + (2 / 153) * math.sqrt(4845) * sh_8_14 * z + (4 / 153) * math.sqrt(646) * sh_8_15 * y - 1 / 306 * math.sqrt(646) * sh_8_16 * z - 2 / 153 * math.sqrt(4845) * sh_8_2 * x ) sh_9_17 = (1 / 9) * math.sqrt(19) * (-2 * sh_8_1 * x + 2 * sh_8_15 * z + sh_8_16 * y) sh_9_18 = (1 / 6) * math.sqrt(38) * (-sh_8_0 * x + sh_8_16 * z) if lmax == 9: return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, 
sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14, sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16, sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18, ], dim=-1, ) sh_10_0 = (1 / 10) * math.sqrt(105) * (sh_9_0 * z + sh_9_18 * x) sh_10_1 = ( (1 / 10) * math.sqrt(21) * sh_9_0 * y + (3 / 20) * math.sqrt(42) * sh_9_1 * z + (3 / 20) * math.sqrt(42) * sh_9_17 * x ) sh_10_2 = ( -1 / 380 * math.sqrt(798) * sh_9_0 * z + (3 / 95) * math.sqrt(399) * sh_9_1 * y + (3 / 380) * math.sqrt(13566) * sh_9_16 * x + (1 / 380) * math.sqrt(798) * sh_9_18 * x + (3 / 380) * math.sqrt(13566) * sh_9_2 * z ) sh_10_3 = ( -3 / 380 * math.sqrt(266) * sh_9_1 * z + (1 / 95) * math.sqrt(6783) * sh_9_15 * x + (3 / 380) * math.sqrt(266) * sh_9_17 * x + (3 / 190) * math.sqrt(2261) * sh_9_2 * y + (1 / 95) * math.sqrt(6783) * sh_9_3 * z ) sh_10_4 = ( (3 / 95) * math.sqrt(665) * sh_9_14 * x + (3 / 190) * math.sqrt(133) * sh_9_16 * x - 3 / 190 * math.sqrt(133) * sh_9_2 * z + (4 / 95) * math.sqrt(399) * sh_9_3 * y + (3 / 95) * math.sqrt(665) * sh_9_4 * z ) sh_10_5 = ( (21 / 380) * math.sqrt(190) * sh_9_13 * x + (1 / 190) * math.sqrt(1995) * sh_9_15 * x - 1 / 190 * math.sqrt(1995) * sh_9_3 * z + (3 / 38) * math.sqrt(133) * sh_9_4 * y + (21 / 380) * math.sqrt(190) * sh_9_5 * z ) sh_10_6 = ( (7 / 380) * math.sqrt(1482) * sh_9_12 * x + (3 / 380) * math.sqrt(1330) * sh_9_14 * x - 3 / 380 * math.sqrt(1330) * sh_9_4 * z + 
(21 / 95) * math.sqrt(19) * sh_9_5 * y + (7 / 380) * math.sqrt(1482) * sh_9_6 * z ) sh_10_7 = ( (3 / 190) * math.sqrt(1729) * sh_9_11 * x + (21 / 380) * math.sqrt(38) * sh_9_13 * x - 21 / 380 * math.sqrt(38) * sh_9_5 * z + (7 / 190) * math.sqrt(741) * sh_9_6 * y + (3 / 190) * math.sqrt(1729) * sh_9_7 * z ) sh_10_8 = ( (3 / 190) * math.sqrt(1463) * sh_9_10 * x + (7 / 190) * math.sqrt(114) * sh_9_12 * x - 7 / 190 * math.sqrt(114) * sh_9_6 * z + (6 / 95) * math.sqrt(266) * sh_9_7 * y + (3 / 190) * math.sqrt(1463) * sh_9_8 * z ) sh_10_9 = ( (3 / 190) * math.sqrt(798) * sh_9_11 * x - 3 / 190 * math.sqrt(798) * sh_9_7 * z + (3 / 190) * math.sqrt(4389) * sh_9_8 * y + (1 / 190) * math.sqrt(21945) * sh_9_9 * x ) sh_10_10 = ( -3 / 190 * math.sqrt(1995) * sh_9_10 * z - 3 / 190 * math.sqrt(1995) * sh_9_8 * x + (1 / 19) * math.sqrt(399) * sh_9_9 * y ) sh_10_11 = ( (3 / 190) * math.sqrt(4389) * sh_9_10 * y - 3 / 190 * math.sqrt(798) * sh_9_11 * z - 3 / 190 * math.sqrt(798) * sh_9_7 * x + (1 / 190) * math.sqrt(21945) * sh_9_9 * z ) sh_10_12 = ( (3 / 190) * math.sqrt(1463) * sh_9_10 * z + (6 / 95) * math.sqrt(266) * sh_9_11 * y - 7 / 190 * math.sqrt(114) * sh_9_12 * z - 7 / 190 * math.sqrt(114) * sh_9_6 * x - 3 / 190 * math.sqrt(1463) * sh_9_8 * x ) sh_10_13 = ( (3 / 190) * math.sqrt(1729) * sh_9_11 * z + (7 / 190) * math.sqrt(741) * sh_9_12 * y - 21 / 380 * math.sqrt(38) * sh_9_13 * z - 21 / 380 * math.sqrt(38) * sh_9_5 * x - 3 / 190 * math.sqrt(1729) * sh_9_7 * x ) sh_10_14 = ( (7 / 380) * math.sqrt(1482) * sh_9_12 * z + (21 / 95) * math.sqrt(19) * sh_9_13 * y - 3 / 380 * math.sqrt(1330) * sh_9_14 * z - 3 / 380 * math.sqrt(1330) * sh_9_4 * x - 7 / 380 * math.sqrt(1482) * sh_9_6 * x ) sh_10_15 = ( (21 / 380) * math.sqrt(190) * sh_9_13 * z + (3 / 38) * math.sqrt(133) * sh_9_14 * y - 1 / 190 * math.sqrt(1995) * sh_9_15 * z - 1 / 190 * math.sqrt(1995) * sh_9_3 * x - 21 / 380 * math.sqrt(190) * sh_9_5 * x ) sh_10_16 = ( (3 / 95) * math.sqrt(665) * sh_9_14 * z + (4 / 95) * 
math.sqrt(399) * sh_9_15 * y - 3 / 190 * math.sqrt(133) * sh_9_16 * z - 3 / 190 * math.sqrt(133) * sh_9_2 * x - 3 / 95 * math.sqrt(665) * sh_9_4 * x ) sh_10_17 = ( -3 / 380 * math.sqrt(266) * sh_9_1 * x + (1 / 95) * math.sqrt(6783) * sh_9_15 * z + (3 / 190) * math.sqrt(2261) * sh_9_16 * y - 3 / 380 * math.sqrt(266) * sh_9_17 * z - 1 / 95 * math.sqrt(6783) * sh_9_3 * x ) sh_10_18 = ( -1 / 380 * math.sqrt(798) * sh_9_0 * x + (3 / 380) * math.sqrt(13566) * sh_9_16 * z + (3 / 95) * math.sqrt(399) * sh_9_17 * y - 1 / 380 * math.sqrt(798) * sh_9_18 * z - 3 / 380 * math.sqrt(13566) * sh_9_2 * x ) sh_10_19 = ( -3 / 20 * math.sqrt(42) * sh_9_1 * x + (3 / 20) * math.sqrt(42) * sh_9_17 * z + (1 / 10) * math.sqrt(21) * sh_9_18 * y ) sh_10_20 = (1 / 10) * math.sqrt(105) * (-sh_9_0 * x + sh_9_18 * z) if lmax == 10: return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14, sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16, sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18, sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20, ], dim=-1, ) sh_11_0 = (1 / 22) * math.sqrt(506) * (sh_10_0 * z + sh_10_20 * x) sh_11_1 = ( (1 / 11) * math.sqrt(23) * sh_10_0 * y + (1 / 
11) * math.sqrt(115) * sh_10_1 * z + (1 / 11) * math.sqrt(115) * sh_10_19 * x ) sh_11_2 = ( -1 / 462 * math.sqrt(966) * sh_10_0 * z + (2 / 231) * math.sqrt(4830) * sh_10_1 * y + (1 / 231) * math.sqrt(45885) * sh_10_18 * x + (1 / 231) * math.sqrt(45885) * sh_10_2 * z + (1 / 462) * math.sqrt(966) * sh_10_20 * x ) sh_11_3 = ( -1 / 154 * math.sqrt(322) * sh_10_1 * z + (1 / 154) * math.sqrt(18354) * sh_10_17 * x + (1 / 154) * math.sqrt(322) * sh_10_19 * x + (1 / 77) * math.sqrt(3059) * sh_10_2 * y + (1 / 154) * math.sqrt(18354) * sh_10_3 * z ) sh_11_4 = ( (1 / 154) * math.sqrt(16422) * sh_10_16 * x + (1 / 77) * math.sqrt(161) * sh_10_18 * x - 1 / 77 * math.sqrt(161) * sh_10_2 * z + (2 / 77) * math.sqrt(966) * sh_10_3 * y + (1 / 154) * math.sqrt(16422) * sh_10_4 * z ) sh_11_5 = ( (2 / 231) * math.sqrt(8211) * sh_10_15 * x + (1 / 231) * math.sqrt(2415) * sh_10_17 * x - 1 / 231 * math.sqrt(2415) * sh_10_3 * z + (1 / 231) * math.sqrt(41055) * sh_10_4 * y + (2 / 231) * math.sqrt(8211) * sh_10_5 * z ) sh_11_6 = ( (2 / 77) * math.sqrt(805) * sh_10_14 * x + (1 / 154) * math.sqrt(1610) * sh_10_16 * x - 1 / 154 * math.sqrt(1610) * sh_10_4 * z + (4 / 77) * math.sqrt(322) * sh_10_5 * y + (2 / 77) * math.sqrt(805) * sh_10_6 * z ) sh_11_7 = ( (1 / 22) * math.sqrt(230) * sh_10_13 * x + (1 / 22) * math.sqrt(46) * sh_10_15 * x - 1 / 22 * math.sqrt(46) * sh_10_5 * z + (1 / 11) * math.sqrt(115) * sh_10_6 * y + (1 / 22) * math.sqrt(230) * sh_10_7 * z ) sh_11_8 = ( (1 / 66) * math.sqrt(1794) * sh_10_12 * x + (1 / 33) * math.sqrt(138) * sh_10_14 * x - 1 / 33 * math.sqrt(138) * sh_10_6 * z + (4 / 33) * math.sqrt(69) * sh_10_7 * y + (1 / 66) * math.sqrt(1794) * sh_10_8 * z ) sh_11_9 = ( (1 / 77) * math.sqrt(2093) * sh_10_11 * x + (1 / 77) * math.sqrt(966) * sh_10_13 * x - 1 / 77 * math.sqrt(966) * sh_10_7 * z + (1 / 77) * math.sqrt(6279) * sh_10_8 * y + (1 / 77) * math.sqrt(2093) * sh_10_9 * z ) sh_11_10 = ( (1 / 77) * math.sqrt(3542) * sh_10_10 * x + (1 / 154) * math.sqrt(4830) * sh_10_12 * x 
- 1 / 154 * math.sqrt(4830) * sh_10_8 * z + (2 / 77) * math.sqrt(1610) * sh_10_9 * y ) sh_11_11 = ( (1 / 21) * math.sqrt(483) * sh_10_10 * y - 1 / 231 * math.sqrt(26565) * sh_10_11 * z - 1 / 231 * math.sqrt(26565) * sh_10_9 * x ) sh_11_12 = ( (1 / 77) * math.sqrt(3542) * sh_10_10 * z + (2 / 77) * math.sqrt(1610) * sh_10_11 * y - 1 / 154 * math.sqrt(4830) * sh_10_12 * z - 1 / 154 * math.sqrt(4830) * sh_10_8 * x ) sh_11_13 = ( (1 / 77) * math.sqrt(2093) * sh_10_11 * z + (1 / 77) * math.sqrt(6279) * sh_10_12 * y - 1 / 77 * math.sqrt(966) * sh_10_13 * z - 1 / 77 * math.sqrt(966) * sh_10_7 * x - 1 / 77 * math.sqrt(2093) * sh_10_9 * x ) sh_11_14 = ( (1 / 66) * math.sqrt(1794) * sh_10_12 * z + (4 / 33) * math.sqrt(69) * sh_10_13 * y - 1 / 33 * math.sqrt(138) * sh_10_14 * z - 1 / 33 * math.sqrt(138) * sh_10_6 * x - 1 / 66 * math.sqrt(1794) * sh_10_8 * x ) sh_11_15 = ( (1 / 22) * math.sqrt(230) * sh_10_13 * z + (1 / 11) * math.sqrt(115) * sh_10_14 * y - 1 / 22 * math.sqrt(46) * sh_10_15 * z - 1 / 22 * math.sqrt(46) * sh_10_5 * x - 1 / 22 * math.sqrt(230) * sh_10_7 * x ) sh_11_16 = ( (2 / 77) * math.sqrt(805) * sh_10_14 * z + (4 / 77) * math.sqrt(322) * sh_10_15 * y - 1 / 154 * math.sqrt(1610) * sh_10_16 * z - 1 / 154 * math.sqrt(1610) * sh_10_4 * x - 2 / 77 * math.sqrt(805) * sh_10_6 * x ) sh_11_17 = ( (2 / 231) * math.sqrt(8211) * sh_10_15 * z + (1 / 231) * math.sqrt(41055) * sh_10_16 * y - 1 / 231 * math.sqrt(2415) * sh_10_17 * z - 1 / 231 * math.sqrt(2415) * sh_10_3 * x - 2 / 231 * math.sqrt(8211) * sh_10_5 * x ) sh_11_18 = ( (1 / 154) * math.sqrt(16422) * sh_10_16 * z + (2 / 77) * math.sqrt(966) * sh_10_17 * y - 1 / 77 * math.sqrt(161) * sh_10_18 * z - 1 / 77 * math.sqrt(161) * sh_10_2 * x - 1 / 154 * math.sqrt(16422) * sh_10_4 * x ) sh_11_19 = ( -1 / 154 * math.sqrt(322) * sh_10_1 * x + (1 / 154) * math.sqrt(18354) * sh_10_17 * z + (1 / 77) * math.sqrt(3059) * sh_10_18 * y - 1 / 154 * math.sqrt(322) * sh_10_19 * z - 1 / 154 * math.sqrt(18354) * sh_10_3 * x ) sh_11_20 = 
( -1 / 462 * math.sqrt(966) * sh_10_0 * x + (1 / 231) * math.sqrt(45885) * sh_10_18 * z + (2 / 231) * math.sqrt(4830) * sh_10_19 * y - 1 / 231 * math.sqrt(45885) * sh_10_2 * x - 1 / 462 * math.sqrt(966) * sh_10_20 * z ) sh_11_21 = ( -1 / 11 * math.sqrt(115) * sh_10_1 * x + (1 / 11) * math.sqrt(115) * sh_10_19 * z + (1 / 11) * math.sqrt(23) * sh_10_20 * y ) sh_11_22 = (1 / 22) * math.sqrt(506) * (-sh_10_0 * x + sh_10_20 * z) if lmax == 11: return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14, sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16, sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18, sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20, sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6, sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, sh_11_12, sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18, sh_11_19, sh_11_20, sh_11_21, sh_11_22, ], dim=-1, ) sh_12_0 = (5 / 12) * math.sqrt(6) * (sh_11_0 * z + sh_11_22 * x) sh_12_1 = (5 / 12) * sh_11_0 * y + (5 / 24) * math.sqrt(22) * sh_11_1 * z + (5 / 24) * math.sqrt(22) * sh_11_21 * x sh_12_2 = ( -5 / 552 * math.sqrt(46) * sh_11_0 * z + (5 / 138) * math.sqrt(253) * sh_11_1 * y + 
(5 / 552) * math.sqrt(10626) * sh_11_2 * z + (5 / 552) * math.sqrt(10626) * sh_11_20 * x + (5 / 552) * math.sqrt(46) * sh_11_22 * x ) sh_12_3 = ( -5 / 552 * math.sqrt(138) * sh_11_1 * z + (5 / 276) * math.sqrt(2415) * sh_11_19 * x + (5 / 92) * math.sqrt(161) * sh_11_2 * y + (5 / 552) * math.sqrt(138) * sh_11_21 * x + (5 / 276) * math.sqrt(2415) * sh_11_3 * z ) sh_12_4 = ( (5 / 276) * math.sqrt(2185) * sh_11_18 * x - 5 / 276 * math.sqrt(69) * sh_11_2 * z + (5 / 276) * math.sqrt(69) * sh_11_20 * x + (5 / 69) * math.sqrt(115) * sh_11_3 * y + (5 / 276) * math.sqrt(2185) * sh_11_4 * z ) sh_12_5 = ( (5 / 184) * math.sqrt(874) * sh_11_17 * x + (5 / 276) * math.sqrt(115) * sh_11_19 * x - 5 / 276 * math.sqrt(115) * sh_11_3 * z + (5 / 276) * math.sqrt(2185) * sh_11_4 * y + (5 / 184) * math.sqrt(874) * sh_11_5 * z ) sh_12_6 = ( (5 / 552) * math.sqrt(3) * ( math.sqrt(2346) * sh_11_16 * x + math.sqrt(230) * sh_11_18 * x - math.sqrt(230) * sh_11_4 * z + 12 * math.sqrt(23) * sh_11_5 * y + math.sqrt(2346) * sh_11_6 * z ) ) sh_12_7 = ( (5 / 138) * math.sqrt(391) * sh_11_15 * x + (5 / 552) * math.sqrt(966) * sh_11_17 * x - 5 / 552 * math.sqrt(966) * sh_11_5 * z + (5 / 276) * math.sqrt(2737) * sh_11_6 * y + (5 / 138) * math.sqrt(391) * sh_11_7 * z ) sh_12_8 = ( (5 / 138) * math.sqrt(345) * sh_11_14 * x + (5 / 276) * math.sqrt(322) * sh_11_16 * x - 5 / 276 * math.sqrt(322) * sh_11_6 * z + (10 / 69) * math.sqrt(46) * sh_11_7 * y + (5 / 138) * math.sqrt(345) * sh_11_8 * z ) sh_12_9 = ( (5 / 552) * math.sqrt(4830) * sh_11_13 * x + (5 / 92) * math.sqrt(46) * sh_11_15 * x - 5 / 92 * math.sqrt(46) * sh_11_7 * z + (5 / 92) * math.sqrt(345) * sh_11_8 * y + (5 / 552) * math.sqrt(4830) * sh_11_9 * z ) sh_12_10 = ( (5 / 552) * math.sqrt(4186) * sh_11_10 * z + (5 / 552) * math.sqrt(4186) * sh_11_12 * x + (5 / 184) * math.sqrt(230) * sh_11_14 * x - 5 / 184 * math.sqrt(230) * sh_11_8 * z + (5 / 138) * math.sqrt(805) * sh_11_9 * y ) sh_12_11 = ( (5 / 276) * math.sqrt(3289) * sh_11_10 * y + (5 / 276) 
* math.sqrt(1794) * sh_11_11 * x + (5 / 552) * math.sqrt(2530) * sh_11_13 * x - 5 / 552 * math.sqrt(2530) * sh_11_9 * z ) sh_12_12 = ( -5 / 276 * math.sqrt(1518) * sh_11_10 * x + (5 / 23) * math.sqrt(23) * sh_11_11 * y - 5 / 276 * math.sqrt(1518) * sh_11_12 * z ) sh_12_13 = ( (5 / 276) * math.sqrt(1794) * sh_11_11 * z + (5 / 276) * math.sqrt(3289) * sh_11_12 * y - 5 / 552 * math.sqrt(2530) * sh_11_13 * z - 5 / 552 * math.sqrt(2530) * sh_11_9 * x ) sh_12_14 = ( -5 / 552 * math.sqrt(4186) * sh_11_10 * x + (5 / 552) * math.sqrt(4186) * sh_11_12 * z + (5 / 138) * math.sqrt(805) * sh_11_13 * y - 5 / 184 * math.sqrt(230) * sh_11_14 * z - 5 / 184 * math.sqrt(230) * sh_11_8 * x ) sh_12_15 = ( (5 / 552) * math.sqrt(4830) * sh_11_13 * z + (5 / 92) * math.sqrt(345) * sh_11_14 * y - 5 / 92 * math.sqrt(46) * sh_11_15 * z - 5 / 92 * math.sqrt(46) * sh_11_7 * x - 5 / 552 * math.sqrt(4830) * sh_11_9 * x ) sh_12_16 = ( (5 / 138) * math.sqrt(345) * sh_11_14 * z + (10 / 69) * math.sqrt(46) * sh_11_15 * y - 5 / 276 * math.sqrt(322) * sh_11_16 * z - 5 / 276 * math.sqrt(322) * sh_11_6 * x - 5 / 138 * math.sqrt(345) * sh_11_8 * x ) sh_12_17 = ( (5 / 138) * math.sqrt(391) * sh_11_15 * z + (5 / 276) * math.sqrt(2737) * sh_11_16 * y - 5 / 552 * math.sqrt(966) * sh_11_17 * z - 5 / 552 * math.sqrt(966) * sh_11_5 * x - 5 / 138 * math.sqrt(391) * sh_11_7 * x ) sh_12_18 = ( (5 / 552) * math.sqrt(3) * ( math.sqrt(2346) * sh_11_16 * z + 12 * math.sqrt(23) * sh_11_17 * y - math.sqrt(230) * sh_11_18 * z - math.sqrt(230) * sh_11_4 * x - math.sqrt(2346) * sh_11_6 * x ) ) sh_12_19 = ( (5 / 184) * math.sqrt(874) * sh_11_17 * z + (5 / 276) * math.sqrt(2185) * sh_11_18 * y - 5 / 276 * math.sqrt(115) * sh_11_19 * z - 5 / 276 * math.sqrt(115) * sh_11_3 * x - 5 / 184 * math.sqrt(874) * sh_11_5 * x ) sh_12_20 = ( (5 / 276) * math.sqrt(2185) * sh_11_18 * z + (5 / 69) * math.sqrt(115) * sh_11_19 * y - 5 / 276 * math.sqrt(69) * sh_11_2 * x - 5 / 276 * math.sqrt(69) * sh_11_20 * z - 5 / 276 * math.sqrt(2185) * 
sh_11_4 * x ) sh_12_21 = ( -5 / 552 * math.sqrt(138) * sh_11_1 * x + (5 / 276) * math.sqrt(2415) * sh_11_19 * z + (5 / 92) * math.sqrt(161) * sh_11_20 * y - 5 / 552 * math.sqrt(138) * sh_11_21 * z - 5 / 276 * math.sqrt(2415) * sh_11_3 * x ) sh_12_22 = ( -5 / 552 * math.sqrt(46) * sh_11_0 * x - 5 / 552 * math.sqrt(10626) * sh_11_2 * x + (5 / 552) * math.sqrt(10626) * sh_11_20 * z + (5 / 138) * math.sqrt(253) * sh_11_21 * y - 5 / 552 * math.sqrt(46) * sh_11_22 * z ) sh_12_23 = -5 / 24 * math.sqrt(22) * sh_11_1 * x + (5 / 24) * math.sqrt(22) * sh_11_21 * z + (5 / 12) * sh_11_22 * y sh_12_24 = (5 / 12) * math.sqrt(6) * (-sh_11_0 * x + sh_11_22 * z) return torch.stack( [ sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14, sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16, sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18, sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20, sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6, sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, sh_11_12, sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18, sh_11_19, sh_11_20, sh_11_21, sh_11_22, sh_12_0, sh_12_1, sh_12_2, sh_12_3, sh_12_4, sh_12_5, sh_12_6, sh_12_7, 
sh_12_8, sh_12_9, sh_12_10, sh_12_11, sh_12_12, sh_12_13, sh_12_14, sh_12_15, sh_12_16, sh_12_17, sh_12_18, sh_12_19, sh_12_20, sh_12_21, sh_12_22, sh_12_23, sh_12_24, ], dim=-1, ) e3nn-0.6.0/e3nn/o3/_spherical_harmonics_generator.py000066400000000000000000000041651514371756200223610ustar00rootroot00000000000000import torch import sympy from sympy.printing.pycode import pycode from e3nn import o3 def _generate_spherical_harmonics(lmax, device=None) -> None: # pragma: no cover r"""code used to generate the code above based on `wigner_3j` """ torch.set_default_dtype(torch.float64) def to_frac(x: float): from fractions import Fraction s = 1 if x >= 0 else -1 x = x**2 x = Fraction(x).limit_denominator() x = s * sympy.sqrt(x) x = sympy.simplify(x) return x print("sh_0_0 = torch.ones_like(x)") print("if lmax == 0:") print(" return torch.stack([") print(" sh_0_0,") print(" ], dim=-1)") print() x_var, y_var, z_var = sympy.symbols("x y z") polynomials = [sympy.sqrt(3) * x_var, sympy.sqrt(3) * y_var, sympy.sqrt(3) * z_var] def sub_z1(p, names, polynormz): p = p.subs(x_var, 0).subs(y_var, 1).subs(z_var, 0) for n, c in zip(names, polynormz): p = p.subs(n, c) return p poly_evalz = [sub_z1(p, [], []) for p in polynomials] for l in range(1, lmax + 1): sh_variables = sympy.symbols(" ".join(f"sh_{l}_{m}" for m in range(2 * l + 1))) for n, p in zip(sh_variables, polynomials): print(f"{n} = {pycode(p)}") print(f"if lmax == {l}:") u = ",\n ".join(", ".join(f"sh_{j}_{m}" for m in range(2 * j + 1)) for j in range(l + 1)) print(f" return torch.stack([\n {u}\n ], dim=-1)") print() if l == lmax: break polynomials = [ sum(to_frac(c.item()) * v * sh for cj, v in zip(cij, [x_var, y_var, z_var]) for c, sh in zip(cj, sh_variables)) for cij in o3.wigner_3j(l + 1, 1, l, device=device) ] poly_evalz = [sub_z1(p, sh_variables, poly_evalz) for p in polynomials] norm = sympy.sqrt(sum(p**2 for p in poly_evalz)) polynomials = [sympy.sqrt(2 * l + 3) * p / norm for p in polynomials] poly_evalz = 
[sympy.sqrt(2 * l + 3) * p / norm for p in poly_evalz] polynomials = [sympy.simplify(p, full=True) for p in polynomials] e3nn-0.6.0/e3nn/o3/_tensor_product/000077500000000000000000000000001514371756200167705ustar00rootroot00000000000000e3nn-0.6.0/e3nn/o3/_tensor_product/__init__.py000066400000000000000000000005411514371756200211010ustar00rootroot00000000000000from ._instruction import Instruction from ._tensor_product import TensorProduct from ._sub import ElementwiseTensorProduct, FullTensorProduct, FullyConnectedTensorProduct, TensorSquare __all__ = [ "Instruction", "TensorProduct", "FullyConnectedTensorProduct", "ElementwiseTensorProduct", "FullTensorProduct", "TensorSquare", ] e3nn-0.6.0/e3nn/o3/_tensor_product/_codegen.py000066400000000000000000001022341514371756200211070ustar00rootroot00000000000000from collections import OrderedDict from math import sqrt from typing import List import torch from e3nn.o3._irreps import Irreps from e3nn.o3._wigner import wigner_3j from e3nn.util import prod from opt_einsum_fx import optimize_einsums_full from torch import fx from ._instruction import Instruction def _sum_tensors(xs: List[torch.Tensor], shape: torch.Size, like: torch.Tensor) -> torch.Tensor: if len(xs) > 0: out = xs[0] for x in xs[1:]: out = out + x return out return like.new_zeros(shape) def codegen_tensor_product_left_right( irreps_in1: Irreps, irreps_in2: Irreps, irreps_out: Irreps, instructions: List[Instruction], shared_weights: bool = False, specialized_code: bool = True, optimize_einsums: bool = True, ) -> fx.GraphModule: graph = fx.Graph() # = Function definitions = tracer = fx.proxy.GraphAppendingTracer(graph) constants = OrderedDict() x1s = fx.Proxy(graph.placeholder("x1", torch.Tensor), tracer=tracer) x2s = fx.Proxy(graph.placeholder("x2", torch.Tensor), tracer=tracer) weights = fx.Proxy(graph.placeholder("w", torch.Tensor), tracer=tracer) if shared_weights: # by broadcasting all but the final irrep dim, we broadcast any and all batch dimensions # 
the use of `:1` is important, rather than `0`, to ensure the case with an empty irrep dimension doesn't error out output_shape = torch.broadcast_tensors(x1s[..., :1], x2s[..., :1])[0].shape[:-1] else: output_shape = torch.broadcast_tensors(x1s[..., :1], x2s[..., :1], weights[..., :1])[0].shape[:-1] # We produce no code for empty instructions instructions = [ins for ins in instructions if 0 not in ins.path_shape] if len(instructions) == 0: outputs = x1s.new_zeros(output_shape + (irreps_out.dim,)) graph.output(outputs.node, torch.Tensor) # Short circut return fx.GraphModule({}, graph, "tp_forward") # = Broadcast inputs = bc_shape = output_shape + (-1,) x1s, x2s = x1s.expand(bc_shape), x2s.expand(bc_shape) if not shared_weights: weights = weights.expand(bc_shape) output_shape = output_shape + (irreps_out.dim,) x1s = x1s.reshape(-1, irreps_in1.dim) x2s = x2s.reshape(-1, irreps_in2.dim) batch_numel = x1s.shape[0] # = Determine number of weights and reshape weights == weight_numel = sum(prod(ins.path_shape) for ins in instructions if ins.has_weight) if weight_numel > 0: weights = weights.reshape(-1, weight_numel) del weight_numel # = extract individual input irreps = # If only one input irrep, can avoid creating a view if len(irreps_in1) == 1: x1_list = [x1s.reshape(batch_numel, irreps_in1[0].mul, irreps_in1[0].ir.dim)] else: x1_list = [ x1s[:, i].reshape(batch_numel, mul_ir.mul, mul_ir.ir.dim) for i, mul_ir in zip(irreps_in1.slices(), irreps_in1) ] x2_list = [] # If only one input irrep, can avoid creating a view if len(irreps_in2) == 1: x2_list.append(x2s.reshape(batch_numel, irreps_in2[0].mul, irreps_in2[0].ir.dim)) else: for i, mul_ir in zip(irreps_in2.slices(), irreps_in2): x2_list.append(x2s[:, i].reshape(batch_numel, mul_ir.mul, mul_ir.ir.dim)) # The einsum string index to prepend to the weights if the weights are not shared and have a batch dimension z = "" if shared_weights else "z" # Cache of input irrep pairs whose outer products (xx) have already been 
computed xx_dict = dict() # Current index in the flat weight tensor flat_weight_index = 0 outputs = [] for ins in instructions: mul_ir_in1 = irreps_in1[ins.i_in1] mul_ir_in2 = irreps_in2[ins.i_in2] mul_ir_out = irreps_out[ins.i_out] assert mul_ir_in1.ir.p * mul_ir_in2.ir.p == mul_ir_out.ir.p assert abs(mul_ir_in1.ir.l - mul_ir_in2.ir.l) <= mul_ir_out.ir.l <= mul_ir_in1.ir.l + mul_ir_in2.ir.l if mul_ir_in1.dim == 0 or mul_ir_in2.dim == 0 or mul_ir_out.dim == 0: continue x1 = x1_list[ins.i_in1] x2 = x2_list[ins.i_in2] assert ins.connection_mode in ["uvw", "uvu", "uvv", "uuw", "uuu", "uvuv", "uvuzuij", x1, x2) else: xx_dict[key] = torch.einsum("zui,zvj->zuvij", x1, x2) xx = xx_dict[key] del key # Create a proxy & request for the relevant wigner w3j # If not used (because of specialized code), will get removed later. w3j_name = f"_w3j_{mul_ir_in1.ir.l}_{mul_ir_in2.ir.l}_{mul_ir_out.ir.l}" w3j = fx.Proxy(graph.get_attr(w3j_name), tracer=tracer) l1l2l3 = (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l) if ins.connection_mode == "uvw": assert ins.has_weight if specialized_code and l1l2l3 == (0, 0, 0): result = torch.einsum( f"{z}uvw,zu,zv->zw", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2.reshape(batch_numel, mul_ir_in2.dim) ) elif specialized_code and mul_ir_in1.ir.l == 0: result = torch.einsum(f"{z}uvw,zu,zvj->zwj", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2) / sqrt( mul_ir_out.ir.dim ) elif specialized_code and mul_ir_in2.ir.l == 0: result = torch.einsum(f"{z}uvw,zui,zv->zwi", w, x1, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt( mul_ir_out.ir.dim ) elif specialized_code and mul_ir_out.ir.l == 0: result = torch.einsum(f"{z}uvw,zui,zvi->zw", w, x1, x2) / sqrt(mul_ir_in1.ir.dim) else: result = torch.einsum(f"{z}uvw,ijk,zuvij->zwk", w, w3j, xx) if ins.connection_mode == "uvu": assert mul_ir_in1.mul == mul_ir_out.mul if ins.has_weight: if specialized_code and l1l2l3 == (0, 0, 0): result = torch.einsum( f"{z}uv,zu,zv->zu", w, x1.reshape(batch_numel, 
mul_ir_in1.dim), x2.reshape(batch_numel, mul_ir_in2.dim) ) elif specialized_code and mul_ir_in1.ir.l == 0: result = torch.einsum(f"{z}uv,zu,zvj->zuj", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2) / sqrt( mul_ir_out.ir.dim ) elif specialized_code and mul_ir_in2.ir.l == 0: result = torch.einsum(f"{z}uv,zui,zv->zui", w, x1, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt( mul_ir_out.ir.dim ) elif specialized_code and mul_ir_out.ir.l == 0: result = torch.einsum(f"{z}uv,zui,zvi->zu", w, x1, x2) / sqrt(mul_ir_in1.ir.dim) else: result = torch.einsum(f"{z}uv,ijk,zuvij->zuk", w, w3j, xx) else: # not so useful operation because v is summed result = torch.einsum("ijk,zuvij->zuk", w3j, xx) if ins.connection_mode == "uvv": assert mul_ir_in2.mul == mul_ir_out.mul if ins.has_weight: if specialized_code and l1l2l3 == (0, 0, 0): result = torch.einsum( f"{z}uv,zu,zv->zv", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2.reshape(batch_numel, mul_ir_in2.dim) ) elif specialized_code and mul_ir_in1.ir.l == 0: result = torch.einsum(f"{z}uv,zu,zvj->zvj", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2) / sqrt( mul_ir_out.ir.dim ) elif specialized_code and mul_ir_in2.ir.l == 0: result = torch.einsum(f"{z}uv,zui,zv->zvi", w, x1, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt( mul_ir_out.ir.dim ) elif specialized_code and mul_ir_out.ir.l == 0: result = torch.einsum(f"{z}uv,zui,zvi->zv", w, x1, x2) / sqrt(mul_ir_in1.ir.dim) else: result = torch.einsum(f"{z}uv,ijk,zuvij->zvk", w, w3j, xx) else: # not so useful operation because u is summed # only specialize out for this path if specialized_code and l1l2l3 == (0, 0, 0): result = torch.einsum( "zu,zv->zv", x1.reshape(batch_numel, mul_ir_in1.dim), x2.reshape(batch_numel, mul_ir_in2.dim) ) elif specialized_code and mul_ir_in1.ir.l == 0: result = torch.einsum("zu,zvj->zvj", x1.reshape(batch_numel, mul_ir_in1.dim), x2) / sqrt(mul_ir_out.ir.dim) elif specialized_code and mul_ir_in2.ir.l == 0: result = torch.einsum("zui,zv->zvi", x1, 
                        x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(mul_ir_out.ir.dim)
                elif specialized_code and mul_ir_out.ir.l == 0:
                    result = torch.einsum("zui,zvi->zv", x1, x2) / sqrt(mul_ir_in1.ir.dim)
                else:
                    result = torch.einsum("ijk,zuvij->zvk", w3j, xx)

        if ins.connection_mode == "uuw":
            assert mul_ir_in1.mul == mul_ir_in2.mul
            if ins.has_weight:
                if specialized_code and l1l2l3 == (0, 0, 0):
                    result = torch.einsum(
                        f"{z}uw,zu,zu->zw", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2.reshape(batch_numel, mul_ir_in2.dim)
                    )
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    result = torch.einsum(f"{z}uw,zu,zuj->zwj", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    result = torch.einsum(f"{z}uw,zui,zu->zwi", w, x1, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif specialized_code and mul_ir_out.ir.l == 0:
                    result = torch.einsum(f"{z}uw,zui,zui->zw", w, x1, x2) / sqrt(mul_ir_in1.ir.dim)
                else:
                    result = torch.einsum(f"{z}uw,ijk,zuij->zwk", w, w3j, xx)
            else:
                # equivalent to tp(x, y, 'uuu').sum('u')
                assert mul_ir_out.mul == 1
                result = torch.einsum("ijk,zuij->zk", w3j, xx)

        if ins.connection_mode == "uuu":
            assert mul_ir_in1.mul == mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and l1l2l3 == (0, 0, 0):
                    result = torch.einsum(
                        f"{z}u,zu,zu->zu", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2.reshape(batch_numel, mul_ir_in2.dim)
                    )
                elif specialized_code and l1l2l3 == (1, 1, 1):
                    result = torch.einsum(f"{z}u,zui->zui", w, torch.cross(x1, x2, dim=2)) / sqrt(2 * 3)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    result = torch.einsum(f"{z}u,zu,zuj->zuj", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    result = torch.einsum(f"{z}u,zui,zu->zui", w, x1, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif specialized_code and mul_ir_out.ir.l == 0:
                    result = torch.einsum(f"{z}u,zui,zui->zu", w,
                        x1, x2) / sqrt(mul_ir_in1.ir.dim)
                else:
                    result = torch.einsum(f"{z}u,ijk,zuij->zuk", w, w3j, xx)
            else:
                if specialized_code and l1l2l3 == (0, 0, 0):
                    result = torch.einsum(
                        "zu,zu->zu", x1.reshape(batch_numel, mul_ir_in1.dim), x2.reshape(batch_numel, mul_ir_in2.dim)
                    )
                elif specialized_code and l1l2l3 == (1, 1, 1):
                    result = torch.cross(x1, x2, dim=2) * (1.0 / sqrt(2 * 3))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    result = torch.einsum("zu,zuj->zuj", x1.reshape(batch_numel, mul_ir_in1.dim), x2) / sqrt(mul_ir_out.ir.dim)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    result = torch.einsum("zui,zu->zui", x1, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(mul_ir_out.ir.dim)
                elif specialized_code and mul_ir_out.ir.l == 0:
                    result = torch.einsum("zui,zui->zu", x1, x2) / sqrt(mul_ir_in1.ir.dim)
                else:
                    result = torch.einsum("ijk,zuij->zuk", w3j, xx)

        if ins.connection_mode == "uvuv":
            assert mul_ir_in1.mul * mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                # TODO implement specialized code
                result = torch.einsum(f"{z}uv,ijk,zuvij->zuvk", w, w3j, xx)
            else:
                # TODO implement specialized code
                result = torch.einsum("ijk,zuvij->zuvk", w3j, xx)

        if ins.connection_mode == "uvu zwij
            if ins.has_weight:
                # TODO implement specialized code
                result = torch.einsum(f"{z}w,ijk,zwij->zwk", w, w3j, xx)
            else:
                # TODO implement specialized code
                result = torch.einsum("ijk,zwij->zwk", w3j, xx)

        if ins.connection_mode == "u zqij
            # TODO implement specialized code
            result = torch.einsum(f"{z}qw,ijk,zqij->zwk", w, w3j, xx)

        result = ins.path_weight * result

        outputs += [result.reshape(batch_numel, mul_ir_out.dim)]

        # Remove unused w3js:
        if len(w3j.node.users) == 0:
            # The w3j nodes are reshapes, so we have to remove them from the graph
            # Although they are dead code, they try to reshape to dimensions that don't exist
            # (since the corresponding w3js are not in w3j)
            # so they screw up the shape propagation, even though they would be removed later as dead code by TorchScript.
            graph.erase_node(w3j.node)
        else:
            if w3j_name not in constants:
                constants[w3j_name] = wigner_3j(mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l)

    # = Return the result =
    outputs = [
        _sum_tensors(
            [out for ins, out in zip(instructions, outputs) if ins.i_out == i_out],
            shape=(batch_numel, mul_ir_out.dim),
            like=x1s,
        )
        for i_out, mul_ir_out in enumerate(irreps_out)
        if mul_ir_out.mul > 0
    ]
    if len(outputs) > 1:
        outputs = torch.cat(outputs, dim=1)
    else:
        # Avoid an unnecessary copy in a size-one torch.cat
        outputs = outputs[0]

    outputs = outputs.reshape(output_shape)

    graph.output(outputs.node, torch.Tensor)

    # check graphs
    graph.lint()

    # Make GraphModules

    # By putting the constants in a Module rather than a dict,
    # we force FX to copy them as buffers instead of as attributes.
    #
    # FX seems to have resolved this issue for dicts in 1.9, but we support all the way back to 1.8.0.
    constants_root = torch.nn.Module()
    for key, value in constants.items():
        constants_root.register_buffer(key, value)
    graphmod = fx.GraphModule(constants_root, graph, class_name="tp_forward")

    # == Optimize ==
    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # Note that for our einsums, we can optimize _once_ for _any_ batch dimension
        # and still get the right path for _all_ batch dimensions.
        # This is because our einsums are essentially of the form:
        #    zuvw,ijk,zuvij->zwk    OR     uvw,ijk,zuvij->zwk
        # In the first case, all but one of the operands have the batch dimension
        #    => The first contraction gains the batch dimension
        #    => All following contractions have batch dimension
        #    => All possible contraction paths have cost that scales linearly in batch size
        #    => The optimal path is the same for all batch sizes
        # For the second case, this logic follows as long as the first contraction is not between the first two operands.
        # Since those two operands do not share any indexes, contracting them first is a rare pathological case.
        # See
        # https://github.com/dgasmith/opt_einsum/issues/158
        # for more details.
        #
        # TODO: consider the impact of maximum intermediate result size on this logic
        #   - this is the `memory_limit` option in opt_einsum
        # TODO: allow user to choose opt_einsum parameters?
        #
        # We use float32 and zeros to save memory and time, since opt_einsum_fx looks only at traced shapes, not values or
        # dtypes.
        batchdim = 4
        example_inputs = (
            torch.zeros((batchdim, irreps_in1.dim)),
            torch.zeros((batchdim, irreps_in2.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                flat_weight_index,
            ),
        )
        graphmod = optimize_einsums_full(graphmod, example_inputs)

    return graphmod


def codegen_tensor_product_right(
    irreps_in1: Irreps,
    irreps_in2: Irreps,
    irreps_out: Irreps,
    instructions: List[Instruction],
    shared_weights: bool = False,
    specialized_code: bool = True,
    optimize_einsums: bool = True,
) -> fx.GraphModule:
    graph = fx.Graph()

    # = Function definitions =
    tracer = fx.proxy.GraphAppendingTracer(graph)
    constants = OrderedDict()

    x2s = fx.Proxy(graph.placeholder("x2", torch.Tensor), tracer=tracer)
    weights = fx.Proxy(graph.placeholder("w", torch.Tensor), tracer=tracer)

    if shared_weights:
        output_shape = x2s.shape[:-1]
    else:
        output_shape = torch.broadcast_tensors(x2s[..., 0], weights[..., 0])[0].shape

    # = Short-circuit for zero dimensional =
    # We produce no code for empty instructions
    instructions = [ins for ins in instructions if 0 not in ins.path_shape]

    if len(instructions) == 0:
        outputs = x2s.new_zeros(
            output_shape
            + (
                irreps_in1.dim,
                irreps_out.dim,
            )
        )

        graph.output(outputs.node, torch.Tensor)
        # Short circuit
        return fx.GraphModule({}, graph, "tp_right")

    # = Broadcast inputs =
    if not shared_weights:
        x2s, weights = x2s.broadcast_to(output_shape + (-1,)), weights.broadcast_to(output_shape + (-1,))

    output_shape = output_shape + (
        irreps_in1.dim,
        irreps_out.dim,
    )

    x2s = x2s.reshape(-1, irreps_in2.dim)
    batch_numel = x2s.shape[0]

    # = Determine number of weights and reshape weights ==
    weight_numel =
sum(prod(ins.path_shape) for ins in instructions if ins.has_weight)
    if weight_numel > 0:
        weights = weights.reshape(-1, weight_numel)
    del weight_numel

    # = book-keeping for wigners =

    # = extract individual input irreps =
    # If only one input irrep, can avoid creating a view
    x2_list = []
    # If only one input irrep, can avoid creating a view
    if len(irreps_in2) == 1:
        x2_list.append(x2s.reshape(batch_numel, irreps_in2[0].mul, irreps_in2[0].ir.dim))
    else:
        for i, mul_ir in zip(irreps_in2.slices(), irreps_in2):
            x2_list.append(x2s[:, i].reshape(batch_numel, mul_ir.mul, mul_ir.ir.dim))

    # The einsum string index to prepend to the weights if the weights are not shared and have a batch dimension
    z = "" if shared_weights else "z"

    # Current index in the flat weight tensor
    flat_weight_index = 0

    outputs = []

    for ins in instructions:
        mul_ir_in1 = irreps_in1[ins.i_in1]
        mul_ir_in2 = irreps_in2[ins.i_in2]
        mul_ir_out = irreps_out[ins.i_out]

        assert mul_ir_in1.ir.p * mul_ir_in2.ir.p == mul_ir_out.ir.p
        assert abs(mul_ir_in1.ir.l - mul_ir_in2.ir.l) <= mul_ir_out.ir.l <= mul_ir_in1.ir.l + mul_ir_in2.ir.l

        if mul_ir_in1.dim == 0 or mul_ir_in2.dim == 0 or mul_ir_out.dim == 0:
            continue

        x2 = x2_list[ins.i_in2]

        e1 = fx.Proxy(
            graph.call_function(torch.eye, (mul_ir_in1.mul,), dict(dtype=x2s.dtype.node, device=x2s.device.node)),
            tracer=tracer,
        )
        e2 = fx.Proxy(
            graph.call_function(torch.eye, (mul_ir_in2.mul,), dict(dtype=x2s.dtype.node, device=x2s.device.node)),
            tracer=tracer,
        )
        i1 = fx.Proxy(
            graph.call_function(torch.eye, (mul_ir_in1.ir.dim,), dict(dtype=x2s.dtype.node, device=x2s.device.node)),
            tracer=tracer,
        )

        assert ins.connection_mode in ["uvw", "uvu", "uvv", "uuw", "uuu", "uvuv", "uvuzuw", w, x2.reshape(batch_numel, mul_ir_in2.dim))
            elif specialized_code and mul_ir_in1.ir.l == 0:
                result = torch.einsum(f"{z}uvw,zvi->zuwi", w, x2) / sqrt(mul_ir_out.ir.dim)
            elif specialized_code and mul_ir_in2.ir.l == 0:
                result = torch.einsum(f"{z}uvw,ij,zv->zuiwj", w, i1, x2.reshape(batch_numel, mul_ir_in2.dim)) /
sqrt(
                    mul_ir_out.ir.dim
                )
            elif specialized_code and mul_ir_out.ir.l == 0:
                result = torch.einsum(f"{z}uvw,zvi->zuiw", w, x2) / sqrt(mul_ir_in1.ir.dim)
            else:
                result = torch.einsum(f"{z}uvw,ijk,zvj->zuiwk", w, w3j, x2)

        if ins.connection_mode == "uvu":
            assert mul_ir_in1.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l) == (0, 0, 0):
                    result = torch.einsum(f"{z}uv,uw,zv->zuw", w, e1, x2.reshape(batch_numel, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    result = torch.einsum(f"{z}uv,uw,zvi->zuwi", w, e1, x2) / sqrt(mul_ir_out.ir.dim)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    result = torch.einsum(f"{z}uv,ij,uw,zv->zuiwj", w, i1, e1, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif specialized_code and mul_ir_out.ir.l == 0:
                    result = torch.einsum(f"{z}uv,uw,zvi->zuiw", w, e1, x2) / sqrt(mul_ir_in1.ir.dim)
                else:
                    result = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwk", w, w3j, e1, x2)
            else:
                # not so useful operation because v is summed
                result = torch.einsum("ijk,uw,zvj->zuiwk", w3j, e1, x2)

        if ins.connection_mode == "uvv":
            assert mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l) == (0, 0, 0):
                    result = torch.einsum(f"{z}uv,vw,zv->zuw", w, e2, x2.reshape(batch_numel, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    result = torch.einsum(f"{z}uv,vw,zvi->zuwi", w, e2, x2) / sqrt(mul_ir_out.ir.dim)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    result = torch.einsum(f"{z}uv,ij,vw,zv->zuiwj", w, i1, e2, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif specialized_code and mul_ir_out.ir.l == 0:
                    result = torch.einsum(f"{z}uv,vw,zvi->zuiw", w, e2, x2) / sqrt(mul_ir_in1.ir.dim)
                else:
                    result = torch.einsum(f"{z}uv,ijk,zvj->zuivk", w, w3j, x2)
            else:
                # not so useful operation because u is summed
                # only specialize out for this path
                s2ones = fx.Proxy(
                    graph.call_function(torch.ones, (mul_ir_in1.mul,), dict(device=x2.device.node, dtype=x2.dtype.node)),
                    tracer=tracer,
                )
                result = torch.einsum("u,ijk,zvj->zuivk", s2ones, w3j, x2)

        if ins.connection_mode == "uuw":
            assert mul_ir_in1.mul == mul_ir_in2.mul
            if ins.has_weight:
                # TODO: specialize right()
                result = torch.einsum(f"{z}uw,ijk,zuj->zuiwk", w, w3j, x2)
            else:
                # equivalent to tp(x, y, 'uuu').sum('u')
                assert mul_ir_out.mul == 1
                result = torch.einsum("ijk,zuj->zuik", w3j, x2)

        if ins.connection_mode == "uuu":
            assert mul_ir_in1.mul == mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l) == (0, 0, 0):
                    result = torch.einsum(f"{z}u,uw,zu->zuw", w, e2, x2.reshape(batch_numel, mul_ir_in2.dim))
                elif specialized_code and (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l) == (1, 1, 1):
                    # For cross product, use the general case right()
                    result = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk", w, w3j, e1, x2)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    result = torch.einsum(f"{z}u,uw,zui->zuwi", w, e2, x2) / sqrt(mul_ir_out.ir.dim)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    result = torch.einsum(f"{z}u,ij,uw,zu->zuiwj", w, i1, e2, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif specialized_code and mul_ir_out.ir.l == 0:
                    result = torch.einsum(f"{z}u,uw,zui->zuiw", w, e2, x2) / sqrt(mul_ir_in1.ir.dim)
                else:
                    result = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk", w, w3j, e1, x2)
            else:
                if specialized_code and (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l) == (0, 0, 0):
                    result = torch.einsum("uw,zu->zuw", e2, x2.reshape(batch_numel, mul_ir_in2.dim))
                elif specialized_code and (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l) == (1, 1, 1):
                    # For cross product, use the general case right()
                    result = torch.einsum("ijk,uw,zuj->zuiwk", w3j, e1, x2)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    result = torch.einsum("uw,zui->zuwi", e2, x2) / sqrt(mul_ir_out.ir.dim)
                elif specialized_code and
mul_ir_in2.ir.l == 0:
                    result = torch.einsum("ij,uw,zu->zuiwj", i1, e2, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif specialized_code and mul_ir_out.ir.l == 0:
                    result = torch.einsum("uw,zui->zuiw", e2, x2) / sqrt(mul_ir_in1.ir.dim)
                else:
                    result = torch.einsum("ijk,uw,zuj->zuiwk", w3j, e1, x2)

        if ins.connection_mode == "uvuv":
            assert mul_ir_in1.mul * mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                # TODO implement specialized code
                result = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwvk", w, w3j, e1, x2)
            else:
                # TODO implement specialized code
                result = torch.einsum("ijk,uw,zvj->zuiwvk", w3j, e1, x2)

        if ins.connection_mode == "uvu 0
            ],
            dim=2,
        )
        for i_in1, mul_ir_in1 in enumerate(irreps_in1)
        if mul_ir_in1.mul > 0
    ]
    if len(outputs) > 1:
        outputs = torch.cat(outputs, dim=1)
    else:
        outputs = outputs[0]

    outputs = outputs.reshape(output_shape)

    graph.output(outputs.node, torch.Tensor)

    # check graphs
    graph.lint()

    # Make GraphModules

    # By putting the constants in a Module rather than a dict,
    # we force FX to copy them as buffers instead of as attributes.
    #
    # FX seems to have resolved this issue for dicts in 1.9, but we support all the way back to 1.8.0.
    constants_root = torch.nn.Module()
    for key, value in constants.items():
        constants_root.register_buffer(key, value)
    graphmod = fx.GraphModule(constants_root, graph, class_name="tp_right")

    # == Optimize ==
    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # Note that for our einsums, we can optimize _once_ for _any_ batch dimension
        # and still get the right path for _all_ batch dimensions.
        # This is because our einsums are essentially of the form:
        #    zuvw,ijk,zuvij->zwk    OR     uvw,ijk,zuvij->zwk
        # In the first case, all but one of the operands have the batch dimension
        #    => The first contraction gains the batch dimension
        #    => All following contractions have batch dimension
        #    => All possible contraction paths have cost that scales linearly in batch size
        #    => The optimal path is the same for all batch sizes
        # For the second case, this logic follows as long as the first contraction is not between the first two operands.
        # Since those two operands do not share any indexes, contracting them first is a rare pathological case. See
        # https://github.com/dgasmith/opt_einsum/issues/158
        # for more details.
        #
        # TODO: consider the impact of maximum intermediate result size on this logic
        #   - this is the `memory_limit` option in opt_einsum
        # TODO: allow user to choose opt_einsum parameters?
        #
        # We use float32 and zeros to save memory and time, since opt_einsum_fx looks only at traced shapes, not values or
        # dtypes.
        batchdim = 4
        example_inputs = (
            torch.zeros((batchdim, irreps_in1.dim)),
            torch.zeros((batchdim, irreps_in2.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                flat_weight_index,
            ),
        )
        graphmod = optimize_einsums_full(graphmod, example_inputs[1:])

    return graphmod
e3nn-0.6.0/e3nn/o3/_tensor_product/_instruction.py000066400000000000000000000003071514371756200220620ustar00rootroot0000000000000
from typing import NamedTuple


class Instruction(NamedTuple):
    i_in1: int
    i_in2: int
    i_out: int
    connection_mode: str
    has_weight: bool
    path_weight: float
    path_shape: tuple
e3nn-0.6.0/e3nn/o3/_tensor_product/_sub.py000066400000000000000000000326141514371756200203000ustar00rootroot0000000000000
from typing import Iterator, Optional

import torch

from e3nn.o3._irreps import Irrep, Irreps
from e3nn.util import prod

from ._tensor_product import TensorProduct


class FullyConnectedTensorProduct(TensorProduct):
    r"""Fully-connected weighted tensor product

    All possible paths allowed by :math:`|l_1 - l_2| \leq l_{out} \leq l_1 + l_2` are created.
    The output is a sum over the different paths:

    .. math::

        z_w = \sum_{u,v} w_{uvw} x_u \otimes y_v + \cdots \text{other paths}

    where :math:`u,v,w` are the indices of the multiplicities.

    Parameters
    ----------
    irreps_in1 : `e3nn.o3.Irreps`
        representation of the first input
    irreps_in2 : `e3nn.o3.Irreps`
        representation of the second input
    irreps_out : `e3nn.o3.Irreps`
        representation of the output
    irrep_normalization : {'component', 'norm'}
        see `e3nn.o3.TensorProduct`
    path_normalization : {'element', 'path'}
        see `e3nn.o3.TensorProduct`
    internal_weights : bool
        see `e3nn.o3.TensorProduct`
    shared_weights : bool
        see `e3nn.o3.TensorProduct`
    """

    def __init__(
        self, irreps_in1, irreps_in2, irreps_out, irrep_normalization: str = None, path_normalization: str = None, **kwargs
    ) -> None:
        irreps_in1 = Irreps(irreps_in1)
        irreps_in2 = Irreps(irreps_in2)
        irreps_out = Irreps(irreps_out)

        instr = [
            (i_1, i_2, i_out, "uvw", True, 1.0)
            for i_1, (_, ir_1) in enumerate(irreps_in1)
            for i_2, (_, ir_2) in enumerate(irreps_in2)
            for i_out, (_, ir_out) in enumerate(irreps_out)
            if ir_out in ir_1 * ir_2
        ]
        super().__init__(
            irreps_in1,
            irreps_in2,
            irreps_out,
            instr,
            irrep_normalization=irrep_normalization,
            path_normalization=path_normalization,
            **kwargs,
        )


class ElementwiseTensorProduct(TensorProduct):
    r"""Elementwise connected tensor product.

    .. math::

        z_u = x_u \otimes y_u

    where :math:`u` runs over the irreps. Note that there are no weights.
    The output representation is determined by the two input representations.

    Parameters
    ----------
    irreps_in1 : `e3nn.o3.Irreps`
        representation of the first input
    irreps_in2 : `e3nn.o3.Irreps`
        representation of the second input
    filter_ir_out : iterator of `e3nn.o3.Irrep`, optional
        filter to select only specific `e3nn.o3.Irrep` of the output
    irrep_normalization : {'component', 'norm'}
        see `e3nn.o3.TensorProduct`

    Examples
    --------
    Elementwise scalar product

    >>> ElementwiseTensorProduct("5x1o + 5x1e", "10x1e", ["0e", "0o"])
    ElementwiseTensorProduct(5x1o+5x1e x 10x1e -> 5x0o+5x0e | 10 paths | 0 weights)
    """

    def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, irrep_normalization: str = None, **kwargs) -> None:
        irreps_in1 = Irreps(irreps_in1).simplify()
        irreps_in2 = Irreps(irreps_in2).simplify()
        if filter_ir_out is not None:
            try:
                filter_ir_out = [Irrep(ir) for ir in filter_ir_out]
            except ValueError:
                raise ValueError(f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep")

        assert irreps_in1.num_irreps == irreps_in2.num_irreps

        irreps_in1 = list(irreps_in1)
        irreps_in2 = list(irreps_in2)

        i = 0
        while i < len(irreps_in1):
            mul_1, ir_1 = irreps_in1[i]
            mul_2, ir_2 = irreps_in2[i]

            if mul_1 < mul_2:
                irreps_in2[i] = (mul_1, ir_2)
                irreps_in2.insert(i + 1, (mul_2 - mul_1, ir_2))

            if mul_2 < mul_1:
                irreps_in1[i] = (mul_2, ir_1)
                irreps_in1.insert(i + 1, (mul_1 - mul_2, ir_1))
            i += 1

        out = []
        instr = []
        for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)):
            assert mul == mul_2
            for ir in ir_1 * ir_2:
                if filter_ir_out is not None and ir not in filter_ir_out:
                    continue

                i_out = len(out)
                out.append((mul, ir))
                instr += [(i, i, i_out, "uuu", False)]

        super().__init__(irreps_in1, irreps_in2, out, instr, irrep_normalization=irrep_normalization, **kwargs)


class FullTensorProduct(TensorProduct):
    r"""Full tensor product between two irreps.

    .. math::

        z_{uv} = x_u \otimes y_v

    where :math:`u` and :math:`v` run over the irreps. Note that there are no weights.
    The output representation is determined by the two input representations.

    Parameters
    ----------
    irreps_in1 : `e3nn.o3.Irreps`
        representation of the first input
    irreps_in2 : `e3nn.o3.Irreps`
        representation of the second input
    filter_ir_out : iterator of `e3nn.o3.Irrep`, optional
        filter to select only specific `e3nn.o3.Irrep` of the output
    irrep_normalization : {'component', 'norm'}
        see `e3nn.o3.TensorProduct`
    """

    def __init__(
        self,
        irreps_in1: Irreps,
        irreps_in2: Irreps,
        filter_ir_out: Iterator[Irrep] = None,
        irrep_normalization: str = None,
        **kwargs,
    ) -> None:
        irreps_in1 = Irreps(irreps_in1).simplify()
        irreps_in2 = Irreps(irreps_in2).simplify()
        if filter_ir_out is not None:
            try:
                filter_ir_out = [Irrep(ir) for ir in filter_ir_out]
            except ValueError:
                raise ValueError(f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep")

        out = []
        instr = []
        for i_1, (mul_1, ir_1) in enumerate(irreps_in1):
            for i_2, (mul_2, ir_2) in enumerate(irreps_in2):
                for ir_out in ir_1 * ir_2:
                    if filter_ir_out is not None and ir_out not in filter_ir_out:
                        continue

                    i_out = len(out)
                    out.append((mul_1 * mul_2, ir_out))
                    instr += [(i_1, i_2, i_out, "uvuv", False)]

        out = Irreps(out)
        out, p, _ = out.sort()

        instr = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instr]

        super().__init__(irreps_in1, irreps_in2, out, instr, irrep_normalization=irrep_normalization, **kwargs)


def _square_instructions_full(irreps_in, filter_ir_out=None, irrep_normalization=None):
    """Generate instructions for square tensor product.

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input
    filter_ir_out : iterator of `e3nn.o3.Irrep`, optional
        filter to select only specific `e3nn.o3.Irrep` of the output
    irrep_normalization : {'component', 'norm', 'none'}
        see `e3nn.o3.TensorProduct`

    Returns
    -------
    irreps_out : `e3nn.o3.Irreps`
        representation of the output
    instr : list of tuple
        list of instructions
    """
    # pylint: disable=too-many-nested-blocks
    irreps_out = []
    instr = []
    for i_1, (mul_1, ir_1) in enumerate(irreps_in):
        for i_2, (mul_2, ir_2) in enumerate(irreps_in):
            for ir_out in ir_1 * ir_2:
                if filter_ir_out is not None and ir_out not in filter_ir_out:
                    continue

                if irrep_normalization == "component":
                    alpha = ir_out.dim
                if irrep_normalization == "norm":
                    alpha = ir_1.dim * ir_2.dim
                if irrep_normalization == "none":
                    alpha = 1

                if i_1 < i_2:
                    i_out = len(irreps_out)
                    irreps_out.append((mul_1 * mul_2, ir_out))
                    instr += [(i_1, i_2, i_out, "uvuv", False, alpha)]
                elif i_1 == i_2:
                    i = i_1
                    mul = mul_1

                    if mul > 1:
                        i_out = len(irreps_out)
                        irreps_out.append((mul * (mul - 1) // 2, ir_out))
                        instr += [(i, i, i_out, "uvu 1:
                            instr += [(i, i, i_out, "u None:
        if irrep_normalization is None:
            irrep_normalization = "component"

        assert irrep_normalization in ["component", "norm", "none"]

        irreps_in = Irreps(irreps_in).simplify()
        if filter_ir_out is not None:
            try:
                filter_ir_out = [Irrep(ir) for ir in filter_ir_out]
            except ValueError as exc:
                raise ValueError(f"Error constructing filter_ir_out irrep: {exc}") from exc

        if irreps_out is None:
            irreps_out, instr = _square_instructions_full(irreps_in, filter_ir_out, irrep_normalization)
        else:
            if filter_ir_out is not None:
                raise ValueError("Both `irreps_out` and `filter_ir_out` are not None, this is ambiguous.")

            irreps_out = Irreps(irreps_out).simplify()

            instr = _square_instructions_fully_connected(irreps_in, irreps_out, irrep_normalization)

        self.irreps_in = irreps_in

        super().__init__(irreps_in, irreps_in, irreps_out, instr, irrep_normalization="none", **kwargs)

    def
__repr__(self) -> str:
        npath = sum(prod(i.path_shape) for i in self.instructions)
        return (
            f"{self.__class__.__name__}"
            f"({self.irreps_in} "
            f"-> {self.irreps_out.simplify()} | {npath} paths | {self.weight_numel} weights)"
        )

    def forward(self, x, weight: Optional[torch.Tensor] = None):  # pylint: disable=arguments-differ
        return super().forward(x, x, weight)
e3nn-0.6.0/e3nn/o3/_tensor_product/_tensor_product.py000066400000000000000000000737131514371756200225660ustar00rootroot0000000000000
from math import sqrt
from typing import List, Optional, Union, Any, Callable
import warnings

import torch
from torch import fx

import e3nn
from e3nn.o3._irreps import Irreps
from e3nn.util import prod
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode

from ._codegen import codegen_tensor_product_left_right, codegen_tensor_product_right
from ._instruction import Instruction

# A list, in order of priority, of codegen providers for the tensor product.
# If a provider does not support the parameters it is given, it should
# return `None`, in which case the next provider in the list will be tried.
_CODEGEN_PROVIDERS_LEFT_RIGHT: List[Callable] = [codegen_tensor_product_left_right]
_CODEGEN_PROVIDERS_RIGHT: List[Callable] = [codegen_tensor_product_right]


@compile_mode("script")
class TensorProduct(CodeGenMixin, torch.nn.Module):
    r"""Tensor product with parametrized paths.

    Parameters
    ----------
    irreps_in1 : `e3nn.o3.Irreps`
        Irreps for the first input.
    irreps_in2 : `e3nn.o3.Irreps`
        Irreps for the second input.
    irreps_out : `e3nn.o3.Irreps`
        Irreps for the output.
    instructions : list of tuple
        List of instructions ``(i_1, i_2, i_out, mode, train[, path_weight])``.

        Each instruction puts ``in1[i_1]`` :math:`\otimes` ``in2[i_2]`` into ``out[i_out]``.

        * ``mode``: `str`. Determines the way the multiplicities are treated, ``"uvw"`` is fully connected. Other valid
          options are: ``'uvw'``, ``'uvu'``, ``'uvv'``, ``'uuw'``, ``'uuu'``, and ``'uvuv'``.
        * ``train``: `bool`.
          `True` if this path should have learnable weights, otherwise `False`.
        * ``path_weight``: `float`. A fixed multiplicative weight to apply to the output of this path. Defaults to 1. Note
          that setting ``path_weight`` breaks the normalization derived from ``in1_var``/``in2_var``/``out_var``.

    in1_var : list of float, Tensor, or None
        Variance for each irrep in ``irreps_in1``. If ``None``, all default to ``1.0``.
    in2_var : list of float, Tensor, or None
        Variance for each irrep in ``irreps_in2``. If ``None``, all default to ``1.0``.
    out_var : list of float, Tensor, or None
        Variance for each irrep in ``irreps_out``. If ``None``, all default to ``1.0``.
    irrep_normalization : {'component', 'norm'}
        The assumed normalization of the input and output representations. If it is set to "norm":

        .. math::

            \| x \| = \| y \| = 1 \Longrightarrow \| x \otimes y \| = 1

    path_normalization : {'element', 'path'}
        If set to ``element``, each output is normalized by the total number of elements (independently of their paths).
        If it is set to ``path``, each path is normalized by the total number of elements in the path, then each output is
        normalized by the number of paths.
    internal_weights : bool
        whether the `e3nn.o3.TensorProduct` contains its learnable weights as a parameter
    shared_weights : bool
        whether the learnable weights are shared among the input's extra dimensions

        * `True` :math:`z_i = w x_i \otimes y_i`
        * `False` :math:`z_i = w_i x_i \otimes y_i`

        where :math:`i` denotes a *batch-like* index.
        ``shared_weights`` cannot be `False` if ``internal_weights`` is `True`.
    compile_left_right : bool
        whether to compile the forward function, true by default
    compile_right : bool
        whether to compile the ``.right`` function, false by default

    Examples
    --------
    Create a module that computes the elementwise cross product of 16 vectors with 16 vectors :math:`z_u = x_u \wedge y_u`

    >>> module = TensorProduct(
    ...     "16x1o", "16x1o", "16x1e",
    ...     [
    ...         (0, 0, 0, "uuu", False)
    ...     ]
    ...
)

    Now mix all 16 vectors with all 16 vectors to make 16 pseudo-vectors :math:`z_w = \sum_{u,v} w_{uvw} x_u \wedge y_v`

    >>> module = TensorProduct(
    ...     [(16, (1, -1))],
    ...     [(16, (1, -1))],
    ...     [(16, (1, 1))],
    ...     [
    ...         (0, 0, 0, "uvw", True)
    ...     ]
    ... )

    With custom input variance and custom path weights:

    >>> module = TensorProduct(
    ...     "8x0o + 8x1o",
    ...     "16x1o",
    ...     "16x1e",
    ...     [
    ...         (0, 0, 0, "uvw", True, 3),
    ...         (1, 0, 0, "uvw", True, 1),
    ...     ],
    ...     in2_var=[1/16]
    ... )

    Example of a dot product:

    >>> irreps = o3.Irreps("3x0e + 4x0o + 1e + 2o + 3o")
    >>> module = TensorProduct(irreps, irreps, "0e", [
    ...     (i, i, 0, 'uuw', False)
    ...     for i, (mul, ir) in enumerate(irreps)
    ... ])

    Implement :math:`z_u = x_u \otimes (\sum_v w_{uv} y_v)`

    >>> module = TensorProduct(
    ...     "8x0o + 7x1o + 3x2e",
    ...     "10x0e + 10x1e + 10x2e",
    ...     "8x0o + 7x1o + 3x2e",
    ...     [
    ...         # paths for the l=0:
    ...         (0, 0, 0, "uvu", True),  # 0x0->0
    ...         # paths for the l=1:
    ...         (1, 0, 1, "uvu", True),  # 1x0->1
    ...         (1, 1, 1, "uvu", True),  # 1x1->1
    ...         (1, 2, 1, "uvu", True),  # 1x2->1
    ...         # paths for the l=2:
    ...         (2, 0, 2, "uvu", True),  # 2x0->2
    ...         (2, 1, 2, "uvu", True),  # 2x1->2
    ...         (2, 2, 2, "uvu", True),  # 2x2->2
    ...     ]
    ... )

    Tensor Product using the xavier uniform initialization:

    >>> irreps_1 = o3.Irreps("5x0e + 10x1o + 1x2e")
    >>> irreps_2 = o3.Irreps("5x0e + 10x1o + 1x2e")
    >>> irreps_out = o3.Irreps("5x0e + 10x1o + 1x2e")
    >>> # create a Fully Connected Tensor Product
    >>> module = o3.TensorProduct(
    ...     irreps_1,
    ...     irreps_2,
    ...     irreps_out,
    ...     [
    ...         (i_1, i_2, i_out, "uvw", True, mul_1 * mul_2)
    ...         for i_1, (mul_1, ir_1) in enumerate(irreps_1)
    ...         for i_2, (mul_2, ir_2) in enumerate(irreps_2)
    ...         for i_out, (mul_out, ir_out) in enumerate(irreps_out)
    ...         if ir_out in ir_1 * ir_2
    ...     ]
    ... )
    >>> with torch.no_grad():
    ...     for weight in module.weight_views():
    ...         mul_1, mul_2, mul_out = weight.shape
    ...         # formula from torch.nn.init.xavier_uniform_
    ...         a = (6 / (mul_1 * mul_2 + mul_out))**0.5
    ...
new_weight = torch.empty_like(weight)
    ...         new_weight.uniform_(-a, a)
    ...         weight[:] = new_weight
    tensor(...)
    >>> n = 1_000
    >>> vars = module(irreps_1.randn(n, -1), irreps_2.randn(n, -1)).var(0)
    >>> assert vars.min() > 1 / 3
    >>> assert vars.max() < 3
    """

    instructions: List[Any]
    shared_weights: bool
    internal_weights: bool
    weight_numel: int
    _did_compile_right: bool
    _specialized_code: bool
    _optimize_einsums: bool
    _profiling_str: str
    _in1_dim: int
    _in2_dim: int

    def __init__(
        self,
        irreps_in1: Irreps,
        irreps_in2: Irreps,
        irreps_out: Irreps,
        instructions: List[tuple],
        in1_var: Optional[Union[List[float], torch.Tensor]] = None,
        in2_var: Optional[Union[List[float], torch.Tensor]] = None,
        out_var: Optional[Union[List[float], torch.Tensor]] = None,
        irrep_normalization: str = None,
        path_normalization: str = None,
        internal_weights: Optional[bool] = None,
        shared_weights: Optional[bool] = None,
        compile_left_right: bool = True,
        compile_right: bool = False,
        normalization=None,  # for backward compatibility
        _specialized_code: Optional[bool] = None,
        _optimize_einsums: Optional[bool] = None,
    ) -> None:
        # === Setup ===
        super().__init__()

        if normalization is not None:
            warnings.warn("`normalization` is deprecated.
Use `irrep_normalization` instead.", DeprecationWarning)
            irrep_normalization = normalization

        if irrep_normalization is None:
            irrep_normalization = "component"

        if path_normalization is None:
            path_normalization = "element"

        assert irrep_normalization in ["component", "norm", "none"]
        assert path_normalization in ["element", "path", "none"]

        self.irreps_in1 = Irreps(irreps_in1)
        self.irreps_in2 = Irreps(irreps_in2)
        self.irreps_out = Irreps(irreps_out)

        del irreps_in1, irreps_in2, irreps_out

        instructions = [x if len(x) == 6 else x + (1.0,) for x in instructions]
        instructions = [
            Instruction(
                i_in1=i_in1,
                i_in2=i_in2,
                i_out=i_out,
                connection_mode=connection_mode,
                has_weight=has_weight,
                path_weight=path_weight,
                path_shape={
                    "uvw": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul, self.irreps_out[i_out].mul),
                    "uvu": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul),
                    "uvv": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul),
                    "uuw": (self.irreps_in1[i_in1].mul, self.irreps_out[i_out].mul),
                    "uuu": (self.irreps_in1[i_in1].mul,),
                    "uvuv": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul),
                    "uvu 0.0:
                alpha /= x
            alpha *= out_var[ins.i_out]
            alpha *= ins.path_weight

            normalization_coefficients += [sqrt(alpha)]

        self.instructions = [
            Instruction(ins.i_in1, ins.i_in2, ins.i_out, ins.connection_mode, ins.has_weight, alpha, ins.path_shape)
            for ins, alpha in zip(instructions, normalization_coefficients)
        ]

        self._in1_dim = self.irreps_in1.dim
        self._in2_dim = self.irreps_in2.dim

        if shared_weights is False and internal_weights is None:
            internal_weights = False

        if shared_weights is None:
            shared_weights = True

        if internal_weights is None:
            internal_weights = shared_weights and any(i.has_weight for i in self.instructions)

        assert shared_weights or not internal_weights
        self.internal_weights = internal_weights
        self.shared_weights = shared_weights

        opt_defaults = e3nn.get_optimization_defaults()
        self._specialized_code = _specialized_code if _specialized_code is not None else
opt_defaults["specialized_code"] self._optimize_einsums = _optimize_einsums if _optimize_einsums is not None else opt_defaults["optimize_einsums"] del opt_defaults # Generate the actual tensor product code if compile_left_right: for codegen in _CODEGEN_PROVIDERS_LEFT_RIGHT: graphmod_left_right = codegen( self.irreps_in1, self.irreps_in2, self.irreps_out, self.instructions, self.shared_weights, self._specialized_code, self._optimize_einsums, ) if graphmod_left_right is not None: break assert graphmod_left_right is not None else: graphmod_left_right = fx.Graph() graphmod_left_right.placeholder("x1", torch.Tensor) graphmod_left_right.placeholder("x2", torch.Tensor) graphmod_left_right.placeholder("w", torch.Tensor) graphmod_left_right.call_function( torch._assert, args=( False, "`left_right` method is not compiled, set `compile_left_right` to True when creating the TensorProduct", ), ) graphmod_left_right = fx.GraphModule(torch.nn.Module(), graphmod_left_right, class_name="tp_forward") if compile_right: for codegen in _CODEGEN_PROVIDERS_RIGHT: graphmod_right = codegen( self.irreps_in1, self.irreps_in2, self.irreps_out, self.instructions, self.shared_weights, self._specialized_code, self._optimize_einsums, ) if graphmod_right is not None: break assert graphmod_right is not None else: graphmod_right = fx.Graph() tmp = graphmod_right.placeholder("x2", torch.Tensor) # Make a dummy no-op graph, it can't be empty or causes IndentationError on unpickle graphmod_right.placeholder("w", torch.Tensor) graphmod_right.output(tmp) del tmp graphmod_right = fx.GraphModule(torch.nn.Module(), graphmod_right, class_name="tp_forward") self._did_compile_right = compile_right self._codegen_register({"_compiled_main_left_right": graphmod_left_right, "_compiled_main_right": graphmod_right}) # === Determine weights === self.weight_numel = sum(prod(ins.path_shape) for ins in self.instructions if ins.has_weight) if internal_weights and self.weight_numel > 0: assert self.shared_weights, "Having 
internal weights impose shared weights" self.weight = torch.nn.Parameter(torch.randn(self.weight_numel)) else: # For TorchScript, there always has to be some kind of defined .weight self.register_buffer("weight", torch.Tensor()) if self.irreps_out.dim > 0: output_mask = torch.cat( [ ( torch.ones(mul * ir.dim) if any( (ins.i_out == i_out) and (ins.path_weight != 0) and (0 not in ins.path_shape) for ins in self.instructions ) else torch.zeros(mul * ir.dim) ) for i_out, (mul, ir) in enumerate(self.irreps_out) ] ) else: output_mask = torch.ones(0) self.register_buffer("output_mask", output_mask) # For TorchScript, this needs to be done in advance: self._profiling_str = str(self) def __repr__(self) -> str: npath = sum(prod(i.path_shape) for i in self.instructions) return ( f"{self.__class__.__name__}" f"({self.irreps_in1.simplify()} x {self.irreps_in2.simplify()} " f"-> {self.irreps_out.simplify()} | {npath} paths | {self.weight_numel} weights)" ) @torch.jit.unused def _prep_weights_python(self, weight: Optional[Union[torch.Tensor, List[torch.Tensor]]]) -> Optional[torch.Tensor]: if isinstance(weight, list): weight_shapes = [ins.path_shape for ins in self.instructions if ins.has_weight] if not self.shared_weights: weight = [w.reshape(-1, prod(shape)) for w, shape in zip(weight, weight_shapes)] else: weight = [w.reshape(prod(shape)) for w, shape in zip(weight, weight_shapes)] return torch.cat(weight, dim=-1) else: return weight def _get_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor: if not torch.jit.is_scripting(): # If we're not scripting, then we're in Python and `weight` could be a List[Tensor] # deal with that: weight = self._prep_weights_python(weight) if weight is None: if self.weight_numel > 0 and not self.internal_weights: raise RuntimeError("Weights must be provided when the TensorProduct does not have `internal_weights`") return self.weight else: if self.shared_weights: torch._assert(weight.shape == (self.weight_numel,), "Invalid weight shape") 
            else:
                torch._assert(weight.shape[-1] == self.weight_numel, "Invalid weight shape")
                torch._assert(weight.ndim > 1, "When shared weights is false, weights must have batch dimension")
            return weight

    @torch.jit.export
    def right(self, y, weight: Optional[torch.Tensor] = None):
        r"""Partially evaluate :math:`w x \otimes y`.

        It returns an operator in the form of a tensor that can act on an arbitrary :math:`x`.

        For example, if the tensor product above is expressed as

        .. math::

            w_{ijk} x_i y_j \rightarrow z_k

        then the right method returns a tensor :math:`b_{ik}` such that

        .. math::

            w_{ijk} y_j \rightarrow b_{ik}

        .. math::

            x_i b_{ik} \rightarrow z_k

        The result of this method can be applied with a tensor contraction:

        .. code-block:: python

            torch.einsum("...ik,...i->...k", right, input)

        Parameters
        ----------
        y : `torch.Tensor`
            tensor of shape ``(..., irreps_in2.dim)``

        weight : `torch.Tensor` or list of `torch.Tensor`, optional
            required if ``internal_weights`` is ``False``
            tensor of shape ``(self.weight_numel,)`` if ``shared_weights`` is ``True``
            tensor of shape ``(..., self.weight_numel)`` if ``shared_weights`` is ``False``
            or list of tensors of shapes ``weight_shape`` / ``(...) + weight_shape``.
            Use ``self.instructions`` to see what each weight is used for.

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(..., irreps_in1.dim, irreps_out.dim)``
        """
        torch._assert(
            self._did_compile_right,
            "`right` method is not compiled, set `compile_right` to True when creating the TensorProduct",
        )
        torch._assert(y.shape[-1] == self._in2_dim, "Incorrect last dimension for y")

        # - PROFILER - with torch.autograd.profiler.record_function(self._profiling_str):
        real_weight = self._get_weights(weight)
        return self._compiled_main_right(y, real_weight)

    def forward(self, x, y, weight: Optional[torch.Tensor] = None):
        r"""Evaluate :math:`w x \otimes y`.

        Parameters
        ----------
        x : `torch.Tensor`
            tensor of shape ``(..., irreps_in1.dim)``

        y : `torch.Tensor`
            tensor of shape ``(..., irreps_in2.dim)``

        weight : `torch.Tensor` or list of `torch.Tensor`, optional
            required if ``internal_weights`` is ``False``
            tensor of shape ``(self.weight_numel,)`` if ``shared_weights`` is ``True``
            tensor of shape ``(..., self.weight_numel)`` if ``shared_weights`` is ``False``
            or list of tensors of shapes ``weight_shape`` / ``(...) + weight_shape``.
            Use ``self.instructions`` to see what each weight is used for.

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(..., irreps_out.dim)``
        """
        torch._assert(x.shape[-1] == self._in1_dim, "Incorrect last dimension for x")
        torch._assert(y.shape[-1] == self._in2_dim, "Incorrect last dimension for y")

        # - PROFILER - with torch.autograd.profiler.record_function(self._profiling_str):
        real_weight = self._get_weights(weight)
        return self._compiled_main_left_right(x, y, real_weight)

    def weight_view_for_instruction(self, instruction: int, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""View of weights corresponding to ``instruction``.

        Parameters
        ----------
        instruction : int
            The index of the instruction to get a view on the weights for. ``self.instructions[instruction].has_weight``
            must be ``True``.

        weight : `torch.Tensor`, optional
            like ``weight`` argument to ``forward()``

        Returns
        -------
        `torch.Tensor`
            A view on ``weight`` or this object's internal weights for the weights corresponding to the
            ``instruction``-th instruction.
        """
        if not self.instructions[instruction].has_weight:
            raise ValueError(f"Instruction {instruction} has no weights.")
        # Only weighted instructions occupy space in the flat weight vector (see `weight_numel`)
        offset = sum(prod(ins.path_shape) for ins in self.instructions[:instruction] if ins.has_weight)
        ins = self.instructions[instruction]
        weight = self._get_weights(weight)
        batchshape = weight.shape[:-1]
        return weight.narrow(-1, offset, prod(ins.path_shape)).view(batchshape + ins.path_shape)

    def weight_views(self, weight: Optional[torch.Tensor] = None, yield_instruction: bool = False):
        r"""Iterator over weight views for each weighted instruction.

        Parameters
        ----------
        weight : `torch.Tensor`, optional
            like ``weight`` argument to ``forward()``

        yield_instruction : `bool`, default False
            Whether to also yield the corresponding instruction.

        Yields
        ------
        If ``yield_instruction`` is ``True``, yields ``(instruction_index, instruction, weight_view)``.
        Otherwise, yields ``weight_view``.
        """
        weight = self._get_weights(weight)
        batchshape = weight.shape[:-1]
        offset = 0
        for ins_i, ins in enumerate(self.instructions):
            if ins.has_weight:
                flatsize = prod(ins.path_shape)
                this_weight = weight.narrow(-1, offset, flatsize).view(batchshape + ins.path_shape)
                offset += flatsize
                if yield_instruction:
                    yield ins_i, ins, this_weight
                else:
                    yield this_weight

    def visualize(
        self, weight: Optional[torch.Tensor] = None, plot_weight: bool = True, aspect_ratio=1, ax=None
    ):  # pragma: no cover
        r"""Visualize the connectivity of this `e3nn.o3.TensorProduct`.

        Parameters
        ----------
        weight : `torch.Tensor`, optional
            like ``weight`` argument to ``forward()``

        plot_weight : `bool`, default True
            Whether to color weighted paths by the mean squared value of their weights (normalized by the largest one).

        aspect_ratio : float, int or ``"auto"``, default 1
            Horizontal-to-vertical scaling of the drawing. ``"auto"`` derives it from the number of input and output irreps.

        ax : ``matplotlib.Axes``, default None
            The axes to plot on. If ``None``, the current axes (``plt.gca()``) are used.

        Returns
        -------
        (fig, ax)
            The figure and axes on which the plot was drawn.
        """
        import numpy as np

        def _intersection(x, u, y, v):
            u2 = np.sum(u**2)
            v2 = np.sum(v**2)
            uv = np.sum(u * v)
            det = u2 * v2 - uv**2
            mu = np.sum((u * uv - v * u2) * (y - x)) / det
            return y + mu * v

        import matplotlib
        import matplotlib.pyplot as plt
        from matplotlib import patches
        from matplotlib.path import Path

        if ax is None:
            ax = plt.gca()

        fig = ax.get_figure()

        # hexagon
        verts = [np.array([np.cos(a * 2 * np.pi / 6), np.sin(a * 2 * np.pi / 6)]) for a in range(6)]
        verts = np.asarray(verts)

        # scale it
        if not (aspect_ratio in ["auto"] or isinstance(aspect_ratio, (float, int))):
            raise ValueError(f"aspect_ratio must be 'auto' or a float or int, got {aspect_ratio}")

        if aspect_ratio == "auto":
            factor = 0.2 / 2
            min_aspect = 1 / 2
            h_factor = max(len(self.irreps_in2), len(self.irreps_in1))
            w_factor = len(self.irreps_out)
            if h_factor / w_factor < min_aspect:
                h_factor = min_aspect * w_factor
            verts[:, 1] *= h_factor * factor
            verts[:, 0] *= w_factor * factor

        if isinstance(aspect_ratio, (float, int)):
            factor = 0.1 * max(len(self.irreps_in2), len(self.irreps_in1), len(self.irreps_out))
            verts[:, 1] *= factor
            verts[:, 0] *= aspect_ratio * factor

        codes = [
            Path.MOVETO,
            Path.LINETO,
            Path.MOVETO,
            Path.LINETO,
            Path.MOVETO,
            Path.LINETO,
        ]

        path = Path(verts, codes)
        patch = patches.PathPatch(path, facecolor="none", lw=1, zorder=2)
        ax.add_patch(patch)

        n = len(self.irreps_in1)
        b, a = verts[2:4]

        c_in1 = (a + b) / 2
        s_in1 = [a + (i + 1) / (n + 1) * (b - a) for i in range(n)]

        n = len(self.irreps_in2)
        b, a = verts[:2]

        c_in2 = (a + b) / 2
        s_in2 = [a + (i + 1) / (n + 1) * (b - a) for i in range(n)]

        n = len(self.irreps_out)
        a, b = verts[4:6]

        s_out = [a + (i + 1) / (n + 1) * (b - a) for i in range(n)]

        # get weights
        if weight is None and not self.internal_weights:
            plot_weight = False
        elif plot_weight:
            with torch.no_grad():
                path_weight = []
                for ins_i, ins in enumerate(self.instructions):
                    if ins.has_weight:
                        this_weight = self.weight_view_for_instruction(ins_i, weight=weight).cpu()
                        path_weight.append(this_weight.pow(2).mean())
                    else:
                        path_weight.append(0)
                path_weight = np.asarray(path_weight)
                path_weight /= np.abs(path_weight).max()
        cmap = matplotlib.colormaps["Blues"]

        for ins_index, ins in enumerate(self.instructions):
            y = _intersection(s_in1[ins.i_in1], c_in1, s_in2[ins.i_in2], c_in2)

            verts = []
            codes = []
            verts += [s_out[ins.i_out], y]
            codes += [Path.MOVETO, Path.LINETO]
            verts += [s_in1[ins.i_in1], y]
            codes += [Path.MOVETO, Path.LINETO]
            verts += [s_in2[ins.i_in2], y]
            codes += [Path.MOVETO, Path.LINETO]

            if plot_weight:
                color = cmap(0.5 + 0.5 * path_weight[ins_index]) if ins.has_weight else "black"
            else:
                color = "green" if ins.has_weight else "black"

            ax.add_patch(
                patches.PathPatch(
                    Path(verts, codes),
                    facecolor="none",
                    edgecolor=color,
                    alpha=0.5,
                    ls="-",
                    lw=1.5,
                )
            )

        # add labels
        padding = 3
        fontsize = 10

        def format_ir(mul_ir) -> str:
            if mul_ir.mul == 1:
                return f"${mul_ir.ir}$"
            return f"${mul_ir.mul} \\times {mul_ir.ir}$"

        for i, mul_ir in enumerate(self.irreps_in1):
            ax.annotate(
                format_ir(mul_ir),
                s_in1[i],
                horizontalalignment="right",
                textcoords="offset points",
                xytext=(-padding, 0),
                fontsize=fontsize,
            )

        for i, mul_ir in enumerate(self.irreps_in2):
            ax.annotate(
                format_ir(mul_ir),
                s_in2[i],
                horizontalalignment="left",
                textcoords="offset points",
                xytext=(padding, 0),
                fontsize=fontsize,
            )

        for i, mul_ir in enumerate(self.irreps_out):
            ax.annotate(
                format_ir(mul_ir),
                s_out[i],
                horizontalalignment="center",
                verticalalignment="top",
                rotation=90,
                textcoords="offset points",
                xytext=(0, -padding),
                fontsize=fontsize,
            )

        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.axis("equal")
        ax.axis("off")

        return fig, ax
e3nn-0.6.0/e3nn/o3/_wigner.py000066400000000000000233661514371756200155750ustar00rootroot00000000000000
r"""Core functions of :math:`SO(3)`"""

import functools
import math
from typing import Union

import torch

from e3nn.util import explicit_default_types


def su2_generators(j: int) -> torch.Tensor:
    m = torch.arange(-j, j)
    raising = torch.diag(-torch.sqrt(j * (j + 1) - m * (m + 1)), diagonal=-1)

    m = torch.arange(-j + 1, j + 1)
    lowering = torch.diag(torch.sqrt(j * (j + 1) - m * (m - 1)), diagonal=1)

    m = torch.arange(-j, j + 1)
    return torch.stack(
        [
            0.5 * (raising + lowering),  # x (usually)
            torch.diag(1j * m),  # z (usually)
            -0.5j * (raising - lowering),  # -y (usually)
        ],
        dim=0,
    )


# Need to do a graph break since Dynamo
# cannot handle power of complex numbers and SymInt in L41
@torch.compiler.disable
def change_basis_real_to_complex(l: int, dtype=None, device=None) -> torch.Tensor:
    # https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
    q = torch.zeros((2 * l + 1, 2 * l + 1), dtype=torch.complex128)
    for m in range(-l, 0):
        q[l + m, l + abs(m)] = 1 / 2**0.5
        q[l + m, l - abs(m)] = -1j / 2**0.5
    q[l, l] = 1
    for m in range(1, l + 1):
        q[l + m, l + abs(m)] = (-1) ** m / 2**0.5
        q[l + m, l - abs(m)] = 1j * (-1) ** m / 2**0.5
    q = (-1j) ** l * q  # Added factor of (-1j)**l to make the Clebsch-Gordan coefficients real

    dtype, device = explicit_default_types(dtype, device)
    dtype = {
        torch.float32: torch.complex64,
        torch.float64: torch.complex128,
    }[dtype]
    # make sure we always get:
    # 1. a copy so mutation doesn't ruin the stored tensors
    # 2. a contiguous tensor, regardless of what transpositions happened above
    return q.to(dtype=dtype, device=device, copy=True, memory_format=torch.contiguous_format)


def so3_generators(l) -> torch.Tensor:
    X = su2_generators(l)
    Q = change_basis_real_to_complex(l)
    X = torch.conj(Q.T) @ X @ Q
    assert torch.all(torch.abs(torch.imag(X)) < 1e-5)
    return torch.real(X)


def wigner_D(l: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    r"""Wigner D matrix representation of :math:`SO(3)`.

    It satisfies the following properties:

    * :math:`D(\text{identity rotation}) = \text{identity matrix}`
    * :math:`D(R_1 \circ R_2) = D(R_1) \circ D(R_2)`
    * :math:`D(R^{-1}) = D(R)^{-1} = D(R)^T`
    * :math:`D(\text{rotation around Y axis})` has some property that allows us to use FFT in `ToS2Grid`

    Parameters
    ----------
    l : int
        :math:`l`

    alpha : `torch.Tensor`
        tensor of shape :math:`(...)`
        Rotation :math:`\alpha` around Y axis, applied third.

    beta : `torch.Tensor`
        tensor of shape :math:`(...)`
        Rotation :math:`\beta` around X axis, applied second.

    gamma : `torch.Tensor`
        tensor of shape :math:`(...)`
        Rotation :math:`\gamma` around Y axis, applied first.

    Returns
    -------
    `torch.Tensor`
        tensor :math:`D^l(\alpha, \beta, \gamma)` of shape :math:`(..., 2l+1, 2l+1)`
    """
    alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma)
    alpha = alpha[..., None, None] % (2 * math.pi)
    beta = beta[..., None, None] % (2 * math.pi)
    gamma = gamma[..., None, None] % (2 * math.pi)
    X = so3_generators(l)
    return torch.matrix_exp(alpha * X[1]) @ torch.matrix_exp(beta * X[0]) @ torch.matrix_exp(gamma * X[1])


def wigner_3j(l1: int, l2: int, l3: int, dtype=None, device=None) -> torch.Tensor:
    r"""Wigner 3j symbols :math:`C_{lmn}`.

    It satisfies the following two properties:

        .. math::

            C_{lmn} = C_{ijk} D_{il}(g) D_{jm}(g) D_{kn}(g) \qquad \forall g \in SO(3)

        where :math:`D` are given by `wigner_D`.

        .. math::

            C_{ijk} C_{ijk} = 1

    Parameters
    ----------
    l1 : int
        :math:`l_1`

    l2 : int
        :math:`l_2`

    l3 : int
        :math:`l_3`

    dtype : torch.dtype or None
        ``dtype`` of the returned tensor. If ``None`` then set to ``torch.get_default_dtype()``.

    device : torch.device or None
        ``device`` of the returned tensor. If ``None`` then set to the default device of the current context.

    Returns
    -------
    `torch.Tensor`
        tensor :math:`C` of shape :math:`(2l_1+1, 2l_2+1, 2l_3+1)`
    """
    assert abs(l2 - l3) <= l1 <= l2 + l3
    assert isinstance(l1, int) and isinstance(l2, int) and isinstance(l3, int)
    C = _so3_clebsch_gordan(l1, l2, l3)

    dtype, device = explicit_default_types(dtype, device)
    # make sure we always get:
    # 1. a copy so mutation doesn't ruin the stored tensors
    # 2. a contiguous tensor, regardless of what transpositions happened above
    return C.to(dtype=dtype, device=device, copy=True, memory_format=torch.contiguous_format)


@functools.lru_cache(maxsize=None)
def _so3_clebsch_gordan(l1: int, l2: int, l3: int) -> torch.Tensor:
    Q1 = change_basis_real_to_complex(l1, dtype=torch.float64)
    Q2 = change_basis_real_to_complex(l2, dtype=torch.float64)
    Q3 = change_basis_real_to_complex(l3, dtype=torch.float64)
    C = _su2_clebsch_gordan(l1, l2, l3).to(dtype=torch.complex128)
    C = torch.einsum("ij,kl,mn,ikn->jlm", Q1, Q2, torch.conj(Q3.T), C)

    # make it real
    assert torch.all(torch.abs(torch.imag(C)) < 1e-5)
    C = torch.real(C)

    # normalization
    C = C / torch.norm(C)
    return C


# Taken from http://qutip.org/docs/3.1.0/modules/qutip/utilities.html

# This file is part of QuTiP: Quantum Toolbox in Python.
#
# Copyright (c) 2011 and later, Paul D. Nation and Robert J. Johansson.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the QuTiP: Quantum Toolbox in Python nor the names
# of its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
###############################################################################


@functools.lru_cache(maxsize=None)
def _su2_clebsch_gordan(j1: Union[int, float], j2: Union[int, float], j3: Union[int, float]) -> torch.Tensor:
    """Calculates the Clebsch-Gordan matrix for SU(2) coupling j1 and j2 to give j3.

    Parameters
    ----------
    j1 : float
        Total angular momentum 1.

    j2 : float
        Total angular momentum 2.

    j3 : float
        Total angular momentum 3.

    Returns
    -------
    cg_matrix : `torch.Tensor`
        Requested Clebsch-Gordan matrix.
    """
    assert isinstance(j1, (int, float))
    assert isinstance(j2, (int, float))
    assert isinstance(j3, (int, float))
    mat = torch.zeros((int(2 * j1 + 1), int(2 * j2 + 1), int(2 * j3 + 1)), dtype=torch.float64)
    if int(2 * j3) in range(int(2 * abs(j1 - j2)), int(2 * (j1 + j2)) + 1, 2):
        for m1 in (x / 2 for x in range(-int(2 * j1), int(2 * j1) + 1, 2)):
            for m2 in (x / 2 for x in range(-int(2 * j2), int(2 * j2) + 1, 2)):
                if abs(m1 + m2) <= j3:
                    mat[int(j1 + m1), int(j2 + m2), int(j3 + m1 + m2)] = _su2_clebsch_gordan_coeff(
                        (j1, m1), (j2, m2), (j3, m1 + m2)
                    )

    return mat


def _su2_clebsch_gordan_coeff(idx1, idx2, idx3):
    """Calculates the Clebsch-Gordan coefficient for SU(2) coupling (j1,m1) and (j2,m2) to give (j3,m3).

    Parameters
    ----------
    idx1 : tuple of float
        ``(j1, m1)``: total angular momentum 1 and its z-component.

    idx2 : tuple of float
        ``(j2, m2)``: total angular momentum 2 and its z-component.

    idx3 : tuple of float
        ``(j3, m3)``: total angular momentum 3 and its z-component.

    Returns
    -------
    cg_coeff : float
        Requested Clebsch-Gordan coefficient.
    """
    from fractions import Fraction
    from math import factorial

    j1, m1 = idx1
    j2, m2 = idx2
    j3, m3 = idx3

    if m3 != m1 + m2:
        return 0
    vmin = int(max([-j1 + j2 + m3, -j1 + m1, 0]))
    vmax = int(min([j2 + j3 + m1, j3 - j1 + j2, j3 + m3]))

    def f(n: int) -> int:
        assert n == round(n)
        return factorial(round(n))

    C = (
        (2.0 * j3 + 1.0)
        * Fraction(
            f(j3 + j1 - j2) * f(j3 - j1 + j2) * f(j1 + j2 - j3) * f(j3 + m3) * f(j3 - m3),
            f(j1 + j2 + j3 + 1) * f(j1 - m1) * f(j1 + m1) * f(j2 - m2) * f(j2 + m2),
        )
    ) ** 0.5

    S = 0
    for v in range(vmin, vmax + 1):
        S += (-1) ** int(v + j2 + m2) * Fraction(
            f(j2 + j3 + m1 - v) * f(j1 - m1 + v), f(v) * f(j3 - j1 + j2 - v) * f(j3 + m3 - v) * f(v + j1 - j2 - m3)
        )
    C = C * S
    return C
e3nn-0.6.0/e3nn/o3/experimental/000077500000000000000000000000001514371756200162545ustar00rootroot00000000000000
e3nn-0.6.0/e3nn/o3/experimental/__init__.py000066400000000000000000000003201514371756200203600ustar00rootroot00000000000000
from ._full_tp import FullTensorProduct as FullTensorProductv2
from ._elementwise_tp import ElementwiseTensorProduct as ElementwiseTensorProductv2

# `__all__` must list names as strings for `from ... import *` to work
__all__ = ["FullTensorProductv2", "ElementwiseTensorProductv2"]
e3nn-0.6.0/e3nn/o3/experimental/_elementwise_tp.py000066400000000000000000000077701514371756200220240ustar00rootroot00000000000000
# flake8: noqa
from typing import Tuple

from e3nn.util.datatypes import Path, Chunk
from e3nn import o3

import torch
from torch import nn
import numpy as np

from ._full_tp import _prepare_inputs


def _align_two_irreps(irreps1: o3.Irreps, irreps2: o3.Irreps) -> Tuple[o3.Irreps, o3.Irreps]:
    assert irreps1.num_irreps == irreps2.num_irreps

    irreps1 = list(irreps1)
    irreps2 = list(irreps2)

    i = 0
    while i < min(len(irreps1), len(irreps2)):
        mul_1, ir_1 = irreps1[i]
        mul_2, ir_2 = irreps2[i]

        if mul_1 < mul_2:
            irreps2[i] = (mul_1, ir_2)
            irreps2.insert(i + 1, (mul_2 - mul_1, ir_2))

        if mul_2 < mul_1:
            irreps1[i] = (mul_2, ir_1)
            irreps1.insert(i + 1, (mul_1 - mul_2, ir_1))

        i += 1

    assert [mul for mul, _ in irreps1] == [mul for mul, _ in irreps2]
    return o3.Irreps(irreps1), o3.Irreps(irreps2)


class ElementwiseTensorProduct(nn.Module):
    def __init__(
        self,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        *,
        filter_ir_out: o3.Irreps = None,
        irrep_normalization: str = "component",
    ):
        """Tensor Product adapted from https://github.com/e3nn/e3nn-jax/blob/cf37f3e95264b34587b3a202ea4c3eb82597307e/e3nn_jax/_src/tensor_products.py#L139-L213"""
        super(ElementwiseTensorProduct, self).__init__()

        # Accept strings as well as Irreps objects
        irreps_in1 = o3.Irreps(irreps_in1)
        irreps_in2 = o3.Irreps(irreps_in2)

        if irreps_in1.num_irreps != irreps_in2.num_irreps:
            raise ValueError(
                "o3.ElementwiseTensorProductv2: inputs must have the same number of irreps, "
                f"got {irreps_in1.num_irreps} and {irreps_in2.num_irreps}"
            )

        irreps_in1, irreps_in2 = _align_two_irreps(irreps_in1, irreps_in2)

        paths = {}
        irreps_out = []
        for (mul_1, ir_1), slice_1, (_, ir_2), slice_2 in zip(
            irreps_in1, irreps_in1.slices(), irreps_in2, irreps_in2.slices()
        ):
            for ir_out in ir_1 * ir_2:
                if filter_ir_out is not None and ir_out not in filter_ir_out:
                    continue
                cg = o3.wigner_3j(ir_1.l, ir_2.l, ir_out.l)
                if irrep_normalization == "component":
                    cg *= np.sqrt(ir_out.dim)
                elif irrep_normalization == "norm":
                    cg *= np.sqrt(ir_1.dim * ir_2.dim)
                else:
                    raise ValueError(f"irrep_normalization={irrep_normalization} not supported")
                self.register_buffer(f"cg_{ir_1.l}_{ir_2.l}_{ir_out.l}", cg)
                paths[(ir_1.l, ir_1.p, ir_2.l, ir_2.p, ir_out.l, ir_out.p)] = Path(
                    Chunk(mul_1, ir_1.dim, slice_1), Chunk(mul_1, ir_2.dim, slice_2), Chunk(mul_1, ir_out.dim)
                )
                irreps_out.append((mul_1, ir_out))
        self.paths = paths
        irreps_out = o3.Irreps(irreps_out)
        self.irreps_out, _, self.inv = irreps_out.sort()
        self.irreps_in1 = irreps_in1
        self.irreps_in2 = irreps_in2

    def forward(
        self,
        input1: torch.Tensor,
        input2: torch.Tensor,
    ) -> torch.Tensor:
        input1, input2, leading_shape = _prepare_inputs(input1, input2)
        chunks = []
        for (l1, _, l2, _, l3, _), (
            (mul_1, input_dim1, slice_1),
            (mul_2, input_dim2, slice_2),
            (output_mul, output_dim, _),
        ) in self.paths.items():
            x1 = input1[..., slice_1].reshape(
                leading_shape
                + (
                    mul_1,
                    input_dim1,
                )
            )
            x2 = input2[..., slice_2].reshape(
                leading_shape
                + (
                    mul_2,
                    input_dim2,
                )
            )
            cg = getattr(self, f"cg_{l1}_{l2}_{l3}")
            chunk = torch.einsum("...ui, ...uj, ijk -> ...uk", x1, x2, cg)
            chunk = torch.reshape(chunk, chunk.shape[:-2] + (output_mul * output_dim,))
            chunks.append(chunk)

        return torch.cat([chunks[i] for i in self.inv], dim=-1)
e3nn-0.6.0/e3nn/o3/experimental/_full_tp.py000066400000000000000000000070401514371756200204330ustar00rootroot00000000000000
# flake8: noqa
from e3nn.util.datatypes import Path, Chunk
from e3nn import o3

import torch
from torch import nn
import numpy as np


def _prepare_inputs(input1, input2):
    dtype = torch.promote_types(input1.dtype, input2.dtype)

    input1 = input1.to(dtype=dtype)
    input2 = input2.to(dtype=dtype)

    leading_shape = torch.broadcast_shapes(input1.shape[:-1], input2.shape[:-1])
    input1 = input1.broadcast_to(leading_shape + (-1,))
    input2 = input2.broadcast_to(leading_shape + (-1,))
    return input1, input2, leading_shape


class FullTensorProduct(nn.Module):
    def __init__(
        self,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        *,
        filter_ir_out: o3.Irreps = None,
        irrep_normalization: str = "component",
        regroup_output: bool = True,
    ):
        """Tensor Product adapted from https://github.com/e3nn/e3nn-jax/blob/cf37f3e95264b34587b3a202ea4c3eb82597307e/e3nn_jax/_src/tensor_products.py#L40-L135"""
        super(FullTensorProduct, self).__init__()

        # Convert unconditionally so that `.slices()` also works when `regroup_output=False`
        irreps_in1 = o3.Irreps(irreps_in1)
        irreps_in2 = o3.Irreps(irreps_in2)
        if regroup_output:
            irreps_in1 = irreps_in1.regroup()
            irreps_in2 = irreps_in2.regroup()

        paths = {}
        irreps_out = []
        for (mul_1, ir_1), slice_1 in zip(irreps_in1, irreps_in1.slices()):
            for (mul_2, ir_2), slice_2 in zip(irreps_in2, irreps_in2.slices()):
                for ir_out in ir_1 * ir_2:
                    if filter_ir_out is not None and ir_out not in filter_ir_out:
                        continue
                    cg = o3.wigner_3j(ir_1.l, ir_2.l, ir_out.l)
                    if irrep_normalization == "component":
                        cg *= np.sqrt(ir_out.dim)
                    elif irrep_normalization == "norm":
                        cg *= np.sqrt(ir_1.dim * ir_2.dim)
                    else:
                        raise ValueError(f"irrep_normalization={irrep_normalization} not supported")
                    self.register_buffer(f"cg_{ir_1.l}_{ir_2.l}_{ir_out.l}", cg)
                    paths[(ir_1.l, ir_1.p, ir_2.l, ir_2.p, ir_out.l, ir_out.p)] = Path(
                        Chunk(mul_1, ir_1.dim, slice_1), Chunk(mul_2, ir_2.dim, slice_2), Chunk(mul_1 * mul_2, ir_out.dim)
                    )
                    irreps_out.append((mul_1 * mul_2, ir_out))
        self.paths = paths
        irreps_out = o3.Irreps(irreps_out)
        self.irreps_out, _, self.inv = irreps_out.sort()
        self.irreps_in1 = irreps_in1
        self.irreps_in2 = irreps_in2

    def forward(
        self,
        input1: torch.Tensor,
        input2: torch.Tensor,
    ) -> torch.Tensor:
        input1, input2, leading_shape = _prepare_inputs(input1, input2)
        chunks = []
        for (l1, _, l2, _, l3, _), (
            (mul_1, input_dim1, slice_1),
            (mul_2, input_dim2, slice_2),
            (output_mul, output_dim, _),
        ) in self.paths.items():
            x1 = input1[..., slice_1].reshape(
                leading_shape
                + (
                    mul_1,
                    input_dim1,
                )
            )
            x2 = input2[..., slice_2].reshape(
                leading_shape
                + (
                    mul_2,
                    input_dim2,
                )
            )
            cg = getattr(self, f"cg_{l1}_{l2}_{l3}")
            chunk = torch.einsum("...ui, ...vj, ijk -> ...uvk", x1, x2, cg)
            chunk = torch.reshape(chunk, chunk.shape[:-3] + (output_mul * output_dim,))
            chunks.append(chunk)

        return torch.cat([chunks[i] for i in self.inv], dim=-1)
e3nn-0.6.0/e3nn/o3/irrep/000077500000000000000000000000001514371756200147005ustar00rootroot00000000000000
e3nn-0.6.0/e3nn/o3/irrep/__init__.py000066400000000000000000000016741514371756200170210ustar00rootroot00000000000000
r"""Allows for clean lookup of irreducible representations of :math:`O(3)`

Examples
--------
Create a scalar representation (:math:`l=0`) of even parity.

>>> from e3nn.o3 import irrep
>>> irrep.l0e == Irrep("0e")
True

>>> from e3nn.o3.irrep import l1o, l2o
>>> l1o + l2o == Irrep("1o") + Irrep("2o")
True
"""

from .._irreps import Irrep


def __getattr__(name: str) -> Irrep:
    r"""Create an `Irrep` object by reflection

    Parameters
    ----------
    name : string
        the o3 object name prefixed by l.
        Example: l1o == Irrep("1o")

    Returns
    -------
    `e3nn.o3.Irrep`
        irreducible representation of :math:`O(3)`
    """
    prefix, *ir = name
    if prefix != "l" or not ir:
        raise AttributeError(f"'e3nn.o3.irrep' module has no attribute '{name}'")
    try:
        return Irrep("".join(ir))
    except (ValueError, AssertionError):
        raise AttributeError(f"'e3nn.o3.irrep' module has no attribute '{name}'")
e3nn-0.6.0/e3nn/util/000077500000000000000000000000001514371756200142135ustar00rootroot00000000000000
e3nn-0.6.0/e3nn/util/__init__.py000066400000000000000000000005621514371756200163270ustar00rootroot00000000000000
from .default_type import (
    torch_get_default_tensor_type,
    torch_get_default_device,
    explicit_default_types,
)


def prod(x):
    """Compute the product of a sequence."""
    out = 1
    for a in x:
        out *= a
    return out


__all__ = [
    "torch_get_default_tensor_type",
    "torch_get_default_device",
    "explicit_default_types",
    "prod",
]
e3nn-0.6.0/e3nn/util/_argtools.py000066400000000000000000000147021514371756200165620ustar00rootroot00000000000000
from typing import Optional
import random
import warnings

import torch

from e3nn.o3._irreps import Irreps


def _is_irreps(obj):
    """Check if obj is an Irreps instance, even across different class definitions.

    This uses a marker attribute instead of isinstance() to handle cases where
    the Irreps class from packaged code differs from the environment's Irreps class.
    """
    return hasattr(type(obj), '_e3nn_irreps_marker')


def _transform(dat, irreps_dat, rot_mat, translation: float = 0.0, output_transform_dtype: bool = False):
    """Transform ``dat`` by ``rot_mat`` and ``translation`` according to ``irreps_dat``."""
    out = []
    transform_dtype = rot_mat.dtype
    translation = torch.as_tensor(translation, dtype=transform_dtype)
    for irreps, a in zip(irreps_dat, dat):
        if output_transform_dtype:
            out_dtype = transform_dtype
        else:
            out_dtype = a.dtype
        if irreps is None:
            out.append(a.clone())
        elif irreps == "cartesian_points":
            translation = torch.as_tensor(translation, device=a.device)
            out.append(((a.to(transform_dtype) @ rot_mat.T.to(a.device)) + translation).to(out_dtype))
        else:
            # For o3.Irreps
            out.append((a.to(transform_dtype) @ irreps.D_from_matrix(rot_mat).T.to(a.device)).to(out_dtype))
    return out


def _get_io_irreps(func, irreps_in=None, irreps_out=None):
    """Preprocess or, if not given, try to infer the I/O irreps for ``func``."""
    SPECIAL_VALS = ["cartesian_points", None]

    if (irreps_in is None or irreps_out is None) and isinstance(func, torch.jit.ScriptModule):
        warnings.warn(
            "Asking to infer irreps in/out of a compiled TorchScript module. This is unreliable, please provide `irreps_in` "
            "and `irreps_out` explicitly."
        )

    if irreps_in is None:
        if hasattr(func, "irreps_in"):
            irreps_in = func.irreps_in  # gets checked for type later
        elif hasattr(func, "irreps_in1"):
            irreps_in = [func.irreps_in1, func.irreps_in2]
        else:
            raise ValueError("Cannot infer irreps_in for %r; provide them explicitly" % func)
    if irreps_out is None:
        if hasattr(func, "irreps_out"):
            irreps_out = func.irreps_out  # gets checked for type later
        else:
            raise ValueError("Cannot infer irreps_out for %r; provide them explicitly" % func)

    if _is_irreps(irreps_in) or irreps_in in SPECIAL_VALS:
        irreps_in = [irreps_in]
    elif isinstance(irreps_in, list):
        irreps_in = [i if i in SPECIAL_VALS else Irreps(i) for i in irreps_in]
    else:
        if isinstance(irreps_in, tuple) and not _is_irreps(irreps_in):
            warnings.warn(
                f"Module {func} had irreps_in of type tuple but not Irreps; ambiguous whether the tuple should be interpreted "
                f"as a tuple representing a single Irreps or a tuple of objects each to be converted to Irreps. Assuming the "
                f"former. If the latter, use a list."
            )
        irreps_in = [Irreps(irreps_in)]

    if _is_irreps(irreps_out) or irreps_out in SPECIAL_VALS:
        irreps_out = [irreps_out]
    elif isinstance(irreps_out, list):
        irreps_out = [i if i in SPECIAL_VALS else Irreps(i) for i in irreps_out]
    else:
        if isinstance(irreps_out, tuple) and not _is_irreps(irreps_out):
            warnings.warn(
                f"Module {func} had irreps_out of type tuple but not Irreps; ambiguous whether the tuple should be "
                f"interpreted as a tuple representing a single Irreps or a tuple of objects each to be converted to Irreps. "
                f"Assuming the former. If the latter, use a list."
            )
        irreps_out = [Irreps(irreps_out)]

    return irreps_in, irreps_out


def _get_args_in(func, args_in=None, irreps_in=None, irreps_out=None):
    irreps_in, irreps_out = _get_io_irreps(func, irreps_in=irreps_in, irreps_out=irreps_out)
    if args_in is None:
        args_in = _rand_args(irreps_in)
    assert len(args_in) == len(irreps_in), "irreps_in and args_in don't match in length"
    return args_in, irreps_in, irreps_out


def _rand_args(irreps_in, batch_size: Optional[int] = None):
    if not all((_is_irreps(i) or i == "cartesian_points") for i in irreps_in):
        raise ValueError(
            "Random arguments cannot be generated when argument types besides Irreps and `'cartesian_points'` are specified; "
            "provide explicit ``args_in``"
        )
    if batch_size is None:
        # Generate random args with random size batch dim between 1 and 4:
        batch_size = random.randint(1, 4)
    args_in = [
        torch.randn(batch_size, 3) if (irreps == "cartesian_points") else irreps.randn(batch_size, -1) for irreps in irreps_in
    ]
    return args_in


def _get_device(mod: torch.nn.Module) -> torch.device:
    # Try to get a parameter
    a_buf = next(mod.parameters(), None)
    if a_buf is None:
        # If there isn't one, try to get a buffer
        a_buf = next(mod.buffers(), None)
    return a_buf.device if a_buf is not None else "cpu"


def _get_floating_dtype(mod: torch.nn.Module) -> torch.dtype:
    """Guess floating dtype for module.

    Assumes no mixed precision.
    """
    # Try to get a parameter
    a_buf = None
    for buf in mod.parameters():
        if buf.is_floating_point():
            a_buf = buf
            break
    if a_buf is None:
        # If there isn't one, try to get a buffer
        for buf in mod.buffers():
            if buf.is_floating_point():
                a_buf = buf
                break
    return a_buf.dtype if a_buf is not None else torch.get_default_dtype()


def _to_device_dtype(args, device=None, dtype=None):
    kwargs = {}
    if device is not None:
        kwargs["device"] = device
    if dtype is not None:
        kwargs["dtype"] = dtype

    if isinstance(args, torch.Tensor):
        if args.is_floating_point():
            # Only convert dtypes of floating tensors
            return args.to(device=device, dtype=dtype)
        else:
            return args.to(device=device)
    elif isinstance(args, tuple):
        return tuple(_to_device_dtype(e, **kwargs) for e in args)
    elif isinstance(args, list):
        return [_to_device_dtype(e, **kwargs) for e in args]
    elif isinstance(args, dict):
        return {k: _to_device_dtype(v, **kwargs) for k, v in args.items()}
    else:
        raise TypeError("Only (nested) dict/tuple/lists of Tensors can be moved to a device/dtype.")
e3nn-0.6.0/e3nn/util/_context.py
# Please see PR #203 for this commented-out code: https://github.com/e3nn/e3nn/pull/203

# from abc import ABCMeta, abstractmethod
# from contextlib import AbstractContextManager
# from functools import wraps


# class AbstractContextDecoratorManager(AbstractContextManager, metaclass=ABCMeta):
#     def __init__(self) -> None:
#         super().__init__()

#     @abstractmethod
#     def __enter__(self):
#         pass

#     @abstractmethod
#     def __exit__(self, exc_type, exc_value, traceback):
#         pass

#     def __call__(self, f):
#         @wraps(f)
#         def wrapper(*args, **kwargs):
#             with self:
#                 return f(*args, **kwargs)

#         return wrapper
e3nn-0.6.0/e3nn/util/codegen/
e3nn-0.6.0/e3nn/util/codegen/__init__.py
from ._mixin import CodeGenMixin

__all__ = [
    "CodeGenMixin",
]
e3nn-0.6.0/e3nn/util/codegen/_mixin.py
import io
import pickle
from typing import Dict

import e3nn.util.jit

import torch
from torch import fx


class CodeGenMixin:
    """Mixin for classes that dynamically generate TorchScript code using FX.

    This class manages evaluating and compiling generated code for subclasses while remaining pickle/deepcopy compatible.
    If subclasses need to override ``__getstate__``/``__setstate__``, they should be sure to call CodeGenMixin's
    implementation first and use its output.
    """

    # pylint: disable=super-with-arguments

    def _codegen_register(
        self,
        funcs: Dict[str, fx.GraphModule],
    ) -> None:
        """Register ``fx.GraphModule``s as TorchScript submodules.

        Parameters
        ----------
        funcs : Dict[str, fx.GraphModule]
            Dictionary mapping submodule names to graph modules.
        """
        if not hasattr(self, "__codegen__"):
            # list of submodule names that are managed by this object
            self.__codegen__ = []
        self.__codegen__.extend(funcs.keys())

        opt_defaults = e3nn.get_optimization_defaults()

        for fname, graphmod in funcs.items():
            assert isinstance(graphmod, fx.GraphModule)

            if opt_defaults["jit_mode"] == "script":
                # With recurse=False, this more or less is equivalent to
                # torch.jit.script(jitable(graphmod))
                scriptmod = e3nn.util.jit.compile(graphmod, recurse=False)
                assert isinstance(scriptmod, torch.jit.ScriptModule)
            else:
                scriptmod = graphmod

            # Add the ScriptModule as a submodule so it can be called
            self.add_module(fname, scriptmod)

    # In order to support copy.deepcopy and pickling, we need to not save the compiled TorchScript functions:
    # See pickle docs: https://docs.python.org/3/library/pickle.html#pickling-class-instances
    def __getstate__(self):
        # - Get a state to work with -
        # We need to check if other parent classes of self define __getstate__
        # torch.nn.Module does not currently implement __get/setstate__ but
        # may in the future, which is why we have
        # these hasattr checks for other superclasses.
        if hasattr(super(CodeGenMixin, self), "__getstate__"):
            out = super(CodeGenMixin, self).__getstate__()
        else:
            out = self.__dict__

        out = out.copy()
        # We need a copy of the _modules OrderedDict
        # Otherwise, modifying the returned state will modify the current module itself
        out["_modules"] = out["_modules"].copy()

        # - Add saved versions of the ScriptModules to the state -
        codegen_state = {}
        if hasattr(self, "__codegen__"):
            for fname in self.__codegen__:
                # Get the module
                smod = getattr(self, fname)
                buffer_type: str
                buffer: bytes
                if isinstance(smod, (fx.GraphModule, torch._dynamo.OptimizedModule)):
                    buffer_type = "fx"
                    # pickle the fx.GraphModule normally
                    buffer = pickle.dumps(smod)
                elif isinstance(smod, torch.jit.ScriptModule):
                    buffer_type = "torchscript"
                    # Save the compiled code as TorchScript IR
                    buffer_io = io.BytesIO()
                    torch.jit.save(smod, buffer_io)
                    # Serialize that IR (just some `bytes`) instead of
                    # the ScriptModule
                    buffer = buffer_io.getvalue()
                else:
                    assert False
                # Save the buffer and a note on what it is so we know how to load it
                codegen_state[fname] = (buffer_type, buffer)
                # Remove the compiled submodule from being a submodule
                # of the saved module
                del out["_modules"][fname]

            out["__codegen__"] = codegen_state
        return out

    def __setstate__(self, d) -> None:
        d = d.copy()
        # We don't want to add this to the object when we call super's __setstate__
        codegen_state = d.pop("__codegen__", None)

        # We need to initialize self first so that we can add submodules
        # We need to check if other parent classes of self define __setstate__
        if hasattr(super(CodeGenMixin, self), "__setstate__"):
            super(CodeGenMixin, self).__setstate__(d)
        else:
            self.__dict__.update(d)

        if codegen_state is not None:
            for fname, (buffer_type, buffer) in codegen_state.items():
                assert isinstance(fname, str)
                assert isinstance(buffer_type, str)
                # Make sure bytes, not ScriptModules, got made
                assert isinstance(buffer, bytes)
                if buffer_type == "fx":
                    smod = pickle.loads(buffer)
                    assert isinstance(smod, (fx.GraphModule, torch._dynamo.OptimizedModule))
                elif buffer_type == "torchscript":
                    # Load the TorchScript IR buffer
                    buffer = io.BytesIO(buffer)
                    smod = torch.jit.load(buffer)
                    assert isinstance(smod, torch.jit.ScriptModule)
                else:
                    raise NotImplementedError
                # Add the ScriptModule as a submodule
                setattr(self, fname, smod)
            self.__codegen__ = list(codegen_state.keys())
e3nn-0.6.0/e3nn/util/datatypes.py
from typing import NamedTuple, Optional


class Chunk(NamedTuple):
    mul: int
    dim: int
    slice: Optional[slice] = None


class Path(NamedTuple):
    input_1_slice: Chunk
    input_2_slice: Chunk
    output_slice: Chunk
e3nn-0.6.0/e3nn/util/default_type.py
from typing import Optional, Tuple

import torch
import torch.jit

# Please see PR #203 for this commented-out code: https://github.com/e3nn/e3nn/pull/203

# from functools import wraps
# from ._context import AbstractContextDecoratorManager


# class torch_default_tensor_type(AbstractContextDecoratorManager):
#     def __init__(self, dtype, device) -> None:
#         super().__init__()
#         self.saved_ttype = None
#         self.dtype = dtype
#         self.device = device
#
#     def __enter__(self):
#         if self.dtype is not None or self.device is not None:
#             self.saved_ttype = torch_get_default_tensor_type()
#             torch.set_default_tensor_type(self.ttype)
#
#     def __exit__(self, exc_type, exc_value, traceback):
#         if self.saved_ttype is not None:
#             torch.set_default_tensor_type(self.saved_ttype)
#             self.saved_ttype = None
#
#     @property
#     def ttype(self):
#         return torch.empty(0, dtype=self.dtype, device=self.device).type()


# class torch_default_dtype(AbstractContextDecoratorManager):
#     def __init__(self, dtype) -> None:
#         super().__init__()
#         self.saved_dtype = None
#         self.dtype = dtype
#
#     def __enter__(self):
#         if self.dtype is not None:
#             self.saved_dtype = torch.get_default_dtype()
#             torch.set_default_dtype(self.dtype)
#
#     def __exit__(self, exc_type, exc_value, traceback):
#         if self.saved_dtype is not None:
#             torch.set_default_dtype(self.saved_dtype)
#             self.saved_dtype = None


# class torch_default_device(torch_default_tensor_type):
#     def __init__(self, device) -> None:
#         super().__init__(None, device)


# class add_type_kwargs(object):
#     _DOC_NOTE = r"""
#     - dtype and device keyword arguments will be passed to torch_default_tensor_type()
#     """
#
#     def __init__(self, dtype=None, device=None) -> None:
#         super().__init__()
#         self.dtype = dtype
#         self.device = device
#
#     def __call__(self, f):
#         @wraps(f)
#         def wrapper(*args, dtype=self.dtype, device=self.device, **kwargs):
#             with torch_default_tensor_type(dtype, device):
#                 return f(*args, **kwargs)
#
#         if wrapper.__doc__ is not None:
#             if not wrapper.__doc__.endswith("\n"):
#                 wrapper.__doc__ += "\n"
#             wrapper.__doc__ += self._DOC_NOTE
#
#         return wrapper


def torch_get_default_tensor_type() -> str:
    return torch.empty(0).type()


def _torch_get_default_dtype() -> torch.dtype:
    """A torchscript-compatible version of torch.get_default_dtype()"""
    return torch.empty(0).dtype


def torch_get_default_device() -> torch.device:
    return torch.empty(0).device


def explicit_default_types(dtype: Optional[torch.dtype], device: Optional[torch.device]) -> Tuple[torch.dtype, torch.device]:
    """A torchscript-compatible type resolver"""
    if dtype is None:
        dtype = _torch_get_default_dtype()
    if device is None:
        device = torch_get_default_device()
    return dtype, device


# def torch_set_default_device(device):
#     ttype = torch_default_device(device).ttype
#     torch.set_default_tensor_type(ttype)
e3nn-0.6.0/e3nn/util/jit.py
import copy
import inspect
import warnings
import re
from typing import Optional, Callable, Tuple
from contextlib import contextmanager

from e3nn import get_optimization_defaults, set_optimization_defaults

import torch
from torch import nn
from torch import fx

from opt_einsum_fx import jitable

ModuleFactory = Callable[..., nn.Module]
TypeTuple = Tuple[type, ...]


_E3NN_COMPILE_MODE = "__e3nn_compile_mode__"
_VALID_MODES = ("trace", "script", "unsupported", None)
_MAKE_TRACING_INPUTS = "_make_tracing_inputs"


def compile_mode(mode: str):
    """Decorator to set the compile mode of a module.

    Parameters
    ----------
    mode : str
        'script', 'trace', 'unsupported', or None
    """
    if mode not in _VALID_MODES:
        raise ValueError("Invalid compile mode")

    def decorator(obj):
        if not (inspect.isclass(obj) and issubclass(obj, torch.nn.Module)):
            raise TypeError("@e3nn.util.jit.compile_mode can only decorate classes derived from torch.nn.Module")
        setattr(obj, _E3NN_COMPILE_MODE, mode)
        return obj

    return decorator


def get_compile_mode(mod: torch.nn.Module) -> str:
    """Get the compilation mode of a module.

    Parameters
    ----------
    mod : torch.nn.Module

    Returns
    -------
    'script', 'trace', 'unsupported', or None if the module was not decorated with @compile_mode
    (an undecorated ``fx.GraphModule`` is treated as 'script')
    """
    if hasattr(mod, _E3NN_COMPILE_MODE):
        mode = getattr(mod, _E3NN_COMPILE_MODE)
    else:
        mode = getattr(type(mod), _E3NN_COMPILE_MODE, None)
    if mode is None and isinstance(mod, fx.GraphModule):
        mode = "script"
    assert mode in _VALID_MODES, "Invalid compile mode `%r`" % mode
    return mode


def compile(
    mod: torch.nn.Module,
    n_trace_checks: int = 1,
    script_options: dict = None,
    trace_options: dict = None,
    in_place: bool = True,
    recurse: bool = True,
):
    """Recursively compile a module and all submodules according to their decorators.

    (Sub)modules without decorators will be unaffected.

    Parameters
    ----------
    mod : torch.nn.Module
        The module to compile. Its submodules are replaced with their compiled versions (on a deep copy if
        ``in_place=False``).
    n_trace_checks : int, default = 1
        How many random example inputs to generate when tracing a module. Must be at least one in order to have a
        tracing input. Extra example inputs will be passed to ``torch.jit.trace_module`` as ``check_inputs`` to confirm
        that the traced compute graph doesn't change.
    script_options : dict, default = {}
        Extra kwargs for ``torch.jit.script``.
    trace_options : dict, default = {}
        Extra kwargs for ``torch.jit.trace_module``.
    in_place : bool, default True
        Whether to insert the recursively compiled submodules in-place, or do a deepcopy first.
    recurse : bool, default True
        Whether to recurse through the module's children before passing the parent to TorchScript.

    Returns
    -------
    The compiled module.
    """
    script_options = script_options or {}
    trace_options = trace_options or {}

    mode = get_compile_mode(mod)
    if mode == "unsupported":
        raise NotImplementedError(f"{type(mod).__name__} does not support TorchScript compilation")

    if not in_place:
        mod = copy.deepcopy(mod)
    # TODO: debug logging
    assert n_trace_checks >= 1

    if recurse:
        # == recurse to children ==
        # This allows us to trace compile submodules of modules we are going to script
        for submod_name, submod in mod.named_children():
            setattr(
                mod,
                submod_name,
                compile(
                    submod,
                    n_trace_checks=n_trace_checks,
                    script_options=script_options,
                    trace_options=trace_options,
                    in_place=True,  # since we deepcopied the module above, we can do inplace
                    recurse=recurse,  # always true in this branch
                ),
            )

    # == Compile this module now ==
    if mode == "script":
        if isinstance(mod, fx.GraphModule):
            mod = jitable(mod)
            # In recent PyTorch versions (probably >1.12, definitely >=2.0), PyTorch's implementation of fx.GraphModule
            # causes a warning to be raised when fx.GraphModules are compiled to TorchScript with `torch.jit.script`:
            #
            #   torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level
            #   annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the
            #   class body, or 2) wrap the type in `torch.jit.Attribute`.
            #
            # Using the debugger traces this back to the following line in PyTorch:
            # https://github.com/pytorch/pytorch/blob/v2.3.1/torch/fx/graph_module.py#L446
            # Because the metadata stored by GraphModule is not relevant to the compiled TorchScript module
            # we need, it should be safe to ignore this warning. The below code suppresses this warning as
            # narrowly as possible to ensure it can still be raised from user code.
            # See also: https://github.com/pytorch/pytorch/issues/89064
            # Note: In PyTorch 2.10.0+, this warning is raised from ast.py instead of torch/jit/_check.py,
            # so we don't filter by module to catch both cases.
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    # warnings treats this argument as a regex, but we want to match a string literal exactly, so escape it:
                    message=re.escape(
                        "The TorchScript type system doesn't support instance-level annotations on empty non-base types "
                        "in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type "
                        "in `torch.jit.Attribute`."
                    ),
                    category=UserWarning,
                    # don't filter by module since in PyTorch 2.10.0+ the warning comes from ast.py instead of torch
                )
                mod = torch.jit.script(mod, **script_options)
        else:
            mod = torch.jit.script(mod, **script_options)
    elif mode == "trace":
        # These are always modules, so we're always using trace_module
        # We need tracing inputs:
        check_inputs = get_tracing_inputs(
            mod,
            n_trace_checks,
        )
        assert len(check_inputs) >= 1, "Must have at least one tracing input."
        # Do the actual trace
        mod = torch.jit.trace_module(mod, inputs=check_inputs[0], check_inputs=check_inputs, **trace_options)
    return mod


def get_tracing_inputs(
    mod: torch.nn.Module, n: int = 1, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
):
    """Get random tracing inputs for ``mod``.

    First checks if ``mod`` has a ``_make_tracing_inputs`` method. If so, calls it with ``n`` as the single argument and
    returns its results.

    Otherwise, attempts to infer the input signature of the module using ``e3nn.util._argtools._get_io_irreps``.

    Parameters
    ----------
    mod : torch.nn.Module
    n : int, default = 1
        A hint for how many inputs are wanted. Usually ``n`` inputs will be returned, but modules are not required to
        return exactly ``n``.
    device : torch.device
        The device to do tracing on. If `None` (default), it is guessed from the module's parameters and buffers.
    dtype : torch.dtype
        The dtype to trace with. If `None` (default), it is guessed from the module's floating-point parameters and
        buffers.

    Returns
    -------
    list of dict
        Tracing inputs in the format of ``torch.jit.trace_module``: dicts mapping method names like ``'forward'`` to
        tuples of arguments.
    """
    # Avoid circular imports
    from ._argtools import _get_device, _get_floating_dtype, _get_io_irreps, _rand_args, _to_device_dtype

    # - Get inputs -
    if hasattr(mod, _MAKE_TRACING_INPUTS):
        # This returns a trace_module style dict of method names to test inputs
        trace_inputs = mod._make_tracing_inputs(n)
        assert isinstance(trace_inputs, list)
        for d in trace_inputs:
            assert isinstance(d, dict), "_make_tracing_inputs must return a list of dict[str, tuple]"
            assert all(
                isinstance(k, str) and isinstance(v, tuple) for k, v in d.items()
            ), "_make_tracing_inputs must return a list of dict[str, tuple]"
    else:
        # Try to infer. This will throw if it can't.
        irreps_in, _ = _get_io_irreps(mod, irreps_out=[None])  # we're only trying to infer inputs
        trace_inputs = [{"forward": _rand_args(irreps_in)} for _ in range(n)]
    # - Put them on the right device -
    if device is None:
        device = _get_device(mod)
    if dtype is None:
        dtype = _get_floating_dtype(mod)
    # Move them
    trace_inputs = _to_device_dtype(trace_inputs, device, dtype)
    return trace_inputs


def trace_module(mod: torch.nn.Module, inputs: dict = None, check_inputs: list = None, in_place: bool = True):
    """Trace a module.

    Identical signature to ``torch.jit.trace_module``, but first recursively compiles ``mod`` using ``compile``.

    Parameters
    ----------
    mod : torch.nn.Module
    inputs : dict
    check_inputs : list of dict

    Returns
    -------
    Traced module.
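
Examples
--------
A minimal, hypothetical sketch of the temporary-override pattern used below for the compile mode and
``_make_tracing_inputs``, written as a ``temporary_attr`` context manager that is not part of e3nn. As a design
variation, it restores inside ``finally`` (so a failed compile cannot leave the override behind) and deletes an
instance attribute that did not exist beforehand, whereas the function below only restores values that were
previously non-``None``:

```python
from contextlib import contextmanager

_MISSING = object()  # sentinel meaning 'no instance attribute was set'


@contextmanager
def temporary_attr(obj, name, value):
    # Look only at the instance __dict__, so a class-level default (the kind
    # of attribute a class decorator sets) is never copied onto the instance.
    old = obj.__dict__.get(name, _MISSING)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        # Runs even if the body raises, e.g. when compilation fails.
        if old is _MISSING:
            delattr(obj, name)
        else:
            setattr(obj, name, old)


class Model:
    # hypothetical class-level default standing in for @compile_mode
    compile_mode = 'script'


m = Model()
with temporary_attr(m, 'compile_mode', 'trace'):
    inside = m.compile_mode
print(inside, m.compile_mode)  # trace script
```

After the block, ``m`` again sees the class-level value because the temporary instance attribute was removed.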
    """
    check_inputs = check_inputs or []
    # Set the compile mode for mod, temporarily
    old_mode = getattr(mod, _E3NN_COMPILE_MODE, None)
    if old_mode is not None and old_mode != "trace":
        warnings.warn(
            f"Trying to trace a module of type {type(mod).__name__} marked with @compile_mode != 'trace', expect errors!"
        )
    setattr(mod, _E3NN_COMPILE_MODE, "trace")

    # If inputs are provided, set make_tracing_input temporarily
    old_make_tracing_input = None
    if inputs is not None:
        old_make_tracing_input = getattr(mod, _MAKE_TRACING_INPUTS, None)
        setattr(mod, _MAKE_TRACING_INPUTS, lambda num: ([inputs] + check_inputs))

    # Compile
    out = compile(mod, in_place=in_place)

    # Restore old values, if we had them
    if old_mode is not None:
        setattr(mod, _E3NN_COMPILE_MODE, old_mode)
    if old_make_tracing_input is not None:
        setattr(mod, _MAKE_TRACING_INPUTS, old_make_tracing_input)
    return out


def trace(mod: torch.nn.Module, example_inputs: tuple = None, check_inputs: list = None, in_place: bool = True):
    """Trace a module.

    Identical signature to ``torch.jit.trace``, but first recursively compiles ``mod`` using :func:`compile`.

    Parameters
    ----------
    mod : torch.nn.Module
    example_inputs : tuple
    check_inputs : list of tuple

    Returns
    -------
    Traced module.
    """
    check_inputs = check_inputs or []

    return trace_module(
        mod=mod,
        inputs=({"forward": example_inputs} if example_inputs is not None else None),
        check_inputs=[{"forward": c} for c in check_inputs],
        in_place=in_place,
    )


def script(mod: torch.nn.Module, in_place: bool = True):
    """Script a module.

    Like ``torch.jit.script``, but first recursively compiles ``mod`` using :func:`compile`.

    Parameters
    ----------
    mod : torch.nn.Module

    Returns
    -------
    Scripted module.
    """
    # Set the compile mode for mod, temporarily
    old_mode = getattr(mod, _E3NN_COMPILE_MODE, None)
    if old_mode is not None and old_mode != "script":
        warnings.warn(
            f"Trying to script a module of type {type(mod).__name__} marked with @compile_mode != 'script', expect errors!"
        )
    setattr(mod, _E3NN_COMPILE_MODE, "script")

    # Compile
    out = compile(mod, in_place=in_place)

    # Restore old values, if we had them
    if old_mode is not None:
        setattr(mod, _E3NN_COMPILE_MODE, old_mode)

    return out


@contextmanager
def disable_e3nn_codegen():
    """Context manager that disables the legacy PyTorch code generation used in e3nn."""
    init_val = get_optimization_defaults()["jit_script_fx"]

    set_optimization_defaults(jit_script_fx=False)

    yield

    set_optimization_defaults(jit_script_fx=init_val)
e3nn-0.6.0/e3nn/util/test.py
import random
import math
import inspect
import itertools
import logging
from typing import Iterable, Optional, Callable
import warnings

import numpy as np

import torch

from e3nn import o3
from e3nn.util.jit import (
    compile,
    get_tracing_inputs,
    get_compile_mode,
    _MAKE_TRACING_INPUTS,
    get_optimization_defaults,
    set_optimization_defaults,
)
from ._argtools import _get_args_in, _get_io_irreps, _transform, _rand_args

# pylint: disable=unused-variable

# Make a logger for reporting error statistics
logger = logging.getLogger(__name__)


def _logging_name(func) -> str:
    """Get a decent string representation of ``func`` for logging"""
    if inspect.isfunction(func):
        return func.__name__
    else:
        return repr(func)


# The default float tolerance
FLOAT_TOLERANCE = {t: torch.as_tensor(v, dtype=t) for t, v in {torch.float32: 1e-3, torch.float64: 1e-9}.items()}


try:
    # If pytest is available, define an e3nn pytest plugin
    # See https://docs.pytest.org/en/stable/fixture.html#using-fixtures-from-other-projects
    import pytest

    @pytest.fixture(scope="session", autouse=True, params=["float32", "float64"])
    def float_tolerance(request):
        """Run all tests with various PyTorch default dtypes.

        This is a session-wide, autouse fixture; you only need to request it explicitly if a test needs to know the
        tolerance for the current default dtype.

        Returns
        -------
        A precision threshold to use for closeness tests.
        """
        old_dtype = torch.get_default_dtype()
        dtype = {"float32": torch.float32, "float64": torch.float64}[request.param]
        torch.set_default_dtype(dtype)
        yield FLOAT_TOLERANCE[dtype]
        torch.set_default_dtype(old_dtype)

except ImportError:
    pass


def random_irreps(
    n: int = 1,
    lmax: int = 4,
    mul_min: int = 0,
    mul_max: int = 5,
    len_min: int = 0,
    len_max: int = 4,
    clean: bool = False,
    allow_empty: bool = True,
):
    r"""Generate random irreps parameters for testing.

    Parameters
    ----------
    n : int, optional
        How many to generate; defaults to 1.
    lmax : int, optional
        The maximum L to generate (inclusive); defaults to 4.
    mul_min : int, optional
        The smallest multiplicity to generate, defaults to 0.
    mul_max : int, optional
        The largest multiplicity to generate, defaults to 5.
    len_min : int, optional
        The smallest number of irreps to generate, defaults to 0.
    len_max : int, optional
        The largest number of irreps to generate, defaults to 4.
    clean : bool, optional
        If ``True``, only ``o3.Irreps`` objects will be returned. If ``False`` (the default), ``e3nn.o3.Irreps``-like
        objects like strings and lists of tuples can be returned.
    allow_empty : bool, optional
        Whether to allow generating empty ``e3nn.o3.Irreps``.

    Returns
    -------
    An irreps-like object if ``n == 1`` or a list of them if ``n > 1``
    """
    assert n >= 1
    assert lmax >= 0
    assert mul_min >= 0
    assert mul_max >= mul_min

    if not allow_empty and len_min == 0:
        len_min = 1
    assert len_min >= 0
    assert len_max >= len_min

    out = []
    for _ in range(n):
        this_irreps = []
        for _ in range(random.randint(len_min, len_max)):
            this_irreps.append((random.randint(mul_min, mul_max), (random.randint(0, lmax), random.choice((1, -1)))))
        if not allow_empty and all(m == 0 for m, _ in this_irreps):
            this_irreps[-1] = (random.randint(1, mul_max), this_irreps[-1][1])
        this_irreps = o3.Irreps(this_irreps)

        if clean:
            outtype = "irreps"
        else:
            outtype = random.choice(("irreps", "str", "list"))
        if outtype == "irreps":
            out.append(this_irreps)
        elif outtype == "str":
            out.append(repr(this_irreps))
        elif outtype == "list":
            out.append([(mul_ir.mul, (mul_ir.ir.l, mul_ir.ir.p)) for mul_ir in this_irreps])

    if n == 1:
        return out[0]
    else:
        return out


def format_equivariance_error(errors: dict) -> str:
    """Format the dictionary returned by ``equivariance_error`` into a readable string.

    Parameters
    ----------
    errors : dict
        A dictionary of errors returned by ``equivariance_error``.

    Returns
    -------
    A string.
    """
    return "\n".join(
        "(parity_k={:d}, did_translate={}) -> max error={:.3e} in argument {}".format(
            int(k[0]), bool(k[1]), float(v.max()), int(v.argmax())
        )
        for k, v in errors.items()
    )


def assert_equivariant(func, args_in=None, irreps_in=None, irreps_out=None, tolerance=None, **kwargs) -> dict:
    r"""Assert that ``func`` is equivariant.

    Parameters
    ----------
    func : callable
        the function to test
    args_in : list or None
        the original input arguments for the function. If ``None`` and the function has ``irreps_in`` consisting only of
        ``o3.Irreps`` and ``'cartesian_points'``, random test inputs will be generated.
    irreps_in : object
        see ``equivariance_error``
    irreps_out : object
        see ``equivariance_error``
    tolerance : float or None
        the threshold below which the equivariance error must fall.
        If ``None`` (the default), ``FLOAT_TOLERANCE[torch.get_default_dtype()]`` is used.
    **kwargs : kwargs
        passed through to ``equivariance_error``.

    Returns
    -------
    The same as ``equivariance_error``: a dictionary mapping tuples ``(parity_k, did_translate)`` to errors
    """
    # Prevent pytest from showing this function in the traceback
    __tracebackhide__ = True

    args_in, irreps_in, irreps_out = _get_args_in(func, args_in=args_in, irreps_in=irreps_in, irreps_out=irreps_out)

    # Get error
    errors = equivariance_error(func, args_in=args_in, irreps_in=irreps_in, irreps_out=irreps_out, **kwargs)

    logger.info(
        "Tested equivariance of `%s` -- max componentwise errors:\n%s",
        _logging_name(func),
        format_equivariance_error(errors),
    )

    # Check it
    if tolerance is None:
        tolerance = FLOAT_TOLERANCE[torch.get_default_dtype()]

    problems = {case: err for case, err in errors.items() if err.max() > tolerance}

    if len(problems) != 0:
        errstr = "Largest componentwise equivariance error was too large for: "
        errstr += format_equivariance_error(problems)
        assert len(problems) == 0, errstr

    return errors


def equivariance_error(
    func,
    args_in,
    irreps_in=None,
    irreps_out=None,
    ntrials: int = 1,
    do_parity: bool = True,
    do_translation: bool = True,
    transform_dtype=torch.float64,
):
    r"""Get the maximum equivariance error for ``func`` over ``ntrials``.

    Each trial randomizes the equivariant transformation tested.

    Parameters
    ----------
    func : callable
        the function to test
    args_in : list
        the original inputs to pass to ``func``.
    irreps_in : list of `e3nn.o3.Irreps` or `e3nn.o3.Irreps`
        the input irreps for each of the arguments in ``args_in``. If left as the default of ``None``,
        ``_get_io_irreps`` will be used to try to infer them. If a sequence is provided, valid elements are also the
        string ``'cartesian_points'``, which denotes that the corresponding input should be dealt with as cartesian
        points in 3D, and ``None``, which indicates that the argument should not be transformed.
    irreps_out : list of `e3nn.o3.Irreps` or `e3nn.o3.Irreps`
        the output irreps for each of the return values of ``func``. Accepts similar values to ``irreps_in``.
    ntrials : int
        run this many trials with random transforms
    do_parity : bool
        whether to test parity
    do_translation : bool
        whether to test translation for ``'cartesian_points'`` inputs
    transform_dtype : torch.dtype, default torch.float64
        the dtype in which the random transformations are built and the errors are computed

    Returns
    -------
    dictionary mapping tuples ``(parity_k, did_translate)`` to an array of errors,
    each entry the biggest over all trials for that output, in order.
    """
    irreps_in, irreps_out = _get_io_irreps(func, irreps_in=irreps_in, irreps_out=irreps_out)

    if do_parity:
        parity_ks = [0, 1]
    else:
        parity_ks = [0]

    if "cartesian_points" not in irreps_in:
        # There's nothing to translate
        do_translation = False
    if do_translation:
        do_translation = [False, True]
    else:
        do_translation = [False]

    tests = list(itertools.product(parity_ks, do_translation))

    neg_inf = -float("Inf")
    device = next(t.device for t in args_in if isinstance(t, torch.Tensor))
    biggest_errs = {test: torch.full((len(irreps_out),), neg_inf, dtype=transform_dtype, device=device) for test in tests}

    for trial in range(ntrials):
        for this_test in tests:
            parity_k, this_do_translate = this_test
            # Build a rotation matrix for point data
            rot_mat = o3.rand_matrix(dtype=transform_dtype)
            # add parity
            rot_mat *= (-1) ** parity_k
            # build translation
            translation = 10 * torch.randn(1, 3, dtype=rot_mat.dtype) if this_do_translate else 0.0

            # Evaluate the function on rotated arguments:
            rot_args = _transform(args_in, irreps_in, rot_mat, translation)
            x1 = func(*rot_args)
            if isinstance(x1, torch.Tensor):
                x1 = [x1]
            elif isinstance(x1, (list, tuple)):
                x1 = list(x1)
            else:
                raise TypeError(f"equivariance_error cannot handle output type {type(x1)}")
            # if `func` was a model, the outputs might be attached in the autograd graph
            # convert into the transform dtype for computing the difference
            x1 = [t.detach().to(transform_dtype) for t in x1]

            # Evaluate the function on the arguments, then apply group action:
            x2 = func(*args_in)
            if isinstance(x2, torch.Tensor):
                x2 = [x2]
            elif isinstance(x2, (list, tuple)):
                x2 = list(x2)
            else:
                raise TypeError(f"equivariance_error cannot handle output type {type(x2)}")
            x2 = [t.detach() for t in x2]

            # confirm sanity
            assert len(x1) == len(x2)
            assert len(x1) == len(irreps_out)

            # apply the group action to x2
            # get this in the transform dtype
            x2 = _transform(x2, irreps_out, rot_mat, translation, output_transform_dtype=True)

            # compute errors in the transform dtype,
            # then convert back to default later
            errors = torch.stack([(a - b).abs().max() for a, b in zip(x1, x2)])

            biggest_errs[this_test] = torch.where(errors > biggest_errs[this_test], errors, biggest_errs[this_test])

    # convert errors back to default dtype to return:
    return {k: v.to(torch.get_default_dtype()) for k, v in biggest_errs.items()}


# TODO: this is only for things marked with @compile_mode.
# Make something else for general script/traceability
def assert_auto_jitable(
    func,
    error_on_warnings: bool = True,
    n_trace_checks: int = 2,
    strict_shapes: bool = True,
):
    r"""Assert that submodule ``func`` is automatically JITable.

    Parameters
    ----------
    func : Callable
        The function to trace.
    error_on_warnings : bool
        If True (default), TracerWarnings emitted by ``torch.jit.trace`` will be treated as errors.
    n_trace_checks : int
        How many random example inputs ``compile`` generates when tracing ``func``; defaults to 2.
    strict_shapes : bool
        Test that the traced function errors on inputs with feature dimensions that don't match the input irreps.

    Returns
    -------
    The compiled (scripted or traced) TorchScript module.
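
Examples
--------
A minimal sketch of the ``strict_shapes`` idea in plain Python: trim one to three trailing features from one
randomly chosen argument and require the call to raise. ``FixedDimLayer`` and ``rejects_bad_shape`` are
hypothetical stand-ins, with lists of rows in place of tensors and ``RuntimeError`` in place of the TorchScript
shape errors caught below:

```python
import random


class FixedDimLayer:
    # Hypothetical stand-in for a module whose inputs must have exactly
    # `dim` features (the irreps dimension) per row.
    def __init__(self, dim):
        self.dim = dim

    def __call__(self, x):
        if any(len(row) != self.dim for row in x):
            raise RuntimeError('feature dimension does not match irreps_in')
        return [sum(row) for row in x]


def rejects_bad_shape(func, args, rng=random):
    # Pick one argument, drop 1-3 trailing features from every row (the
    # analogue of x[..., :-k]), and report whether the call raised.
    bad_args = list(args)
    which = rng.randint(0, len(bad_args) - 1)
    cut = rng.randint(1, 3)
    bad_args[which] = [row[:-cut] for row in bad_args[which]]
    try:
        func(*bad_args)
    except RuntimeError:
        return True
    return False


layer = FixedDimLayer(dim=4)
good = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]]
print(layer(good))  # [10.0, 2.0]
print(rejects_bad_shape(layer, [good]))  # True
```

A layer that never validates its input width would make ``rejects_bad_shape`` return ``False``, which is the case
the check below reports as an ``AssertionError``.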
    """
    # Prevent pytest from showing this function in the traceback
    __tracebackhide__ = True

    if get_compile_mode(func) is None:
        raise ValueError("assert_auto_jitable is only for modules marked with @compile_mode")

    # Test tracing
    with warnings.catch_warnings():
        if error_on_warnings:
            warnings.filterwarnings("error", category=torch.jit.TracerWarning)
        func_jit = compile(func, n_trace_checks=n_trace_checks)

    # Confirm that it rejects incorrect shapes
    # This check only makes sense if all inputs are Tensors with irreps; otherwise we can't know how to modify the arguments
    # or that our modifications make them wrong.
    if strict_shapes and not hasattr(func, _MAKE_TRACING_INPUTS):
        try:
            all_bad_args = get_tracing_inputs(func, n=1)[0]
        except ValueError:
            # couldn't infer, don't check
            pass
        else:
            for method, bad_args in all_bad_args.items():
                # Since _rand_args is OK, they're all Irreps style args where changing the feature dimension is wrong
                bad_which = random.randint(0, len(bad_args) - 1)
                bad_args = list(bad_args)
                bad_args[bad_which] = bad_args[bad_which][..., : -random.randint(1, 3)]  # make bad shape
                try:
                    if method == "forward":
                        func_jit(*bad_args)
                    else:
                        getattr(func_jit, method)(*bad_args)
                except (torch.jit.Error, RuntimeError):  # type: ignore
                    # As far as I can tell, there's no good way to introspect TorchScript exceptions.
                    pass
                else:
                    raise AssertionError("Traced function didn't error on bad input shape")

    return func_jit


def assert_torch_compile(
    compile_mode: str,
    func: Callable,
    *args,
    **kwargs,
) -> torch.nn.Module:
    r"""Assert that the module built by ``func`` compiles with ``torch.compile(..., fullgraph=True)``.

    Parameters
    ----------
    compile_mode : str
        Value of the ``jit_mode`` optimization default to use while the module is built.
    func : Callable
        Zero-argument callable that builds the module, e.g. a ``functools.partial`` of a ``torch.nn.Module`` class.
    *args
        Positional arguments for the module's ``forward``.
    **kwargs
        Keyword arguments for the module's ``forward``.

    Returns
    -------
    The compiled module.
    """
    # Turn off torch.jit.script in CodeGenMixin to enable torch.compile.
jit_mode_before = get_optimization_defaults()["jit_mode"] try: set_optimization_defaults(jit_mode=compile_mode) m = func() torch._dynamo.reset() # Clear cache from previous runs m_pt2 = torch.compile(m, fullgraph=True) m_pt2(*args, **kwargs) finally: set_optimization_defaults(jit_mode=jit_mode_before) return m_pt2 # TODO: custom in_vars, out_vars support def assert_normalized( func: torch.nn.Module, irreps_in=None, irreps_out=None, normalization: str = "component", n_input: int = 10_000, n_weight: Optional[int] = None, weights: Optional[Iterable[torch.nn.Parameter]] = None, atol: float = 0.1, ) -> None: r"""Assert that ``func`` is normalized. See https://docs.e3nn.org/en/stable/guide/normalization.html for more information on the normalization scheme. ``atol``, ``n_input``, and ``n_weight`` may need to be increased significantly for the statistics to converge and the test to pass. Parameters ---------- func : torch.nn.Module the module to test irreps_in : object see ``equivariance_error`` irreps_out : object see ``equivariance_error`` normalization : str, default "component" one of "component" or "norm". Note that this is defined for both the inputs and the outputs; if you need separate normalizations for input and output please file a feature request. n_input : int, default 10_000 the number of input samples to use for each weight init n_weight : int, optional the number of weight initializations to sample. Defaults to 20 if there are weights to reinitialize; otherwise it must be ``None`` or 1. weights : optional iterable of parameters the weights to reinitialize ``n_weight`` times. If ``None`` (default), ``func.parameters()`` will be used. atol : float, default 0.1 tolerance for checking moments. Higher values allow smaller ``n_input`` and ``n_weight``, which keeps the computational cost of this test manageable.
""" # Prevent pytest from showing this function in the traceback __tracebackhide__ = True if normalization not in ("component", "norm"): raise ValueError(f"invalid normalization `{normalization}`") irreps_in, irreps_out = _get_io_irreps(func, irreps_in=irreps_in, irreps_out=irreps_out) if all(i.num_irreps == 0 for i in irreps_in) or all(i.num_irreps == 0 for i in irreps_out): # Short-circuit return if weights is None: if isinstance(func, torch.nn.Module): weights = func.parameters() else: weights = [] weights = list(weights) if len(weights) == 0: assert n_weight is None or n_weight == 1, "Without weights to re-init, n_weight must be 1 or None" n_weight = 1 else: n_weight = 20 if n_weight is None else n_weight with torch.no_grad(): expected_squares = [torch.zeros(irreps.dim) for irreps in irreps_out] n_samples = 0 for weight_init in range(n_weight): # generate weight sample for param in weights: param.normal_() # generate input sample args_in = _rand_args(irreps_in, batch_size=n_input) # args_in gives component normalized irreps if normalization == "norm": for i, irreps in enumerate(irreps_in): for mul_ir, ir_slice in zip(irreps, irreps.slices()): args_in[i][:, ir_slice].div_(math.sqrt(mul_ir.ir.dim)) # run func this_outs = func(*args_in) if not isinstance(this_outs, (list, tuple)): this_outs = (this_outs,) assert len(this_outs) == len(irreps_out) # square this_outs = [e.square() for e in this_outs] # update running average for i in range(len(irreps_out)): assert this_outs[i].shape[0] == n_input update = this_outs[i].sum(dim=0) - n_input * expected_squares[i] update.div_(n_input + n_samples) expected_squares[i].add_(update) n_samples += n_input # check them for expected_square, irreps in zip(expected_squares, irreps_out): if irreps == "cartesian_points" or irreps is None: continue if normalization == "component": targets = [1.0] * len(irreps) elif normalization == "norm": # unit-norm irreps of dimension d have < x_i^2 > = 1 / d per component targets = [1.0 / ir.dim for _, ir in irreps] for i, (target,
ir_slice) in enumerate(zip(targets, irreps.slices())): if ir_slice.start == ir_slice.stop: continue max_componentwise = (expected_square[ir_slice] - target).abs().max().item() logger.info("Tested normalization of %r: max componentwise error %.6f", _logging_name(func), max_componentwise) assert max_componentwise <= atol, ( f"< x_i^2 > !~= {target:.6f} for output irrep #{i}, {irreps[i]}. " f"Max componentwise error: {max_componentwise:.6f}" ) def set_random_seeds() -> None: """Set the random seeds to try to get some reproducibility""" torch.manual_seed(0) random.seed(0) np.random.seed(0) e3nn-0.6.0/examples/000077500000000000000000000000001514371756200142115ustar00rootroot00000000000000e3nn-0.6.0/examples/README.md000066400000000000000000000001671514371756200154740ustar00rootroot00000000000000Examples ======== More examples are available in [the user guides](https://docs.e3nn.org/en/stable/guide/guide.html). e3nn-0.6.0/examples/atom_types.py000066400000000000000000000071331514371756200167530ustar00rootroot00000000000000"""Different parameters for the different atom types based on `tetris_polynomial` idea: if we have num_z types of atoms we have num_z^2 types of edges. Instead of having spherical harmonics for the edge attributes we have num_z^2 times the spherical harmonics, all zero except for the type of the edge.
>>> test() """ import torch from torch_cluster import radius_graph from torch_geometric.data import Data, DataLoader from torch_scatter import scatter from e3nn import o3 from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct class InvariantPolynomial(torch.nn.Module): def __init__(self, irreps_out, num_z, lmax) -> None: super().__init__() self.num_z = num_z self.irreps_sh = o3.Irreps.spherical_harmonics(lmax) # to multiply the edge type one-hot with the spherical harmonics to get the edge attributes self.mul = TensorProduct( [(num_z**2, "0e")], self.irreps_sh, [(num_z**2, ir) for _, ir in self.irreps_sh], [(0, l, l, "uvu", False) for l in range(lmax + 1)], ) irreps_attr = self.mul.irreps_out irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o") irreps_out = o3.Irreps(irreps_out) self.tp1 = FullyConnectedTensorProduct( irreps_in1=self.irreps_sh, irreps_in2=irreps_attr, irreps_out=irreps_mid, ) self.tp2 = FullyConnectedTensorProduct( irreps_in1=irreps_mid, irreps_in2=irreps_attr, irreps_out=irreps_out, ) def forward(self, data) -> torch.Tensor: num_neighbors = 3 # typical number of neighbors num_nodes = 4 # typical number of nodes num_z = self.num_z # number of atom types # graph edge_src, edge_dst = radius_graph(data.pos, 10.0, data.batch) # spherical harmonics edge_vec = data.pos[edge_src] - data.pos[edge_dst] edge_sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, normalize=False, normalization="component") # edge types edge_zz = num_z * data.z[edge_src] + data.z[edge_dst] # from 0 to num_z^2 - 1 edge_zz = torch.nn.functional.one_hot(edge_zz, num_z**2).mul(num_z) edge_zz = edge_zz.to(edge_sh.dtype) # edge attributes edge_attr = self.mul(edge_zz, edge_sh) # For each node, the initial features are the sum of the spherical harmonics of the neighbors node_features = scatter(edge_sh, edge_dst, dim=0).div(num_neighbors**0.5) # For each edge, tensor product the features on the source node with the spherical harmonics edge_features = 
self.tp1(node_features[edge_src], edge_attr) node_features = scatter(edge_features, edge_dst, dim=0).div(num_neighbors**0.5) edge_features = self.tp2(node_features[edge_src], edge_attr) node_features = scatter(edge_features, edge_dst, dim=0).div(num_neighbors**0.5) # For each graph, all the node's features are summed return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5) def test() -> None: torch.set_default_dtype(torch.float64) pos = torch.tensor( [ [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.5], ] ) # atom type z = torch.tensor([0, 1, 2, 2]) dataset = [Data(pos=pos @ R.T, z=z) for R in o3.rand_matrix(10)] data = next(iter(DataLoader(dataset, batch_size=len(dataset)))) f = InvariantPolynomial("0e+0o", num_z=3, lmax=3) out = f(data) # expect invariant output assert out.std(0).max() < 1e-5 e3nn-0.6.0/examples/s2cnn/000077500000000000000000000000001514371756200152345ustar00rootroot00000000000000e3nn-0.6.0/examples/s2cnn/mnist/000077500000000000000000000000001514371756200163665ustar00rootroot00000000000000e3nn-0.6.0/examples/s2cnn/mnist/README.md000066400000000000000000000001701514371756200176430ustar00rootroot00000000000000Try to reproduce https://github.com/jonkhler/s2cnn/tree/master/examples/mnist ``` python gendata.py python train.py ```e3nn-0.6.0/examples/s2cnn/mnist/gendata.py000066400000000000000000000166431514371756200203550ustar00rootroot00000000000000"""Module to generate the spherical mnist data set""" import argparse import gzip import pickle import numpy as np from torchvision import datasets from e3nn.o3 import s2_grid NORTHPOLE_EPSILON = 1e-3 def rand_rotation_matrix(deflection: float = 1.0, randnums=None): """ Creates a random rotation matrix. deflection: the magnitude of the rotation. For 0, no rotation; for 1, completely random rotation. Small deflection => small perturbation. randnums: 3 random numbers in the range [0, 1]. If `None`, they will be auto-generated.
# http://blog.lostinmyterminal.com/python/2015/05/12/random-rotation-matrix.html """ if randnums is None: randnums = np.random.uniform(size=(3,)) theta, phi, z = randnums theta = theta * 2.0 * deflection * np.pi # Rotation about the pole (Z). phi = phi * 2.0 * np.pi # For direction of pole deflection. z = z * 2.0 * deflection # For magnitude of pole deflection. # Compute a vector V used for distributing points over the sphere # via the reflection I - V Transpose(V). This formulation of V # will guarantee that if x[1] and x[2] are uniformly distributed, # the reflected points will be uniform on the sphere. Note that V # has length sqrt(2) to eliminate the 2 in the Householder matrix. r = np.sqrt(z) V = (np.sin(phi) * r, np.cos(phi) * r, np.sqrt(2.0 - z)) st = np.sin(theta) ct = np.cos(theta) R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1))) # Construct the rotation matrix ( V Transpose(V) - I ) R. M = (np.outer(V, V) - np.eye(3)).dot(R) return M def rotate_grid(rot, grid): x, y, z = grid xyz = np.stack((x, y, z)) x_r, y_r, z_r = np.einsum("ij,jab->iab", rot, xyz) return x_r, y_r, z_r def get_projection_grid(b): """returns the spherical grid in euclidean coordinates, where the sphere's center is moved to (0, 0, 1)""" theta, phi = s2_grid(2 * b, 2 * b) phi, theta = np.meshgrid(phi.numpy(), theta.numpy(), indexing="ij") x_ = np.sin(theta) * np.cos(phi) y_ = np.sin(theta) * np.sin(phi) z_ = np.cos(theta) return x_, y_, z_ def project_sphere_on_xy_plane(grid, projection_origin): """returns xy coordinates on the plane obtained from projecting each point of the spherical grid along the ray from the projection origin through the sphere""" sx, sy, sz = projection_origin x, y, z = grid z = z.copy() + 1 t = -z / (z - sz) qx = t * (x - sx) + x qy = t * (y - sy) + y xmin = 1 / 2 * (-1 - sx) + -1 ymin = 1 / 2 * (-1 - sy) + -1 # ensure that plane projection # ends up on southern hemisphere rx = (qx - xmin) / (2 * np.abs(xmin)) ry = (qy - ymin) / (2 * np.abs(ymin)) return rx, ry 
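# Illustrative sanity check (hypothetical helper, not part of the original
# script). rand_rotation_matrix builds M = (V V^T - I) R with |V|^2 = 2, i.e.
# V V^T - I = 2 v v^T - I = -H for the unit vector v = V / sqrt(2) and the
# Householder reflection H = I - 2 v v^T. In 3D, det(-H) = (-1)^3 * det(H) = +1,
# so M should be a proper rotation: orthogonal with determinant +1.
# Usage sketch: assert is_proper_rotation(rand_rotation_matrix())
import numpy as np


def is_proper_rotation(m, atol: float = 1e-8) -> bool:
    """Return True if ``m`` is a 3x3 orthogonal matrix with determinant +1."""
    m = np.asarray(m, dtype=np.float64)
    if m.shape != (3, 3):
        return False
    # orthogonality: M M^T == I (up to atol)
    orthogonal = np.allclose(m @ m.T, np.eye(3), atol=atol)
    # proper rotation (no reflection): det(M) == +1 (up to atol)
    proper = np.isclose(np.linalg.det(m), 1.0, atol=atol)
    return bool(orthogonal and proper)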
def sample_within_bounds(signal, x, y, bounds): xmin, xmax, ymin, ymax = bounds idxs = (xmin <= x) & (x < xmax) & (ymin <= y) & (y < ymax) if len(signal.shape) > 2: sample = np.zeros((signal.shape[0], x.shape[0], x.shape[1])) sample[:, idxs] = signal[:, x[idxs], y[idxs]] else: sample = np.zeros((x.shape[0], x.shape[1])) sample[idxs] = signal[x[idxs], y[idxs]] return sample def sample_bilinear(signal, rx, ry): signal_dim_x = signal.shape[1] signal_dim_y = signal.shape[2] rx *= signal_dim_x ry *= signal_dim_y # discretize sample position ix = rx.astype(int) iy = ry.astype(int) # obtain four sample coordinates ix0 = ix - 1 iy0 = iy - 1 ix1 = ix + 1 iy1 = iy + 1 bounds = (0, signal_dim_x, 0, signal_dim_y) # sample signal at each four positions signal_00 = sample_within_bounds(signal, ix0, iy0, bounds) signal_10 = sample_within_bounds(signal, ix1, iy0, bounds) signal_01 = sample_within_bounds(signal, ix0, iy1, bounds) signal_11 = sample_within_bounds(signal, ix1, iy1, bounds) # linear interpolation in x-direction fx1 = (ix1 - rx) * signal_00 + (rx - ix0) * signal_10 fx2 = (ix1 - rx) * signal_01 + (rx - ix0) * signal_11 # linear interpolation in y-direction return (iy1 - ry) * fx1 + (ry - iy0) * fx2 def project_2d_on_sphere(signal, grid, projection_origin=None): if projection_origin is None: projection_origin = (0, 0, 2 + NORTHPOLE_EPSILON) rx, ry = project_sphere_on_xy_plane(grid, projection_origin) sample = sample_bilinear(signal, rx, ry) # ensure that only south hemisphere gets projected sample *= (grid[2] <= 1).astype(np.float64) # rescale signal to [0,1] sample_min = sample.min(axis=(1, 2)).reshape(-1, 1, 1) sample_max = sample.max(axis=(1, 2)).reshape(-1, 1, 1) sample = (sample - sample_min) / (sample_max - sample_min) sample *= 255 sample = sample.astype(np.uint8) return sample def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--bandwidth", help="the bandwidth of the S2 signal", type=int, default=30, required=False) 
parser.add_argument("--noise", help="the rotational noise applied on the sphere", type=float, default=1.0, required=False) parser.add_argument("--chunk_size", help="size of image chunk with same rotation", type=int, default=500, required=False) parser.add_argument( "--mnist_data_folder", help="folder for saving the mnist data", type=str, default="MNIST_data", required=False ) parser.add_argument( "--output_file", help="file for saving the data output (.gz file)", type=str, default="s2_mnist.gz", required=False ) parser.add_argument("--no_rotate_train", help="do not rotate train set", dest="no_rotate_train", action="store_true") parser.add_argument("--no_rotate_test", help="do not rotate test set", dest="no_rotate_test", action="store_true") args = parser.parse_args() print("getting mnist data") trainset = datasets.MNIST(root=args.mnist_data_folder, train=True, download=True) testset = datasets.MNIST(root=args.mnist_data_folder, train=False, download=True) mnist_train = {} mnist_train["images"] = trainset.data.numpy() mnist_train["labels"] = trainset.targets.numpy() mnist_test = {} mnist_test["images"] = testset.data.numpy() mnist_test["labels"] = testset.targets.numpy() grid = get_projection_grid(b=args.bandwidth) # result dataset = {} no_rotate = {"train": args.no_rotate_train, "test": args.no_rotate_test} for label, data in zip(["train", "test"], [mnist_train, mnist_test]): print(f"projecting {label} data set") current = 0 signals = data["images"].reshape(-1, 28, 28).astype(np.float64) n_signals = signals.shape[0] projections = np.ndarray((signals.shape[0], 2 * args.bandwidth, 2 * args.bandwidth), dtype=np.uint8) while current < n_signals: if not no_rotate[label]: rot = rand_rotation_matrix(deflection=args.noise) rotated_grid = rotate_grid(rot, grid) else: rotated_grid = grid idxs = np.arange(current, min(n_signals, current + args.chunk_size)) chunk = signals[idxs] projections[idxs] = project_2d_on_sphere(chunk, rotated_grid) current += args.chunk_size 
print(f"\r{current}/{n_signals}", end="") print("") dataset[label] = {"images": projections, "labels": data["labels"]} print("writing pickle") with gzip.open(args.output_file, "wb") as f: pickle.dump(dataset, f) print("done") if __name__ == "__main__": main() e3nn-0.6.0/examples/s2cnn/mnist/train.py000066400000000000000000000160531514371756200200620ustar00rootroot00000000000000import gzip import math import pickle import numpy as np import torch from e3nn import o3 from e3nn.nn import SO3Activation def s2_near_identity_grid(max_beta: float = math.pi / 8, n_alpha: int = 8, n_beta: int = 3) -> torch.Tensor: """ :return: rings around the north pole size of the kernel = n_alpha * n_beta """ beta = torch.arange(1, n_beta + 1) * max_beta / n_beta alpha = torch.linspace(0, 2 * math.pi, n_alpha + 1)[:-1] a, b = torch.meshgrid(alpha, beta, indexing="ij") b = b.flatten() a = a.flatten() return torch.stack((a, b)) def so3_near_identity_grid( max_beta: float = math.pi / 8, max_gamma: float = 2 * math.pi, n_alpha: int = 8, n_beta: int = 3, n_gamma=None ) -> torch.Tensor: """ :return: rings of rotations around the identity, all points (rotations) in a ring are at the same distance from the identity size of the kernel = n_alpha * n_beta * n_gamma """ if n_gamma is None: n_gamma = n_alpha # similar to regular representations beta = torch.arange(1, n_beta + 1) * max_beta / n_beta alpha = torch.linspace(0, 2 * math.pi, n_alpha + 1)[:-1] # n_alpha points, endpoint excluded pre_gamma = torch.linspace(-max_gamma, max_gamma, n_gamma) A, B, preC = torch.meshgrid(alpha, beta, pre_gamma, indexing="ij") C = preC - A A = A.flatten() B = B.flatten() C = C.flatten() return torch.stack((A, B, C)) def s2_irreps(lmax: int) -> o3.Irreps: return o3.Irreps([(1, (l, 1)) for l in range(lmax + 1)]) def so3_irreps(lmax: int) -> o3.Irreps: return o3.Irreps([(2 * l + 1, (l, 1)) for l in range(lmax + 1)]) def flat_wigner(lmax: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor: return torch.cat([(2 * l + 1) ** 0.5
* o3.wigner_D(l, alpha, beta, gamma).flatten(-2) for l in range(lmax + 1)], dim=-1) class S2Convolution(torch.nn.Module): def __init__(self, f_in, f_out, lmax, kernel_grid) -> None: super().__init__() self.register_parameter( "w", torch.nn.Parameter(torch.randn(f_in, f_out, kernel_grid.shape[1])) ) # [f_in, f_out, n_s2_pts] self.register_buffer( "Y", o3.spherical_harmonics_alpha_beta(range(lmax + 1), *kernel_grid, normalization="component") ) # [n_s2_pts, psi] self.lin = o3.Linear(s2_irreps(lmax), so3_irreps(lmax), f_in=f_in, f_out=f_out, internal_weights=False) def forward(self, x): psi = torch.einsum("ni,xyn->xyi", self.Y, self.w) / self.Y.shape[0] ** 0.5 return self.lin(x, weight=psi) class SO3Convolution(torch.nn.Module): def __init__(self, f_in, f_out, lmax, kernel_grid) -> None: super().__init__() self.register_parameter( "w", torch.nn.Parameter(torch.randn(f_in, f_out, kernel_grid.shape[1])) ) # [f_in, f_out, n_so3_pts] self.register_buffer("D", flat_wigner(lmax, *kernel_grid)) # [n_so3_pts, psi] self.lin = o3.Linear(so3_irreps(lmax), so3_irreps(lmax), f_in=f_in, f_out=f_out, internal_weights=False) def forward(self, x): psi = torch.einsum("ni,xyn->xyi", self.D, self.w) / self.D.shape[0] ** 0.5 return self.lin(x, weight=psi) class S2ConvNet_original(torch.nn.Module): def __init__(self) -> None: super().__init__() f1 = 20 f2 = 40 f_output = 10 b_in = 60 lmax1 = 10 b_l1 = 10 lmax2 = 5 b_l2 = 6 grid_s2 = s2_near_identity_grid() grid_so3 = so3_near_identity_grid() self.from_s2 = o3.FromS2Grid((b_in, b_in), lmax1) self.conv1 = S2Convolution(1, f1, lmax1, kernel_grid=grid_s2) self.act1 = SO3Activation(lmax1, lmax2, torch.relu, b_l1) self.conv2 = SO3Convolution(f1, f2, lmax2, kernel_grid=grid_so3) self.act2 = SO3Activation(lmax2, 0, torch.relu, b_l2) self.w_out = torch.nn.Parameter(torch.randn(f2, f_output)) def forward(self, x): x = x.transpose(-1, -2) # [batch, features, alpha, beta] -> [batch, features, beta, alpha] x = self.from_s2(x) # [batch, features, beta, 
alpha] -> [batch, features, irreps] x = self.conv1(x) # [batch, features, irreps] -> [batch, features, irreps] x = self.act1(x) # [batch, features, irreps] -> [batch, features, irreps] x = self.conv2(x) # [batch, features, irreps] -> [batch, features, irreps] x = self.act2(x) # [batch, features, scalar] x = x.flatten(1) @ self.w_out / self.w_out.shape[0] return x MNIST_PATH = "s2_mnist.gz" DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") NUM_EPOCHS = 20 BATCH_SIZE = 32 LEARNING_RATE = 5e-3 def load_data(path, batch_size): with gzip.open(path, "rb") as f: dataset = pickle.load(f) train_data = torch.from_numpy(dataset["train"]["images"][:, None, :, :].astype(np.float32)) train_labels = torch.from_numpy(dataset["train"]["labels"].astype(np.int64)) # train_data /= 57 This normalization was hurtful, see @dmklee comment in discussions/344 train_dataset = torch.utils.data.TensorDataset(train_data, train_labels) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_data = torch.from_numpy(dataset["test"]["images"][:, None, :, :].astype(np.float32)) test_labels = torch.from_numpy(dataset["test"]["labels"].astype(np.int64)) # test_data /= 57 test_dataset = torch.utils.data.TensorDataset(test_data, test_labels) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True) return train_loader, test_loader, train_dataset, test_dataset def main() -> None: train_loader, test_loader, train_dataset, _ = load_data(MNIST_PATH, BATCH_SIZE) classifier = S2ConvNet_original() classifier.to(DEVICE) print("#params", sum(x.numel() for x in classifier.parameters())) optimizer = torch.optim.Adam(classifier.parameters(), lr=LEARNING_RATE) for epoch in range(NUM_EPOCHS): for i, (images, labels) in enumerate(train_loader): classifier.train() images = images.to(DEVICE) labels = labels.to(DEVICE) optimizer.zero_grad() outputs = classifier(images) loss = torch.nn.functional.cross_entropy(outputs, 
labels) loss.backward() optimizer.step() print( "\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}".format( epoch + 1, NUM_EPOCHS, i + 1, len(train_dataset) // BATCH_SIZE, loss.item() ), end="", ) print("") correct = 0 total = 0 for images, labels in test_loader: classifier.eval() with torch.no_grad(): images = images.to(DEVICE) labels = labels.to(DEVICE) outputs = classifier(images) _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).long().sum().item() print(f"Test Accuracy: {100 * correct / total}") if __name__ == "__main__": main() e3nn-0.6.0/examples/tensor_product_benchmark.py000066400000000000000000000065731514371756200216620ustar00rootroot00000000000000import argparse import logging import torch from torch.utils.benchmark import Timer from e3nn.o3 import Irreps, FullyConnectedTensorProduct, ElementwiseTensorProduct from e3nn.util.jit import compile logging.basicConfig(level=logging.DEBUG) # https://stackoverflow.com/a/15008806/1008938 def t_or_f(arg) -> bool: ua = str(arg).upper() if "TRUE".startswith(ua): return True elif "FALSE".startswith(ua): return False else: raise ValueError(str(arg)) def main() -> None: parser = argparse.ArgumentParser(prog="tensor_product_benchmark") parser.add_argument("--jit", type=t_or_f, default=True) parser.add_argument("--irreps", type=str, default="8x0e + 8x1e + 8x2e + 8x3o") parser.add_argument("--irreps-in1", type=str, default=None) parser.add_argument("--irreps-in2", type=str, default=None) parser.add_argument("--irreps-out", type=str, default=None) parser.add_argument("--cuda", type=t_or_f, default=True) parser.add_argument("--backward", type=t_or_f, default=True) parser.add_argument("--opt-ein", type=t_or_f, default=True) parser.add_argument("--specialized-code", type=t_or_f, default=True) parser.add_argument("--elementwise", action="store_true") parser.add_argument("-n", type=int, default=1000) parser.add_argument("--batch", type=int, default=10) args = parser.parse_args() device 
= "cuda" if (torch.cuda.is_available() and args.cuda) else "cpu" args.cuda = device == "cuda" print("======= Benchmark with settings: ======") for key, val in vars(args).items(): print(f"{key:>18} : {val}") print("=" * 40) irreps_in1 = Irreps(args.irreps_in1 if args.irreps_in1 else args.irreps) irreps_in2 = Irreps(args.irreps_in2 if args.irreps_in2 else args.irreps) irreps_out = Irreps(args.irreps_out if args.irreps_out else args.irreps) if args.elementwise: tp = ElementwiseTensorProduct( irreps_in1, irreps_in2, _specialized_code=args.specialized_code, _optimize_einsums=args.opt_ein ) if args.backward: print("Elementwise TP has no weights, cannot backward. Setting --backward False.") args.backward = False else: tp = FullyConnectedTensorProduct( irreps_in1, irreps_in2, irreps_out, _specialized_code=args.specialized_code, _optimize_einsums=args.opt_ein ) tp = tp.to(device=device) assert len(tp.instructions) > 0, "Bad irreps, no instructions" print(f"Tensor product: {tp}") print("Instructions:") for ins in tp.instructions: print(f" {ins}") # from https://pytorch.org/docs/master/_modules/torch/utils/benchmark/utils/timer.html#Timer.timeit warmup = max(int(args.n // 100), 1) inputs = iter( [ (irreps_in1.randn(args.batch, -1).to(device=device), irreps_in2.randn(args.batch, -1).to(device=device)) for _ in range(args.n + warmup) ] ) # compile if args.jit: tp = compile(tp) print("starting...") # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones t = Timer( stmt=("tp.zero_grad()\n" "out = tp(*next(inputs))\n" + ("out.tanh().sum().backward()\n" if args.backward else "")), globals={"tp": tp, "inputs": inputs}, ) perloop = t.timeit(args.n) print() print(perloop) if __name__ == "__main__": main() e3nn-0.6.0/examples/tensor_product_profile.py000066400000000000000000000064161514371756200213640ustar00rootroot00000000000000import argparse import logging import torch from e3nn.o3 import Irreps, FullyConnectedTensorProduct from 
e3nn.util.jit import compile logging.basicConfig(level=logging.DEBUG) # https://stackoverflow.com/a/15008806/1008938 def t_or_f(arg) -> bool: ua = str(arg).upper() if "TRUE".startswith(ua): return True elif "FALSE".startswith(ua): return False else: raise ValueError(str(arg)) def main() -> None: parser = argparse.ArgumentParser(prog="tensor_product_profile") parser.add_argument("--jit", type=t_or_f, default=True) parser.add_argument("--irreps-in1", type=str, default="8x0e + 8x1e + 8x2e + 8x3e") parser.add_argument("--irreps-in2", type=str, default="8x0e + 8x1e + 8x2e + 8x3e") parser.add_argument("--irreps-out", type=str, default="8x0e + 8x1e + 8x2e + 8x3e") parser.add_argument("--cuda", type=t_or_f, default=True) parser.add_argument("--backward", type=t_or_f, default=True) parser.add_argument("--opt-ein", type=t_or_f, default=True) parser.add_argument("--specialized-code", type=t_or_f, default=True) parser.add_argument("-w", type=int, default=10) parser.add_argument("-n", type=int, default=3) parser.add_argument("--batch", type=int, default=10) args = parser.parse_args() device = "cuda" if (torch.cuda.is_available() and args.cuda) else "cpu" args.cuda = device == "cuda" if args.cuda: # Workaround for CUDA driver issues # See https://github.com/pytorch/pytorch/issues/60158#issuecomment-866294291 with torch.profiler.profile() as _: pass print("======= Benchmark with settings: ======") for key, val in vars(args).items(): print(f"{key:>18} : {val}") print("=" * 40) irreps_in1 = Irreps(args.irreps_in1) irreps_in2 = Irreps(args.irreps_in2) irreps_out = Irreps(args.irreps_out) tp = FullyConnectedTensorProduct( irreps_in1, irreps_in2, irreps_out, _specialized_code=args.specialized_code, _optimize_einsums=args.opt_ein ) tp = tp.to(device=device) inputs = [ (irreps_in1.randn(args.batch, -1).to(device=device), irreps_in2.randn(args.batch, -1).to(device=device)) for _ in range(1 + args.w + args.n) ] if args.backward: for tmp in inputs: for t in tmp: t.requires_grad_(True)
inputs = iter(inputs) # compile if args.jit: print("JITing...") tp = compile(tp) print("starting...") called_num = [0] def trace_handler(p) -> None: print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) p.export_chrome_trace("test_trace_" + str(called_num[0]) + ".json") called_num[0] += 1 with torch.profiler.profile( activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], schedule=torch.profiler.schedule(wait=1, warmup=args.w, active=args.n), on_trace_ready=trace_handler, ) as p: for _ in range(1 + args.w + args.n): out = tp(*next(inputs)) if args.backward: # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones out.tanh().sum().backward() p.step() if __name__ == "__main__": main() e3nn-0.6.0/examples/tetris.py000066400000000000000000000105601514371756200160770ustar00rootroot00000000000000"""Classify tetris using gate activation function Implement an equivariant model using gates to fit the tetris dataset Exact equivariance to :math:`E(3)` >>> test() """ import torch from torch_geometric.data import Data, DataLoader from e3nn import o3 from e3nn.nn.models.v2106.gate_points_networks import SimpleNetwork def tetris(): pos = [ [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)], # chiral_shape_1 [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)], # chiral_shape_2 [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)], # square [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)], # line [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)], # corner [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)], # L [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)], # T [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)], # zigzag ] pos = torch.tensor(pos, dtype=torch.get_default_dtype()) # Since chiral shapes are the mirror of one another we need an *odd* scalar to distinguish them labels = torch.tensor( [ [+1, 0, 0, 0, 0, 0, 0], # chiral_shape_1 [-1, 0, 0, 0, 0, 0, 0], # chiral_shape_2 [0, 1, 0, 0, 0, 0, 0], # square [0, 0, 1, 0, 0,
0, 0], # line [0, 0, 0, 1, 0, 0, 0], # corner [0, 0, 0, 0, 1, 0, 0], # L [0, 0, 0, 0, 0, 1, 0], # T [0, 0, 0, 0, 0, 0, 1], # zigzag ], dtype=torch.get_default_dtype(), ) # apply random rotation pos = torch.einsum("zij,zaj->zai", o3.rand_matrix(len(pos)), pos) return pos, labels def make_batch(pos): # put in torch_geometric format dataset = [Data(pos=pos, x=torch.ones(4, 1)) for pos in pos] return next(iter(DataLoader(dataset, batch_size=len(dataset)))) def Network() -> SimpleNetwork: return SimpleNetwork( irreps_in="0e", irreps_out="0o + 6x0e", max_radius=1.5, num_neighbors=2.0, num_nodes=4.0, ) def main() -> None: x, y = tetris() train_x, train_y = make_batch(x[1:]), y[1:] # don't train on both chiral shapes x, y = tetris() test_x, test_y = make_batch(x), y f = Network() print("Built a model:") print(f) optim = torch.optim.Adam(f.parameters(), lr=1e-3) # == Training == for step in range(300): pred = f(train_x) loss = (pred - train_y).pow(2).sum() optim.zero_grad() loss.backward() optim.step() if step % 10 == 0: accuracy = f(test_x).round().eq(test_y).all(dim=1).double().mean(dim=0).item() print(f"epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy") # == Check equivariance == # Because the model outputs (pseudo)scalars, we can directly # check its equivariance to the same data with new rotations: print("Testing equivariance directly...") rotated_x, _ = tetris() rotated_x = make_batch(rotated_x) error = f(rotated_x) - f(test_x) print(f"Equivariance error = {error.abs().max().item():.1e}") if __name__ == "__main__": main() def test() -> None: torch.set_default_dtype(torch.float64) data, labels = tetris() data = make_batch(data) f = Network() pred = f(data) loss = (pred - labels).pow(2).sum() loss.backward() rotated_data, _ = tetris() rotated_data = make_batch(rotated_data) error = f(rotated_data) - f(data) assert error.abs().max() < 1e-10 def profile() -> None: data, labels = tetris() data = make_batch(data) data = data.to(device="cuda") labels =
labels.to(device="cuda") f = Network() f.to(device="cuda") optim = torch.optim.Adam(f.parameters(), lr=1e-2) called_num = [0] def trace_handler(p) -> None: print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) p.export_chrome_trace("test_trace_" + str(called_num[0]) + ".json") called_num[0] += 1 with torch.profiler.profile( activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], schedule=torch.profiler.schedule(wait=50, warmup=1, active=1), on_trace_ready=trace_handler, ) as p: for _ in range(52): pred = f(data) loss = (pred - labels).pow(2).sum() optim.zero_grad() loss.backward() optim.step() p.step() e3nn-0.6.0/examples/tetris_gate.py000066400000000000000000000175461514371756200171120ustar00rootroot00000000000000"""Classify tetris using gate activation function Implement an equivariant model using gates to fit the tetris dataset Exact equivariance to :math:`E(3)` >>> test() """ import logging import torch from torch_cluster import radius_graph from torch_geometric.data import Data, DataLoader from torch_scatter import scatter from e3nn import o3 from e3nn.nn import FullyConnectedNet, Gate from e3nn.o3 import FullyConnectedTensorProduct from e3nn.math import soft_one_hot_linspace from e3nn.util.test import assert_equivariant def tetris(): pos = [ [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)], # chiral_shape_1 [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)], # chiral_shape_2 [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)], # square [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)], # line [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)], # corner [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)], # L [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)], # T [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)], # zigzag ] pos = torch.tensor(pos, dtype=torch.get_default_dtype()) # Since chiral shapes are the mirror of one another we need an *odd* scalar to distinguish them labels = torch.tensor( [ [+1, 0, 0, 0, 0, 0, 0], # chiral_shape_1 [-1, 0, 0,
0, 0, 0, 0],  # chiral_shape_2
            [0, 1, 0, 0, 0, 0, 0],  # square
            [0, 0, 1, 0, 0, 0, 0],  # line
            [0, 0, 0, 1, 0, 0, 0],  # corner
            [0, 0, 0, 0, 1, 0, 0],  # L
            [0, 0, 0, 0, 0, 1, 0],  # T
            [0, 0, 0, 0, 0, 0, 1],  # zigzag
        ],
        dtype=torch.get_default_dtype(),
    )

    # apply random rotation
    pos = torch.einsum("zij,zaj->zai", o3.rand_matrix(len(pos)), pos)

    # put in torch_geometric format
    dataset = [Data(pos=pos) for pos in pos]
    data = next(iter(DataLoader(dataset, batch_size=len(dataset))))

    return data, labels


def mean_std(name, x) -> None:
    print(f"{name} \t{x.mean():.1f} ± ({x.var(0).mean().sqrt():.1f}|{x.std():.1f})")


class Convolution(torch.nn.Module):
    def __init__(self, irreps_in, irreps_sh, irreps_out, num_neighbors) -> None:
        super().__init__()

        self.num_neighbors = num_neighbors

        tp = FullyConnectedTensorProduct(
            irreps_in1=irreps_in,
            irreps_in2=irreps_sh,
            irreps_out=irreps_out,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet([3, 256, tp.weight_numel], torch.relu)
        self.tp = tp
        self.irreps_out = self.tp.irreps_out

    def forward(self, node_features, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor:
        weight = self.fc(edge_scalars)
        edge_features = self.tp(node_features[edge_src], edge_attr, weight)
        node_features = scatter(edge_features, edge_dst, dim=0).div(self.num_neighbors**0.5)
        return node_features


class Network(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.num_neighbors = 3.8  # typical number of neighbors
        self.irreps_sh = o3.Irreps.spherical_harmonics(3)

        irreps = self.irreps_sh

        # First layer with gate
        gate = Gate(
            "16x0e + 16x0o",
            [torch.relu, torch.abs],  # scalars
            "8x0e + 8x0o + 8x0e + 8x0o",
            [torch.relu, torch.tanh, torch.relu, torch.tanh],  # gates (scalars)
            "16x1o + 16x1e",  # gated tensors, num_irreps must match the number of gates
        )
        self.conv = Convolution(irreps, self.irreps_sh, gate.irreps_in, self.num_neighbors)
        self.gate = gate
        irreps = self.gate.irreps_out

        # Final layer
        self.final = Convolution(irreps, self.irreps_sh, "0o + 6x0e", self.num_neighbors)
        self.irreps_out = self.final.irreps_out

    def forward(self, data) -> torch.Tensor:
        num_nodes = 4  # typical number of nodes

        edge_src, edge_dst = radius_graph(x=data.pos, r=2.5, batch=data.batch)
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        edge_attr = o3.spherical_harmonics(l=self.irreps_sh, x=edge_vec, normalize=True, normalization="component")
        edge_length_embedded = (
            soft_one_hot_linspace(x=edge_vec.norm(dim=1), start=0.5, end=2.5, number=3, basis="smooth_finite", cutoff=True)
            * 3**0.5
        )

        x = scatter(edge_attr, edge_dst, dim=0).div(self.num_neighbors**0.5)

        x = self.conv(x, edge_src, edge_dst, edge_attr, edge_length_embedded)
        x = self.gate(x)
        x = self.final(x, edge_src, edge_dst, edge_attr, edge_length_embedded)

        return scatter(x, data.batch, dim=0).div(num_nodes**0.5)


def main() -> None:
    data, labels = tetris()
    f = Network()

    print("Built a model:")
    print(f)

    optim = torch.optim.Adam(f.parameters(), lr=1e-3)

    # == Training ==
    for step in range(200):
        pred = f(data)
        loss = (pred - labels).pow(2).sum()

        optim.zero_grad()
        loss.backward()
        optim.step()

        if step % 10 == 0:
            accuracy = pred.round().eq(labels).all(dim=1).double().mean(dim=0).item()
            print(f"epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy")

    # == Check equivariance ==
    # Because the model outputs (pseudo)scalars, we can check its equivariance
    # directly by feeding the same shapes with new random rotations:
    print("Testing equivariance directly...")
    rotated_data, _ = tetris()
    error = f(rotated_data) - f(data)
    print(f"Equivariance error = {error.abs().max().item():.1e}")

    print("Testing equivariance using `assert_equivariant`...")
    # We can also use the library's `assert_equivariant` helper
    # `assert_equivariant` also tests parity and translation, and
    # can handle non-(pseudo)scalar outputs.
    # To interface between it and torch_geometric, we use a small wrapper:
    def wrapper(pos, batch):
        return f(Data(pos=pos, batch=batch))

    # `assert_equivariant` uses logging to print a summary of the equivariance error,
    # so we enable logging
    logging.basicConfig(level=logging.INFO)
    assert_equivariant(
        wrapper,
        # We provide the original data that `assert_equivariant` will transform...
        args_in=[data.pos, data.batch],
        # ...in accordance with these irreps...
        irreps_in=[
            "cartesian_points",  # pos has vector 1o irreps, but is also translation equivariant
            None,  # `None` indicates invariant, possibly non-floating-point data
        ],
        # ...and confirm that the outputs transform correspondingly for these irreps:
        irreps_out=[f.irreps_out],
    )


if __name__ == "__main__":
    main()


def test() -> None:
    torch.set_default_dtype(torch.float64)

    data, labels = tetris()
    f = Network()

    pred = f(data)
    loss = (pred - labels).pow(2).sum()
    loss.backward()

    rotated_data, _ = tetris()
    error = f(rotated_data) - f(data)
    assert error.abs().max() < 1e-10


def profile() -> None:
    data, labels = tetris()
    data = data.to(device="cuda")
    labels = labels.to(device="cuda")

    f = Network()
    f.to(device="cuda")

    optim = torch.optim.Adam(f.parameters(), lr=1e-2)

    called_num = [0]

    def trace_handler(p) -> None:
        print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
        p.export_chrome_trace("test_trace_" + str(called_num[0]) + ".json")
        called_num[0] += 1

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=50, warmup=1, active=1),
        on_trace_ready=trace_handler,
    ) as p:
        for _ in range(52):
            pred = f(data)
            loss = (pred - labels).pow(2).sum()

            optim.zero_grad()
            loss.backward()
            optim.step()

            p.step()

e3nn-0.6.0/examples/tetris_polynomial.py
"""Minimal example

Implement an equivariant polynomial to fit the tetris dataset

Exact equivariance to :math:`E(3)`

This example is minimal:
* there is no dependency on the distance to the neighbors (tetris pieces are made of edges of length 1)
* there are no non-linearities other than the tensor product, therefore this model is a polynomial

>>> test()
"""
import logging

import torch
from torch_cluster import radius_graph
from torch_geometric.data import Data, DataLoader
from torch_scatter import scatter

from e3nn import o3
from e3nn.o3 import FullyConnectedTensorProduct
from e3nn.util.test import assert_equivariant


def tetris():
    pos = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
    ]
    pos = torch.tensor(pos, dtype=torch.get_default_dtype())

    # Since chiral shapes are the mirror of one another we need an *odd* scalar to distinguish them
    labels = torch.tensor(
        [
            [+1, 0, 0, 0, 0, 0, 0],  # chiral_shape_1
            [-1, 0, 0, 0, 0, 0, 0],  # chiral_shape_2
            [0, 1, 0, 0, 0, 0, 0],  # square
            [0, 0, 1, 0, 0, 0, 0],  # line
            [0, 0, 0, 1, 0, 0, 0],  # corner
            [0, 0, 0, 0, 1, 0, 0],  # L
            [0, 0, 0, 0, 0, 1, 0],  # T
            [0, 0, 0, 0, 0, 0, 1],  # zigzag
        ],
        dtype=torch.get_default_dtype(),
    )

    # apply random rotation
    pos = torch.einsum("zij,zaj->zai", o3.rand_matrix(len(pos)), pos)

    # put in torch_geometric format
    dataset = [Data(pos=pos) for pos in pos]
    data = next(iter(DataLoader(dataset, batch_size=len(dataset))))

    return data, labels


class InvariantPolynomial(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.irreps_sh: o3.Irreps = o3.Irreps.spherical_harmonics(3)
        irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
        irreps_out = o3.Irreps("0o + 6x0e")

        self.tp1 = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_sh,
            irreps_in2=self.irreps_sh,
            irreps_out=irreps_mid,
        )
        self.tp2 = FullyConnectedTensorProduct(
            irreps_in1=irreps_mid,
            irreps_in2=self.irreps_sh,
            irreps_out=irreps_out,
        )
        self.irreps_out = self.tp2.irreps_out

    def forward(self, data) -> torch.Tensor:
        num_neighbors = 2  # typical number of neighbors
        num_nodes = 4  # typical number of nodes

        edge_src, edge_dst = radius_graph(x=data.pos, r=1.1, batch=data.batch)  # tensors of indices representing the graph
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        edge_sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=edge_vec,
            normalize=False,  # here we don't normalize otherwise it would not be a polynomial
            normalization="component",
        )

        # For each node, the initial features are the sum of the spherical harmonics of the neighbors
        node_features = scatter(edge_sh, edge_dst, dim=0).div(num_neighbors**0.5)

        # For each edge, tensor product the features on the source node with the spherical harmonics
        edge_features = self.tp1(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0).div(num_neighbors**0.5)

        edge_features = self.tp2(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0).div(num_neighbors**0.5)

        # For each graph, all the nodes' features are summed
        return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)


def main() -> None:
    data, labels = tetris()
    f = InvariantPolynomial()

    optim = torch.optim.Adam(f.parameters(), lr=1e-2)

    # == Train ==
    for step in range(200):
        pred = f(data)
        loss = (pred - labels).pow(2).sum()

        optim.zero_grad()
        loss.backward()
        optim.step()

        if step % 10 == 0:
            accuracy = pred.round().eq(labels).all(dim=1).double().mean(dim=0).item()
            print(f"epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy")

    # == Check equivariance ==
    # Because the model outputs (pseudo)scalars, we can check its equivariance
    # directly by feeding the same shapes with new random rotations:
    print("Testing equivariance directly...")
    rotated_data, _ = tetris()
    error = f(rotated_data) - f(data)
    print(f"Equivariance error = {error.abs().max().item():.1e}")

    print("Testing equivariance using `assert_equivariant`...")
    # We can also use the library's `assert_equivariant` helper
    # `assert_equivariant` also tests parity and translation, and
    # can handle non-(pseudo)scalar outputs.
    # To interface between it and torch_geometric, we use a small wrapper:
    def wrapper(pos, batch):
        return f(Data(pos=pos, batch=batch))

    # `assert_equivariant` uses logging to print a summary of the equivariance error,
    # so we enable logging
    logging.basicConfig(level=logging.INFO)
    assert_equivariant(
        wrapper,
        # We provide the original data that `assert_equivariant` will transform...
        args_in=[data.pos, data.batch],
        # ...in accordance with these irreps...
        irreps_in=[
            "cartesian_points",  # pos has vector 1o irreps, but is also translation equivariant
            None,  # `None` indicates invariant, possibly non-floating-point data
        ],
        # ...and confirm that the outputs transform correspondingly for these irreps:
        irreps_out=[f.irreps_out],
    )


if __name__ == "__main__":
    main()


def test() -> None:
    data, labels = tetris()
    f = InvariantPolynomial()

    pred = f(data)
    loss = (pred - labels).pow(2).sum()
    loss.backward()

    rotated_data, _ = tetris()
    error = f(rotated_data) - f(data)
    assert error.abs().max() < 1e-5

e3nn-0.6.0/pyproject.toml
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[project]
name = "e3nn"
requires-python = ">=3.8"
dynamic = ["version", "readme"]
license = {text = "MIT"}
description = "Equivariant convolutional neural networks for the group E(3) of 3 dimensional rotations, translations, and mirrors."
classifiers = [
    "Programming Language :: Python",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3 :: Only",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "License :: OSI Approved :: MIT License",
    "Operating System :: POSIX",
    "Operating System :: MacOS",
]
dependencies = [
    "sympy",
    "scipy",
    "torch>=2.2.0",
    "opt_einsum_fx>=0.1.4"
]

[project.optional-dependencies]
dev = [
    "pytest",
    "pre-commit",
]

[project.urls]
homepage = "https://e3nn.org"
documentation = "https://docs.e3nn.org/"
repository = "https://github.com/e3nn/e3nn.git"
changelog = "https://github.com/e3nn/e3nn/blob/main/.github/CHANGELOG.md"

[tool.setuptools.packages.find]
exclude = [
    "tests",
    "tests.*",
]

[tool.setuptools.dynamic]
version = {attr = "e3nn.__version__"}
readme = {file = ["README.md"], content-type = "text/markdown"}

[tool.black]
line-length = 127
target-version = ['py38']
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
)/
'''

[tool.flake8]
max-line-length = 127
max-complexity = 21
select = ["B", "C", "E", "F", "W", "T4", "B9"]
ignore = [
    "E741",
    "E203",
    "W503",
    "C901"
]
exclude = [
    ".eggs",
    "*.egg",
    "build",
    "dist",
    "docs/_build",
    "notebook"
]
per-file-ignores = [
    "e3nn/o3/cartesian_spherical_harmonics.py: E226",
]

[tool.coverage.run]
source = ["e3nn"]

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "torch.jit.script",
    "raise",
    "except",
]

[tool.pylint.typecheck]
generated-members = "numpy.*,torch.*"

[tool.pylint."messages control"]
disable = [
    "protected-access",
    "no-else-return",
    "raise-missing-from",
    "invalid-name",
    "duplicate-code",
    "import-outside-toplevel",
    "missing-docstring",
    "bad-continuation",
    "locally-disabled",
    "too-few-public-methods",
    "too-many-arguments",
    "too-many-instance-attributes",
    "too-many-local-variables",
    "too-many-locals",
    "too-many-branches",
    "too-many-statements",
    "too-many-return-statements",
    "redefined-builtin",
    "redefined-outer-name",
    "line-too-long",
    "fixme",
]

e3nn-0.6.0/tests/conftest.py
import pytest

# For good practice, we *should* do this:
# See https://docs.pytest.org/en/stable/fixture.html#using-fixtures-from-other-projects
# pytest_plugins = ['e3nn.util.test']
# But doing so exposes float_tolerance to doctests, which don't support parametrized, autouse fixtures.
# Importing directly somehow only brings in the fixture later, preventing the issue.
from e3nn.util import test

# Suppress linter errors
float_tolerance = test.float_tolerance


@pytest.fixture(autouse=True)
def set_random_seed() -> None:
    """Set the random seeds to try to get some reproducibility"""
    test.set_random_seeds()

e3nn-0.6.0/tests/defaults_test.py
import e3nn


def test_opt_defaults() -> None:
    a = e3nn.o3.FullyConnectedTensorProduct("4x1o", "4x1o", "4x1o")
    b = e3nn.o3.Linear("4x1o", "4x1o")
    assert a._specialized_code
    assert a._optimize_einsums
    assert b._optimize_einsums

    old_defaults = e3nn.get_optimization_defaults()
    try:
        e3nn.set_optimization_defaults(optimize_einsums=False)
        a = e3nn.o3.FullyConnectedTensorProduct("4x1o", "4x1o", "4x1o")
        b = e3nn.o3.Linear("4x1o", "4x1o")
        assert a._specialized_code
        assert not a._optimize_einsums
        assert not b._optimize_einsums
    finally:
        e3nn.set_optimization_defaults(**old_defaults)

    a = e3nn.o3.FullyConnectedTensorProduct("4x1o", "4x1o", "4x1o")
    b = e3nn.o3.Linear("4x1o", "3x1o")
    assert a._specialized_code
    assert a._optimize_einsums
    assert b._optimize_einsums
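
# A minimal, dependency-free sketch of the save/modify/restore pattern that
# `test_opt_defaults` follows with `e3nn.get_optimization_defaults()` and
# `e3nn.set_optimization_defaults(...)`. Everything below is hypothetical
# stand-in code, not e3nn's implementation: `_demo_defaults` plays the role of
# the library's global defaults, and `_OverriddenDefaults` wraps the
# try/finally in a context manager so the old values come back even when an
# assertion inside the `with` block fails.
_demo_defaults = {"specialized_code": True, "optimize_einsums": True}


def _get_defaults() -> dict:
    # return a copy so callers cannot mutate the "global" state by accident
    return dict(_demo_defaults)


def _set_defaults(**kwargs) -> None:
    # validate every key first, so a bad call leaves the defaults untouched
    unknown = set(kwargs) - set(_demo_defaults)
    if unknown:
        raise ValueError(f"unknown defaults: {sorted(unknown)}")
    _demo_defaults.update(kwargs)


class _OverriddenDefaults:
    def __init__(self, **kwargs) -> None:
        self.kwargs = kwargs
        self.old = None

    def __enter__(self):
        self.old = _get_defaults()
        _set_defaults(**self.kwargs)
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # restore unconditionally, then let any exception propagate
        _set_defaults(**self.old)
        return False


def test_overridden_defaults_sketch() -> None:
    with _OverriddenDefaults(optimize_einsums=False):
        assert _get_defaults()["optimize_einsums"] is False
        assert _get_defaults()["specialized_code"] is True
    assert _get_defaults()["optimize_einsums"] is True

# Compared with an inline try/finally, the context manager keeps the restore
# logic in one place when several tests need to flip the same defaults.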
e3nn-0.6.0/tests/math/bessel_test.py
import torch
import pytest

import e3nn


@pytest.mark.parametrize("n", [1, 2, 4])
def test_bessel(n: int):
    x = torch.linspace(0.0, 1.0, 100)
    y = e3nn.math.bessel(x, n)
    assert y.shape == (100, n)

e3nn-0.6.0/tests/math/normalize_activation_test.py
import torch

from e3nn.math import normalize2mom


def test_device() -> None:
    act = torch.nn.ReLU()
    act = normalize2mom(act)


def test_identity() -> None:
    act1 = normalize2mom(torch.relu)
    act2 = normalize2mom(act1)

    x = torch.randn(10)
    assert (act1(x) == act2(x)).all()


def test_deterministic() -> None:
    act1 = normalize2mom(torch.tanh)
    act2 = normalize2mom(torch.tanh)

    x = torch.randn(10)
    assert (act1(x) == act2(x)).all()

e3nn-0.6.0/tests/math/perm_test.py
import math

import pytest
import torch

from e3nn.math import perm


@pytest.mark.parametrize("n", [0, 1, 2, 3, 4, 5])
def test_inverse(n) -> None:
    for p in perm.group(n):
        ip = perm.inverse(p)

        assert perm.compose(p, ip) == perm.identity(n)
        assert perm.compose(ip, p) == perm.identity(n)


@pytest.mark.parametrize("n", [0, 1, 2, 3, 4, 5])
def test_int_inverse(n) -> None:
    for j in range(math.factorial(n)):
        p = perm.from_int(j, n)
        i = perm.to_int(p)
        assert i == j


@pytest.mark.parametrize("n", [0, 1, 2, 3, 4, 5])
def test_int_injection(n) -> None:
    group = {perm.from_int(j, n) for j in range(math.factorial(n))}
    assert len(group) == math.factorial(n)


def test_germinate() -> None:
    assert perm.is_group(perm.germinate({(1, 2, 3, 4, 0)}))
    assert perm.is_group(perm.germinate({(1, 0, 2, 3), (0, 2, 1, 3), (0, 1, 3, 2)}))


@pytest.mark.parametrize("n", [0, 1, 2, 3, 4, 5])
def test_rand(n) -> None:
    assert perm.is_perm(perm.rand(n))


def test_not_group() -> None:
    assert not perm.is_group(set())  # empty
    assert not perm.is_group({(1, 0, 2), (0, 2, 1), (1, 2, 0), (2, 0, 1), (2, 1, 0)})  # missing neutral
    assert not perm.is_group({(0, 1, 2), (1, 2, 0)})  # missing inverse
    assert not perm.is_group({(0, 1, 2, 3), (3, 0, 1, 2), (1, 2, 3, 0)})  # g1 . g2 not in G


def test_to_cycles() -> None:
    assert perm.to_cycles((1, 2, 3, 0)) == {(0, 1, 2, 3)}
    assert perm.to_cycles((2, 3, 0, 1)) == {(0, 2), (1, 3)}


def test_sign() -> None:
    assert perm.sign((1, 0, 3, 2)) == 1
    assert perm.sign((1, 0, 3, 2, 5, 6, 7, 4)) == -1


@pytest.mark.parametrize("n", [3, 7, 15])
def test_standard_representation(float_tolerance, n) -> None:
    # identity
    e = perm.standard_representation(perm.identity(n))
    assert torch.allclose(e, torch.eye(n - 1), atol=float_tolerance)

    # inverse
    p = perm.rand(n)
    a = perm.standard_representation(p)
    b = perm.standard_representation(perm.inverse(p))
    assert torch.allclose(a, torch.inverse(b), atol=float_tolerance)

    # compose
    p1, p2 = perm.rand(n), perm.rand(n)
    a = perm.standard_representation(p1) @ perm.standard_representation(p2)
    b = perm.standard_representation(perm.compose(p1, p2))
    assert torch.allclose(a, b, atol=float_tolerance)

    # orthogonal
    a = perm.standard_representation(perm.rand(n))
    assert torch.allclose(a @ a.T, torch.eye(n - 1), atol=float_tolerance)


@pytest.mark.parametrize("n", [3, 7, 15])
def test_natural_representation(float_tolerance, n) -> None:
    p = perm.rand(n)
    a = torch.eye(n)[list(perm.inverse(p))]
    b = perm.natural_representation(p)
    assert torch.allclose(a, b, atol=float_tolerance)

    p = perm.rand(n)
    a = torch.eye(n)[:, list(p)]
    b = perm.natural_representation(p)
    assert torch.allclose(a, b, atol=float_tolerance)

    # orthogonal
    a = perm.natural_representation(perm.rand(n))
    assert torch.allclose(a @ a.T, torch.eye(n), atol=float_tolerance)

e3nn-0.6.0/tests/math/soft_one_hot_test.py
import pytest
import torch

from e3nn.math import soft_one_hot_linspace


@pytest.mark.parametrize("basis", ["gaussian", "cosine", "fourier", "bessel", "smooth_finite"])
def test_with_compile(basis) -> None:
    # torch.compile recompiles for every basis and every dtype
    torch._dynamo.config.cache_size_limit = 32

    x = torch.linspace(-2.0, 3.0, 20)
    kwargs = dict(start=-1.0, end=2.0, number=5, basis=basis, cutoff=True)
    y = soft_one_hot_linspace(x, **kwargs)
    y_compiled = torch.compile(soft_one_hot_linspace, fullgraph=True)(x, **kwargs)
    assert y.shape == y_compiled.shape
    assert y.dtype == y_compiled.dtype
    assert y.device == y_compiled.device
    assert torch.allclose(y, y_compiled, atol=1e-7)


@pytest.mark.parametrize("basis", ["gaussian", "cosine", "fourier", "bessel", "smooth_finite"])
def test_zero_out(basis) -> None:
    x1 = torch.linspace(-2.0, -1.1, 20)
    x2 = torch.linspace(2.1, 3.0, 20)
    x = torch.cat([x1, x2])

    y = soft_one_hot_linspace(x, -1.0, 2.0, 5, basis, cutoff=True)
    if basis == "gaussian":
        assert y.abs().max() < 0.22
    else:
        assert y.abs().max() == 0.0


@pytest.mark.parametrize("basis", ["gaussian", "cosine", "fourier", "smooth_finite"])
@pytest.mark.parametrize("cutoff", [True, False])
def test_normalized(basis, cutoff) -> None:
    x = torch.linspace(-14.0, 105.0, 50)
    y = soft_one_hot_linspace(x, -20.0, 120.0, 12, basis, cutoff)

    assert 0.4 < y.pow(2).sum(1).min()
    assert y.pow(2).sum(1).max() < 2.0

e3nn-0.6.0/tests/math/soft_unit_step_test.py
import torch

from e3nn.math import soft_unit_step


def test_grad() -> None:
    torch.set_default_dtype(torch.float64)

    x = torch.linspace(-1, 1, 1000, requires_grad=True)

    def f(x):
        return soft_unit_step(x).sum()

    assert torch.autograd.gradcheck(f, (x,), check_undefined_grad=False)


def test_grads() -> None:
    x = torch.linspace(-1, 1, 1000, requires_grad=True)

    y0 = soft_unit_step(x)
    assert torch.isfinite(y0).all()

    (y1,) = torch.autograd.grad(y0.sum(), x, create_graph=True)
    assert torch.isfinite(y1).all()

    (y2,) = torch.autograd.grad(y1.sum(), x, create_graph=True)
    assert torch.isfinite(y2).all()

    (y3,) = torch.autograd.grad(y2.sum(), x, create_graph=True)
    assert torch.isfinite(y3).all()

e3nn-0.6.0/tests/nn/activation_test.py
import pytest
import functools

import torch

from e3nn import o3
from e3nn.nn import Activation
from e3nn.util.test import assert_equivariant, assert_auto_jitable, assert_normalized, assert_torch_compile


@pytest.mark.parametrize(
    "irreps_in,acts",
    [("256x0o", [torch.abs]), ("37x0e", [torch.tanh]), ("4x0e + 3x0o", [torch.nn.functional.silu, torch.abs])],
)
def test_activation(irreps_in, acts) -> None:
    irreps_in = o3.Irreps(irreps_in)

    a = Activation(irreps_in, acts)
    inp = irreps_in.randn(13, -1)

    assert_auto_jitable(a)
    assert_torch_compile("inductor", functools.partial(Activation, irreps_in, acts), inp)
    assert_equivariant(a)

    out = a(inp)
    for ir_slice, act in zip(irreps_in.slices(), acts):
        this_out = out[:, ir_slice]
        true_up_to_factor = act(inp[:, ir_slice])
        factors = this_out / true_up_to_factor
        assert torch.allclose(factors, factors[0])

    assert_normalized(a)

e3nn-0.6.0/tests/nn/batchnorm_test.py
import torch

from e3nn import o3
from e3nn.nn import BatchNorm
from e3nn.util.test import assert_equivariant
import pytest


def test_equivariant() -> None:
    irreps = o3.Irreps("3x0e + 3x0o + 4x1e")

    m = BatchNorm(irreps)
    m(irreps.randn(16, -1))
    m(irreps.randn(16, -1))

    m.train()
    assert_equivariant(m, irreps_in=[irreps], irreps_out=[irreps])

    m.eval()
    assert_equivariant(m, irreps_in=[irreps], irreps_out=[irreps])


@pytest.mark.parametrize("affine", [True, False])
@pytest.mark.parametrize("reduce", ["mean", "max"])
@pytest.mark.parametrize("normalization", ["norm", "component"])
@pytest.mark.parametrize("instance", [True,
False])
def test_modes(affine, reduce, normalization, instance) -> None:
    irreps = o3.Irreps("10x0e + 5x1e")

    m = BatchNorm(irreps, affine=affine, reduce=reduce, normalization=normalization, instance=instance)
    repr(m)

    m.train()
    m(irreps.randn(20, 20, -1))

    m.eval()
    m(irreps.randn(20, 20, -1))


@pytest.mark.parametrize("instance", [True, False])
def test_normalization(float_tolerance, instance) -> None:
    sqrt_float_tolerance = torch.sqrt(float_tolerance)

    batch, n = 20, 20
    irreps = o3.Irreps("3x0e + 4x1e")

    m = BatchNorm(irreps, normalization="norm", instance=instance)

    x = torch.randn(batch, n, irreps.dim).mul(5.0).add(10.0)
    x = m(x)

    a = x[..., :3]  # [batch, space, mul]
    assert a.mean([0, 1]).abs().max() < float_tolerance
    assert a.pow(2).mean([0, 1]).sub(1).abs().max() < sqrt_float_tolerance

    a = x[..., 3:].reshape(batch, n, 4, 3)  # [batch, space, mul, repr]
    assert a.pow(2).sum(3).mean([0, 1]).sub(1).abs().max() < sqrt_float_tolerance

    m = BatchNorm(irreps, normalization="component", instance=instance)

    x = torch.randn(batch, n, irreps.dim).mul(5.0).add(10.0)
    x = m(x)

    a = x[..., :3]  # [batch, space, mul]
    assert a.mean([0, 1]).abs().max() < float_tolerance
    assert a.pow(2).mean([0, 1]).sub(1).abs().max() < sqrt_float_tolerance

    a = x[..., 3:].reshape(batch, n, 4, 3)  # [batch, space, mul, repr]
    assert a.pow(2).mean(3).mean([0, 1]).sub(1).abs().max() < sqrt_float_tolerance

e3nn-0.6.0/tests/nn/dropout_test.py
import copy

import torch

from e3nn.nn import Dropout
from e3nn.util.test import assert_auto_jitable, assert_equivariant


def test_dropout() -> None:
    c = Dropout(irreps="10x1e + 10x0e", p=0.75)
    x = c.irreps.randn(5, 2, -1)

    for c in [c, assert_auto_jitable(c)]:
        c.eval()
        assert c(x).eq(x).all()

        c.train()
        y = c(x)
        assert (y.eq(x / 0.25) | y.eq(0)).all()

        def wrap(x):
            torch.manual_seed(0)
            return c(x)

        assert_equivariant(wrap, args_in=[x], irreps_in=[c.irreps], irreps_out=[c.irreps])


def test_copy() -> None:
    c = Dropout(irreps="0e + 1e", p=0.5)
    _ = copy.deepcopy(c)

e3nn-0.6.0/tests/nn/extract_test.py
import pytest
import copy
import functools

import torch

from e3nn.nn import Extract, ExtractIr
from e3nn.util.test import assert_auto_jitable, assert_equivariant, assert_torch_compile


def test_extract() -> None:
    c = Extract("1e + 0e + 0e", ["0e", "0e"], [(1,), (2,)])
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    assert out == (torch.Tensor([1.0]), torch.Tensor([2.0]))
    assert_auto_jitable(c)
    assert_torch_compile(
        "inductor",
        functools.partial(Extract, "1e + 0e + 0e", ["0e", "0e"], [(1,), (2,)]),
        torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]),
    )
    assert_equivariant(c, irreps_out=list(c.irreps_outs))


@pytest.mark.parametrize("squeeze", [True, False])
def test_extract_single(squeeze) -> None:
    c = Extract("1e + 0e + 0e", ["0e"], [(1,)], squeeze_out=squeeze)
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    if squeeze:
        assert isinstance(out, torch.Tensor)
    else:
        assert len(out) == 1
        out = out[0]
    assert out == torch.Tensor([1.0])
    assert_auto_jitable(c)
    assert_torch_compile(
        "inductor",
        functools.partial(Extract, "1e + 0e + 0e", ["0e"], [(1,)], squeeze_out=squeeze),
        torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]),
    )
    assert_equivariant(c, irreps_out=list(c.irreps_outs))


def test_extract_ir() -> None:
    c = ExtractIr("1e + 0e + 0e", "0e")
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    assert torch.all(out == torch.Tensor([1.0, 2.0]))
    assert_auto_jitable(c)
    assert_torch_compile(
        "inductor", functools.partial(ExtractIr, "1e + 0e + 0e", "0e"), torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])
    )
    assert_equivariant(c)


def test_copy() -> None:
    c = Extract("1e + 0e + 0e", ["0e", "0e"], [(1,), (2,)])
    _ = copy.deepcopy(c)
    c = ExtractIr("1e + 0e + 0e", "0e")
    _ = copy.deepcopy(c)

e3nn-0.6.0/tests/nn/fc_test.py
import torch
import functools
import pytest

from e3nn.nn import FullyConnectedNet
from e3nn.util.test import assert_auto_jitable, assert_torch_compile


@pytest.mark.parametrize("act", [None, torch.tanh])
@pytest.mark.parametrize("var_in, var_out, out_act", [(1, 1, False), (1, 1, True), (0.1, 10.0, False), (0.1, 0.05, True)])
def test_variance(act, var_in, var_out, out_act) -> None:
    hs = (1000, 500, 1500, 4)

    f = FullyConnectedNet(hs, act, var_in, var_out, out_act)

    x = torch.randn(2000, hs[0]) * var_in**0.5
    y = f(x) / var_out**0.5

    if not out_act:
        assert y.mean().abs() < 0.5
        assert y.pow(2).mean().log10().abs() < torch.tensor(1.5).log10()

    f = assert_auto_jitable(f)
    f(x)

    f_pt2 = assert_torch_compile("inductor", functools.partial(FullyConnectedNet, hs, act, var_in, var_out, out_act), x)
    f_pt2(x)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
def test_data_parallel() -> None:
    fc = torch.nn.DataParallel(FullyConnectedNet([10, 20, 30]).cuda())

    y = fc(torch.randn(32, 10).cuda())
    y.sum().backward()

e3nn-0.6.0/tests/nn/gate_test.py
import functools

import torch

from e3nn.o3 import Irreps
from e3nn.nn import Gate
from e3nn.nn._gate import _Sortcut
from e3nn.util.test import assert_equivariant, assert_auto_jitable, assert_normalized, assert_torch_compile


def test_gate() -> None:
    irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated = (
        Irreps("16x0o"),
        [torch.tanh],
        Irreps("32x0o"),
        [torch.tanh],
        Irreps("16x1e+16x1o"),
    )

    sc = _Sortcut(irreps_scalars, irreps_gates)
    assert_auto_jitable(sc)

    g = Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)
    irreps = Irreps("16x0o+32x0o+16x1e+16x1o")
    assert_equivariant(g)
    assert_auto_jitable(g)
    assert_torch_compile(
        "inductor",
        functools.partial(Gate, irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated),
        irreps.randn(-1),
    )
    assert_normalized(g)
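
# A hypothetical, dependency-free sketch of the dimension bookkeeping that
# `test_gate` relies on; it does not call e3nn. The Gate input is
# scalars + gates + gated, each irrep of degree l with multiplicity `mul`
# contributes mul * (2l + 1) components, and there must be exactly one gate
# scalar per gated irrep. `_parse_irreps`, `_dim` and `_num_irreps` are
# illustrative helpers that only understand simple strings such as
# "16x1e+16x1o"; real code should use `e3nn.o3.Irreps` instead.
def _parse_irreps(s: str):
    out = []
    for term in s.replace(" ", "").split("+"):
        # "16x1e" -> ("16", "1e"); a bare "0e" means multiplicity 1
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), int(ir[:-1]), ir[-1]))
    return out


def _dim(s: str) -> int:
    return sum(mul * (2 * l + 1) for mul, l, _ in _parse_irreps(s))


def _num_irreps(s: str) -> int:
    return sum(mul for mul, _, _ in _parse_irreps(s))


def test_gate_bookkeeping_sketch() -> None:
    scalars, gates, gated = "16x0o", "32x0o", "16x1e+16x1o"
    # one gate per gated irrep: 32 == 16 + 16
    assert _num_irreps(gates) == _num_irreps(gated)
    # the input string used in `test_gate` is the concatenation of all three parts
    assert _dim("16x0o+32x0o+16x1e+16x1o") == _dim(scalars) + _dim(gates) + _dim(gated)
    # the gates are consumed, so the output keeps only the scalars and the gated tensors
    assert _dim(scalars) + _dim(gated) == 112

# Counting irreps rather than components is what makes the gate rule
# parity-agnostic: 16x1e and 16x1o each need 16 gates.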
e3nn-0.6.0/tests/nn/models/000077500000000000000000000000001514371756200154335ustar00rootroot00000000000000e3nn-0.6.0/tests/nn/models/gate_points_2101_test.py000066400000000000000000000062411514371756200220260ustar00rootroot00000000000000import copy import random import tempfile import pytest import torch from e3nn import o3 from e3nn.nn.models.gate_points_2101 import Network from e3nn.util.test import assert_auto_jitable, assert_equivariant @pytest.fixture def network(): num_nodes = 5 irreps_in = o3.Irreps("3x0e + 2x1o") irreps_attr = o3.Irreps("10x0e") irreps_out = o3.Irreps("2x0o + 2x1o + 2x2e") f = Network( irreps_in, o3.Irreps("5x0e + 5x0o + 5x1e + 5x1o"), irreps_out, irreps_attr, o3.Irreps.spherical_harmonics(3), layers=3, max_radius=2.0, number_of_basis=5, radial_layers=2, radial_neurons=100, num_neighbors=4.0, num_nodes=num_nodes, ) def random_graph(): N = random.randint(3, 7) return {"pos": torch.randn(N, 3), "x": f.irreps_in.randn(N, -1), "z": f.irreps_node_attr.randn(N, -1)} return f, random_graph def test_convolution_jit(network) -> None: f, _ = network # Get a convolution from the network assert_auto_jitable(f.layers[0].first) def test_gate_points_2101_equivariant(network) -> None: f, random_graph = network # -- Test equivariance: -- def wrapper(pos, x, z): data = dict(pos=pos, x=x, z=z, batch=torch.zeros(pos.shape[0], dtype=torch.long)) return f(data) assert_equivariant( wrapper, irreps_in=["cartesian_points", f.irreps_in, f.irreps_node_attr], irreps_out=[f.irreps_out], ) def test_copy(network) -> None: f, random_graph = network fcopy = copy.deepcopy(f) g = random_graph() assert torch.allclose(f(g), fcopy(g)) def test_save(network) -> None: f, random_graph = network # Get a saved, loaded network with tempfile.NamedTemporaryFile(suffix=".pth") as tmp: torch.save(f.state_dict(), tmp.name) # Recreate network with same parameters as fixture irreps_in = o3.Irreps("3x0e + 2x1o") irreps_attr = o3.Irreps("10x0e") irreps_out = o3.Irreps("2x0o + 2x1o + 2x2e") 
f2 = Network( irreps_in, o3.Irreps("5x0e + 5x0o + 5x1e + 5x1o"), irreps_out, irreps_attr, o3.Irreps.spherical_harmonics(3), layers=3, max_radius=2.0, number_of_basis=5, radial_layers=2, radial_neurons=100, num_neighbors=4.0, num_nodes=5, ) f2.load_state_dict(torch.load(tmp.name, weights_only=False)) x = random_graph() assert torch.all(f(x) == f2(x)) # Get a double-saved network with tempfile.NamedTemporaryFile(suffix=".pth") as tmp: torch.save(f2.state_dict(), tmp.name) f3 = Network( irreps_in, o3.Irreps("5x0e + 5x0o + 5x1e + 5x1o"), irreps_out, irreps_attr, o3.Irreps.spherical_harmonics(3), layers=3, max_radius=2.0, number_of_basis=5, radial_layers=2, radial_neurons=100, num_neighbors=4.0, num_nodes=5, ) f3.load_state_dict(torch.load(tmp.name, weights_only=False)) assert torch.all(f(x) == f3(x)) e3nn-0.6.0/tests/nn/models/gate_points_2102_test.py000066400000000000000000000031541514371756200220270ustar00rootroot00000000000000import copy import random import pytest import torch from e3nn import o3 from e3nn.nn.models.gate_points_2102 import Network from e3nn.util.test import assert_auto_jitable, assert_equivariant @pytest.fixture def network(): num_nodes = 5 irreps_in = o3.Irreps("3x0e + 2x1o") irreps_attr = o3.Irreps("10x0e") irreps_out = o3.Irreps("2x0o + 2x1o + 2x2e") f = Network( irreps_in, o3.Irreps("5x0e + 5x0o + 5x1e + 5x1o"), irreps_out, irreps_attr, o3.Irreps.spherical_harmonics(3), layers=3, max_radius=2.0, number_of_basis=5, radial_layers=2, radial_neurons=100, num_neighbors=4.0, num_nodes=num_nodes, ) def random_graph(): N = random.randint(3, 7) return {"pos": torch.randn(N, 3), "x": f.irreps_in.randn(N, -1), "z": f.irreps_node_attr.randn(N, -1)} return f, random_graph def test_convolution_jit(network) -> None: f, _ = network # Get a convolution from the network assert_auto_jitable(f.layers[0].first) def test_gate_points_2102_equivariant(network) -> None: f, random_graph = network # -- Test equivariance: -- def wrapper(pos, x, z): data = dict(pos=pos, 
                    x=x, z=z, batch=torch.zeros(pos.shape[0], dtype=torch.long))
        return f(data)

    assert_equivariant(
        wrapper,
        irreps_in=["cartesian_points", f.irreps_in, f.irreps_node_attr],
        irreps_out=[f.irreps_out],
    )


def test_copy(network) -> None:
    f, random_graph = network
    fcopy = copy.deepcopy(f)
    g = random_graph()
    assert torch.allclose(f(g), fcopy(g))

e3nn-0.6.0/tests/nn/models/v2203/
e3nn-0.6.0/tests/nn/models/v2203/sparse_voxel_convolution_test.py
import pytest
import torch
from e3nn.o3 import Irreps
import math

rotations = [
    (0.0, 0.0, 0.0),
    (0.0, 0.0, math.pi / 2),
    (0.0, 0.0, math.pi),
    (0.0, math.pi / 2, 0.0),
    (0.0, math.pi / 2, math.pi / 2),
    (0.0, math.pi / 2, math.pi),
    (0.0, math.pi, 0.0),
    (math.pi / 2, 0.0, 0.0),
    (math.pi / 2, 0.0, math.pi / 2),
    (math.pi / 2, 0.0, math.pi),
    (math.pi / 2, math.pi / 2, 0.0),
]


def rotate_sparse_tensor(x, irreps, abc):
    """Apply the rotation given by the Euler angles ``abc`` to a sparse tensor"""
    from MinkowskiEngine import SparseTensor

    # rotate the coordinates (like vectors l=1)
    coordinates = x.C[:, 1:].to(x.F.dtype)
    coordinates = torch.einsum("ij,bj->bi", Irreps("1e").D_from_angles(*abc), coordinates)
    assert (coordinates - coordinates.round()).abs().max() < 1e-6
    coordinates = coordinates.round().to(torch.int32)
    coordinates = torch.cat([x.C[:, :1], coordinates], dim=1)

    # rotate the features (according to `irreps`)
    features = x.F
    features = torch.einsum("ij,bj->bi", irreps.D_from_angles(*abc), features)

    return SparseTensor(coordinates=coordinates, features=features)


@pytest.mark.parametrize("abc", rotations)
def test_equivariance(abc) -> None:
    pytest.importorskip("MinkowskiEngine")
    from MinkowskiEngine import SparseTensor
    from e3nn.nn.models.v2203.sparse_voxel_convolution import Convolution

    abc = torch.tensor(abc)

    irreps_in = Irreps("1e")
    irreps_out = Irreps("0e + 1e + 2e")

    conv = Convolution(irreps_in,
                       irreps_out, irreps_sh="0e + 1e + 2e", diameter=7, num_radial_basis=3, steps=(1.0, 1.0, 1.0))

    x1 = SparseTensor(
        coordinates=torch.tensor([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 1, 0]], dtype=torch.int32),
        features=irreps_in.randn(4, -1),
    )
    x2 = rotate_sparse_tensor(x1, irreps_in, abc)

    y2 = conv(x2)
    y1 = conv(x1)
    y1 = rotate_sparse_tensor(y1, irreps_out, abc)

    # check equivariance
    assert (y1.C - y2.C).abs().max() == 0
    assert (y1.F - y2.F).abs().max() < 1e-7 * y1.F.abs().max()

e3nn-0.6.0/tests/nn/normact_test.py
import pytest
import functools
import torch

import e3nn
from e3nn.nn import NormActivation
from e3nn.util.test import assert_equivariant, assert_auto_jitable, assert_torch_compile


@pytest.mark.parametrize("do_bias", [True, False])
@pytest.mark.parametrize("nonlin", [torch.tanh, torch.sigmoid])
def test_norm_activation(float_tolerance, do_bias, nonlin) -> None:
    irreps_in = e3nn.o3.Irreps("4x0e + 5x1o")
    N_batch = 3
    in_features = torch.randn(N_batch, irreps_in.dim)
    # Set some features to zero to test avoiding divide by zero
    in_features[0, 0] = 0  # batch 0, scalar 0
    in_features[1, 4 : 4 + 3] = 0  # batch 1, vector 0

    norm_act = NormActivation(irreps_in=irreps_in, scalar_nonlinearity=nonlin, normalize=True, bias=do_bias)

    if do_bias:
        assert len(list(norm_act.parameters())) == 1
        with torch.no_grad():
            norm_act.biases[:] = torch.randn(norm_act.biases.shape)
    else:
        # Assert that there are no biases
        assert len(list(norm_act.parameters())) == 0

    out = norm_act(in_features)

    if do_bias:
        assert out.requires_grad

    for batch in range(N_batch):
        # scalars should be the nonlin of their abs with the same sign.
        scalar_in = in_features[batch, :4]
        if do_bias:
            true_nonlin_arg = scalar_in.abs() + norm_act.biases[:4]
        else:
            true_nonlin_arg = scalar_in.abs()
        assert torch.allclose(torch.sign(scalar_in) * nonlin(true_nonlin_arg), out[batch, :4], atol=float_tolerance)

        # vectors
        # first, check norms:
        vector_in = in_features[batch, 4:].reshape(5, 3)
        in_norms = vector_in.norm(dim=-1)
        vector_out = out[batch, 4:].reshape(5, 3)
        out_norms = vector_out.norm(dim=-1)
        # Can only check direction on vectors that have one:
        mask = (in_norms > 0) & (out_norms > 0)
        if do_bias:
            true_nonlin_arg = in_norms + norm_act.biases[4:]
        else:
            true_nonlin_arg = in_norms
        # Check norms for nonzero vectors
        assert torch.allclose(nonlin(true_nonlin_arg).abs()[mask], out_norms[mask], atol=float_tolerance)
        # Check that zeros maintained for zero inputs
        assert torch.allclose(in_norms[~mask], out_norms[~mask], atol=float_tolerance)
        # then that directions are unchanged up to sign:
        assert torch.allclose(
            torch.einsum(  # dot products
                "ni,ni->n",
                vector_in[mask] / in_norms[mask, None],
                vector_out[mask] / out_norms[mask, None],
            ).abs(),
            torch.ones(mask.sum()),
            atol=float_tolerance,
        )


@pytest.mark.parametrize("do_bias", [True, False])
@pytest.mark.parametrize("nonlin", [torch.tanh, torch.sigmoid])
def test_norm_activation_equivariant(do_bias, nonlin) -> None:
    irreps_in = e3nn.o3.Irreps(
        # test lots of different irreps
        "2x0e + 3x0o + 5x1o + 1x1e + 2x2e + 1x2o + 1x3e + 1x3o + 1x5e + 1x6o"
    )

    norm_act = NormActivation(irreps_in=irreps_in, scalar_nonlinearity=nonlin, bias=do_bias)

    if do_bias:
        # Set up some nonzero biases
        assert len(list(norm_act.parameters())) == 1
        with torch.no_grad():
            norm_act.biases[:] = torch.randn(norm_act.biases.shape)

    assert_equivariant(norm_act)
    assert_torch_compile(
        "inductor",
        functools.partial(NormActivation, irreps_in=irreps_in, scalar_nonlinearity=nonlin, bias=do_bias),
        irreps_in.randn(-1),
    )
    assert_auto_jitable(norm_act)


@pytest.mark.parametrize("do_bias", [True, False])
@pytest.mark.parametrize("nonlin", [torch.tanh, torch.sigmoid])
def test_zeros(do_bias, nonlin) -> None:
    """Confirm that `epsilon` gives non-NaN grads"""
    irreps_in = e3nn.o3.Irreps("2x0e + 3x0o")
    norm_act = NormActivation(
        irreps_in=irreps_in,
        scalar_nonlinearity=nonlin,
        bias=do_bias,
        normalize=True,
    )
    with torch.autograd.set_detect_anomaly(True):
        inp = torch.zeros(norm_act.irreps_in.dim, requires_grad=True)
        out = norm_act(inp)
        grads = torch.autograd.grad(
            outputs=out.sum(),
            inputs=inp,
        )[0]
        assert torch.all(torch.isfinite(grads))

e3nn-0.6.0/tests/nn/s2act_test.py
import itertools

import pytest
import torch

from e3nn import io
from e3nn.nn import S2Activation
from e3nn.util.test import assert_equivariant


@pytest.mark.parametrize(
    "act, normalization, p_val, p_arg",
    itertools.product([torch.tanh, lambda x: x**2], ["norm", "component"], [-1, 1], [-1, 1]),
)
def test_equivariance(float_tolerance, act, normalization, p_val, p_arg) -> None:
    irreps = io.SphericalTensor(3, p_val, p_arg)
    # TODO: torch.compile(fullgraph=True) not working
    m = S2Activation(irreps, act, 120, normalization=normalization, lmax_out=6, random_rot=True)

    assert_equivariant(m, ntrials=10, tolerance=torch.sqrt(float_tolerance))

e3nn-0.6.0/tests/nn/so3act_test.py
import pytest
import torch

from e3nn import o3
from e3nn.nn import SO3Activation
from e3nn.util.test import assert_equivariant
from e3nn.util.jit import compile


def so3_irreps(lmax: int) -> o3.Irreps:
    return o3.Irreps([(2 * l + 1, (l, 1)) for l in range(lmax + 1)])


@pytest.mark.parametrize("lmax", [1, 2, 3, 4])
@pytest.mark.parametrize("act", [torch.tanh, lambda x: x**2])
def test_equivariance(act, lmax: int) -> None:
    # TODO: torch.compile(fullgraph=True) not working
    m = SO3Activation(lmax, lmax, act, 6)

    assert_equivariant(m, ntrials=10, tolerance=0.04, irreps_in=so3_irreps(lmax),
                       irreps_out=so3_irreps(lmax))


@pytest.mark.parametrize("aspect_ratio", [1, 2, 3, 4])
def test_identity(aspect_ratio) -> None:
    irreps = o3.Irreps([(2 * l + 1, (l, 1)) for l in range(5 + 1)])

    m = SO3Activation(5, 5, lambda x: x, 6, aspect_ratio=aspect_ratio)
    m = compile(m)

    x = irreps.randn(-1)
    y = m(x)

    m_pt2 = torch.compile(m, fullgraph=True)
    y2 = m_pt2(x)
    assert torch.allclose(y, y2)

    mse = (x - y).pow(2).mean()
    assert mse < 1e-5, mse

e3nn-0.6.0/tests/o3/
e3nn-0.6.0/tests/o3/angular_spherical_harmonics_test.py
import math
import functools

import torch

from e3nn import o3
from e3nn.util.test import assert_auto_jitable, assert_torch_compile


def test_jit(float_tolerance) -> None:
    sh = o3.SphericalHarmonicsAlphaBeta([0, 1, 2])
    a = torch.randn(5, 4)
    b = torch.randn(5, 4)

    jited = assert_auto_jitable(sh)
    assert (sh(a, b) - jited(a, b)).abs().max() < float_tolerance

    pt2 = assert_torch_compile("inductor", functools.partial(o3.SphericalHarmonicsAlphaBeta, [0, 1, 2]), a, b)
    assert (sh(a, b) - pt2(a, b)).abs().max() < float_tolerance


def test_sh_equivariance1(float_tolerance) -> None:
    r"""test
    - compose
    - spherical_harmonics_alpha_beta
    - irrep
    """
    for l in range(7 + 1):
        a, b, _ = o3.rand_angles()
        alpha, beta, gamma = o3.rand_angles()

        ra, rb, _ = o3.compose_angles(alpha, beta, gamma, a, b, torch.tensor(0.0))
        Yrx = o3.spherical_harmonics_alpha_beta(l, ra, rb)

        Y = o3.spherical_harmonics_alpha_beta(l, a, b)
        DrY = o3.wigner_D(l, alpha, beta, gamma) @ Y

        assert (Yrx - DrY).abs().max() < float_tolerance * Y.abs().max()


def test_sh_is_in_irrep(float_tolerance) -> None:
    for l in range(4 + 1):
        a, b, _ = o3.rand_angles()
        Y = o3.spherical_harmonics_alpha_beta(l, a, b) * math.sqrt(4 * math.pi) / math.sqrt(2 * l + 1)
        D = o3.wigner_D(l, a, b, torch.zeros(()))
        assert (Y - D[:, l]).abs().max() < float_tolerance


def test_sh_same(float_tolerance) -> None:
    for l in range(4 + 1):
        x = torch.randn(10, 3)
        a, b = o3.xyz_to_angles(x)
        y1 = o3.spherical_harmonics(l, x, True)
        y2 = o3.spherical_harmonics_alpha_beta(l, a, b)
        assert (y1 - y2).abs().max() < float_tolerance

e3nn-0.6.0/tests/o3/cartesian_spherical_harmonics_test.py
import math
import io
import functools

import pytest
import torch

from e3nn import o3
from e3nn import set_optimization_defaults, get_optimization_defaults
from e3nn.util.test import assert_auto_jitable, assert_equivariant, assert_torch_compile


def test_weird_call() -> None:
    o3.spherical_harmonics([4, 1, 2, 3, 3, 1, 0], torch.randn(2, 1, 2, 3), False)


def test_weird_irreps() -> None:
    # string input
    o3.spherical_harmonics("0e + 1o", torch.randn(1, 3), False)

    # Weird multiplicities
    irreps = o3.Irreps("1x0e + 4x1o + 3x2e")
    out = o3.spherical_harmonics(irreps, torch.randn(7, 3), True)
    assert out.shape[-1] == irreps.dim

    # Bad parity
    with pytest.raises(ValueError):
        # L = 1 shouldn't be even for a vector input
        o3.SphericalHarmonics(
            irreps_out="1x0e + 4x1e + 3x2e",
            normalize=True,
            normalization="integral",
            irreps_in="1o",
        )

    # Good parity but pseudovector input
    _ = o3.SphericalHarmonics(irreps_in="1e", irreps_out="1x0e + 4x1e + 3x2e", normalize=True)

    # Invalid input
    with pytest.raises(ValueError):
        _ = o3.SphericalHarmonics(irreps_in="1e + 3o", irreps_out="1x0e + 4x1e + 3x2e", normalize=True)  # invalid


def test_zeros() -> None:
    assert torch.allclose(
        o3.spherical_harmonics([0, 1], torch.zeros(1, 3), False, normalization="norm"), torch.tensor([[1, 0, 0, 0.0]])
    )


def test_equivariance(float_tolerance) -> None:
    lmax = 5
    irreps = o3.Irreps.spherical_harmonics(lmax)
    x = torch.randn(2, 3)
    abc = o3.rand_angles()
    y1 = o3.spherical_harmonics(irreps, x @ o3.angles_to_matrix(*abc).T, False)
    y2 = o3.spherical_harmonics(irreps, x, False) @ irreps.D_from_angles(*abc).T

    assert (y1 - y2).abs().max() < 10 * float_tolerance


def test_backwardable() -> None:
    lmax = 3
    ls = list(range(lmax + 1))

    xyz = torch.tensor(
        [
            [0.0, 0.0, 1.0],
            [1.0, 0, 0],
            [0.0, 10.0, 0],
            [0.435, 0.7644, 0.023],
        ],
        requires_grad=True,
        dtype=torch.float64,
    )

    def func(pos):
        return o3.spherical_harmonics(ls, pos, False)

    assert torch.autograd.gradcheck(func, (xyz,), check_undefined_grad=False)


@pytest.mark.parametrize("l", range(10 + 1))
def test_normalization(float_tolerance, l) -> None:
    n = o3.spherical_harmonics(l, torch.randn(3), normalize=True, normalization="integral").pow(2).mean()
    assert abs(n - 1 / (4 * math.pi)) < float_tolerance

    n = o3.spherical_harmonics(l, torch.randn(3), normalize=True, normalization="norm").norm()
    assert abs(n - 1) < float_tolerance

    n = o3.spherical_harmonics(l, torch.randn(3), normalize=True, normalization="component").pow(2).mean()
    assert abs(n - 1) < float_tolerance


def test_closure() -> None:
    r"""
    integral of Ylm * Yjn = delta_lj delta_mn
    integral of 1 over the unit sphere = 4 pi
    """
    x = torch.randn(1_000_000, 3)
    Ys = [o3.spherical_harmonics(l, x, True) for l in range(0, 3 + 1)]
    for l1, Y1 in enumerate(Ys):
        for l2, Y2 in enumerate(Ys):
            m = Y1[:, :, None] * Y2[:, None, :]
            m = m.mean(0) * 4 * math.pi
            if l1 == l2:
                i = torch.eye(2 * l1 + 1)
                assert (m - i).abs().max() < 0.01
            else:
                assert m.abs().max() < 0.01


@pytest.mark.parametrize("l", range(11 + 1))
def test_parity(float_tolerance, l) -> None:
    r"""
    (-1)^l Y(x) = Y(-x)
    """
    x = torch.randn(3)
    Y1 = (-1) ** l * o3.spherical_harmonics(l, x, False)
    Y2 = o3.spherical_harmonics(l, -x, False)
    assert (Y1 - Y2).abs().max() < float_tolerance


@pytest.mark.parametrize("l", range(9 + 1))
def test_recurrence_relation(float_tolerance, l) -> None:
    if torch.get_default_dtype() != torch.float64 and l > 6:
        pytest.xfail("we expect this to fail for high l and single precision")

    x = torch.randn(3, requires_grad=True)

    a = o3.spherical_harmonics(l + 1, x, False)

    b = torch.einsum("ijk,j,k->i", o3.wigner_3j(l + 1, l, 1), o3.spherical_harmonics(l, x, False), x)

    alpha = b.norm() / a.norm()

    assert (a / a.norm() - b / b.norm()).abs().max() < 10 * float_tolerance

    def f(x):
        return o3.spherical_harmonics(l + 1, x, False)

    a = torch.autograd.functional.jacobian(f, x)

    b = (l + 1) / alpha * torch.einsum("ijk,j->ik", o3.wigner_3j(l + 1, l, 1), o3.spherical_harmonics(l, x, False))

    assert (a - b).abs().max() < 100 * float_tolerance


@pytest.mark.parametrize("normalization", ["integral", "component", "norm"])
@pytest.mark.parametrize("normalize", [True, False])
def test_module(normalization, normalize) -> None:
    l = o3.Irreps("0e + 1o + 3o")
    sp = o3.SphericalHarmonics(l, normalize, normalization)
    xyz = torch.randn(11, 3)
    sp_jit = assert_auto_jitable(sp)
    assert torch.allclose(sp_jit(xyz), o3.spherical_harmonics(l, xyz, normalize, normalization))
    assert_equivariant(sp)
    sp_pt2 = assert_torch_compile("inductor", functools.partial(o3.SphericalHarmonics, l, normalize, normalization), xyz)
    assert torch.allclose(sp_pt2(xyz), o3.spherical_harmonics(l, xyz, normalize, normalization))


@pytest.mark.parametrize("jit_mode", ["inductor", "eager"])
def test_pickle(jit_mode):
    l = o3.Irreps("0e + 1o + 3o")
    # Turning off the torch.jit.script in CodeGenMix to enable torch.compile.
    jit_mode_before = get_optimization_defaults()["jit_mode"]
    try:
        # Cannot pickle with compiled submodules
        set_optimization_defaults(jit_mode=jit_mode)
        sp = o3.SphericalHarmonics(l, normalization="integral", normalize=True)
        buffer = io.BytesIO()
        torch.save(sp.state_dict(), buffer)
        buffer.seek(0)
        sp2 = o3.SphericalHarmonics(l, normalization="integral", normalize=True)
        sp2.load_state_dict(torch.load(buffer, weights_only=False))
        xyz = torch.randn(11, 3)
        assert torch.allclose(sp(xyz), sp2(xyz))
    finally:
        set_optimization_defaults(jit_mode=jit_mode_before)

e3nn-0.6.0/tests/o3/experimental/
e3nn-0.6.0/tests/o3/experimental/benchmark_pt2.py
# flake8: noqa


def main():
    import torch
    from torch._inductor.utils import print_performance

    # Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
    torch.set_float32_matmul_precision("high")
    import torch._dynamo.config
    import torch._inductor.config

    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.triton.unique_kernel_names = True

    device = "cuda"

    compile_mode = (
        "max-autotune"  # Bringing out all of the tricks that Torch 2.0 has but "reduce-overhead" should work as well
    )

    from e3nn import o3, util
    import numpy as np
    from torch import nn
    import time

    LMAX = 8
    CHANNEL = 128
    BATCH = 100

    for lmax in range(1, LMAX + 1):
        irreps = o3.Irreps.spherical_harmonics(lmax)
        irreps_x = (CHANNEL * irreps).regroup()
        x = irreps_x.randn(BATCH, -1).to(device=device)
        irreps_y = irreps
        y = irreps_y.randn(BATCH, -1).to(device=device)
        print(f"{irreps_x} \\otimes {irreps_y}")

        tp = o3.FullTensorProduct(irreps_x, irreps_y)  # Doesn't work with fullgraph=True
        tp_jit_compile = util.jit.compile(tp).to(device=device)

        tp_compile = torch.compile(tp, mode=compile_mode).to(device=device)
        print(
            f"TP JIT lmax {lmax} channel {CHANNEL} batch {BATCH}: {print_performance(lambda: tp_jit_compile(x, y), times=100, repeat=10)*1000:.3f}ms"
        )
        print(
            f"TP Torch 2.0 lmax {lmax} channel {CHANNEL} batch {BATCH}: {print_performance(lambda: tp_compile(x, y), times=100, repeat=10)*1000:.3f}ms"
        )

        tp_experimental = o3.experimental.FullTensorProductv2(irreps_x, irreps_y)
        tp_experimental_compile = torch.compile(tp_experimental, mode=compile_mode, fullgraph=True).to(device=device)
        print(
            f"TP Experimental Torch 2.0 lmax {lmax} channel {CHANNEL} batch {BATCH}: {print_performance(lambda: tp_experimental_compile(x, y), times=100, repeat=10)*1000:.3f}ms"
        )


if __name__ == "__main__":
    main()

e3nn-0.6.0/tests/o3/experimental/test_elementwise_tp.py
import torch
from e3nn import o3
import pytest


@pytest.mark.parametrize("irreps_in1, irreps_in2", [("15x0e", "5x0e + 5x1o + 5x1e"), ("2x0e + 1x1e", "2x0o + 1x1e")])
def test_elementwise_tp(irreps_in1, irreps_in2):
    irreps_in1 = o3.Irreps(irreps_in1)
    irreps_in2 = o3.Irreps(irreps_in2)
    x1 = irreps_in1.randn(5, -1)
    x2 = irreps_in2.randn(5, -1)

    tp = o3.ElementwiseTensorProduct(irreps_in1, irreps_in2)
    tp_pt2 = torch.compile(o3.experimental.ElementwiseTensorProductv2(irreps_in1, irreps_in2), fullgraph=True)

    result_tp = tp(x1, x2)
    result_tp2 = tp_pt2(x1, x2)
    assert tp.irreps_out == tp_pt2.irreps_out
    torch.testing.assert_close(result_tp, result_tp2)

e3nn-0.6.0/tests/o3/experimental/test_fulltp.py
import torch
from e3nn import o3
import pytest


@pytest.mark.parametrize("irreps_in1", ["0e", "0e + 1e"])
@pytest.mark.parametrize("irreps_in2", ["2x0e", "2x0e + 3x1e"])
def test_fulltp(irreps_in1, irreps_in2):
    x = o3.Irreps(irreps_in1).randn(10, -1)
    y = o3.Irreps(irreps_in2).randn(10, -1)

    tp_pt2 = torch.compile(o3.experimental.FullTensorProductv2(irreps_in1, irreps_in2), fullgraph=True)
    tp = o3.FullTensorProduct(irreps_in1,
                              irreps_in2)

    assert tp_pt2.irreps_out == tp.irreps_out
    torch.testing.assert_close(tp_pt2(x, y), tp(x, y))

e3nn-0.6.0/tests/o3/irreps_test.py
import pytest

from e3nn import o3
from e3nn.o3 import irrep


def test_creation() -> None:
    o3.Irrep(3, 1)
    ir = o3.Irrep("3e")
    o3.Irrep(ir)
    assert o3.Irrep("10o") == o3.Irrep(10, -1)
    assert o3.Irrep("1y") == o3.Irrep("1o")

    irreps = o3.Irreps(ir)
    o3.Irreps(irreps)
    o3.Irreps([(32, (4, -1))])
    o3.Irreps("11e")
    assert o3.Irreps("16x1e + 32 x 2o") == o3.Irreps([(16, (1, 1)), (32, (2, -1))])
    o3.Irreps(["1e", "2o"])
    o3.Irreps([(16, "3e"), "1e"])
    o3.Irreps([(16, "3e"), "1e", (256, (1, -1))])

    assert irrep.l0e == o3.Irrep("0e")
    from e3nn.o3.irrep import l1y

    assert l1y == o3.Irrep("1y")


def test_properties() -> None:
    irrep = o3.Irrep("3e")
    assert irrep.l == 3
    assert irrep.p == 1
    assert irrep.dim == 7

    assert o3.Irrep(repr(irrep)) == irrep

    l, p = o3.Irrep("5o")
    assert l == 5
    assert p == -1

    iterator = o3.Irrep.iterator(5)
    assert len(list(iterator)) == 12

    iterator = o3.Irrep.iterator()
    for x in range(100):
        irrep = next(iterator)
        assert irrep.l == x // 2
        assert irrep.p in (-1, 1)
        assert irrep.dim == 2 * (x // 2) + 1

    irreps = o3.Irreps("4x1e + 6x2e + 12x2o")
    assert o3.Irreps(repr(irreps)) == irreps


def test_arithmetic() -> None:
    assert 3 * o3.Irrep("6o") == o3.Irreps("3x6o")
    products = list(o3.Irrep("1o") * o3.Irrep("2e"))
    assert products == [o3.Irrep("1o"), o3.Irrep("2o"), o3.Irrep("3o")]

    assert o3.Irrep("4o") + o3.Irrep("7e") == o3.Irreps("4o + 7e")

    assert 2 * o3.Irreps("2x2e + 4x1o") == o3.Irreps("2x2e + 4x1o + 2x2e + 4x1o")
    assert o3.Irreps("2x2e + 4x1o") * 2 == o3.Irreps("2x2e + 4x1o + 2x2e + 4x1o")

    assert o3.Irreps("1o + 4o") + o3.Irreps("1o + 7e") == o3.Irreps("1o + 4o + 1o + 7e")


def test_empty_irreps() -> None:
    assert o3.Irreps() == o3.Irreps("") == o3.Irreps([])
    assert len(o3.Irreps()) == 0
    assert o3.Irreps().dim == 0
    assert o3.Irreps().ls == []
    assert o3.Irreps().num_irreps == 0


def test_getitem() -> None:
    irreps = o3.Irreps("16x1e + 3e + 2e + 5o")
    assert irreps[0] == (16, o3.Irrep("1e"))
    assert irreps[3] == (1, o3.Irrep("5o"))
    assert irreps[-1] == (1, o3.Irrep("5o"))

    sliced = irreps[2:]
    assert isinstance(sliced, o3.Irreps)
    assert sliced == o3.Irreps("2e + 5o")


def test_cat() -> None:
    irreps = o3.Irreps("4x1e + 6x2e + 12x2o") + o3.Irreps("1x1e + 2x2e + 12x4o")
    assert len(irreps) == 6
    assert irreps.ls == [1] * 4 + [2] * 6 + [2] * 12 + [1] * 1 + [2] * 2 + [4] * 12
    assert irreps.lmax == 4
    assert irreps.num_irreps == 4 + 6 + 12 + 1 + 2 + 12


def test_contains() -> None:
    assert o3.Irrep("2e") in o3.Irreps("3x0e + 2x2e + 1x3o")
    assert o3.Irrep("2o") not in o3.Irreps("3x0e + 2x2e + 1x3o")


def test_errors() -> None:
    """Test invalid irrep specifications"""
    # Irrep
    with pytest.raises(ValueError):
        o3.Irrep(-1)
    with pytest.raises(ValueError):
        o3.Irrep(1, p=2)
    with pytest.raises(ValueError):
        o3.Irrep("-1e")

    # Irreps
    with pytest.raises(ValueError):
        o3.Irreps("-1x1e")
    with pytest.raises(ValueError):
        o3.Irreps("1x-1e")
    with pytest.raises(ValueError):
        o3.Irreps("bla")


@pytest.mark.xfail()
def test_fail1() -> None:
    o3.Irreps([(32, 1)])


def test_slice_by_mul():
    assert o3.Irreps("10x0e").slice_by_mul[1:4] == o3.Irreps("3x0e")
    assert o3.Irreps("10x0e + 10x1e").slice_by_mul[5:15] == o3.Irreps("5x0e + 5x1e")
    assert o3.Irreps("10x0e + 2e + 10x1e").slice_by_mul[5:15] == o3.Irreps("5x0e + 2e + 4x1e")

e3nn-0.6.0/tests/o3/linear_test.py
import pytest
import functools
from typing import Optional

import torch

from e3nn import o3
from e3nn.util.test import assert_equivariant, assert_auto_jitable, random_irreps, assert_normalized, assert_torch_compile


class SlowLinear(torch.nn.Module):
    r"""Inefficient implementation of Linear relying on TensorProduct."""

    def __init__(
        self,
        irreps_in,
        irreps_out,
        internal_weights=None,
        shared_weights=None,
    ) -> None:
        super().__init__()

        irreps_in = o3.Irreps(irreps_in)
        irreps_out = o3.Irreps(irreps_out)

        instr = [
            (i_in, 0, i_out, "uvw", True, 1.0)
            for i_in, (_, ir_in) in enumerate(irreps_in)
            for i_out, (_, ir_out) in enumerate(irreps_out)
            if ir_in == ir_out
        ]
        self.tp = o3.TensorProduct(
            irreps_in,
            "0e",
            irreps_out,
            instr,
            internal_weights=internal_weights,
            shared_weights=shared_weights,
        )
        self.output_mask = self.tp.output_mask
        self.irreps_in = irreps_in
        self.irreps_out = irreps_out

    def forward(self, features, weight: Optional[torch.Tensor] = None):
        ones = torch.ones(features.shape[:-1] + (1,), dtype=features.dtype, device=features.device)
        return self.tp(features, ones, weight)


def test_linear() -> None:
    irreps_in = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    m = o3.Linear(irreps_in, irreps_out)
    m(torch.randn(irreps_in.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_torch_compile(
        "inductor",
        functools.partial(o3.Linear, irreps_in, irreps_out),
        torch.randn(irreps_in.dim),
    )
    assert_normalized(m, n_weight=100, n_input=10_000, atol=0.5)


def test_bias() -> None:
    irreps_in = o3.Irreps("2x0e + 1e + 2x0e + 0o")
    irreps_out = o3.Irreps("3x0e + 1e + 3x0e + 5x0e + 0o")
    m = o3.Linear(irreps_in, irreps_out, biases=[True, False, False, True, False])
    with torch.no_grad():
        m.bias[:].fill_(1.0)
    x = m(torch.zeros(irreps_in.dim))

    assert torch.allclose(x, torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]))

    assert_equivariant(m)
    assert_auto_jitable(m)

    m = o3.Linear("0e + 0o + 1e + 1o", "10x0e + 0o + 1e + 1o", biases=True)

    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_normalized(m, n_weight=100, n_input=10_000, atol=0.5, weights=[m.weight])


def test_single_out() -> None:
    l1 = o3.Linear("5x0e", "5x0e")
    l2 = o3.Linear("5x0e", "5x0e + 3x0o")
    with torch.no_grad():
        l1.weight[:] = l2.weight
    x = torch.randn(3, 5)
    out1 = l1(x)
    out2 = l2(x)
    assert out1.shape == (3, 5)
    assert out2.shape == (3, 8)
    assert torch.allclose(out1, out2[:,
                                     :5])
    assert torch.all(out2[:, 5:] == 0)


# We want to be sure to test a multiple-same L case, a single irrep case, and an empty irrep case
@pytest.mark.parametrize("irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e"] + random_irreps(n=4))
@pytest.mark.parametrize("irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e"] + random_irreps(n=4))
def test_linear_like_tp(irreps_in, irreps_out) -> None:
    """Test that Linear gives the same results as the corresponding TensorProduct."""
    m = o3.Linear(irreps_in, irreps_out)
    m_true = SlowLinear(irreps_in, irreps_out)
    with torch.no_grad():
        m_true.tp.weight[:] = m.weight
    inp = torch.randn(4, m.irreps_in.dim)
    assert torch.allclose(
        m(inp),
        m_true(inp),
        atol={torch.float32: 1e-6, torch.float64: 1e-10}[torch.get_default_dtype()],
    )


def test_output_mask() -> None:
    irreps_in = o3.Irreps("1e + 2e")
    irreps_out = o3.Irreps("3e + 5x2o")
    m = o3.Linear(irreps_in, irreps_out)
    assert torch.all(m.output_mask == torch.zeros(m.irreps_out.dim, dtype=torch.bool))


def test_instructions_parameter() -> None:
    m = o3.Linear("4x0e + 3x4o", "1x2e + 4x0o")
    assert len(m.instructions) == 0
    assert not torch.any(m.output_mask)

    with pytest.raises(ValueError):
        m = o3.Linear(
            "4x0e + 3x4o",
            "1x2e + 4x0e",
            # invalid mixture of 0e and 2e
            instructions=[(0, 0)],
        )

    with pytest.raises(IndexError):
        m = o3.Linear("4x0e + 3x4o", "1x2e + 4x0e", instructions=[(4, 0)])


def test_empty_instructions() -> None:
    m = o3.Linear(o3.Irreps.spherical_harmonics(3), o3.Irreps.spherical_harmonics(3), instructions=[])
    assert len(m.instructions) == 0
    assert not torch.any(m.output_mask)
    inp = m.irreps_in.randn(3, -1)
    out = m(inp)
    assert torch.all(out == 0.0)


def test_default_instructions() -> None:
    m = o3.Linear(
        "4x0e + 3x1o + 2x0e",
        "2x1o + 8x0e",
    )
    assert len(m.instructions) == 3
    assert torch.all(m.output_mask)
    ins_set = set((ins.i_in, ins.i_out) for ins in m.instructions)
    assert ins_set == {(0, 1), (1, 0), (2, 1)}
    assert set(ins.path_shape for ins in m.instructions) == {(4, 8), (2, 8), (3, 2)}


def test_instructions() -> None:
    m = o3.Linear("4x0e + 3x1o + 2x0e", "2x1o + 8x0e", instructions=[(0, 1), (1, 0)])
    inp = m.irreps_in.randn(3, -1)
    inp[:, : m.irreps_in[:2].dim] = 0.0
    out = m(inp)
    assert torch.allclose(out, torch.zeros(1))


def test_weight_view() -> None:
    m = o3.Linear("4x0e + 3x1o + 2x0e", "2x1o + 8x0e", instructions=[(0, 1), (1, 0)])
    inp = m.irreps_in.randn(3, -1)
    assert m.weight_view_for_instruction(0).shape == (4, 8)
    assert m.weight_view_for_instruction(1).shape == (3, 2)
    # Make weights going to output 0 all zeros
    with torch.no_grad():
        m.weight_view_for_instruction(1).fill_(0.0)
    out = m(inp)
    assert torch.allclose(out[:, :6], torch.zeros(1))

    for w in m.weight_views():
        with torch.no_grad():
            w.fill_(2.0)
    for i, ins, w in m.weight_views(yield_instruction=True):
        assert (w - 2.0).norm() == 0.0


def test_weight_view_unshared() -> None:
    m = o3.Linear("4x0e + 3x1o + 2x0e", "2x1o + 8x0e", instructions=[(0, 1), (1, 0)], shared_weights=False)
    batchdim = 7
    inp = m.irreps_in.randn(batchdim, -1)
    weights = torch.randn(batchdim, m.weight_numel)
    assert m.weight_view_for_instruction(0, weights).shape == (batchdim, 4, 8)
    assert m.weight_view_for_instruction(1, weights).shape == (batchdim, 3, 2)
    # Make weights going to output 0 all zeros
    with torch.no_grad():
        m.weight_view_for_instruction(1, weights).fill_(0.0)
    out = m(inp, weights)
    assert torch.allclose(out[:, :6], torch.zeros(1))


def test_f() -> None:
    m = o3.Linear("0e + 1e + 2e", "0e + 2x1e + 2e", f_in=44, f_out=25, _optimize_einsums=False)
    assert_equivariant(m, args_in=[torch.randn(10, 44, 9)])
    m = assert_auto_jitable(m)
    y = m(torch.randn(10, 44, 9))
    assert m.weight_numel == 4
    assert m.weight.numel() == 44 * 25 * 4
    assert 0.7 < y.pow(2).mean() < 1.4

e3nn-0.6.0/tests/o3/norm_test.py
import pytest
import functools

import torch

from e3nn import o3
from e3nn.util.test import assert_equivariant, assert_auto_jitable, random_irreps, assert_torch_compile


@pytest.mark.parametrize("irreps_in", ["", "5x0e", "1e + 2e + 4x1e + 3x3o"] + random_irreps(n=4))
@pytest.mark.parametrize("squared", [True, False])
def test_norm(irreps_in, squared) -> None:
    m = o3.Norm(irreps_in, squared=squared)
    m(torch.randn(m.irreps_in.dim))
    if m.irreps_in.dim == 0:
        return
    assert_equivariant(m)
    assert_torch_compile("inductor", functools.partial(o3.Norm, irreps_in, squared=squared), torch.randn(m.irreps_in.dim))
    assert_auto_jitable(m)


@pytest.mark.parametrize("squared", [True, False])
def test_grad(squared) -> None:
    """Confirm that the gradient is zero at zero input"""
    irreps_in = o3.Irreps("2x0e + 3x0o")
    norm = o3.Norm(irreps_in, squared=squared)

    with torch.autograd.set_detect_anomaly(True):
        inp = torch.zeros(norm.irreps_in.dim, requires_grad=True)
        out = norm(inp)
        grads = torch.autograd.grad(
            outputs=out.sum(),
            inputs=inp,
        )[0]
        assert torch.allclose(grads, torch.zeros(1))


@pytest.mark.parametrize("squared", [True, False])
def test_vector_norm(squared) -> None:
    n = 10
    batch = 3
    irreps_in = o3.Irreps([(n, (1, -1))])
    vecs = torch.randn(batch, n, 3)
    norm_mod = o3.Norm(irreps_in, squared=squared)
    norms = norm_mod(vecs.reshape(batch, -1))
    norms_true = vecs.norm(dim=-1)
    if squared:
        norms_true.square_()
    assert torch.allclose(norms_true, norms.reshape(batch, n))

e3nn-0.6.0/tests/o3/reduce_tensor_test.py
import tempfile
import functools

import torch

from e3nn import o3
from e3nn.util.test import assert_auto_jitable, assert_equivariant, assert_torch_compile


def test_save_load() -> None:
    tp1 = o3.ReducedTensorProducts("ij=-ji", i="5x0e + 1e")
    with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
        torch.save(tp1.state_dict(), tmp.name)
        tp2 = o3.ReducedTensorProducts("ij=-ji", i="5x0e + 1e")
        tp2.load_state_dict(torch.load(tmp.name, weights_only=False))

    xs = (torch.randn(2, 5 + 3), torch.randn(2, 5 + 3))
    assert torch.allclose(tp1(*xs), tp2(*xs))

    assert torch.allclose(tp1.change_of_basis, tp2.change_of_basis)


def test_antisymmetric_matrix(float_tolerance) -> None:
    tp = o3.ReducedTensorProducts("ij=-ji", i="5x0e + 1e")

    Q = tp.change_of_basis
    x = torch.randn(2, 5 + 3)

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_torch_compile("inductor", functools.partial(o3.ReducedTensorProducts, "ij=-ji", i="5x0e + 1e"), *x)
    assert_auto_jitable(tp)

    assert (tp(*x) - torch.einsum("xij,i,j", Q, *x)).abs().max() < float_tolerance

    assert (Q + torch.einsum("xij->xji", Q)).abs().max() < float_tolerance


def test_reduce_tensor_Levi_Civita_symbol(float_tolerance) -> None:
    tp = o3.ReducedTensorProducts("ijk=-ikj=-jik", i="1e")
    assert tp.irreps_out == o3.Irreps("0e")

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(3, 3)
    assert (tp(*x) - torch.einsum("xijk,i,j,k", Q, *x)).abs().max() < float_tolerance

    assert (Q + torch.einsum("xijk->xikj", Q)).abs().max() < float_tolerance
    assert (Q + torch.einsum("xijk->xjik", Q)).abs().max() < float_tolerance


def test_reduce_tensor_antisymmetric_L2(float_tolerance) -> None:
    tp = o3.ReducedTensorProducts("ijk=-ikj=-jik", i="2e")

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(3, 5)
    assert (tp(*x) - torch.einsum("xijk,i,j,k", Q, *x)).abs().max() < float_tolerance

    assert (Q + torch.einsum("xijk->xikj", Q)).abs().max() < float_tolerance
    assert (Q + torch.einsum("xijk->xjik", Q)).abs().max() < float_tolerance


def test_reduce_tensor_elasticity_tensor(float_tolerance) -> None:
    tp = o3.ReducedTensorProducts("ijkl=jikl=klij", i="1e")
    assert tp.irreps_out.dim == 21

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(4, 3)
    assert (tp(*x) - torch.einsum("xijkl,i,j,k,l", Q, *x)).abs().max() < float_tolerance

    assert (Q -
torch.einsum("xijkl->xjikl", Q)).abs().max() < float_tolerance assert (Q - torch.einsum("xijkl->xijlk", Q)).abs().max() < float_tolerance assert (Q - torch.einsum("xijkl->xklij", Q)).abs().max() < float_tolerance def test_reduce_tensor_elasticity_tensor_parity(float_tolerance) -> None: tp = o3.ReducedTensorProducts("ijkl=jikl=klij", i="1o") assert tp.irreps_out.dim == 21 assert all(ir.p == 1 for _, ir in tp.irreps_out) assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out) assert_auto_jitable(tp) Q = tp.change_of_basis x = torch.randn(4, 3) assert (tp(*x) - torch.einsum("xijkl,i,j,k,l", Q, *x)).abs().max() < float_tolerance assert (Q - torch.einsum("xijkl->xjikl", Q)).abs().max() < float_tolerance assert (Q - torch.einsum("xijkl->xijlk", Q)).abs().max() < float_tolerance assert (Q - torch.einsum("xijkl->xklij", Q)).abs().max() < float_tolerance e3nn-0.6.0/tests/o3/rotation_test.py000066400000000000000000000065031514371756200173320ustar00rootroot00000000000000import torch from e3nn import o3 def test_xyz(float_tolerance) -> None: R = o3.rand_matrix(10) assert (R @ R.transpose(-1, -2) - torch.eye(3)).abs().max() < float_tolerance a, b, c = o3.matrix_to_angles(R) pos1 = o3.angles_to_xyz(a, b) pos2 = R @ torch.tensor([0, 1.0, 0]) assert torch.allclose(pos1, pos2, atol=float_tolerance) a2, b2 = o3.xyz_to_angles(pos2) assert (a - a2).abs().max() < float_tolerance assert (b - b2).abs().max() < float_tolerance def test_conversions(float_tolerance) -> None: def wrap(f): def g(x): if isinstance(x, tuple): return f(*x) else: return f(x) return g def identity(x): return x conv = [ [identity, wrap(o3.angles_to_matrix), wrap(o3.angles_to_axis_angle), wrap(o3.angles_to_quaternion)], [wrap(o3.matrix_to_angles), identity, wrap(o3.matrix_to_axis_angle), wrap(o3.matrix_to_quaternion)], [wrap(o3.axis_angle_to_angles), wrap(o3.axis_angle_to_matrix), identity, wrap(o3.axis_angle_to_quaternion)], [wrap(o3.quaternion_to_angles), wrap(o3.quaternion_to_matrix), 
wrap(o3.quaternion_to_axis_angle), identity], ] R1 = o3.rand_matrix(100) path = [1, 2, 3, 0, 2, 0, 3, 1, 3, 2, 1, 0, 1] g = R1 for i, j in zip(path, path[1:]): g = conv[i][j](g) R2 = g assert (R1 - R2).abs().median() < float_tolerance def test_compose(float_tolerance) -> None: q1 = o3.rand_quaternion(10) q2 = o3.rand_quaternion(10) q = o3.compose_quaternion(q1, q2) R1 = o3.quaternion_to_matrix(q1) R2 = o3.quaternion_to_matrix(q2) R = R1 @ R2 abc1 = o3.quaternion_to_angles(q1) abc2 = o3.quaternion_to_angles(q2) abc = o3.compose_angles(*abc1, *abc2) ax1, a1 = o3.quaternion_to_axis_angle(q1) ax2, a2 = o3.quaternion_to_axis_angle(q2) ax, a = o3.compose_axis_angle(ax1, a1, ax2, a2) R1 = o3.quaternion_to_matrix(q) R2 = R R3 = o3.angles_to_matrix(*abc) R4 = o3.axis_angle_to_matrix(ax, a) assert (R1 - R2).abs().max().median() < float_tolerance assert (R1 - R3).abs().max().median() < float_tolerance assert (R1 - R4).abs().max().median() < float_tolerance def test_inverse_angles(float_tolerance) -> None: a = o3.rand_angles() b = o3.inverse_angles(*a) c = o3.compose_angles(*a, *b) e = o3.identity_angles(requires_grad=True) rc = o3.angles_to_matrix(*c) re = o3.angles_to_matrix(*e) assert (rc - re).abs().max() < float_tolerance # test `requires_grad` re.sum().backward() assert e[0].grad is not None def test_rand_axis_angle() -> None: axis, angle = o3.rand_axis_angle(1_000_000) x = o3.axis_angle_to_matrix(axis, angle) @ torch.tensor([0.2, 0.5, 0.3]) assert x[:, 0].mean().max() < 0.005 assert x[:, 1].mean().max() < 0.005 assert x[:, 2].mean().max() < 0.005 def test_matrix_xyz(float_tolerance) -> None: x = torch.randn(100, 3) y = torch.einsum("zij,zj->zi", o3.matrix_x(torch.randn(100)), x) assert (x[:, 0] - y[:, 0]).abs().max() < float_tolerance y = torch.einsum("zij,zj->zi", o3.matrix_y(torch.randn(100)), x) assert (x[:, 1] - y[:, 1]).abs().max() < float_tolerance y = torch.einsum("zij,zj->zi", o3.matrix_z(torch.randn(100)), x) assert (x[:, 2] - y[:, 2]).abs().max() < 
float_tolerance e3nn-0.6.0/tests/o3/s2_test.py000066400000000000000000000031671514371756200160220ustar00rootroot00000000000000import torch import pytest from e3nn.o3 import ToS2Grid, FromS2Grid, Irreps from e3nn.util.test import assert_equivariant @pytest.mark.parametrize("res_a", [11, 12, 13, 14, 15, 16, None]) @pytest.mark.parametrize("res_b", [12, 14, 16, None]) @pytest.mark.parametrize("lmax", [0, 1, 5, None]) def test_inverse1(float_tolerance, lmax, res_b, res_a) -> None: if lmax is None and res_b is None and res_a is None: return m = FromS2Grid((res_b, res_a), lmax) k = ToS2Grid(lmax, (res_b, res_a)) res_b, res_a = m.res_beta, m.res_alpha x = torch.randn(res_b, res_a) x = k(m(x)) # remove high frequencies y = k(m(x)) assert (x - y).abs().max().item() < float_tolerance @pytest.mark.parametrize("res_a", [11, 12, 13, 14, 15, 16, None]) @pytest.mark.parametrize("res_b", [12, 14, 16, None]) @pytest.mark.parametrize("lmax", [0, 1, 5, None]) def test_inverse2(float_tolerance, lmax, res_b, res_a) -> None: if lmax is None and res_b is None and res_a is None: return m = FromS2Grid((res_b, res_a), lmax) k = ToS2Grid(lmax, (res_b, res_a)) lmax = m.lmax x = torch.randn((lmax + 1) ** 2) y = m(k(x)) assert (x - y).abs().max().item() < float_tolerance @pytest.mark.parametrize("res_a", [100, 101]) @pytest.mark.parametrize("res_b", [98, 100]) @pytest.mark.parametrize("lmax", [1, 5]) def test_equivariance(lmax, res_b, res_a) -> None: m = FromS2Grid((res_b, res_a), lmax) k = ToS2Grid(lmax, (res_b, res_a)) def f(x): y = k(x) y = y.exp() return m(y) f.irreps_in = f.irreps_out = Irreps.spherical_harmonics(lmax) assert_equivariant(f) e3nn-0.6.0/tests/o3/tensor_product_sub_test.py000066400000000000000000000074051514371756200214200ustar00rootroot00000000000000import torch import functools from e3nn import o3 from e3nn.nn import Identity from e3nn.o3 import FullyConnectedTensorProduct, FullTensorProduct, Norm, TensorSquare from e3nn.util.test import assert_equivariant, 
assert_auto_jitable, assert_torch_compile def test_fully_connected() -> None: irreps_in1 = o3.Irreps("1e + 2e + 3x3o") irreps_in2 = o3.Irreps("1e + 2e + 3x3o") irreps_out = o3.Irreps("1e + 2e + 3x3o") m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out) print(m) m(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim)) assert_equivariant(m) assert_auto_jitable(m) assert_torch_compile( "inductor", functools.partial(FullyConnectedTensorProduct, irreps_in1, irreps_in2, irreps_out), torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim), ) def test_fully_connected_normalization() -> None: m = FullyConnectedTensorProduct("10x0e", "10x0e", "0e") for p in m.parameters(): p.data.fill_(1.0) n = FullyConnectedTensorProduct("3x0e + 7x0e", "3x0e + 7x0e", "0e") for p in n.parameters(): p.data.fill_(1.0) x1, x2 = torch.randn(2, 3, 10) assert torch.allclose(m(x1, x2), n(x1, x2)) def test_id() -> None: irreps_in = o3.Irreps("1e + 2e + 3x3o") irreps_out = o3.Irreps("1e + 2e + 3x3o") m = Identity(irreps_in, irreps_out) print(m) m(torch.randn(irreps_in.dim)) assert_equivariant(m) assert_auto_jitable(m, strict_shapes=False) assert_torch_compile("inductor", functools.partial(Identity, irreps_in, irreps_out), torch.randn(irreps_in.dim)) def test_full() -> None: irreps_in1 = o3.Irreps("1e + 2e + 3x3o") irreps_in2 = o3.Irreps("1e + 2x2e + 2x3o") m = FullTensorProduct(irreps_in1, irreps_in2) print(m) assert_equivariant(m) assert_auto_jitable(m) assert_torch_compile( "inductor", functools.partial(FullTensorProduct, irreps_in1, irreps_in2), irreps_in1.randn(-1), irreps_in2.randn(-1) ) def test_norm() -> None: irreps_in = o3.Irreps("3x0e + 5x1o") scalars = torch.randn(3) vecs = torch.randn(5, 3) norm = Norm(irreps_in=irreps_in) out_norms = norm(torch.cat((scalars.reshape(1, -1), vecs.reshape(1, -1)), dim=-1)) true_scalar_norms = torch.abs(scalars) true_vec_norms = torch.linalg.norm(vecs, dim=-1) assert torch.allclose(out_norms[0, :3], true_scalar_norms) assert 
torch.allclose(out_norms[0, 3:], true_vec_norms) assert_equivariant(norm) assert_auto_jitable(norm) assert_torch_compile( "inductor", functools.partial(Norm, irreps_in=irreps_in), torch.cat((scalars.reshape(1, -1), vecs.reshape(1, -1)), dim=-1), ) def test_square_normalization() -> None: irreps = o3.Irreps("0e + 1e + 2e") tp = TensorSquare(irreps, irrep_normalization="norm") x = irreps.randn(1_000_000, -1, normalization="norm") y = tp(x) n = Norm(tp.irreps_out, squared=True)(y) assert (n.mean(0).log().abs().exp() < 1.1).all() irreps = o3.Irreps("0e + 3x1e + 3e") tp = o3.TensorSquare(irreps, irrep_normalization="component") x = irreps.randn(1_000_000, -1, normalization="component") y = tp(x) assert (y.pow(2).mean(0).log().abs().exp() < 1.1).all() tp = TensorSquare(irreps, irrep_normalization="none") y = tp(x) assert not (y.pow(2).mean(0).log().abs().exp() < 1.1).all() # with weights tp = TensorSquare(irreps, irreps) n = 2_000 y = torch.stack([tp(tp.irreps_in.randn(n, -1), torch.randn(tp.weight_numel)) for _ in range(n)]) assert (y.pow(2).mean([0, 1]).log().abs().exp() < 1.1).all() def test_square_elasticity_tensor() -> None: tp = TensorSquare("1o") tp = TensorSquare(tp.irreps_out) assert tp.irreps_out.simplify() == o3.Irreps("2x0e + 2x2e + 4e") e3nn-0.6.0/tests/o3/tensor_product_test.py000066400000000000000000000422631514371756200205500ustar00rootroot00000000000000import random import copy import tempfile import functools import pytest import torch from e3nn.o3 import TensorProduct, FullyConnectedTensorProduct, Irreps from e3nn.util.test import assert_equivariant, assert_auto_jitable, assert_normalized, assert_torch_compile def make_tp(l1, p1, l2, p2, lo, po, mode, weight, mul: int = 25, path_weights: bool = True, **kwargs): def mul_out(mul): if mode == "uvuv": return mul**2 if mode == "uvu None: eps = float_tolerance n = 1_500 tol = 3.0 m = make_tp(l1, p1, l2, p2, lo, po, mode, weight) # bilinear x1 = torch.randn(2, m.irreps_in1.dim) x2 = torch.randn(2, 
m.irreps_in1.dim) y1 = torch.randn(2, m.irreps_in2.dim) y2 = torch.randn(2, m.irreps_in2.dim) z1 = m(x1 + 1.7 * x2, y1 - y2) z2 = m(x1, y1 - y2) + 1.7 * m(x2, y1 - y2) z3 = m(x1 + 1.7 * x2, y1) - m(x1 + 1.7 * x2, y2) assert (z1 - z2).abs().max() < eps assert (z1 - z3).abs().max() < eps # right z1 = m(x1, y1) z2 = torch.einsum("zi,zij->zj", x1, m.right(y1)) assert (z1 - z2).abs().max() < eps # variance x1 = torch.randn(n, m.irreps_in1.dim) y1 = torch.randn(n, m.irreps_in2.dim) z1 = m(x1, y1).var(0) assert z1.mean().log10().abs() < torch.tensor(tol).log10() # equivariance assert_equivariant(m, irreps_in=[m.irreps_in1, m.irreps_in2], irreps_out=m.irreps_out) if weight: # linear in weights w1 = m.weight.clone().normal_() w2 = m.weight.clone().normal_() z1 = m(x1, y1, weight=w1) + 1.5 * m(x1, y1, weight=w2) z2 = m(x1, y1, weight=w1 + 1.5 * w2) assert (z1 - z2).abs().max() < eps # This is a fairly expensive test, so we don't run too many configs @pytest.mark.parametrize("path_normalization", ["element", "path"]) @pytest.mark.parametrize("l1, p1, l2, p2, lo, po, mode, weight", random_params(n=8)) def test_normalized(l1, p1, l2, p2, lo, po, mode, weight, path_normalization) -> None: if torch.get_default_dtype() != torch.float32: pytest.skip("No reason to run expensive normalization tests again at float64 expense.") # Explicit fixed path weights screw with the output normalization, # so don't use them m = make_tp(l1, p1, l2, p2, lo, po, mode, weight, mul=5, path_weights=False, path_normalization=path_normalization) # normalization # n_weight, n_input has to be decently high to ensure statistical convergence # especially for uvuv assert_normalized(m, n_weight=100, n_input=10_000, atol=0.5) def test_empty() -> None: m = TensorProduct( "0x0e + 1o + 2e", "0e + 1o + 2e", "0x0e + 1o", [ (0, 0, 0, "uvw", True), (1, 1, 0, "uvw", True), ], compile_right=True, ) x1, x2 = m.irreps_in1.randn(4, -1), m.irreps_in2.randn(4, -1) out = m(x1, x2) assert out.shape == (4, m.irreps_out.dim) 
    assert torch.all(out == 0.0)  # no instruction leads to the 1o output
    m.right(x2)


@pytest.mark.parametrize("normalization", ["component", "norm"])
@pytest.mark.parametrize(
    "mode,weighted",
    [
        ("uvw", True),
        ("uvu", True),
        ("uvu", False),
        ("uvv", True),
        ("uvv", False),
        ("uuu", True),
        ("uuu", False),
        ("uuw", True),
        ("uuw", False),
    ],
)
def test_specialized_code(normalization, mode, weighted, float_tolerance) -> None:
    irreps_in1 = Irreps("4x0e + 4x1e + 4x2e")
    irreps_in2 = Irreps("5x0e + 5x1e + 5x2e")
    irreps_out = Irreps("6x0e + 6x1e + 6x2e")

    if mode == "uvu":
        irreps_out = irreps_in1
    elif mode == "uvv":
        irreps_out = irreps_in2
    elif mode == "uuu":
        irreps_in2 = irreps_in1
        irreps_out = irreps_in1
    elif mode == "uuw":
        irreps_in2 = irreps_in1
        # When unweighted, uuw is a plain sum over u and requires an output mul of 1
        if not weighted:
            irreps_out = Irreps([(1, ir) for _, ir in irreps_out])

    ins = [
        (0, 0, 0, mode, weighted, 1.0),
        (0, 1, 1, mode, weighted, 1.0),
        (1, 0, 1, mode, weighted, 1.0),
        (1, 1, 0, mode, weighted, 1.0),
        (1, 1, 1, mode, weighted, 1.0),
        (0, 2, 2, mode, weighted, 1.0),
        (2, 0, 2, mode, weighted, 1.0),
        (2, 2, 0, mode, weighted, 1.0),
        (2, 1, 1, mode, weighted, 1.0),
    ]
    tp1 = TensorProduct(
        irreps_in1,
        irreps_in2,
        irreps_out,
        ins,
        irrep_normalization=normalization,
        compile_right=True,
        _specialized_code=False,
    )
    tp2 = TensorProduct(
        irreps_in1,
        irreps_in2,
        irreps_out,
        ins,
        irrep_normalization=normalization,
        compile_right=True,
        _specialized_code=True,
    )
    with torch.no_grad():
        tp2.weight[:] = tp1.weight

    x = irreps_in1.randn(3, -1)
    y = irreps_in2.randn(3, -1)
    assert (tp1(x, y) - tp2(x, y)).abs().max() < float_tolerance
    assert (tp1.right(y) - tp2.right(y)).abs().max() < float_tolerance


def test_empty_irreps() -> None:
    tp = FullyConnectedTensorProduct("0e + 1e", Irreps([]), "0e + 1e")
    out = tp(torch.randn(1, 2, 4), torch.randn(2, 1, 0))
    assert out.shape == (2, 2, 4)


def test_single_out() -> None:
    tp1 = TensorProduct("5x0e", "5x0e", "5x0e", [(0, 0, 0, "uvw", True, 1.0)])
    tp2 = TensorProduct("5x0e", "5x0e", "5x0e + 3x0o", [(0, 0, 0, "uvw", True, 1.0)])
    with torch.no_grad():
        tp2.weight[:] = tp1.weight
    x1, x2 = torch.randn(3, 5), torch.randn(3, 5)
    out1 = tp1(x1, x2)
    out2 = tp2(x1, x2)
    assert out1.shape == (3, 5)
    assert out2.shape == (3, 8)
    assert torch.allclose(out1, out2[:, :5])
    assert torch.all(out2[:, 5:] == 0)


def test_empty_inputs() -> None:
    tp = FullyConnectedTensorProduct("0e + 1e", "0e + 1e", "0e + 1e", compile_right=True)
    out = tp(torch.randn(2, 1, 0, 1, 4), torch.randn(1, 2, 0, 3, 4))
    assert out.shape == (2, 2, 0, 3, 4)

    out = tp.right(torch.randn(1, 2, 0, 3, 4))
    assert out.shape == (1, 2, 0, 3, 4, 4)


@pytest.mark.parametrize("l1, p1, l2, p2, lo, po, mode, weight", random_params(n=2))
@pytest.mark.parametrize("special_code", [True, False])
@pytest.mark.parametrize("opt_ein", [True, False])
def test_jit(l1, p1, l2, p2, lo, po, mode, weight, special_code, opt_ein) -> None:
    """Test the JIT.

    This test is separate from test_optimizations to ensure that jitting a model on its own
    introduces little or no numerical error.
    """
    orig_tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight, _specialized_code=special_code, _optimize_einsums=opt_ein)
    opt_tp = assert_auto_jitable(orig_tp)

    # Confirm equivariance of optimized model
    assert_equivariant(opt_tp, irreps_in=[orig_tp.irreps_in1, orig_tp.irreps_in2], irreps_out=orig_tp.irreps_out)

    # Confirm that it gives same results
    x1 = orig_tp.irreps_in1.randn(2, -1)
    x2 = orig_tp.irreps_in2.randn(2, -1)
    # TorchScript should cause very little, if any, numerical error
    assert torch.allclose(
        orig_tp(x1, x2),
        opt_tp(x1, x2),
    )
    assert torch.allclose(
        orig_tp.right(x2),
        opt_tp.right(x2),
    )


@pytest.mark.parametrize("l1, p1, l2, p2, lo, po, mode, weight", random_params(n=4))
@pytest.mark.parametrize("special_code", [True, False])
@pytest.mark.parametrize("opt_ein", [True, False])
@pytest.mark.parametrize("jit", [True, False])
def test_optimizations(l1, p1, l2, p2, lo, po, mode, weight, special_code, opt_ein, jit, float_tolerance) -> None:
    orig_tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight, _specialized_code=False, _optimize_einsums=False)
    opt_tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight, _specialized_code=special_code, _optimize_einsums=opt_ein)
    # We don't use state_dict here since that contains things like wigners that
    # can differ between optimized and unoptimized TPs
    with torch.no_grad():
        opt_tp.weight[:] = orig_tp.weight
    assert opt_tp._specialized_code == special_code
    assert opt_tp._optimize_einsums == opt_ein

    if jit:
        opt_tp = assert_auto_jitable(opt_tp)

    # Confirm equivariance of optimized model
    assert_equivariant(opt_tp, irreps_in=[orig_tp.irreps_in1, orig_tp.irreps_in2], irreps_out=orig_tp.irreps_out)

    # Confirm that it gives same results
    x1 = orig_tp.irreps_in1.randn(2, -1)
    x2 = orig_tp.irreps_in2.randn(2, -1)
    assert torch.allclose(
        orig_tp(x1, x2),
        opt_tp(x1, x2),
        atol=float_tolerance,  # numerical optimizations can cause meaningful numerical error by changing operations
    )
    assert torch.allclose(orig_tp.right(x2), opt_tp.right(x2), atol=float_tolerance)

    # We also test .to(), even if only with a dtype, to ensure that various optimizations still
    # always store constants in correct ways
    other_dtype = next(d for d in [torch.float32, torch.float64] if d != torch.get_default_dtype())
    x1, x2 = x1.to(other_dtype), x2.to(other_dtype)
    opt_tp = opt_tp.to(other_dtype)
    assert opt_tp(x1, x2).dtype == other_dtype


def test_input_weights_python() -> None:
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    # - shared_weights = False -
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out, internal_weights=False, shared_weights=False)
    bdim = random.randint(1, 3)
    x1 = irreps_in1.randn(bdim, -1)
    x2 = irreps_in2.randn(bdim, -1)
    w = [torch.randn((bdim,) + ins.path_shape) for ins in m.instructions if ins.has_weight]
    m(x1, x2, w)
    # - shared_weights = True -
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out, internal_weights=False, shared_weights=True)
    bdim = random.randint(1, 3)
    x1 = irreps_in1.randn(bdim, -1)
    x2 = irreps_in2.randn(bdim, -1)
    w = [torch.randn(ins.path_shape) for ins in m.instructions if ins.has_weight]
    m(x1, x2, w)


def test_input_weights_jit() -> None:
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    # - shared_weights = False -
    m = FullyConnectedTensorProduct(
        irreps_in1,
        irreps_in2,
        irreps_out,
        internal_weights=False,
        shared_weights=False,
        compile_right=True,
    )
    traced = assert_auto_jitable(m)
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    w = torch.randn(2, m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, w[0])  # it should reject insufficient weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))

    # Confirm that weird batch dimensions give the same results
    for f in (m, traced):
        x1 = irreps_in1.randn(2, 1, 4, -1)
        x2 = irreps_in2.randn(2, 3, 1, -1)
        w = torch.randn(3, 4, f.weight_numel)
        assert torch.allclose(
            f(x1, x2, w).reshape(24, -1),
            f(
                x1.expand(2, 3, 4, -1).reshape(24, -1),
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1),
            ),
        )
        assert torch.allclose(
            f.right(x2, w).reshape(24, -1),
            f.right(x2.expand(2, 3, 4, -1).reshape(24, -1), w[None].expand(2, 3, 4, -1).reshape(24, -1)).reshape(24, -1),
        )

    # - shared_weights = True -
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out, internal_weights=False, shared_weights=True)
    w = torch.randn(m.weight_numel)
    traced = assert_auto_jitable(m)
    assert_torch_compile(
        "inductor",
        functools.partial(
            FullyConnectedTensorProduct, irreps_in1, irreps_in2, irreps_out, internal_weights=False, shared_weights=True
        ),
        x1,
        x2,
        w,
    )
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, torch.randn(2, m.weight_numel))  # it should reject too many weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))


def test_weight_view_for_instruction() -> None:
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)

    # Find all paths to the first output
    ins_idexes = [i for i, ins in enumerate(m.instructions) if ins.i_out == 0]
    with torch.no_grad():
        for i in ins_idexes:
            m.weight_view_for_instruction(i).zero_()

    out = m(x1, x2)
    assert torch.all(out[:, :1] == 0.0)
    assert torch.any(out[:, 1:] > 0.0)


def test_weight_views() -> None:
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    batchdim = 3
    x1 = irreps_in1.randn(batchdim, -1)
    x2 = irreps_in2.randn(batchdim, -1)
    # shared weights
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    with torch.no_grad():
        for w in m.weight_views():
            w.zero_()
    assert torch.all(m(x1, x2) == 0.0)

    # unshared weights
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out, shared_weights=False)
    weights = torch.randn(batchdim, m.weight_numel)
    with torch.no_grad():
        for w in m.weight_views(weights):
            w.zero_()
    assert torch.all(m(x1, x2, weights) == 0.0)


@pytest.mark.parametrize("l1, p1, l2, p2, lo, po, mode, weight", random_params(n=1))
def test_deepcopy(l1, p1, l2, p2, lo, po, mode, weight) -> None:
    tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight)
    assert_auto_jitable(tp)
    x1 = torch.randn(2, tp.irreps_in1.dim)
    x2 = torch.randn(2, tp.irreps_in2.dim)
    res1 = tp(x1, x2)
    tp_copy = copy.deepcopy(tp)
    res2 = tp_copy(x1, x2)
    assert torch.allclose(res1, res2)


@pytest.mark.parametrize("l1, p1, l2, p2, lo, po, mode, weight", random_params(n=1))
def test_save(l1, p1, l2, p2, lo, po, mode, weight) -> None:
    tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight)
    assert_auto_jitable(tp)
    # Saved TP
    with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
        torch.save(tp.state_dict(), tmp.name)
        tp2 = make_tp(l1, p1, l2, p2, lo, po, mode, weight)
        tp2.load_state_dict(torch.load(tmp.name, weights_only=False))
    # JITed, saved TP
    with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
        tp_jit = assert_auto_jitable(tp)
        tp_jit.save(tmp.name)
        tp3 = torch.jit.load(tmp.name)
    # Double-saved TP
    with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
        torch.save(tp2.state_dict(), tmp.name)
        tp4 = make_tp(l1, p1, l2, p2, lo, po, mode, weight)
        tp4.load_state_dict(torch.load(tmp.name, weights_only=False))
    x1 = torch.randn(2, tp.irreps_in1.dim)
    x2 = torch.randn(2, tp.irreps_in2.dim)
    res1 = tp(x1, x2)
    res2 = tp2(x1, x2)
    res3 = tp3(x1, x2)
    res4 = tp4(x1, x2)
    assert torch.allclose(res1, res2)
    assert torch.allclose(res1, res3)
    assert torch.allclose(res1, res4)


def test_triu_mode() -> None:
    tp = TensorProduct("10x0e", "10x0e", "45x0e", [(0, 0, 0, "uvu None:
    assert torch.allclose(o3.wigner_3j(1, 2, 3), o3.wigner_3j(1, 3, 2).transpose(1, 2))
    assert torch.allclose(o3.wigner_3j(1, 2, 3), o3.wigner_3j(2, 1, 3).transpose(0, 1))
    assert torch.allclose(o3.wigner_3j(1, 2, 3), o3.wigner_3j(3, 2, 1).transpose(0, 2))
    assert torch.allclose(o3.wigner_3j(1, 2, 3), o3.wigner_3j(3, 1, 2).transpose(0, 1).transpose(1, 2))
    assert torch.allclose(o3.wigner_3j(1, 2, 3), o3.wigner_3j(2, 3, 1).transpose(0, 2).transpose(1, 2))


@pytest.mark.parametrize("l1,l2,l3", [(1, 2, 3), (2, 3, 4), (3, 4, 5), (1, 1, 1), (1, 1, 0), (1, 0, 1), (0, 1, 1), (2, 2, 2)])
def test_wigner_3j(l1, l2, l3, float_tolerance) -> None:
    abc = o3.rand_angles(10)

    C = o3.wigner_3j(l1, l2, l3)

    D1 = o3.Irrep(l1, 1).D_from_angles(*abc)
    D2 = o3.Irrep(l2, 1).D_from_angles(*abc)
    D3 = o3.Irrep(l3, 1).D_from_angles(*abc)

    C2 = torch.einsum("ijk,zil,zjm,zkn->zlmn", C, D1, D2, D3)
    assert (C - C2).abs().max() < float_tolerance


def test_cartesian(float_tolerance) -> None:
    abc = o3.rand_angles(10)
    R = o3.angles_to_matrix(*abc)
    D = o3.wigner_D(1, *abc)
    assert (R - D).abs().max() < float_tolerance


def commutator(A, B):
    return A @ B - B @ A


@pytest.mark.parametrize("j", [0, 1 / 2, 1, 3 / 2, 2, 5 / 2])
def test_su2_algebra(j, float_tolerance) -> None:
    X = o3.su2_generators(j)
    assert torch.allclose(commutator(X[0], X[1]), X[2], atol=float_tolerance)
    assert torch.allclose(commutator(X[1], X[2]), X[0], atol=float_tolerance)


# e3nn-0.6.0/tests/util/test_jit.py
import pytest
import warnings
import torch
from e3nn.o3 import Linear, Irreps
from e3nn.nn import FullyConnectedNet
from e3nn.util.jit import script, trace_module, compile_mode, compile
from e3nn.util.test import assert_equivariant, assert_auto_jitable


def test_submod_tracing() -> None:
    """Check that tracing actually occurs"""

    @compile_mode("trace")
    class BadTrace(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] == 7:
                return x.new_ones(8)
            else:
                return x

        # This class has no irreps_in, so we need this to allow trace compilation
        def make_tracing_input(self):
            return {"forward": torch.randn(8, 3)}

    class ParentMod(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.child = BadTrace()

        def forward(self, x):
            return torch.as_tensor(0.5585) * self.child(x)

    parent = ParentMod()

    with pytest.raises(Exception):
        with warnings.catch_warnings():
            warnings.filterwarnings("error", category=torch.jit.TracerWarning)
            script(parent)


def test_submod_scripting() -> None:
    """Check that scripting actually occurs"""

    @compile_mode("script")
    class ScriptSubmod(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] == 7:
                return x.new_zeros(8)
            else:
                return x

    class ParentMod(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.child = ScriptSubmod()

        def forward(self, x):
            return self.child(x)

    parent = ParentMod()
    assert parent(torch.randn(7, 4)).shape == (8,)

    parent_trace = trace_module(parent, inputs={"forward": (torch.randn(7, 4),)})  # get the conditional behaviour
    # Does it get the behaviour it was traced for?
    assert parent_trace(torch.randn(7, 4)).shape == (8,)
    # Does it get the conditional that should have been scripted?
    x = torch.randn(5, 7)
    assert torch.allclose(parent(x), x)
    assert torch.allclose(parent_trace(x), x)


def test_compilation() -> None:
    class Supermod(torch.nn.Module):
        def forward(self, x):
            return x * 2.0

    @compile_mode("trace")
    class ChildMod(Supermod):
        def forward(self, x):
            return super().forward(x) * 3.0

        def _make_tracing_inputs(self, n: int):
            return [{"forward": (torch.randn(2, 3),)} for _ in range(n)]

    # This module can't be compiled directly by TorchScript, since ChildMod is a subclass and calls super() in forward()
    class ContainerMod(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.submod = ChildMod()
            self.alpha = torch.randn(1).squeeze()

        def forward(self, x):
            return self.submod(x) + self.alpha * self.submod(x)

    mod = ContainerMod()
    # Try and xfail with torch.jit.script
    with pytest.raises((RuntimeError, torch.jit.Error)):
        mod_script = torch.jit.script(mod)
    # Compile with our compiler
    mod_script = script(mod)

    x = torch.randn(3, 2)
    assert torch.allclose(mod(x), mod_script(x))


def test_equivariant() -> None:
    # Confirm that a compiled module (here a Linear) is still equivariant
    irreps_in = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    mod = Linear(irreps_in, irreps_out)
    mod_script = compile(mod)
    assert_equivariant(
        mod_script,
        # we provide explicit irreps because inferring them on a script module is not reliable
        irreps_in=irreps_in,
        irreps_out=irreps_out,
    )


def test_unsupported() -> None:
    @compile_mode("unsupported")
    class ChildMod(torch.nn.Module):
        pass

    class Supermod(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.child = ChildMod()

    mod = Supermod()
    with pytest.raises(NotImplementedError):
        mod = script(mod)


def test_trace_dtypes() -> None:
    # FullyConnectedNet is traced
    fc = FullyConnectedNet([8, 16, 8])
    # compile in a dtype other than the default
    target_dtype = {torch.float32: torch.float64, torch.float64: torch.float32}[torch.get_default_dtype()]
    fc = fc.to(dtype=target_dtype)
    for weight in fc.parameters():
        assert weight.dtype == target_dtype
    assert_auto_jitable(fc)


# e3nn-0.6.0/tests/util/test_test.py
import pytest
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
from e3nn.util.test import assert_equivariant, assert_auto_jitable, assert_normalized, random_irreps


def test_assert_equivariant() -> None:
    def not_equivariant(x1, x2):
        return x1 * x2

    not_equivariant.irreps_in1 = o3.Irreps("2x0e + 1x1e + 3x2o + 1x4e")
    not_equivariant.irreps_in2 = o3.Irreps("2x0o + 3x0o + 3x2e + 1x4o")
    not_equivariant.irreps_out = o3.Irreps("1x1e + 2x0o + 3x2e + 1x4o")
    assert not_equivariant.irreps_in1.dim == not_equivariant.irreps_in2.dim
    assert not_equivariant.irreps_in1.dim == not_equivariant.irreps_out.dim
    with pytest.raises(AssertionError):
        assert_equivariant(not_equivariant)


def test_jit_trace() -> None:
    @compile_mode("trace")
    class NotTracable(torch.nn.Module):
        def forward(self, param):
            if param.shape[0] == 7:
                return torch.ones(8)
            else:
                return torch.randn(8, 3)

    not_tracable = NotTracable()
    not_tracable.irreps_in = o3.Irreps("2x0e")
    not_tracable.irreps_out = o3.Irreps("1x1o")
    # TorchScript returns some weird exceptions...
    with pytest.raises(Exception):
        assert_auto_jitable(not_tracable)


def test_bad_normalize() -> None:
    def not_normal(x1) -> torch.Tensor:
        return 870.0 * x1.square().relu()

    not_normal.irreps_in = random_irreps(clean=True, allow_empty=False)
    not_normal.irreps_out = not_normal.irreps_in
    with pytest.raises(AssertionError):
        assert_normalized(not_normal)


def test_normalized_ident() -> None:
    def ident(x1):
        return x1

    ident.irreps_in = random_irreps(clean=True, allow_empty=False)
    ident.irreps_out = ident.irreps_in
    assert_normalized(ident)
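The reduced-tensor-product tests in `reduce_tensor_test.py` check dimension counts directly. `"ijkl=jikl=klij"` over a vector index must give 21 independent components, and the Levi-Civita formula `"ijk=-ikj=-jik"` over `"1e"` must reduce to a single `0e`.

Those counts can be reproduced independently of e3nn with a minimal pure-Python sketch based on orbit counting. The helper `count_independent_components` is hypothetical and not part of e3nn, and so is its encoding of a formula as `(permutation, sign)` pairs.

- Each index tuple is propagated through the signed permutations.
- An orbit contributes one free parameter, unless it maps back onto itself with a flipped sign. That forces every entry in the orbit to zero.

The sketch only counts dimensions. It does not build the change of basis `Q`, and it does not split the result into irreps the way `o3.ReducedTensorProducts` does.

```python
from collections import deque
from itertools import product


def count_independent_components(dim, rank, symmetries):
    """Count the free parameters of a rank-`rank` tensor whose indices run over
    `dim` values and which satisfies T[permuted idx] == sign * T[idx] for every
    (perm, sign) in `symmetries`. Slot `a` of the permuted tuple holds idx[perm[a]],
    so "jikl" relative to "ijkl" is encoded as (1, 0, 2, 3)."""
    sign_of = {}  # index tuple -> sign relative to its orbit representative
    free = 0
    for start in product(range(dim), repeat=rank):
        if start in sign_of:
            continue  # already covered by an earlier orbit
        sign_of[start] = 1
        queue = deque([start])
        consistent = True
        # Breadth-first search over the orbit of `start` under the signed permutations
        while queue:
            idx = queue.popleft()
            for perm, sign in symmetries:
                new = tuple(idx[a] for a in perm)
                new_sign = sign_of[idx] * sign
                if new not in sign_of:
                    sign_of[new] = new_sign
                    queue.append(new)
                elif sign_of[new] != new_sign:
                    # The orbit maps onto itself with a sign flip, so all its entries vanish
                    consistent = False
        if consistent:
            free += 1
    return free


# "ijkl=jikl=klij" over a 3-dimensional index: the elasticity tensor
print(count_independent_components(3, 4, [((1, 0, 2, 3), 1), ((2, 3, 0, 1), 1)]))  # prints 21
# "ijk=-ikj=-jik" over a 3-dimensional index: the Levi-Civita symbol
print(count_independent_components(3, 3, [((0, 2, 1), -1), ((1, 0, 2), -1)]))  # prints 1
```

The count ignores parity, which is consistent with the `i="1o"` elasticity test also expecting 21. As another example, `((1, 0), -1)` with `dim=8` encodes `"ij=-ji"` with `i="5x0e + 1e"`, as used in `test_antisymmetric_matrix`. The sketch then returns 28, the dimension of the antisymmetric part of an 8 x 8 matrix.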