pax_global_header00006660000000000000000000000064151524077370014525gustar00rootroot0000000000000052 comment=db56c466d69decb5f7f2311adf9e60f132f076ca tee-ar-ex-trx-python-a304ac2/000077500000000000000000000000001515240773700160455ustar00rootroot00000000000000tee-ar-ex-trx-python-a304ac2/.flake8000066400000000000000000000002121515240773700172130ustar00rootroot00000000000000[flake8] max-line-length = 88 max-complexity = 10 exclude = .git, __pycache__, .tox, .eggs, *.egg, build, disttee-ar-ex-trx-python-a304ac2/.github/000077500000000000000000000000001515240773700174055ustar00rootroot00000000000000tee-ar-ex-trx-python-a304ac2/.github/workflows/000077500000000000000000000000001515240773700214425ustar00rootroot00000000000000tee-ar-ex-trx-python-a304ac2/.github/workflows/codeformat.yml000066400000000000000000000007421515240773700243130ustar00rootroot00000000000000name: Code Format on: push: branches: [master] pull_request: branches: [master] permissions: contents: read jobs: pre-commit: name: Pre-commit checks runs-on: ubuntu-latest steps: - name: Check out repository uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.12" - name: Install and run pre-commit hooks uses: pre-commit/action@v3.0.1 tee-ar-ex-trx-python-a304ac2/.github/workflows/coverage.yml000066400000000000000000000020641515240773700237620ustar00rootroot00000000000000name: Coverage on: push: branches: [master] pull_request: branches: [master] permissions: contents: read jobs: coverage: name: Code Coverage runs-on: ubuntu-latest steps: - name: Check out repository uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.12" cache: pip cache-dependency-path: pyproject.toml - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install -e .[test] - name: Run tests with coverage run: | pytest trx/tests --cov=trx --cov-report=xml --cov-report=term-missing - name: 
Upload coverage to Codecov uses: codecov/codecov-action@v4 with: files: ./coverage.xml flags: unittests name: codecov-trx fail_ci_if_error: false verbose: true env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} tee-ar-ex-trx-python-a304ac2/.github/workflows/docbuild.yml000066400000000000000000000101621515240773700237520ustar00rootroot00000000000000name: Documentation build on: push: branches: [ master ] tags: - '*' pull_request: branches: [ master ] permissions: contents: write jobs: build: runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ["3.13"] steps: - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history and tags for setuptools_scm - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install run: | python -m pip install cython python -m pip install --upgrade pip pip install .[doc] - name: Build docs run: | cd docs SPHINXOPTS="-W" make html - name: Upload docs uses: actions/upload-artifact@v4 with: name: docs path: docs/_build/html deploy-dev: needs: build runs-on: ubuntu-latest if: github.event_name == 'push' && github.ref == 'refs/heads/master' && github.repository == 'tee-ar-ex/trx-python' steps: - uses: actions/checkout@v4 - uses: actions/download-artifact@v4 with: name: docs path: docs/_build/html - name: Publish dev docs to Github Pages uses: JamesIves/github-pages-deploy-action@v4 with: branch: gh-pages folder: docs/_build/html target-folder: dev deploy-release: needs: build runs-on: ubuntu-latest if: startsWith(github.ref, 'refs/tags/') && github.repository == 'tee-ar-ex/trx-python' steps: - uses: actions/checkout@v4 - uses: actions/download-artifact@v4 with: name: docs path: docs/_build/html - name: Get version from tag id: get_version run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT - name: Fetch existing switcher.json from gh-pages run: | curl -sSL https://tee-ar-ex.github.io/trx-python/switcher.json -o 
switcher.json || echo '[]' > switcher.json - name: Update switcher.json with new version run: python tools/update_switcher.py switcher.json --version ${{ steps.get_version.outputs.VERSION }} - name: Create root files (redirect + switcher.json) run: | mkdir -p root_files cp switcher.json root_files/ cat > root_files/index.html << 'EOF' trx-python - TRX File Format for Tractography

trx-python Documentation

Python implementation of the TRX file format for tractography data.

If you are not redirected automatically, visit the stable documentation.

EOF - name: Publish root files (redirect + switcher.json) uses: JamesIves/github-pages-deploy-action@v4 with: branch: gh-pages folder: root_files target-folder: . clean: false - name: Publish release docs to Github Pages uses: JamesIves/github-pages-deploy-action@v4 with: branch: gh-pages folder: docs/_build/html target-folder: ${{ steps.get_version.outputs.VERSION }} - name: Publish stable docs to Github Pages uses: JamesIves/github-pages-deploy-action@v4 with: branch: gh-pages folder: docs/_build/html target-folder: stable tee-ar-ex-trx-python-a304ac2/.github/workflows/publish-to-test-pypi.yml000066400000000000000000000046471515240773700262220ustar00rootroot00000000000000name: Publish to TestPyPI and PyPI on: push: branches: - master tags: - "*" concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: build-sdist: name: Build sdist runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Build sdist run: pipx run build - uses: actions/upload-artifact@v6 with: name: source-dist path: ./dist/* test-sdist: name: Test sdist needs: [build-sdist] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@v7 with: name: source-dist path: ./dist - uses: actions/setup-python@v6 with: python-version: "3.11" - name: Display Python version run: python -c "import sys; print(sys.version)" - name: Install sdist without optional dependencies run: pip install dist/*.tar.gz - run: python -c 'import trx; print(trx.__version__)' - name: Install pytest run: pip install pytest psutil pytest-console-scripts pytest-cov - name: Run tests run: pytest -v --pyargs trx pre-publish: runs-on: ubuntu-latest needs: [test-sdist] steps: - uses: actions/download-artifact@v7 with: path: dist/ pattern: '*-dist' merge-multiple: true - run: ls -lR dist/ - run: pipx run twine check dist/* test-pypi-publish: runs-on: ubuntu-latest needs: [pre-publish] steps: - uses: actions/download-artifact@v7 with: path: dist/ pattern: '*-dist' 
merge-multiple: true - run: ls -lR dist/ - uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.TEST_PYPI_API_TOKEN }} repository-url: https://test.pypi.org/legacy/ verbose: true pypi-publish: runs-on: ubuntu-latest environment: "Package deployment" needs: [pre-publish] if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') steps: - uses: actions/download-artifact@v7 with: path: dist/ pattern: '*-dist' merge-multiple: true - run: ls -lR dist/ - uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} tee-ar-ex-trx-python-a304ac2/.github/workflows/test.yml000066400000000000000000000017451515240773700231530ustar00rootroot00000000000000name: Tests on: push: branches: [master] pull_request: branches: [master] permissions: contents: read jobs: test: name: Python ${{ matrix.python-version }} • ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: python-version: ["3.11", "3.12", "3.13"] os: [ubuntu-latest, windows-latest, macos-latest] steps: - uses: actions/checkout@v4 with: fetch-depth: 0 # needed for setuptools_scm version detection - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: pip cache-dependency-path: pyproject.toml - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install -e .[dev,test] python -c "import trx; print(trx.__version__)" - name: Test run: spin test tee-ar-ex-trx-python-a304ac2/.gitignore000066400000000000000000000036161515240773700200430ustar00rootroot00000000000000trx/version.py .DS_Store *_version.py # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # 
PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py, cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. # Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ .vscode/ tmp/ auto_examples/ CLAUDE.md claude.md agents.md AGENTS.md sg_execution_times.rst tee-ar-ex-trx-python-a304ac2/.pre-commit-config.yaml000066400000000000000000000012411515240773700223240ustar00rootroot00000000000000default_language_version: python: python3 repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.14.14 hooks: # Run the linter - id: ruff args: [ --fix ] # Run the formatter - id: ruff-format - repo: https://github.com/codespell-project/codespell rev: v2.4.1 hooks: - id: codespell args: [--skip, "pyproject.toml,docs/_build/*,*.egg-info"] additional_dependencies: - tomli - repo: https://github.com/numpy/numpydoc rev: v1.10.0 hooks: - id: numpydoc-validation name: numpydoc validate description: Run numpydoc validation across trx package tee-ar-ex-trx-python-a304ac2/.spin/000077500000000000000000000000001515240773700170745ustar00rootroot00000000000000tee-ar-ex-trx-python-a304ac2/.spin/__init__.py000066400000000000000000000000351515240773700212030ustar00rootroot00000000000000"""Spin commands package.""" tee-ar-ex-trx-python-a304ac2/.spin/cmds.py000066400000000000000000000206131515240773700203760ustar00rootroot00000000000000"""Custom spin commands for trx-python development.""" import glob import os import shutil import subprocess import sys import tempfile import click UPSTREAM_URL = "https://github.com/tee-ar-ex/trx-python.git" UPSTREAM_NAME = "upstream" def run(cmd, check=True, capture=True): """Run a shell command. Parameters ---------- cmd : list of str Command and arguments to execute. check : bool, optional If True, check the return code and report errors. 
capture : bool, optional If True, capture stdout and stderr. Returns ------- str or int or None Captured stdout string, return code, or None on error. """ result = subprocess.run(cmd, capture_output=capture, text=True, check=False) if check and result.returncode != 0: if capture: click.echo(f"Error: {result.stderr}", err=True) return None return result.stdout.strip() if capture else result.returncode def get_remotes(): """Get dict of remote names to URLs. Returns ------- dict Mapping of remote names to their fetch URLs. """ output = run(["git", "remote", "-v"]) if not output: return {} remotes = {} for line in output.split("\n"): if "(fetch)" in line: parts = line.split() remotes[parts[0]] = parts[1] return remotes @click.command() def setup(): """Set up development environment (fetch tags from upstream). This command configures your fork for development by: 1. Adding the upstream remote if not present 2. Fetching tags from upstream (required for correct version detection) Run this once after cloning your fork. 
""" click.echo("Setting up trx-python development environment...\n") # Check if in git repo if run(["git", "rev-parse", "--git-dir"], check=False) is None: click.echo("Error: Not in a git repository", err=True) sys.exit(1) # Check/add upstream remote remotes = get_remotes() upstream_remote = None for name, url in remotes.items(): if UPSTREAM_URL.rstrip(".git") in url.rstrip(".git"): upstream_remote = name click.echo(f"Found upstream remote: {name}") break if upstream_remote is None: click.echo(f"Adding upstream remote: {UPSTREAM_URL}") run(["git", "remote", "add", UPSTREAM_NAME, UPSTREAM_URL]) upstream_remote = UPSTREAM_NAME # Fetch tags click.echo(f"\nFetching tags from {upstream_remote}...") run(["git", "fetch", upstream_remote, "--tags"], capture=False) # Verify version click.echo("\nVerifying version detection...") try: from setuptools_scm import get_version version = get_version() click.echo(f"Detected version: {version}") # Check for suspicious version patterns if version.startswith("0.0"): click.echo( "\nWarning: Version starts with 0.0 - tags may not be fetched.", err=True, ) sys.exit(1) except ImportError: click.echo("Note: Install setuptools_scm to verify version detection") click.echo("\nSetup complete! You can now run:") click.echo(" spin install # Install in development mode") click.echo(" spin test # Run tests") @click.command() @click.option( "-m", "--match", "pattern", default=None, help="Only run tests matching this pattern (passed to pytest -k)", ) @click.option("-v", "--verbose", is_flag=True, default=False, help="Verbose output") @click.argument("pytest_args", nargs=-1) def test(pattern, verbose, pytest_args): """Run tests using pytest. Additional arguments are passed directly to pytest. Parameters ---------- pattern : str or None Only run tests matching this pattern (passed to pytest -k). verbose : bool If True, enable verbose output. pytest_args : tuple Additional arguments passed directly to pytest. 
""" cmd = ["pytest", "trx/tests"] if pattern: cmd.extend(["-k", pattern]) if verbose: cmd.append("-v") if pytest_args: cmd.extend(pytest_args) click.echo(f"Running: {' '.join(cmd)}\n") sys.exit(run(cmd, capture=False, check=False)) @click.command() @click.option( "--fix", is_flag=True, default=False, help="Automatically fix issues where possible" ) def lint(fix): """Run linting checks using ruff and codespell. Parameters ---------- fix : bool If True, automatically fix issues where possible. """ click.echo("Running ruff linter...") cmd = ["ruff", "check", "."] if fix: cmd.append("--fix") result = run(cmd, capture=False, check=False) if result != 0: click.echo("\nLinting issues found!", err=True) sys.exit(1) click.echo("\nRunning ruff formatter check...") cmd_format = ["ruff", "format", "--check", "."] result = run(cmd_format, capture=False, check=False) if result != 0: click.echo("\nFormatting issues found!", err=True) sys.exit(1) click.echo("\nRunning codespell...") cmd_spell = [ "codespell", "--skip", "*.pyc,.git,pyproject.toml,./docs/_build/*,*.egg-info,./build/*,./dist/*,./tmp/*", "trx", "docs/source", ".spin", ] result = run(cmd_spell, capture=False, check=False) if result != 0: click.echo("\nSpelling issues found!", err=True) sys.exit(1) click.echo("\nAll checks passed!") @click.command() @click.option( "--clean", is_flag=True, default=False, help="Clean build directory before building" ) @click.option( "--open", "open_browser", is_flag=True, default=False, help="Open documentation in browser after building", ) def docs(clean, open_browser): """Build documentation using Sphinx. Parameters ---------- clean : bool If True, clean build directory before building. open_browser : bool If True, open documentation in browser after building. 
""" import os docs_dir = "docs" if clean: click.echo("Cleaning build directory...") build_dir = os.path.join(docs_dir, "_build") if os.path.exists(build_dir): shutil.rmtree(build_dir) # Clean sphinx-gallery generated files gallery_dir = os.path.join(docs_dir, "source", "auto_examples") if os.path.exists(gallery_dir): click.echo("Cleaning sphinx-gallery generated files...") shutil.rmtree(gallery_dir) # Clean sphinx-gallery execution times file sg_times = os.path.join(docs_dir, "source", "sg_execution_times.rst") if os.path.exists(sg_times): os.remove(sg_times) click.echo("Building documentation...") cmd = ["make", "-C", docs_dir, "html"] result = run(cmd, capture=False, check=False) if result == 0: index_path = os.path.abspath( os.path.join(docs_dir, "_build", "html", "index.html") ) click.echo("\nDocs built successfully!") click.echo(f"Open: {index_path}") if open_browser: import webbrowser webbrowser.open(f"file://{index_path}") sys.exit(result) @click.command() def clean(): # noqa: C901 """Clean up temporary files and build artifacts.""" click.echo("Cleaning up temporary files...") # Clean TRX temp directory trx_tmp_dir = os.getenv("TRX_TMPDIR", tempfile.gettempdir()) if os.path.exists(trx_tmp_dir): temp_files = glob.glob(os.path.join(trx_tmp_dir, "trx_*")) for temp_dir in temp_files: if os.path.isdir(temp_dir): click.echo(f"Removing temporary directory: {temp_dir}") shutil.rmtree(temp_dir) # Clean build artifacts for build_pattern in ["build", "dist", "*.egg-info"]: for path in glob.glob(build_pattern): if os.path.isdir(path): click.echo(f"Removing build directory: {path}") shutil.rmtree(path) elif os.path.isfile(path): click.echo(f"Removing build file: {path}") os.remove(path) # Clean Python cache for cache_dir in ["**/__pycache__", "**/.pytest_cache"]: for path in glob.glob(cache_dir, recursive=True): if os.path.isdir(path): click.echo(f"Removing cache directory: {path}") shutil.rmtree(path) click.echo("Cleanup complete!") 
tee-ar-ex-trx-python-a304ac2/.zenodo.json000066400000000000000000000053351515240773700203220ustar00rootroot00000000000000{ "upload_type": "software", "title": "trx-python: A Python implementation of the TRX tractography file format", "description": "trx-python is a Python implementation of the TRX file format for brain tractography data. TRX is a community-driven tractography file format designed to facilitate dataset exchange, interoperability, and state-of-the-art analyses, acting as a replacement for the myriad of existing formats. The library uses memory-mapped files to efficiently handle large neuroimaging datasets and provides a unified I/O interface supporting TRX, TRK, TCK, FIB, VTK, and DPY formats.", "access_right": "open", "license": "bsd-3-clause", "language": "eng", "keywords": [ "tractography", "neuroimaging", "diffusion MRI", "white matter", "brain connectivity", "streamlines", "connectome", "file format", "memory-mapped", "nibabel", "dipy" ], "creators": [ { "name": "Rheault, François", "affiliation": "Université de Sherbrooke", "orcid": "0000-0002-0097-8004" } ], "contributors": [ { "name": "Rokem, Ariel", "affiliation": "University of Washington", "orcid": "0000-0003-0679-1985", "type": "ProjectMember" }, { "name": "Koudoro, Serge", "affiliation": "Indiana University Bloomington", "orcid": "0000-0002-9819-9884", "type": "ProjectMember" }, { "name": "Hayot-Sasson, Valérie", "affiliation": "ÉTS Montréal", "orcid": "0000-0002-4830-4535", "type": "ProjectMember" }, { "name": "Sólon Heinsfeld, Anibal", "affiliation": "University of Texas at Austin", "orcid": "0000-0002-2050-0614", "type": "ProjectMember" }, { "name": "Beasley, Benjamin A.", "type": "Other" }, { "name": "Mollier, Étienne", "affiliation": "Debian Project", "type": "Other" } ], "grants": [ { "id": "10.13039/100000025::1R01MH126699" } ], "related_identifiers": [ { "identifier": "https://github.com/tee-ar-ex/trx-python", "relation": "isSupplementTo", "resource_type": "software" }, { 
"identifier": "https://tee-ar-ex.github.io/trx-python/", "relation": "isDocumentedBy", "resource_type": "publication-other" } ], "communities": [ { "identifier": "neuroscience" } ] } tee-ar-ex-trx-python-a304ac2/LICENSE000066400000000000000000000024421515240773700170540ustar00rootroot00000000000000Copyright (c) 2021 -- , Francois Rheault and others All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
tee-ar-ex-trx-python-a304ac2/README.md000066400000000000000000000077421515240773700173360ustar00rootroot00000000000000# trx-python [![Tests](https://github.com/tee-ar-ex/trx-python/actions/workflows/test.yml/badge.svg)](https://github.com/tee-ar-ex/trx-python/actions/workflows/test.yml) [![Code Format](https://github.com/tee-ar-ex/trx-python/actions/workflows/codeformat.yml/badge.svg)](https://github.com/tee-ar-ex/trx-python/actions/workflows/codeformat.yml) [![codecov](https://codecov.io/gh/tee-ar-ex/trx-python/branch/master/graph/badge.svg)](https://codecov.io/gh/tee-ar-ex/trx-python) [![PyPI version](https://badge.fury.io/py/trx-python.svg)](https://badge.fury.io/py/trx-python) A Python implementation of the TRX file format for tractography data. For details, please visit the [documentation](https://tee-ar-ex.github.io/trx-python/). ## Installation ### From PyPI ```bash pip install trx-python ``` ### From Source ```bash git clone https://github.com/tee-ar-ex/trx-python.git cd trx-python pip install . 
``` ## Quick Start ### Loading and Saving Tractograms ```python from trx.io import load, save # Load a tractogram (supports .trx, .trk, .tck, .vtk, .fib, .dpy) trx = load("tractogram.trx") # Save to a different format save(trx, "output.trk") ``` ### Command-Line Interface TRX-Python provides a unified CLI (`trx`) for common operations: ```bash # Show all available commands trx --help # Display TRX file information (header, groups, data keys, archive contents) trx info data.trx # Convert between formats trx convert input.trk output.trx # Concatenate tractograms trx concatenate tract1.trx tract2.trx merged.trx # Validate a TRX file trx validate data.trx ``` Individual commands are also available for backward compatibility: ```bash trx_info data.trx trx_convert_tractogram input.trk output.trx trx_concatenate_tractograms tract1.trx tract2.trx merged.trx trx_validate data.trx ``` ## Development We use [spin](https://github.com/scientific-python/spin) for development workflow. ### First-Time Setup ```bash # Clone the repository (or your fork) git clone https://github.com/tee-ar-ex/trx-python.git cd trx-python # Install with all dependencies pip install -e ".[all]" # Set up development environment (fetches upstream tags) spin setup ``` ### Common Commands ```bash spin setup # Set up development environment spin install # Install in editable mode spin test # Run all tests spin test -m memmap # Run tests matching pattern spin lint # Run linting (ruff) spin lint --fix # Auto-fix linting issues spin docs # Build documentation spin clean # Clean temporary files ``` Run `spin` without arguments to see all available commands. ### Code Quality We use [ruff](https://docs.astral.sh/ruff/) for linting and formatting: ```bash # Check for issues spin lint # Auto-fix issues spin lint --fix # Format code ruff format . 
``` ### Pre-commit Hooks ```bash # Install hooks pre-commit install # Run on all files pre-commit run --all-files ``` ## Temporary Directory The TRX file format uses memory-mapped files to limit RAM usage. When dealing with large files, several gigabytes may be required on disk. By default, temporary files are stored in: - Linux/macOS: `/tmp` - Windows: `C:\WINDOWS\Temp` To change the directory: ```bash # Use a specific directory (must exist) export TRX_TMPDIR=/path/to/tmp # Use current working directory export TRX_TMPDIR=use_working_dir ``` Temporary folders are automatically cleaned, but if the code crashes unexpectedly, ensure folders are deleted manually. ## Troubleshooting If the `trx` command is not working as expected, run `trx --debug` to print diagnostic information about the Python interpreter, package location, and whether all required and optional dependencies are installed. ## Documentation Full documentation is available at https://tee-ar-ex.github.io/trx-python/ To build locally: ```bash spin docs --open ``` ## Contributing We welcome contributions! Please see our [Contributing Guide](https://tee-ar-ex.github.io/trx-python/contributing.html) for details. ## License BSD License - see [LICENSE](LICENSE) for details. tee-ar-ex-trx-python-a304ac2/docs/000077500000000000000000000000001515240773700167755ustar00rootroot00000000000000tee-ar-ex-trx-python-a304ac2/docs/Makefile000066400000000000000000000011771515240773700204430ustar00rootroot00000000000000# Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. 
$(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) tee-ar-ex-trx-python-a304ac2/docs/_static/000077500000000000000000000000001515240773700204235ustar00rootroot00000000000000tee-ar-ex-trx-python-a304ac2/docs/_static/switcher.json000066400000000000000000000004351515240773700231500ustar00rootroot00000000000000[ { "name": "dev", "version": "dev", "url": "https://tee-ar-ex.github.io/trx-python/dev/" }, { "name": "stable", "version": "stable", "url": "https://tee-ar-ex.github.io/trx-python/stable/", "preferred": true } ] tee-ar-ex-trx-python-a304ac2/docs/_static/trx_logo.png000066400000000000000000001140361515240773700227730ustar00rootroot00000000000000PNG  IHDR":9IDATx]Wu/"zd˖d7`[q( 8$@ $! ixy{!0L T 7\d[Vh[9{>wdC$c)?y<9e*[w-__=~_7/ B?0}@_O$D/̌}t1%JJ @QK `N%G@_O3QVLnMQO1~˭d/[ nIł[ Cbn2"+H[ 'ZCFIo'Y[/چH2H,DE1Njxj>ݐA#Ƞ @hИ$#eJȉ $+X%, "to_~)X Mj,rID;]k,ND,0;Ua9zIgtP;̄Qgfo+ fXN=Y'h-&&o EMiO<:5cL=i4rUh jm8⴦F1`0[@Fh/W(B= %&fdK :3qI&|^6:yISťBI7N Sl†!n n0t-3 Nb`@Id Say1~]̓*2 9crG}sj8N䅞:~˙@_+VXYde! P=v xQT|g CEA `S$Z2 c_,E9Q(X+PDc ց/qR' ӪNLCO>eVYRXU)6[RH@F&Ke庣`2?G%^E)Hid؜׵ k_lX, q4Isw 'F^aԆ!2Il {DEz@z:H:$M%IOHѢ2$0I+ e+ N#2#tpr󝍯~\:' 9#}f7j;iG5^{y2 (>'ٯVۈ:v= hj4&n!H1ji"aY'[NUH si2hX`!;s־S֫0=-fZcIʉ ;.iڅ9VM-%Q >s5E,) [FcMK6P|]cb! 
'(3KW^ʓhV +(mό%YIt,yrq1^R^ ;šL_xJηt^9dD*?+??}RtT*m|%sm|%z#'P&ʒBki&rDQNiYʲw?6 jD˧3V-[zltΜ!Ff1LKZDd[6ܟ$K%^P| Va}䖻bR&t(QcUj!&&P^xB tӞ'Ln#qY0 LxX?IjQ#_6r"F0v{,6δX6FIM5ash O G[qt̳Lǒ1Uۺ sP.@t;{JY"% b~wŋ[p>'aw %D 6$j7DR{hќ\b mSSr{OIDB+5#!-E7~rZ( ]x)G1] }R v D[o];guҷr{̻տ6,'8T0TœQ^dOD>@_$<]FNMuRgiVxSzbMXZAU (ڌXt疩i+tr ]\fLBp)r:A5Zlrvso9}%0*U[̹9Ӏ9RlHe6YV3L X!X4W<}tW1paNT)vu~'G&&cB$NÚ'|_SN^@X(-%ǿnV<:# 9SMxamqsmIYr@mK}.끌.fX!i!5o>魿o84?ˮYg4zvW<]I/'DmH= >}WDpFuL0Sym `fChT,{PNl}}3^ !+cb4(P B~rp##pNOdA7Pj՝ͽÍD7nD xꚫr}mPԣ1L$ ءGw,`C` f/0s[KZiLbP݄VfiLZ'bpWx WwxTA۾W߲]u&$Rʗ;bIaI͘ظ(P PHΛ`i*')3)=aBVA  [9&mxL^:nNe640a *pcMIoP-⸿3YsjٹHfEo8L+/5]@0q-f3lAQTi!+IoZoL Ns#z*~;4MaY:W%gZ [W}qp#,[Vx“jN`I;'ݎs*X4M׫{vPض\9 %Nq<ȼ8CGd H!@3i063& VNXp2lwtoYv a-yOfj䴿~ډqr:R֓?wOa\šyu |K/LfyteHOaۮhrzQgl.[^0T^bYyp*W@i^((9ԓO|FݛE-.=ꍰt4 m0f&ERE7t i|hlxP&kєoI3τ|V +5C`1R4gӓ;u`2w2t&&''M?'iֳM3~(V ToҒe6Ƨ8쓼Ԍ| w[T6s'l7tߑ4j9W_uҪU9ZG0,nOB R>=-:&w6[/|3SSXh=[A.,u~#MUDl I(Z4NL>=s:=7]^x/=v1 y-B~-QzkU5_86 _|5|u84Mr6UqR1Nu;);J<񍖀Hf_ʿP^6ϧBJdt/$1cwGHP|`wnBk, =q%B10 ( ?*|ZB{0 }.dvܷ5gXV~NYrJ؃ȓA0 @Z֒*3hY`fИ??z]'5"V##gG[8>uBx7?pŋƟ|2F,Zg5M5il2h6.NG]$Y(=s׭?|˥TQzij|1mY [9y装 )Uvx;:OZ)%"m4jϩI.8iiu}[dQ\ßykU@ִth$A'jGgr,hEzo6D_7>['RWԪE^SxQy3L|j߂fRZAݘy)'/?{N-RB񅱝(>Mh4Z}9[Dp3AUS .CWZn2+9&`bem^lz]CgM=PW;jv1Qe塡'k'ASt8Y`b4 qTGcP6fĝU,?Th8x@r^/|sϺ 5C0^_.D\K ˣb3[V5LAD#fBRVCCﯪυ0-Ę8WV^ϵɾcg<|[ͨ=mO \ut*w`wxkG1XDRA~Pu2S3!3gEcЂxDf$C& [qG3ܒEP74Sw>$ eznB:n'PPi0\jLgqZiL~&lEےB8~I`"hZGEjҀ"+\? 
zGك+W6V6+kO=zHi tgGQC׎`5Q$Bi)H (-u}^~g @e FM.8 3 dF\ӽѩ#㤴N`0Mck.Y7߹ w@0%qE$|) SCk!KVKg)ԶL).v%Τ>^I49n}>> ӭ_y"|T7 L Q @8P>?mр6.U 08>IcV$hnJ_q[GgbԶ<\&3tk*Mk6->jevE%sa̟sѢI0-`&ĿE7L- qMFIYE hDp)4n,{gq:?&͕`(d><'Umr &Jh(X+|)}SQdh?w L{xbrGbMQv d(ZD4gZ䳾'+d er!GC °lƚʾ$1v&uM~P+  赯͟|TqC2?_O;)6^75:ErHy&(#Q3+@ʩA>c8}~pUgqiקnh|F3I"=мIbDIl*Ȱʈ'# CmT賎M-y.A/ڶ'?{X JqQX;\W*J| y>E2`ĭKZk͍f)*O6sKUr]vN~8o'UBZ8vm|\4VJ'i&8KQ7W7(Mk7g_g]0@c%FSL-`8V:n[΁:ˉ>+Ƨ:6s;sj_/_޿ڜ7)LSVMG|e@HCZ{՜}vwM_Vsկ0I"\>ddLc=n:7ܸ1$m^" ~K09 c:YQ,y@UNja݌?0@glT-|X1"i.Ani=l⩫/Q\5PwX'omgT{+Q|en+׮ \tqejM[E 7dOG[`Ά 5w_|ܹ'ϪD%=}Sk{V"f${N0 pXH)&o~$I>s+yۇ׬شn%7glY4'IoT& t{TxEOYc%+(R|\b -@+cb"j\xv뗇dB?7T!4uzrdTڂc TyCU'wTDF? _,XPTrvh(h6~r9u%ZkX߳+aK >D})1yP ;0_/K|x_ݐ2͕'J8fw ws}ݠs(֚=\iT L{LIJqWVU+%z?)-YRT0TJ0dk15{~rqͪx&PJ3)K*VkJ )$2NY|%1F֐k\oY$.AܪU]glV>7>ꈃHaWo2f1Xq[-?OO4l2sЙpLb`ֱ%â}-yc * Hrvn\3eDB&-[Eˁ[|EQù`wzAY _?m[:;p˗ZS(]|nazcfVTofEQk_1O~mdow~3O̟y՛z-̅Y'4*Cc7$l[J4Iv[(|_ߣĻ5PwF3ԨBESCY&5'E`׶:$7t KOhnnԕO76cɊs'nʑ$FjXYnm\)_ΘcPZJ% Oz$G֘[kݱZƘ_%~;VOB,تFL䘓y]sաK_A90x*943-#zݐꏢsB`B)dnu_0{GF[ (:Ao>!3"Ƨ6׼5DYFfX&hQ{M=a m]k;0@H֟P_#[›4Sۼv~Uk)vvv*b`hm/ng֬\z+ןZbɷxtbrx?}Ff^{on<3:8ɧY=8 |`|ڵp?:IyeCE# `؜??G^k^ǧә(&8SUb= &ʚEcI =}xlDGJL6ˢ7jM(%|_ش mZ!@BD;m|2)W%\:y'zW8=UZjɨ9* D1ϿuWzg癧sU,YdUzm0~-HڝA~:˄ϸ,f} w`;x1)hf{{M4@:V*{*[i,#IV-7U{@Lv}/V4l>-VsDY"NXDg91IS6+ |>7;0 ?p% +VJTRN\kb/J_5;B JY) |{'zN>V#5ĭ;|ڴ>ͧe2ltPg926D;Ê5YBr>0TA1*,}f6j]IL:NYc`/l+vGQ:wvv̙ omwuuVRQM1XHEH8m . c}")_Y~\"OX\?m0 h^:UnLm%Iȸ Uf%̨C(ڜY'y'Mi:Ea9 PS8{Cn?k[pY\.U-֗3W)׼꧔V?|&7D grf9?2,̋jgAdH kvJhjjQE`Ji}'b)'jn&작}@໣4&&./Pıd2k3gNZd2 O"1[ 16d^V-Zw>-\fsSѳ' KV /(d V哟`vxiLba8I%xzÚz3_n} cBNBi.IL%v>Np`%?r&+6DWJ?=<׮-Jm~xXsL̞[o쁙iLbJO7٪LNslrJ'v?eQ3? 
(ƞn l"Į7ol0瞙]>/R ?w-wIOV夶lqt-uShuCPʐ|K-:ácg1 02xԯ ʽw[?csׯ,ơ߿o+/=ee&ߖ  L9zXb`Lz~;`kȧg% &ܭwN\ήJZ7SlbB=3.X!kvԋQ JLgEIηG$ p3, A$tQ$鑑oyK`P5KJ9Mo^G)8$"Ct ׯ˗ Fq<ʔ 1 6E>fWK35zƧ{>fpIpmQI*l%+yï24ze}AjP'(KA2y( 'Gg맮.-B+dNVP"m2 xjT=B &g '̂99YЇ==SzՒ 876O͞]_wG}M,8}sӠ)$Nus IvJ#0!>- u]؟0XfPs.>pbQ(i=>_u屰܃MSw¾[k_g߶EYs@~X,saڽrل-0" [sn!b:7Sx*{va]ՊG?~wiuuuQ_DiU-Z[[۶۶#94AQdr*BQ'6&V؊ge(ă:P^&6>ȳ"XɲZD>%HjLOVf]u'+zwU{`bI(ƣ08O҇^@krɑzV\$4VF⏆9q0͉qĒvw.VKw]|?ZQ>O"zɘS`b{y);cg YfeLO&/߽wq΃'ur hWWZK\Wlf̨3Vnyzcgm˿Gsgs,* Ye8;I3M*TFT٘Ә$SO@ lQ)HHõ{J4UN$<m|4B\=Elj|O5E7wpщC{ܳ#_n}ECNUC6`ҳfҳZg. ֜UirMDN Gdz*yjx;Q" r)0oFPϱrGB+ 'wn9~H#W\:ۦw( ~?$:tyƘOH"˜$YGTrO0& ցh)oLIn1 0@l$Qd˗->#=[َk S ZloBKaԇCc`?vT+CskgP"it;w}襚fcNL%L=uOij>/?X5! c d'qJ'7؊V(uohfB`mr \p :% J7Ƶӊ:cͬnm)hašXL~Oitofpg8~%U(l^{U^:pՋճ0?DMNՙ9ю)i?PMVٞ39}cO 7o>/?u`$+tUg9+IOW\a6fI,'&lMעihx41ikQ5]Sq,&lP }C#GmU(4BLZ鵸1 [<#{|9h\x^߷nb~Aғ`lv]~hxݓhŦ3iI폣X8V ]fcB|p=*>DXR26Z۹v5iJˈ+ؕgClL&,S='3'rEp@ڻ-gQ,w ~&>qܷjB6FJ$ݠ,&0mD!V00e(Ţ CWr'>}]45>Ym߹78% 6$I|0wh!Ѱib)Ñ.)b0KHMCŞբDcA1Dq]|*Kz3o{0h ,BȼB˄(d$ģ^Zb]hL#LD!ZXծsπ2a}ro<]͑?*J-0ƔJ0 > Owx#(|fsoLpg"JlUyG=PqF]>{}i|#8`Z}g91[cI!Mham5lzϻrgnl : 3ۚ+9NB~t\0>l7MБ$mGLT!VAlv-b"5O}(YLP>ORhkxxIkr66[2IB̛dџ%3̷c;m E['O36 *aA4ib]T ~֙iIKEKju޼y .r# ъ??\,T|Ivjlȏ{߇)4wB ,''`0nmguFwnSLeѱmUۥ z0Q*!.zF~mąmN னǙU+?]cƄH o/n`IS"x?Ռ"#'GicZKNkHwoOtNt k=sYFi,$3Z4'څXȌvGŜ6bCwݷ/淫3E&~elǎ⼹*?7{oMOUFwY4K"fx|1#x{^3}!T"n:!3].hJ,VL+G) =OCD´[Y)VE[۶㵒c -D(#ڎV 168wȘ#}>= y4:(`Ki W^?enU{drVvEkZZge1  8~N;x=ғ;M@ 0npzowW7s$}ΞX7C1FG9Fv6~ν8c%Z0VRwX p#io4X(!=}o89d"&~#ZJO=,]b #}-Y-`z킅Lh4ab=aksl:EqiiW~ˏ[u/TG[CitZZe758p;xo]2Q- yi3 A!uyO~oŕ{3Bt3Ƙm H y&E2nsAR1qmPZ$sb}؟'BsxXӊعU̐?}/[+&,M\ȼ`8PST|!Qw4)ob;R*LoZSVsMT;{W茎?H3T򮽰"D r77)oonnCh𞎡M=b-D.?˕j0aMuq𶻩*vYdU?~ιJ]{04$A0"a ?PĴkZ5πaMJ<009t޻{4 ==^~5Jhخ&ˍ CF')L /AV%i뭪ME &en|V4zCyusIB ^t6ljpSU%X=G,@p* )QK0 sùU`8 JrF"EDhH/Vu}/|n.Ѣ,EѴ^؏ ]DQ\w\W""uѧ|sKkAo[䵱- D=Ďa#R!YBix.M(L@䌙C/r- ge~R[wc:] ض;+(vV:{k2B dqFG ?z)pEX錟)&AQ2S7j"]v;O9tQG @ CI[ r[uܿ}b^;:Y̙Lj쳏+MnMLX:;-a"dg/q%o,L~ɉk;$_o¥$1cEdxd+E6ҬUZgά?C`oaJc[Kۗ:bf>%xbf tauZ#" oŀ`Ň|t_wˍֹeg{ykVya:64Z:d˥mb 
LtP#l*PBYL'דfE 4 40 2r !8iOEiba=ؑYm֚I 9/. (P*e7-h\wnvr{:zC B&pbeB(=˟sΚ/{*C^|ڜӢwVL0:yk\4m "t(uGbÛ?o]IfJ@"6n.;%]/بOVnWu.>֌ǟa=eVI2>q<ש>gU63쌱BHc w GF(Gw4SK:*I)|%!h*<"s̚=b+MoEoIQ 3.~ϻN'r-t4<ݱq%m-~gooՓ%LP`&7X#䲅>>z/|32Rg[DzҸ]b_7we=[}{/_$ >A2%Ixmu Nu25&'ؽ iP>v|Ju6_NJűisH2YYwu|ﰃeO\ 3}æ0Ds ZѾtמ/ 9-SAI5Z79>flóIH_f}YY9'6FiB3^*O'+;FCV@18?6ĬP(143DM{.Y#; H\s660mv6 +1$iw%RV3U֞=NF~sA00JMƠ~:sk~)RY+ aynYƄ%#UnpD1z08j7G v$3LLh?ԟꑠrKw_^㗯#/zh( $MN!OzQk&ͶNa*uƎ~CUEX883eڝ"`"K.Mf]Nf9&Im#CiwU}`YU8n-.f{Ǵt7Iзگm+/zâs_,ۺmT'r?ߒaJuH?0P];6|: [XJO<>vPfsIS~X2Y>a0HiCciVD9CN)#^{yG t)-ɳvdLB|oNJ|}_;DQ,WOVr%4J"bu QQu|Gڎ@ C,eA9}&~' (Zi׬0 FA_6-=XjH3>GZYV)Mbi93'l=`nq {Gے׼}/9o"R%&`V JYl!PMņR;1A=ӽY1( )@@Ede9ޔ>N_ɮ &CNK Bj%@ѪyuePLCiEV\qgiXliK)\&LJQ8&I?DZt FV?Gxg֮~]G[$m!+ljiMM-?Z<8똂k. "޵+G͉>}2p<)fԂF[7V٦5#  ;VlӧbT2ULPL)3n@ze*@;hR@"3poζa IJycDS9g¬ oMa zdlwDI'/>c@Da;8(YAmSNBh7I4O)\ kNz˯If?fK]cR"P x#O2MN;I̖; )"\<(Xt& ycGa&}Om˰އ6p6v|nHƃ$f֑[{ޓo;|enyj7`ff@[h[&qGnc7Vg6L\U[s0!!;$#vyۤe|\TcQ *%Vbd>0QN V(me\f7lzܢ&{pXgzfV-V/kGsk{MBj&>`DZ [7W7m]lN,_<qXaS{7 0F޳d"$ TeCB֨ GZ_T#䪕v")lk^9i572x[1F8$VT5ieY!eKGA.jBI#je <hJsxcN2zנ0:E?E .mBE#C?+4%Vh43`EH&c<4^{ݾLۢ>CX)kP!S(aG>s"H㽰KDu#U6n Gz7AgrcC7C}r 0CI~OY b6(Cϗhrc)3Sv1,ǸipGydC [wNsқ$,xٖN7D "}( y g?9b: ^[X0oܾCj2qb, !!#*jMi23qA-1Jz=m&md6 UE`p֢k-. 
ԱgBBC, >&5{2dANZ2a!Io=K% "ȎΡKNV, Y҆0ql\޽dQ6,]WЖ$WKq0}= '+k<f!>FԌ,L붝v?AP c[9HT&kkv>v]ZmQ Y3*i@ j&YT(Z+4MYzFۛL7{DW@iHɮ: tl$i_ @OJ^~ዖ孛J+1"0QW+C' |W㻣7 j>;\n]'Z_^ c]raNKf8B١JL~ylN|6!ⱒ-(ij,Q\ֶ  !cS!HZG&)E,wiH'OSLOբkg4^v04] <,ʁ8]4IPPr=2\e(dxAK% ?#$ڟڹ=g[v砆6bgAR7YsDm´Uc`aZ:sඇ/~Wvj4]ym8ʴ4+[siN9_~Ǽca@aSN6o#=:k벓N`Bd| {TߴлDAfWmI씎+H%@(ZضUB$f l&%`tb0imQtbl2&J[<͚̈Q%My{ ?D܁~TӦꇊKmƙI2LhM٦h 4shƮ") " #}r8Ga_j'zz'~tp)Ͳ̼=0Un+psNU&ouF* 4FKR2A?MmΧEmC'#liet}R|P֌9I6-ibv?e\Fp4QHiϐהJX%y08#Vâ?weijw..R :DiKV% BI&6.cw#:)kn<0XZNʽW\}wٌ\}jT2V%KW]tS/xcP%+uTiz xA?>vmM=ZX}[6* `M i8}7,lR#iםx61!& }U\;ALX>?QvBУ;o'#@ K YۦТ~^0uwP0Q2k8r8Rɪ`Krc1jJb o9pҐѠuÝɺo~O|P9giMTt3gׇnq?_׃+cM]8\W[w^xoݰk!X(jvc7 3 k$J4&Xu” [ "&|rFKxF6L Ɔ@ԵDR,,M_$''?/r^m49#e"(DaFg( ⼔>a 1aKmH(0jϮ/7_ H x1_~';48tZ ˉ+KSvt- mã`\GFK7(XVa_@BrH"lFqɲtEnѱ2x}ѯ5s_~6ApCexh+69nh Q(&,ѷ渭ٶ޷S@En(Uho#H56grʁOr0Obs j0uTRwis=g a*> Jg I @ֈƍ4ftv* L~c'<zl;0BGQ{ HP Ž؅!wd5h (<.46O$ ~qvqړK5MF0^TY#UԈK [qkҦS<`QgxvwSNd U6;j?^@dsL[x'Gl²mz!'J~%=?T0u(ۃ_ϥ'Xڱ@eeU+7cdNg:vnyUu0<5ѿI}&Ų4vM@X$}[ـ85)Bz̸z*+Qdh4$±.a8IJN2.GȊ 5k|>d,,S7{o]wH,eF$hΓb %c8%kkyh/vwwvtmZ 1yd-6M\3&IN-)'7l^dmsR$lhţ~SOS_u-e)lJ_S;11#i2.DU`RN=f.&# 9*!F.PXDs>%J|LV0Er;Ze]Wh.v%(2u"[Գ3.w:+L~HcD9Quƪ~o#'^og2٬v;{;{ڋ==\.}/ϓ#dxrtxBk`'>D$T$CA- iΙOOXs1b*?g*>;WBp+LD[`{}qI1ބ 0~w udQZc|Vuv[K>*/GGths,wg $"S2R5 ayMHG(4b"RNN8ZxZ*&:#C\Wyt|[[{.d|/z8"u|=?繎?8p-,U۽ %#U 69S˼F$27&v:<@eD 5eI#GB*,eZ|V޻gIoIX";JҜy,^BDFʝjϹe܆V zL ,JlyQg#%94${S )ԌLV<&|>$B"5j}?yhIO! 
LYIjVXPvWVc#v_mNzynm4h,k%\TZ <)D"=]~@QrTV\@?,@t[/h 3~i/KS {Oyl XĞѣeE&"!@}L6򸭷mf`s}-?fIvfa)-$ӖI;lk\^o5g9><0-J`B4(D(Gc ;l_oeEvXwgo|/My^{[[Wg-mFع'w#e^CD lشy@#?ly{~T?V+B ꈨ d$34ZUJ$]&R W\uw?~ lCd܆1F?4RG$GaE@Y+۳~dqQ~]HՕVZ~?ŎbG,^d7_3?s~G%cRw퓍p ߿= jPH\~{Gz$Ao+(!!P j|'>2OY~8Z;ݹkۺzȶ; j躟|?W]/^qߛO=+c_|=<26,鍞sWvE:RV*]vޏWcfHH&(#boT2&>2 ].1tE#&9~1&Rpqu75Ǒp_o(mqu LY| th`+RHXTV0<&!yʺir>?FFF/㖀+lk[ܵkoM?s3{z&I׾=?wz`>̎_gvܹh¿z?tyHv}/ ѱ}œ&_V,{ F ua`iA}h!?3ceो{ ٬i=]τwu"?-[?w-vt~#L{_k?>wtwys/җ?|-Xpɛc__χ}W~_r^[[[\O~{πlS/.A#cю 1,'C{j|Y,hKeA<K )YMHR-e+i ˜}%-v(9$!Ap_ݭ%ʕ :׾qÍ>l&sÍ@ėEg /x!+W0e֞p| o~M05޽$ ZyHC+GJdU:HϜ/fdZ),,G滛gĤ@H?{_?g/Z8N-h|/mٺ5\ro 7×^[ȣ5\^~Ս0͗J(x8)o]50bv(GMAw'ہ!/W@p< *8"zT؄ds`N2u)œOTj5pKzG>)W*==g]yd`bc_K/=ҥ#mh}ԁn6q$=]m)3 8Yҕ Ii1x0ZM'Kkm8P: 8C_ٳ;wN7^$*?0'j#)eTo瞾|Awu_Fϔ$ǮYscW\18g3sȪ(Sb6&4U\Dka(%klkFiŠӀQN90 s"$YpfBD-]~wtɒ_6l|ߥyvO&z> o۾[n}߼W\dɥ{w/_$"P=Cw|@J, [WsWUwCK:Euz %Ei) pT#sFJexx䮻~?yC~ 0#exk^ǿpg9 \k*x_:aiSSHX"4:;6 Y0%2DEvX@{:Y B9Ќ53둭_O|sʫn\.{[|q喇-d3;8⍯kqlyg<\}EV_k.M+W,کB EWP1=~ݣrݣ.f|io_Jzd #lӁee5[ፆro zi\*s:n2zbtWWkO2hv!;v4'\@RNZ{dŋ6xݏ`O9~jz駽7X/`l|(Y0S.f1Vz( #Eɀdbp,ށjB =%:6˱{OkSccc?甓3y_u w~}ŗٽ;^z#~_k^ƁlY%߁@]x%5I3aVE}A<,p$Lk+\A+N(>o}oy<<甓 3!+W>؆v;vI6ʫJW}~v㍻v|cY~ҿ PN zlZ LtZriV&=#L+圛ʤhs\wh(6ư _7կK.ݾcG4 Շ9_ 'xb_1q!}_c/Z#; =l/jX)DSbwZt:тMhA =݋hs]ϭ0VFHшmokٹk8gv_q M{"]vE=,F&/۸i3+^к#^v>C`mYJQ=M9,岱`=HH=|{6g"<,X{ >^XܹkNcU|uM-t=cW}Zu_o]FFG>yޢ510v訕BrG0 2LO@w Ӷ+yI2sz{;,:Eg<9|0|U*Վv)e^eJN;Kcn-]3K_?'7r1=f͚~Su_blMolc86)8hxh 4Ќ g}/+¾=-<y^b+^vI'sŒR7xwv̼]7Wկ| D q  _?{ ]ׯyg%3/bH-;H`(38 jZN}"-̈́~!Oi{JHR4jd)YcLr$V*= Rz_RRd|?~>ѡvl|ke )DT,s)'=j)*FސdNB/-j#ޑӚBQ&č5![:ew<42cF^ hp|b2)N?+6M \R A:8Q*?a aݎv^L6xLk-:BT '&;Ȩar?3[CrV.y/gYZ,dy+ /jƗz:@^hvdv5ڵgh)DtXgd׬E2뇌\B&L h(s("0ꆘ`q){wB8KE#'A{TptEӇOc *JPω?&9=2S0fs"ui ZDM-_UcۿZ=䛧'E{ma%d,+Yz4B=N &T2"MsXLS)}ֻYPkV^Qh3.H^Lj6%10 m;rvL.Z($B *w+`i2 !Ià9R)?ҐYj& y@]P`<Oof%g 9 $ ,ˑ`B[})V ImRph`R*2Se(_nI"0A!">qaU12EiÚiN @q `0PdӰσmPl+EB($I2 1t};u$ ?>&PsN9y@`;m6 ]df%)p@3i͸Q-VYCQSԊcBʃ *Wѹ(,S0 B`C2'( "Ũ- 2^{$Y Y +m8QS=ǐQSX YHvIs,U!HN Z)Vu`%&dYV4]CCR2f 
4!a)3M٨-"̂G2tE:#Y/Ǝfkx#jKfB>*/C I?Qf[ x7aòi5: nQuwb{s::8KAѫivC=ԡ(HMmΩ-k(Wk#'P i˦*ti s8D4G:2䵄+P4 ;&}I({R#7aIEb *^Y0"^ uD!EV[ۄ@Z`PxlCFɉ;6|LHPPĥPC:WFIDAT#hx\9ٱ k09T/qTˎ@]@8M[&BV䇍&  ŰiPq8Yѡ¦&re"2tv!h-Iwu %m|;sx ̀ATxa6۩#S녾.h<2hױI|mFxxHY&X~ښ|-ձI!oixN)t #:vbY7 k'麮%;wƋ"<UI(2YȺ? Z7 $VJ /B1*4jhM2 Q`klCzwmĨ# h^n8hXGVhR!iȭ8$p8b;M ˟_848*IPg4zk"9.]D D1q.d&c6yV|ؕU} ]/.J$zK.99-lDAʌ6Ac8PxAp@r@06 _ZsޮѶmpmxT+Zwe =|ש-gt"TId[hroZ!/Vl =JI/XnخZ7 +F-_4MȺAG\pjJPю'uc ިCV;?۳~=c33#F߮\:JU(5[i5 B&x 륮:T5 a8U\MwI8Ƿ F_qzZEb^"uG QrnGG1H"^mSLQ#*SܭB)=5:{w< Nf%Ih͈brk_ܾz+o jvWcpL&  X7a=Eqv51.TeqSΕӷm6hH`Oc4"8gǹ01P;V1P^hyvA"qA`\WEW.牗 o~F=hLPge\^3<;K6ZU[qct Vol]X_bow_,jm}EZ%Luak+_l~r)wv\o@''葙CLoUK\[X>#*vֶ[ѿ>Fx>tfVh|dqw}=kI @[gFg.0nJ~ VHmy2񽉙s٦/.crպ*VͿc}pe-&y,}]}< NI0ۃ! 7\#Vx$઄K`*@X ⩚ئ&*"0z˱1|K{ ŒDV)b@k(.|1|~iDRcTt~my 2\$LT( G*P+ w=xAJ3RCpXxKt lWpi0]^[vO&:27(@i;g_,^x9k9NEY(*:VD#h@,*b꠆ *JQ1[աTokhR ɏ1`fk߃XxwbśΓ3!>cŁ׿Pj p pee[HR\in[GL: Vi*hɢ"%lFcۣ+W sErWYH)ͭ6ʹ`qgҢ^+̴ S=\XYxP X.o^_*Ͽ72aÀ XU ^\7wӖ |y݄k5O?՗$T,\Z!s.ĺqP;1VǙT4YFU0y]ecWJkXhiYצa,d Lt%8=g{p8Jyl9ʣfKnbUҋQ/G͔{i?t_{c+jnN*B4e+G4`ILrϥH渜'((E՚*Cu_#)dxXpֽاDAX:)FM7Ɩ7&q8ʙRzCLAZ c+fe6lX3i6Re)C]}UYqccPՁAUU5Xi0$Ģ,Wñ~bk+[?Qa\["))GQ–@ץ7jX׃A\l 8 FO@)&ڢgV{c3GL17 k&M4ɮ@'fp8٭[[F~uCtLg#`L LY9 G`tׁ&$Hh G` 9 HL周j}jșHHI}|t0c[=0cc?cd&U4ARx:c 'F 6.(Y1ftl)c#O,M RCĊElz9Q&9b(JM! #8!DcZ 8u]BOĀ (NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. 
echo.If you don't have Sphinx installed, grab it from echo.https://www.sphinx-doc.org/ exit /b 1 ) if "%1" == "" goto help %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd tee-ar-ex-trx-python-a304ac2/docs/source/000077500000000000000000000000001515240773700202755ustar00rootroot00000000000000tee-ar-ex-trx-python-a304ac2/docs/source/conf.py000066400000000000000000000137641515240773700216070ustar00rootroot00000000000000"""Sphinx configuration for trx-python documentation.""" # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import os
import warnings
from datetime import datetime as dt

# -- Version information -----------------------------------------------------

# Get version from environment variable (set by CI) or package
version = os.environ.get('TRX_VERSION', None)
if version is None:
    try:
        from trx import __version__
        version = __version__
    except ImportError:
        # The package is not importable (e.g. docs built without installing
        # trx-python); fall back to a generic development label.
        version = "dev"

# Normalize version for switcher matching.
# Any development build (e.g. "0.3.1.dev4+gabc" or plain "dev") is matched
# against the "dev" entry of switcher.json; release versions match verbatim.
version_match = "dev" if "dev" in version else version

# -- Project information -----------------------------------------------------

project = 'trx-python'
copyright = f'2021-{dt.now().year}, The TRX developers'
author = 'The TRX developers'
release = version

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.todo',
    'sphinx.ext.doctest',
    'sphinx.ext.intersphinx',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
    'sphinx.ext.autosummary',
    'autoapi.extension',
    'numpydoc',
    'sphinx_gallery.gen_gallery',
    'sphinx_design',
]

# Suppress known deprecation warnings from dependencies
# astroid 4.x deprecation - will be fixed when sphinx-autoapi updates for astroid 5.x
warnings.filterwarnings(
    'ignore',
    message="importing .* from 'astroid' is deprecated",
    category=DeprecationWarning
)

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
def _validate_reference_urls(urls, timeout=5):
    """Validate reference URLs and return only reachable ones.

    Checks if the objects.inv file (used by sphinx for intersphinx) is
    accessible at each URL. This is a best-effort check: any failure to
    open a URL is reported on stdout and the entry is skipped, so an
    unreachable documentation site never aborts the build.

    Parameters
    ----------
    urls : dict
        Dictionary of package names to documentation URLs.
    timeout : int
        Connection timeout in seconds.

    Returns
    -------
    dict
        Dictionary containing only URLs that are reachable.
    """
    import urllib.error
    import urllib.request

    valid_urls = {}
    for name, url in urls.items():
        objects_inv_url = url.rstrip('/') + '/objects.inv'
        try:
            req = urllib.request.Request(
                objects_inv_url,
                headers={'User-Agent': 'Sphinx-doc-builder'}
            )
            # Only reachability matters, so close the response immediately
            # instead of leaving the connection open until garbage collection.
            with urllib.request.urlopen(req, timeout=timeout):
                pass
        except urllib.error.URLError as e:
            reason = getattr(e, 'reason', str(e))
            print(f"WARNING: Skipping '{name}' reference URL ({url}): {reason}")
        except Exception as e:
            # Deliberately broad: malformed URLs, timeouts or protocol errors
            # must not break the documentation build.
            print(f"WARNING: Skipping '{name}' reference URL ({url}): {e}")
        else:
            valid_urls[name] = url

    return valid_urls
code-block:: bash git clone https://github.com/YOUR_USERNAME/trx-python.git cd trx-python 3. **Set up development environment**: .. code-block:: bash pip install -e ".[all]" spin setup The ``spin setup`` command fetches version tags from upstream, which is required for correct version detection. 4. **Create a branch** for your changes: .. code-block:: bash git checkout -b my-feature-branch Making Changes -------------- Development Workflow ~~~~~~~~~~~~~~~~~~~~ We use `spin `_ for development workflow: .. code-block:: bash spin install # Install in editable mode spin test # Run all tests spin lint # Run linting (ruff) spin docs # Build documentation Before Submitting ~~~~~~~~~~~~~~~~~ 1. **Run tests** to ensure your changes don't break existing functionality: .. code-block:: bash spin test 2. **Run linting** to ensure code style compliance: .. code-block:: bash spin lint You can auto-fix many issues with: .. code-block:: bash spin lint --fix 3. **Format your code** using ruff: .. code-block:: bash ruff format . 4. **Write tests** for any new functionality 5. **Update documentation** if needed Submitting a Pull Request ------------------------- 1. **Push your changes** to your fork: .. code-block:: bash git push origin my-feature-branch 2. **Open a Pull Request** on GitHub against the ``master`` branch 3. **Describe your changes** in the PR description: - What does this PR do? - Why is this change needed? - How was it tested? 4. **Wait for CI checks** to pass 5. **Address review feedback** if requested Code Style ---------- We follow these conventions: - **PEP 8** style guide - **Line length**: 88 characters maximum - **Docstrings**: NumPy style format - **Type hints**: Encouraged but not required Example docstring: .. code-block:: python def my_function(param1, param2): """Short description of the function. Parameters ---------- param1 : int Description of param1. param2 : str Description of param2. Returns ------- result : bool Description of return value. 
Examples -------- >>> my_function(1, "test") True """ pass We use `ruff `_ for linting and formatting. Configuration is in ``ruff.toml``. Testing ------- Tests are located in ``trx/tests/``. We use pytest for testing. Running Tests ~~~~~~~~~~~~~ .. code-block:: bash # Run all tests spin test # Run tests matching a pattern spin test -m memmap # Run with verbose output spin test -v # Run a specific test file pytest trx/tests/test_memmap.py Writing Tests ~~~~~~~~~~~~~ - Place tests in ``trx/tests/`` - Name test files ``test_*.py`` - Name test functions ``test_*`` - Use pytest fixtures for common setup Documentation ------------- Documentation is built with Sphinx and hosted on GitHub Pages. Building Docs ~~~~~~~~~~~~~ .. code-block:: bash spin docs # Build documentation spin docs --clean # Clean build spin docs --open # Build and open in browser Writing Documentation ~~~~~~~~~~~~~~~~~~~~~ - Documentation source is in ``docs/source/`` - Use reStructuredText format - API documentation is auto-generated from docstrings Getting Help ------------ - **GitHub Issues**: For bugs and feature requests - **GitHub Discussions**: For questions and discussions Thank you for contributing to TRX-Python! tee-ar-ex-trx-python-a304ac2/docs/source/dev.rst000066400000000000000000000173311515240773700216120ustar00rootroot00000000000000Developer Guide =============== This guide provides detailed information for developers working on TRX-Python. .. toctree:: :maxdepth: 1 contributing Installation for Development ---------------------------- Prerequisites ~~~~~~~~~~~~~ - Python 3.11 or later (Python 3.12+ recommended) - Git - pip Setting Up Your Environment ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1. **Clone the repository**: .. code-block:: bash # If you're a contributor, fork first then clone your fork git clone https://github.com/YOUR_USERNAME/trx-python.git cd trx-python 2. **Install with all development dependencies**: .. 
code-block:: bash pip install -e ".[all]" This installs: - Core dependencies (numpy, nibabel, deepdiff, typer) - Development tools (spin, setuptools_scm) - Documentation tools (sphinx, numpydoc) - Style tools (ruff, pre-commit) - Testing tools (pytest, pytest-cov) 3. **Set up the development environment**: .. code-block:: bash spin setup This command: - Adds upstream remote if missing - Fetches version tags for correct ``setuptools_scm`` version detection Using Spin ---------- We use `spin `_ for development workflow. Spin provides a consistent interface for common development tasks. Available Commands ~~~~~~~~~~~~~~~~~~ Run ``spin`` without arguments to see all available commands: .. code-block:: bash spin **Setup Commands:** .. code-block:: bash spin setup # Configure development environment **Build Commands:** .. code-block:: bash spin install # Install package in editable mode **Test Commands:** .. code-block:: bash spin test # Run all tests spin test -m NAME # Run tests matching pattern spin test -v # Verbose output spin lint # Run ruff linting spin lint --fix # Auto-fix linting issues **Documentation Commands:** .. code-block:: bash spin docs # Build documentation spin docs --clean # Clean and rebuild spin docs --open # Build and open in browser **Cleanup Commands:** .. code-block:: bash spin clean # Remove temporary files and build artifacts Code Quality ------------ Linting with Ruff ~~~~~~~~~~~~~~~~~ We use `ruff `_ for linting and formatting. Configuration is in ``ruff.toml``. .. code-block:: bash # Check for issues spin lint # Auto-fix issues spin lint --fix # Format code ruff format . # Check formatting without changes ruff format --check . Pre-commit Hooks ~~~~~~~~~~~~~~~~ We recommend using pre-commit hooks to catch issues before committing: .. 
code-block:: bash # Install pre-commit hooks pre-commit install # Run hooks manually on all files pre-commit run --all-files The hooks run: - ``ruff`` - Linting with auto-fix - ``ruff-format`` - Code formatting - ``codespell`` - Spell checking Testing ------- Running Tests ~~~~~~~~~~~~~ .. code-block:: bash # Run all tests spin test # Run tests matching a pattern spin test -m memmap # Run with pytest directly pytest trx/tests # Run with coverage pytest trx/tests --cov=trx --cov-report=term-missing Test Data ~~~~~~~~~ Test data is automatically downloaded from Figshare on first run. Data is cached in ``~/.tee_ar_ex/``. You can manually fetch test data: .. code-block:: python from trx.fetcher import fetch_data, get_testing_files_dict fetch_data(get_testing_files_dict()) Writing Tests ~~~~~~~~~~~~~ - Tests go in ``trx/tests/`` - Use pytest fixtures for setup/teardown - Use ``pytest.mark.skipif`` for conditional tests Example: .. code-block:: python import pytest import numpy as np from numpy.testing import assert_array_equal def test_my_function(): result = my_function(input_data) expected = np.array([1, 2, 3]) assert_array_equal(result, expected) @pytest.mark.skipif(not dipy_available, reason="Dipy required") def test_with_dipy(): # Test that requires dipy pass Documentation ------------- Building Documentation ~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash # Build docs spin docs # Clean build spin docs --clean # Build and open in browser spin docs --open Documentation is built with Sphinx and uses: - ``pydata-sphinx-theme`` for styling - ``sphinx-autoapi`` for API documentation - ``numpydoc`` for NumPy-style docstrings Writing Documentation ~~~~~~~~~~~~~~~~~~~~~ - Source files are in ``docs/source/`` - Use reStructuredText format - API docs are auto-generated from docstrings NumPy Docstring Format ~~~~~~~~~~~~~~~~~~~~~~ All functions and classes should be documented using NumPy-style docstrings: .. 
code-block:: python def load(filename, reference=None): """Load a tractogram file. Parameters ---------- filename : str Path to the tractogram file. reference : str, optional Path to reference anatomy for formats that require it. Returns ------- tractogram : TrxFile or StatefulTractogram The loaded tractogram. Raises ------ ValueError If the file format is not supported. See Also -------- save : Save a tractogram to file. Examples -------- >>> from trx.io import load >>> trx = load("tractogram.trx") """ pass Project Structure ----------------- .. code-block:: text trx-python/ ├── trx/ # Main package │ ├── __init__.py │ ├── cli.py # Command-line interface (Typer) │ ├── fetcher.py # Test data fetching │ ├── io.py # Unified I/O interface │ ├── streamlines_ops.py # Streamline operations │ ├── trx_file_memmap.py # Core TrxFile class │ ├── utils.py # Utility functions │ ├── viz.py # Visualization (optional) │ ├── workflows.py # High-level workflows │ └── tests/ # Test suite ├── docs/ # Documentation │ └── source/ ├── .github/ # GitHub Actions workflows │ └── workflows/ ├── .spin/ # Spin configuration │ └── cmds.py ├── pyproject.toml # Project configuration ├── ruff.toml # Ruff configuration └── .pre-commit-config.yaml # Pre-commit hooks Continuous Integration ---------------------- GitHub Actions runs on every push and pull request: - **test.yml**: Runs tests on Python 3.11-3.13 across Linux, macOS, Windows - **codeformat.yml**: Checks code formatting with pre-commit/ruff - **coverage.yml**: Generates code coverage reports - **docbuild.yml**: Builds and deploys documentation Environment Variables --------------------- TRX_TMPDIR ~~~~~~~~~~ Controls where temporary files are stored during memory-mapped operations. .. 
code-block:: bash # Use a specific directory export TRX_TMPDIR=/path/to/tmp # Use current working directory export TRX_TMPDIR=use_working_dir Default: System temp directory (``/tmp`` on Linux/macOS, ``C:\WINDOWS\Temp`` on Windows) Release Process --------------- Releases are managed via GitHub: 1. Update version in ``pyproject.toml`` if needed 2. Create a GitHub release with appropriate tag 3. CI automatically publishes to PyPI Version Detection ~~~~~~~~~~~~~~~~~ We use ``setuptools_scm`` for automatic version detection from git tags. This requires: - Proper git tags from upstream - Running ``spin setup`` after cloning a fork tee-ar-ex-trx-python-a304ac2/docs/source/index.rst000066400000000000000000000065111515240773700221410ustar00rootroot00000000000000.. trx-python documentation master file, created by sphinx-quickstart on Fri Jun 24 23:14:56 2022. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. TRX: A community-oriented tractography file format =================================================== We propose **TRX**, a tractography file format designed to facilitate dataset exchange, interoperability, and state-of-the-art analyses, acting as a community-driven replacement for the myriad existing file formats. Getting Started ~~~~~~~~~~~~~~~ New to TRX? Start here: 1. **Understand the format**: Read the :doc:`trx_specifications` to understand the TRX file structure 2. **Learn by example**: Follow our :doc:`auto_examples/index` to learn how to read, write, and manipulate TRX files 3. **Use the CLI tools**: Check out the :doc:`scripts` documentation for command-line operations .. grid:: 2 .. grid-item-card:: Tutorials :link: auto_examples/index :link-type: doc Learn how to work with TRX files through hands-on tutorials covering reading/writing files, working with groups, and using metadata. .. 
grid-item-card:: TRX Specifications :link: trx_specifications :link-type: doc Complete technical specifications of the TRX file format including header fields, array structures, and naming conventions. Why TRX? ~~~~~~~~ File formats that store the results of computational tractography were typically developed within specific software packages. This approach has facilitated a myriad of applications, but this development approach has also generated insularity within software packages, and has limited standardization. Moreover, because tractography file formats were developed to solve immediate challenges, only a limited breadth of applications within a single software package was envisioned, sometimes also neglecting computational performance. Given the growing interest in tractography methods and applications, and the increasing size and complexity of datasets, a community-driven standardization of tractography has become a priority. To address these challenges, our community initiated a discussion to design a new file format and agreed to participate in its conception, development, and, if successful, its adoption. The goal of TRX is to become the first, community-driven, standard amongst tractography file formats. As with other file formats like NIfTI, we believe that TRX will serve the community well and the growing computational needs of our field. We encourage community members to consider early contributions to our proposal so as to ensure the new standard will cover the needs of the wider audience of software developers, toolboxes, and scientists. Our long-term plan is to integrate TRX within the `Brain Imaging Data Structure (BIDS) `_ ecosystem. Acknowledgments ~~~~~~~~~~~~~~~~ Development of TRX is supported by `NIMH grant 1R01MH126699 `_. .. toctree:: :maxdepth: 2 :caption: User Guide: trx_specifications scripts .. toctree:: :maxdepth: 2 :caption: Tutorials: auto_examples/index .. toctree:: :maxdepth: 2 :caption: Development: dev .. 
toctree:: :maxdepth: 2 :caption: API Reference: autoapi/index tee-ar-ex-trx-python-a304ac2/docs/source/scripts.rst000066400000000000000000000145431515240773700225250ustar00rootroot00000000000000:html_theme.sidebar_secondary.remove: Command-line Interface ====================== The TRX toolkit provides a unified command-line interface ``trx`` as well as individual standalone commands for backward compatibility. All commands become available on your ``PATH`` after installing ``trx-python``. Each command supports ``--help`` for full options. Unified CLI: ``trx`` -------------------- The recommended way to use TRX commands is through the unified ``trx`` CLI: .. code-block:: bash trx --help # Show all available commands trx --help # Show help for a specific command Available subcommands: - ``trx info`` - Display detailed TRX file information - ``trx concatenate`` - Concatenate multiple tractograms - ``trx convert`` - Convert between tractography formats - ``trx convert-dsi`` - Fix DSI-Studio TRK files - ``trx generate`` - Generate TRX from raw data files - ``trx manipulate-dtype`` - Change array data types - ``trx compare`` - Simple tractogram comparison - ``trx validate`` - Validate and clean TRX files - ``trx verify-header`` - Check header compatibility - ``trx visualize`` - Visualize tractogram overlap Standalone Commands ------------------- For backward compatibility, standalone commands are also available: trx_info ~~~~~~~~ Display detailed information about a TRX file, including file size, compression status, header metadata (affine, dimensions, voxel sizes), streamline/vertex counts, data keys (dpv, dps, dpg), groups, and archive contents. - Only ``.trx`` files are supported. .. code-block:: bash # Using unified CLI trx info tractogram.trx # Using standalone command trx_info tractogram.trx trx_concatenate_tractograms ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Concatenate multiple tractograms into a single output. 
- Supports ``trk``, ``tck``, ``vtk``, ``fib``, ``dpy``, and ``trx`` inputs. - Flags: ``--delete-dpv``, ``--delete-dps``, ``--delete-groups`` to drop mismatched metadata; ``--reference`` for formats needing an anatomy reference; ``-f`` to overwrite. .. code-block:: bash # Using unified CLI trx concatenate in1.trk in2.trk merged.trx # Using standalone command trx_concatenate_tractograms in1.trk in2.trk merged.trx trx_convert_dsi_studio ~~~~~~~~~~~~~~~~~~~~~~ Convert a DSI Studio ``.trk`` with accompanying ``.nii.gz`` reference into a cleaned ``.trk`` or TRX. .. code-block:: bash # Using unified CLI trx convert-dsi input.trk reference.nii.gz cleaned.trk # Using standalone command trx_convert_dsi_studio input.trk reference.nii.gz cleaned.trk trx_convert_tractogram ~~~~~~~~~~~~~~~~~~~~~~ General-purpose converter between ``trk``, ``tck``, ``vtk``, ``fib``, ``dpy``, and ``trx``. - Flags: ``--reference`` for formats needing a NIfTI, ``--positions-dtype``, ``--offsets-dtype``, ``-f`` to overwrite. .. code-block:: bash # Using unified CLI trx convert input.trk output.trx --positions-dtype float32 --offsets-dtype uint64 # Using standalone command trx_convert_tractogram input.trk output.trx --positions-dtype float32 --offsets-dtype uint64 trx_generate_from_scratch ~~~~~~~~~~~~~~~~~~~~~~~~~ Build a TRX file from raw NumPy arrays or CSV streamline coordinates. - Flags: ``--positions``, ``--offsets``, ``--positions-dtype``, ``--offsets-dtype``, spatial options (``--space``, ``--origin``), and metadata loaders for dpv/dps/groups/dpg. .. code-block:: bash # Using unified CLI trx generate fa.nii.gz output.trx --positions positions.npy --offsets offsets.npy # Using standalone command trx_generate_from_scratch fa.nii.gz output.trx --positions positions.npy --offsets offsets.npy trx_manipulate_datatype ~~~~~~~~~~~~~~~~~~~~~~~ Rewrite TRX datasets with new dtypes for positions/offsets/dpv/dps/dpg/groups. - Accepts per-field dtype arguments and overwrites with ``-f``. .. 
code-block:: bash # Using unified CLI trx manipulate-dtype input.trx output.trx --positions-dtype float16 --dpv color,uint8 # Using standalone command trx_manipulate_datatype input.trx output.trx --positions-dtype float16 --dpv color,uint8 trx_simple_compare ~~~~~~~~~~~~~~~~~~ Compare two tractograms for quick difference checks. .. code-block:: bash # Using unified CLI trx compare first.trk second.trk # Using standalone command trx_simple_compare first.trk second.trk trx_validate ~~~~~~~~~~~~ Validate a TRX file for consistency and remove invalid streamlines. .. code-block:: bash # Using unified CLI trx validate data.trx --out cleaned.trx # Using standalone command trx_validate data.trx --out cleaned.trx trx_verify_header_compatibility ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Check whether tractogram headers are compatible for operations such as concatenation. .. code-block:: bash # Using unified CLI trx verify-header file1.trk file2.trk # Using standalone command trx_verify_header_compatibility file1.trk file2.trk trx_visualize_overlap ~~~~~~~~~~~~~~~~~~~~~ Visualize streamline overlap between tractograms (requires visualization dependencies). .. code-block:: bash # Using unified CLI trx visualize tractogram.trk reference.nii.gz # Using standalone command trx_visualize_overlap tractogram.trk reference.nii.gz Troubleshooting --------------- If the ``trx`` command is not working as expected, run ``trx --debug`` to print diagnostic information about the Python interpreter, package location, and whether all required and optional dependencies are installed: .. 
code-block:: bash trx --debug # Example output: # Environment diagnostics: # Python executable : /Users/you/myenv/bin/python # sys.prefix : /Users/you/myenv # trx-python version: 0.3.1 # trx package : /Users/you/myenv/lib/python3.11/site-packages/trx # # Required dependencies: # deepdiff found # nibabel found # numpy found # typer found # # Optional dependencies: # dipy found # fury not found # vtk not found Notes ----- - Test datasets for examples can be fetched with ``python -m trx.fetcher`` helpers: ``fetch_data(get_testing_files_dict())`` downloads to ``$TRX_HOME`` (default ``~/.tee_ar_ex``). - All commands print detailed usage with ``--help``. - The unified ``trx`` CLI uses `Typer `_ for beautiful terminal output with colors and rich formatting. tee-ar-ex-trx-python-a304ac2/docs/source/trx_specifications.rst000066400000000000000000000153711515240773700247360ustar00rootroot00000000000000:html_theme.sidebar_secondary.remove: TRX File Format Specifications =============================== This document contains the complete specifications for the TRX (Tractography File Format) as defined by the TRX specification. TRX is a community-oriented tractography file format designed to facilitate dataset exchange, interoperability, and state-of-the-art analyses. 
General Properties ------------------ **File Structure** - (Un)-Compressed Zip File or simple folder architecture - File architecture describes the data - Each file basename is the metadata's name - Each file extension is the metadata's dtype - Each file's dimension is the value between the basename and the extension (1-dimension arrays do not have to follow this convention for readability) **Data Organization** - All arrays have a C-style memory layout (row-major) - All arrays have a little-endian byte order - Compression is optional: - Use ``ZIP_STORE`` for uncompressed storage - Use ``ZIP_DEFLATE`` if compression is desired - Compressed TRX files will have to be decompressed before being loaded Header ------ The header contains metadata for readability, run-time checks, and broader compatibility. It is stored as a dictionary in JSON format with the following fields: **Required Fields:** .. code-block:: text VOXEL_TO_RASMM : 4x4 transformation matrix (list of 4 lists, each containing 4 floats) DIMENSIONS : Image dimensions (list of 3 uint16) NB_STREAMLINES : Number of streamlines (uint32) NB_VERTICES : Total number of vertices (uint64) Arrays ------ positions.{N}.float{16,32,64} ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Written in world space (RASMM), similar to TCK files - Should always be float16/32/64 (default recommended: float16) - Stored as a contiguous 2D array of 3D coordinates with shape (NB_VERTICES, 3) - The {N} dimension specifier can be omitted for 1D arrays for readability offsets.uint{32,64} ~~~~~~~~~~~~~~~~~~~ - Should always be uint32 or uint64 - Indicates the starting vertex index for each streamline (starts at 0) - Streamline lengths can be calculated by: 1. Checking the header for total vertices count 2. Using positions array size: ``positions.size / 3`` (equivalently ``positions.shape[0]``) 3. 
Calculating differences between consecutive elements: append total_vertices to offsets array and compute ediff1d dpv (data_per_vertex) ~~~~~~~~~~~~~~~~~~~~~ - Always of size (NB_VERTICES, 1) or (NB_VERTICES, N) - Contains data associated with each vertex/point along streamlines - Common uses: FA values, colors, curvature, local coordinate systems dps (data_per_streamline) ~~~~~~~~~~~~~~~~~~~~~~~~~ - Always of size (NB_STREAMLINES, 1) or (NB_STREAMLINES, N) - Contains data associated with entire streamlines - Common uses: bundle IDs, mean metrics, algorithm information Groups ------ Groups are tables of indices that allow sparse & overlapping representation (clusters, connectomics, bundles). **Properties:** - All indices must be ``0 <= id < NB_STREAMLINES`` - Datatype should be uint32 - Allows efficient retrieval of predefined streamline subsets from memmaps - Variables can have different sizes dpg (data_per_group) ~~~~~~~~~~~~~~~~~~~~ - Each folder corresponds to name of a group - Not all metadata have to be present in all groups - Always of size (1,) or (N,) per group - Contains group-specific metadata like volumes, mean values, color codes Supported Data Types -------------------- The TRX format supports the following data types: **Integer Types:** - ``int8``, ``int16``, ``int32``, ``int64`` - ``uint8``, ``uint16``, ``uint32``, ``uint64`` **Floating Point Types:** - ``float16``, ``float32``, ``float64`` **Boolean Type:** - ``bit`` (for boolean data) Example File Structure ---------------------- .. 
code-block:: text OHBM_demo.trx |-- dpg | |-- AF_L | | |-- mean_fa.float16 | | |-- shuffle_colors.3.uint8 | | +-- volume.uint32 | |-- AF_R | | |-- mean_fa.float16 | | |-- shuffle_colors.3.uint8 | | +-- volume.uint32 | |-- CC | | |-- mean_fa.float16 | | |-- shuffle_colors.3.uint8 | | +-- volume.uint32 | |-- CST_L | | +-- shuffle_colors.3.uint8 | |-- CST_R | | +-- shuffle_colors.3.uint8 | |-- SLF_L | | |-- mean_fa.float16 | | |-- shuffle_colors.3.uint8 | | +-- volume.uint32 | +-- SLF_R | |-- mean_fa.float16 | |-- shuffle_colors.3.uint8 | +-- volume.uint32 |-- dpv | |-- color_x.uint8 | |-- color_y.uint8 | |-- color_z.uint8 | +-- fa.float16 |-- dps | |-- algo.uint8 | |-- algo.json | |-- clusters_QB.uint16 | |-- commit_colors.3.uint8 | +-- commit_weights.float32 |-- groups | |-- AF_L.uint32 | |-- AF_R.uint32 | |-- CC.uint32 | |-- CST_L.uint32 | |-- CST_R.uint32 | |-- SLF_L.uint32 | +-- SLF_R.uint32 |-- header.json |-- offsets.uint64 +-- positions.3.float16 Naming Conventions ------------------ **Files:** - Basename = metadata name - Extension = data type - Dimension specifiers between basename and extension (optional for 1D) **Examples:** - ``positions.3.float16`` - 3D position data as float16 - ``fa.float16`` - 1D fractional anisotropy values as float16 - ``colors.3.uint8`` - RGB color values as 8-bit unsigned integers - ``bundle_id.uint8`` - Bundle identifiers as 8-bit unsigned integers Memory and Performance Considerations ------------------------------------- **Memory Efficiency:** - Use float16 for positional data when precision allows - Choose appropriate integer sizes for indices (uint32 for streamline indices) - Consider compression for disk storage but expect decompression overhead **Performance:** - C-style memory layout enables efficient numpy operations - Little-endian byte order ensures consistency across platforms - Memory-mapped access for large datasets without full loading **Scalability:** - Support for arbitrarily large numbers of streamlines and 
vertices - Group-based organization enables efficient subset operations - Flexible metadata structure accommodates various analysis workflows Compatibility and Integration ----------------------------- TRX is designed for integration with existing neuroimaging ecosystems: **Current Support:** - Native support in the trx-python library - Conversion tools for common tractography formats (TCK, TRK, etc.) - Integration with DIPY for advanced processing **Future Goals:** - Integration with the Brain Imaging Data Structure (BIDS) ecosystem - Support in major neuroimaging software packages - Standardization across the tractography community For the latest updates and community discussions, see: - `TRX Specification Repository <https://github.com/tee-ar-ex/trx-spec>`_ - `TRX Python Implementation <https://github.com/tee-ar-ex/trx-python>`_ tee-ar-ex-trx-python-a304ac2/examples/000077500000000000000000000000001515240773700176635ustar00rootroot00000000000000tee-ar-ex-trx-python-a304ac2/examples/README.txt000066400000000000000000000012441515240773700213620ustar00rootroot00000000000000Tutorials ========= These tutorials demonstrate how to use the trx-python library for working with TRX tractography files. Getting Started --------------- New to TRX? These tutorials will guide you through: 1. **Reading and Writing TRX Files** - Load, inspect, and save TRX files 2. **Working with Groups** - Organize streamlines into anatomical bundles 3. **Data Per Vertex and Streamline** - Work with metadata Prerequisites ------------- To run these tutorials, you need: - trx-python installed (``pip install trx-python[all]``) - An internet connection (for downloading test data on first run) Each tutorial can be run as a Python script or in a Jupyter notebook. 
# -*- coding: utf-8 -*-
# sphinx_gallery_thumbnail_path = '../docs/_static/trx_logo.png'
"""
Data Per Vertex and Data Per Streamline
========================================

This tutorial demonstrates how to work with metadata in TRX files.
TRX supports two types of metadata:

- **Data Per Vertex (dpv)**: Information attached to each point along streamlines
- **Data Per Streamline (dps)**: Information attached to entire streamlines

By the end of this tutorial, you will know how to:

- Access dpv and dps data in a TRX file
- Understand the data shapes and organization
- Compute summary statistics on metadata as a starting point for
  filtering and analysis
"""

# %%
# Understanding DPV and DPS
# -------------------------
#
# **Data Per Vertex (dpv):**
#
# - Attached to each individual point (vertex) in all streamlines
# - Shape: (NB_VERTICES, 1) for scalar data or (NB_VERTICES, N) for vector data
# - Common uses: FA values at each point, RGB colors, local orientations
#
# **Data Per Streamline (dps):**
#
# - Attached to entire streamlines (one value per streamline)
# - Shape: (NB_STREAMLINES, 1) for scalar data or (NB_STREAMLINES, N) for vector data
# - Common uses: bundle ID, mean FA, streamline length, tracking algorithm ID

# %%
# Loading a TRX file with metadata
# --------------------------------
#
# Let's load a TRX file and explore its metadata.

import os

import numpy as np

from trx.fetcher import fetch_data, get_home, get_testing_files_dict
from trx.trx_file_memmap import load

# Download test data (needs an internet connection on the first run)
fetch_data(get_testing_files_dict(), keys="gold_standard.zip")
trx_home = get_home()
trx_path = os.path.join(trx_home, "gold_standard", "gs.trx")

# Load the TRX file
trx = load(trx_path)

print(f"Loaded TRX with {len(trx)} streamlines")
print(f"Total vertices: {trx.header['NB_VERTICES']}")

# %%
# Exploring Data Per Vertex (dpv)
# -------------------------------
#
# Let's see what dpv data is available.

print("Data Per Vertex keys:", list(trx.data_per_vertex.keys()))

# Examine each dpv field.
# NOTE(review): ``_data`` is a private attribute; presumably it is the flat
# array holding the values of every vertex for this field -- confirm against
# the trx-python API before relying on it outside of a tutorial.
for key in trx.data_per_vertex:
    data = trx.data_per_vertex[key]
    print(f"\n {key}:")
    print(f" Shape: {data._data.shape}")
    print(f" Dtype: {data._data.dtype}")
    print(f" Sample values: {data._data[:3].flatten()}")

# %%
# Accessing dpv for a specific streamline
# ---------------------------------------
#
# The dpv data is organized to match the streamlines. You can access
# the dpv values for a specific streamline using the same indices.

if len(trx.data_per_vertex) > 0:
    # Any key works for the demonstration; take the first one available
    first_dpv_key = list(trx.data_per_vertex.keys())[0]
    dpv_data = trx.data_per_vertex[first_dpv_key]

    # Get dpv values for the first streamline
    first_streamline_dpv = dpv_data[0]
    print(f"DPV '{first_dpv_key}' for first streamline:")
    print(f" Shape: {first_streamline_dpv.shape}")
    print(f" Values: {first_streamline_dpv.flatten()}")

# %%
# Exploring Data Per Streamline (dps)
# -----------------------------------
#
# Now let's examine the dps data.

print("Data Per Streamline keys:", list(trx.data_per_streamline.keys()))

# Examine each dps field (shape and dtype are read directly from the entry)
for key in trx.data_per_streamline:
    data = trx.data_per_streamline[key]
    print(f"\n {key}:")
    print(f" Shape: {data.shape}")
    print(f" Dtype: {data.dtype}")
    print(f" First 5 values: {data[:5].flatten()}")

# %%
# DPS for filtering streamlines
# -----------------------------
#
# A common use case is filtering streamlines based on dps values.
# For example, selecting streamlines with high FA values.
#
# The code below only computes the summary statistics (min, max, mean,
# standard deviation) that one would typically inspect before choosing a
# threshold; it does not perform the filtering itself.

if len(trx.data_per_streamline) > 0:
    # Use the first dps key for demonstration
    first_dps_key = list(trx.data_per_streamline.keys())[0]
    dps_data = trx.data_per_streamline[first_dps_key]

    # Calculate some statistics
    print(f"\nStatistics for '{first_dps_key}':")
    print(f" Min: {np.min(dps_data):.4f}")
    print(f" Max: {np.max(dps_data):.4f}")
    print(f" Mean: {np.mean(dps_data):.4f}")
    print(f" Std: {np.std(dps_data):.4f}")

# %%
# File structure for dpv and dps
# ------------------------------
#
# In the TRX format, dpv and dps are stored in separate directories:
#
# .. code-block:: text
#
#     my_tractogram.trx/
#     |-- dpv/
#     |   |-- fa.float16          # FA values per vertex
#     |   |-- colors.3.uint8      # RGB colors (3 values per vertex)
#     |   +-- curvature.float32   # Curvature per vertex
#     |-- dps/
#     |   |-- bundle_id.uint8     # Bundle assignment per streamline
#     |   |-- length.uint16       # Length per streamline
#     |   +-- mean_fa.float32     # Mean FA per streamline
#     +-- ...
#
# The filename format is: ``name.dtype`` or ``name.dimension.dtype``

# %%
# Working with multi-dimensional data
# -----------------------------------
#
# Both dpv and dps can have multiple dimensions. For example, RGB colors
# have 3 values per vertex.

print("\nDemonstrating multi-dimensional data:")

# Check for any multi-dimensional dpv (second axis longer than 1)
for key in trx.data_per_vertex:
    data = trx.data_per_vertex[key]
    if len(data._data.shape) > 1 and data._data.shape[1] > 1:
        print(f" {key}: {data._data.shape[1]}D data per vertex")

# Check for any multi-dimensional dps (second axis longer than 1)
for key in trx.data_per_streamline:
    data = trx.data_per_streamline[key]
    if len(data.shape) > 1 and data.shape[1] > 1:
        print(f" {key}: {data.shape[1]}D data per streamline")

# %%
# Relationship between dpv and streamlines
# ----------------------------------------
#
# It's important to understand how dpv data maps to individual streamlines.
# Each streamline's dpv values can be accessed using the streamline's
# vertex indices.

# Get vertex counts for first few streamlines
print("\nVertex distribution for first 5 streamlines:")
for i in range(min(5, len(trx))):
    streamline = trx.streamlines[i]
    print(f" Streamline {i}: {len(streamline)} vertices")

# Total vertices should match: summing every streamline's length is compared
# (by eye, in the printed output) with the NB_VERTICES header field.
total_from_streamlines = sum(len(trx.streamlines[i]) for i in range(len(trx)))
print(f"\nTotal vertices from streamlines: {total_from_streamlines}")
print(f"Total vertices in header: {trx.header['NB_VERTICES']}")

# %%
# Summary
# -------
#
# In this tutorial, you learned how to:
#
# - Access dpv data using ``trx.data_per_vertex[key]``
# - Access dps data using ``trx.data_per_streamline[key]``
# - Understand the shape conventions for scalar and vector data
# - Use metadata for statistical analysis
# - Understand the file structure for dpv and dps
#
# The TRX format's metadata system is designed for flexibility, allowing
# you to attach any kind of information to vertices or streamlines.
# -*- coding: utf-8 -*-
# sphinx_gallery_thumbnail_path = '../docs/_static/trx_logo.png'
"""
Working with Groups
====================

This tutorial demonstrates how to work with groups in TRX files.
Groups allow you to organize streamlines into meaningful subsets,
such as anatomical bundles or clusters.

By the end of this tutorial, you will know how to:

- Access groups in a TRX file
- Extract streamlines belonging to a specific group
- Understand the relationship between groups and data_per_group (dpg)
- Work with overlapping groups
"""

# %%
# What are Groups?
# ----------------
#
# Groups in TRX files are collections of streamline indices. They enable:
#
# - **Sparse representation**: Only store indices instead of copying data
# - **Overlapping membership**: A streamline can belong to multiple groups
# - **Efficient access**: Quickly extract predefined subsets of streamlines
#
# Common use cases include anatomical bundles (e.g., Arcuate Fasciculus,
# Corpus Callosum), clustering results, or connectivity-based groupings.

# %%
# Loading a TRX file with groups
# ------------------------------
#
# Let's load a TRX file that contains group information.

import os

import numpy as np

from trx.fetcher import fetch_data, get_home, get_testing_files_dict
from trx.trx_file_memmap import load

# Download test data (needs an internet connection on the first run)
fetch_data(get_testing_files_dict(), keys="gold_standard.zip")
trx_home = get_home()
trx_path = os.path.join(trx_home, "gold_standard", "gs.trx")

# Load the TRX file
trx = load(trx_path)

print(f"Loaded TRX with {len(trx)} streamlines")

# %%
# Accessing groups
# ----------------
#
# Groups are stored as a dictionary where keys are group names and values
# are numpy arrays of streamline indices.

print(f"Available groups: {list(trx.groups.keys())}")

# Check the number of groups
print(f"Number of groups: {len(trx.groups)}")

# %%
# Let's examine the groups in more detail:

for group_name, indices in trx.groups.items():
    print(f" {group_name}: {len(indices)} streamlines")

# %%
# Extracting a group
# ------------------
#
# You can extract all streamlines belonging to a specific group using
# the ``get_group()`` method.

if len(trx.groups) > 0:
    # Get the first group name
    first_group = list(trx.groups.keys())[0]

    # Extract the group as a new TrxFile
    group_trx = trx.get_group(first_group)
    print(f"Extracted group '{first_group}' with {len(group_trx)} streamlines")

    # You can also access the raw indices
    group_indices = trx.groups[first_group]
    print(f"Raw indices (first 10): {group_indices[:10]}")
else:
    print("No groups available in this file")

# %%
# Using group indices directly
# ----------------------------
#
# You can use group indices to select streamlines directly with the
# ``select()`` method.

if len(trx.groups) > 0:
    first_group = list(trx.groups.keys())[0]
    indices = trx.groups[first_group]

    # Select streamlines using indices
    selected = trx.select(indices[:5])  # Select first 5 from the group
    print(f"Selected {len(selected)} streamlines from group '{first_group}'")

# %%
# Data per group (dpg)
# --------------------
#
# Groups can have associated metadata stored in ``data_per_group`` (dpg).
# This is useful for storing group-level statistics like mean FA, volume,
# or color codes.

print(f"Data per group keys: {list(trx.data_per_group.keys())}")

# Check what metadata is available for each group.
# ``data_per_group`` is indexed by group name first, then by metadata name.
for group_name in trx.data_per_group:
    dpg_keys = list(trx.data_per_group[group_name].keys())
    print(f" {group_name}: {dpg_keys}")

# %%
# Creating groups manually
# ------------------------
#
# You can create groups by assigning indices to the groups dictionary.
# Here's an example of how groups work conceptually.
#
# Note that ``example_groups`` below is a plain, standalone dictionary used
# for illustration only: it is never attached to the loaded ``trx`` object.

# Example: Create conceptual groups for 10 streamlines
example_groups = {
    "bundle_A": np.array([0, 1, 2, 3], dtype=np.uint32),
    "bundle_B": np.array([4, 5, 6, 7, 8, 9], dtype=np.uint32),
    "overlapping": np.array([3, 4, 5], dtype=np.uint32),  # Overlaps with A and B
}

print("Example groups:")
for name, indices in example_groups.items():
    print(f" {name}: streamlines {indices}")

# Note: Streamline 3 is in both bundle_A and overlapping
# Note: Streamlines 4, 5 are in both bundle_B and overlapping
print("\nOverlapping groups are allowed in TRX!")

# %%
# Group file structure
# --------------------
#
# In the TRX file format, groups are stored as binary files in a ``groups/``
# directory:
#
# .. code-block:: text
#
#     my_tractogram.trx/
#     |-- groups/
#     |   |-- AF_L.uint32    # Arcuate Fasciculus Left
#     |   |-- AF_R.uint32    # Arcuate Fasciculus Right
#     |   |-- CC.uint32      # Corpus Callosum
#     |   +-- CST_L.uint32   # Corticospinal Tract Left
#     +-- ...
#
# Each file contains a flat array of streamline indices as uint32 values.

# %%
# Filtering streamlines by group
# ------------------------------
#
# A common workflow is to filter streamlines based on group membership
# and then analyze or visualize specific bundles.

if len(trx.groups) > 0:
    # Get all group names
    group_names = list(trx.groups.keys())

    # Report statistics for each group
    print("Group statistics:")
    for group_name in group_names:
        group_trx = trx.get_group(group_name)
        # NOTE(review): ``_data`` is a private attribute; presumably the flat
        # array of all points of the extracted streamlines -- confirm.
        total_points = len(group_trx.streamlines._data)
        # Guard against empty groups to avoid a division by zero
        avg_length = total_points / len(group_trx) if len(group_trx) > 0 else 0
        print(f" {group_name}:")
        print(f" - Streamlines: {len(group_trx)}")
        print(f" - Total points: {total_points}")
        print(f" - Avg points per streamline: {avg_length:.1f}")

# %%
# Summary
# -------
#
# In this tutorial, you learned how to:
#
# - Access groups using ``trx.groups``
# - Extract group streamlines using ``get_group()``
# - Work with ``data_per_group`` (dpg) metadata
# - Understand that groups can overlap
# - Filter and analyze streamlines by group membership
#
# Groups are a powerful feature of the TRX format that enable efficient
# organization and retrieval of streamline subsets without data duplication.
import os
import tempfile

from trx.fetcher import fetch_data, get_home, get_testing_files_dict
from trx.trx_file_memmap import load, save

# Download test data (needs an internet connection on the first run)
fetch_data(get_testing_files_dict(), keys="gold_standard.zip")
trx_home = get_home()
trx_path = os.path.join(trx_home, "gold_standard", "gs.trx")

# Load the TRX file
trx = load(trx_path)

print("TRX file loaded successfully!")

# %%
# Inspecting TRX file contents
# ----------------------------
#
# The TrxFile object has several key attributes that you can inspect.
# Let's look at what's inside our loaded file.

# Print a summary of the TRX file
print(trx)

# %%
# The header contains essential metadata about the tractogram:

print("Header information:")
print(f" Number of streamlines: {trx.header['NB_STREAMLINES']}")
print(f" Number of vertices: {trx.header['NB_VERTICES']}")
print(f" Image dimensions: {trx.header['DIMENSIONS']}")
print(f" Voxel to RASMM affine:\n{trx.header['VOXEL_TO_RASMM']}")

# %%
# Accessing streamlines
# ---------------------
#
# Streamlines are the core data in a TRX file. Each streamline is a sequence
# of 3D points representing a fiber tract in the brain.

print(f"Number of streamlines: {len(trx)}")
# NOTE(review): ``_data`` is a private attribute; presumably the flat array
# of all points of all streamlines -- confirm against the trx-python API.
print(f"Total number of vertices: {len(trx.streamlines._data)}")

# Access the first streamline
first_streamline = trx.streamlines[0]
print(f"\nFirst streamline has {len(first_streamline)} points")
print(f"First 3 points of the first streamline:\n{first_streamline[:3]}")

# %%
# Accessing metadata
# ------------------
#
# TRX files can contain additional data per vertex (dpv) and per streamline (dps).

print("Data per vertex (dpv) keys:", list(trx.data_per_vertex.keys()))
print("Data per streamline (dps) keys:", list(trx.data_per_streamline.keys()))
print("Groups:", list(trx.groups.keys()))

# %%
# Selecting a subset of streamlines
# ---------------------------------
#
# You can easily select a subset of streamlines using indices or slicing.

# Select first 5 streamlines
subset = trx[:5]
print(f"Subset has {len(subset)} streamlines")

# Select specific streamlines by indices (ensure indices are valid).
# Clamping with ``min`` keeps every index within range even for very small
# files; clamped indices may then repeat.
max_idx = len(trx) - 1
indices = [0, min(2, max_idx), min(5, max_idx)]
selected = trx.select(indices)
print(f"Selected {len(selected)} streamlines")

# %%
# Saving a TRX file
# -----------------
#
# You can save a TRX file back to disk. The file can be saved as a compressed
# or uncompressed zip archive, or as a directory.

# NOTE(review): ``reloaded`` is still referenced when the temporary directory
# is removed at the end of the ``with`` block. If ``load()`` keeps files under
# ``tmpdir`` open (e.g. memory-mapped), cleanup may fail on some platforms --
# confirm whether it should be released explicitly first.
with tempfile.TemporaryDirectory() as tmpdir:
    # Save as TRX file (zip archive)
    output_path = os.path.join(tmpdir, "output.trx")
    save(trx, output_path)
    print(f"Saved TRX file to: {output_path}")
    print(f"File size: {os.path.getsize(output_path)} bytes")

    # Reload to verify
    reloaded = load(output_path)
    print(f"Reloaded TRX has {len(reloaded)} streamlines")

# %%
# Creating a TRX file from an existing one
# ----------------------------------------
#
# A common workflow is to create a new TRX file based on an existing one,
# preserving the spatial reference information.

# Create a deepcopy of the loaded TRX file
trx_copy = trx.deepcopy()
print(f"Created copy with {len(trx_copy)} streamlines")
print(f"Header preserved: DIMENSIONS = {trx_copy.header['DIMENSIONS']}")

# %%
# Summary
# -------
#
# In this tutorial, you learned how to:
#
# - Load TRX files using ``load()``
# - Inspect header information and streamline data
# - Access data per vertex (dpv) and data per streamline (dps)
# - Select subsets of streamlines
# - Save TRX files using ``save()``
# - Create copies of TRX files using ``deepcopy()``
#
# The TRX format is designed for memory efficiency through memory-mapping,
# making it suitable for large tractography datasets.
tee-ar-ex-trx-python-a304ac2/pyproject.toml000066400000000000000000000072721515240773700207710ustar00rootroot00000000000000[build-system] requires = ["setuptools >= 64", "wheel", "setuptools_scm[toml] >= 7"] build-backend = "setuptools.build_meta" [project] name = "trx-python" dynamic = ["version"] description = "A community-oriented file format for tractography" readme = "README.md" license = {text = "BSD License"} requires-python = ">=3.11" authors = [ {name = "The TRX developers"} ] classifiers = [ "Development Status :: 3 - Alpha", "Environment :: Console", "Intended Audience :: Science/Research", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering", ] dependencies = [ "deepdiff", "nibabel >= 5", "numpy >= 1.22", "typer >= 0.9", ] [project.optional-dependencies] dev = [ "spin >= 0.13", "setuptools_scm", ] doc = [ "matplotlib", "numpydoc", "pydata-sphinx-theme >= 0.16.1", "sphinx >= 8.2.0", "sphinx-autoapi >= 3.4.0", "sphinx-design", "sphinx-gallery", ] style = [ "codespell", "pre-commit", "ruff", ] test = [ "psutil", "pytest >= 7", "pytest-console-scripts >= 0", "pytest-cov", ] all = [ "trx-python[dev]", "trx-python[doc]", "trx-python[style]", "trx-python[test]", ] [project.urls] Homepage = "https://github.com/tee-ar-ex/trx-python" Documentation = "https://tee-ar-ex.github.io/trx-python/" Repository = "https://github.com/tee-ar-ex/trx-python" [project.scripts] trx = "trx.cli:main" trx_concatenate_tractograms = "trx.cli:concatenate_tractograms_cmd" trx_convert_dsi_studio = "trx.cli:convert_dsi_cmd" trx_convert_tractogram = "trx.cli:convert_cmd" trx_generate_from_scratch = "trx.cli:generate_cmd" trx_manipulate_datatype = "trx.cli:manipulate_dtype_cmd" trx_simple_compare = "trx.cli:compare_cmd" trx_validate = 
"trx.cli:validate_cmd" trx_verify_header_compatibility = "trx.cli:verify_header_cmd" trx_visualize_overlap = "trx.cli:visualize_cmd" trx_info = "trx.cli:info_cmd" [tool.setuptools] packages = ["trx"] include-package-data = true [tool.setuptools_scm] write_to = "trx/_version.py" fallback_version = "0.0" local_scheme = "no-local-version" [tool.codespell] ignore-words-list = "astroid" [tool.spin] package = "trx" [tool.spin.commands] "Setup" = [".spin/cmds.py:setup"] "Build" = ["spin.cmds.pip.install"] "Test" = [".spin/cmds.py:test", ".spin/cmds.py:lint"] "Docs" = [".spin/cmds.py:docs"] "Clean" = [".spin/cmds.py:clean"] [tool.numpydoc_validation] checks = [ "all", # report on all checks, except the below # These we ignore: "GL01", # Docstring should start in the line immediately after the quotes "GL02", # Closing quotes on own line (doesn't work on Python 3.13 anyway) "EX01", "EX02", # examples failed (we test them separately) "ES01", # no extended summary "SA01", # no see also "YD01", # no yields section "SA04", # no description in See Also "PR04", # Parameter "shape (n_channels" has no type "RT02", # The first line of the Returns section should contain only the type ] # remember to use single quotes for regex in TOML exclude = [ # don't report on objects that match any of these regex '\.undocumented_method$', '\.__repr__$', '\.__str__$', '\.__len__$', '\.__getitem__$', '\.__deepcopy__$', ] exclude_files = [ # don't process filepaths that match these regex '^trx/tests/.*', '^module/gui.*', '^examples/.*', ] override_SS05 = [ # override SS05 to allow docstrings starting with these words '^Process ', '^Assess ', '^Access ', ] tee-ar-ex-trx-python-a304ac2/ruff.toml000066400000000000000000000015331515240773700177060ustar00rootroot00000000000000target-version = "py312" line-length = 88 force-exclude = true extend-exclude = [ "__pycache__", "build", "_version.py", "docs/**", ] [lint] select = [ "F", # Pyflakes "E", # pycodestyle errors "W", # pycodestyle warnings "C", # mccabe complexity "B", 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Update switcher.json for documentation version switching.

This script maintains the version switcher JSON file used by
pydata-sphinx-theme to enable users to switch between different
documentation versions.
"""

import argparse
import json
from pathlib import Path
import sys

BASE_URL = "https://tee-ar-ex.github.io/trx-python"


def load_switcher(path):
    """Load existing switcher.json or return empty list.

    Parameters
    ----------
    path : str or Path
        Path to the switcher.json file.

    Returns
    -------
    list
        List of version entries (dicts) from the switcher file. An empty
        list is returned when the file is missing, is not valid JSON, or
        holds JSON whose top level is not a list. Items of the list that
        are not dicts are dropped.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
        return []

    # The file may be valid JSON without being a switcher list (for example
    # an error payload such as ``{"message": "Not Found"}``). Every other
    # function in this module calls ``insert``/``get``/``pop`` on the
    # result, so anything that is not a list of dicts is treated like a
    # missing file instead of crashing later.
    if not isinstance(data, list):
        return []
    return [entry for entry in data if isinstance(entry, dict)]


def save_switcher(path, versions):
    """Save switcher.json with proper formatting.

    Parameters
    ----------
    path : str or Path
        Path to the switcher.json file.
    versions : list
        List of version entries to write.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(versions, f, indent=4)
        # Terminate the file with a newline
        f.write("\n")


def ensure_dev_entry(versions):
    """Ensure dev entry exists in versions list.

    Parameters
    ----------
    versions : list
        List of version entries. Modified in place.

    Returns
    -------
    list
        Updated versions list with dev entry (the same list object).
    """
    dev_exists = any(v.get("version") == "dev" for v in versions)
    if not dev_exists:
        # The dev entry always goes first
        versions.insert(0, {"name": "dev", "version": "dev", "url": f"{BASE_URL}/dev/"})
    return versions


def ensure_stable_entry(versions):
    """Ensure stable entry exists with preferred flag.

    Parameters
    ----------
    versions : list
        List of version entries. Modified in place.

    Returns
    -------
    list
        Updated versions list with stable entry (the same list object).
    """
    stable_idx = next(
        (i for i, v in enumerate(versions) if v.get("version") == "stable"), None
    )
    if stable_idx is not None:
        versions[stable_idx]["preferred"] = True
    else:
        versions.append(
            {
                "name": "stable",
                "version": "stable",
                "url": f"{BASE_URL}/stable/",
                "preferred": True,
            }
        )
    return versions


def add_version(versions, version):
    """Add a new version entry to the versions list.

    Parameters
    ----------
    versions : list
        List of version entries. Modified in place.
    version : str
        Version string to add (e.g., "0.5.0").

    Returns
    -------
    list
        Updated versions list (the same list object).
    """
    # Remove 'preferred' from all existing entries
    for v in versions:
        v.pop("preferred", None)

    # Check if this version already exists
    version_exists = any(v.get("version") == version for v in versions)

    if not version_exists:
        new_entry = {
            "name": version,
            "version": version,
            "url": f"{BASE_URL}/{version}/",
        }
        # Find dev entry index to insert after it
        dev_idx = next(
            (i for i, v in enumerate(versions) if v.get("version") == "dev"), -1
        )
        if dev_idx >= 0:
            versions.insert(dev_idx + 1, new_entry)
        else:
            versions.insert(0, new_entry)

    return versions


def main(argv=None):
    """Run the switcher update workflow.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments to parse. Defaults to ``sys.argv[1:]`` when
        None, which is the behavior of ``ArgumentParser.parse_args``.

    Returns
    -------
    int
        Exit code (0 for success).
    """
    parser = argparse.ArgumentParser(
        description="Update switcher.json for documentation version switching"
    )
    parser.add_argument("switcher_path", type=Path, help="Path to switcher.json file")
    parser.add_argument("--version", type=str, help="New version to add (e.g., 0.5.0)")

    args = parser.parse_args(argv)

    # Load existing versions
    versions = load_switcher(args.switcher_path)

    # Add new version if specified
    if args.version:
        versions = add_version(versions, args.version)

    # Ensure required entries exist. This runs after add_version because
    # add_version strips every 'preferred' flag and ensure_stable_entry
    # puts it back on the stable entry.
    versions = ensure_dev_entry(versions)
    versions = ensure_stable_entry(versions)

    # Save updated switcher.json
    save_switcher(args.switcher_path, versions)

    # Print result for CI logs
    print(f"Updated {args.switcher_path}:")
    print(json.dumps(versions, indent=4))

    return 0


if __name__ == "__main__":
    sys.exit(main())
def _check_overwrite(filepath: Path, overwrite: bool) -> None:
    """Abort the command when the output exists and overwriting is disabled.

    Parameters
    ----------
    filepath : Path
        Path to the output file about to be written.
    overwrite : bool
        If True, an existing file may be replaced.

    Raises
    ------
    typer.Exit
        With exit code 1 when ``filepath`` is an existing file and
        ``overwrite`` is False.
    """
    already_there = filepath.is_file()
    if overwrite or not already_there:
        return

    message = typer.style(
        f"Error: {filepath} already exists. Use --force to overwrite.",
        fg=typer.colors.RED,
    )
    # Report on stderr so that stdout stays clean.
    typer.echo(message, err=True)
    raise typer.Exit(code=1)
@app.command("convert")
def convert(
    in_tractogram: Annotated[
        Path,
        typer.Argument(help="Input tractogram. Format: trk, tck, vtk, fib, dpy, trx."),
    ],
    out_tractogram: Annotated[
        Path,
        typer.Argument(help="Output tractogram. Format: trk, tck, vtk, fib, dpy, trx."),
    ],
    reference: Annotated[
        Optional[Path],
        typer.Option(
            "--reference",
            "-r",
            help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).",
        ),
    ] = None,
    positions_dtype: Annotated[
        str,
        typer.Option(
            "--positions-dtype",
            help="Datatype for positions in TRX output.",
        ),
    ] = "float32",
    offsets_dtype: Annotated[
        str,
        typer.Option(
            "--offsets-dtype",
            help="Datatype for offsets in TRX output.",
        ),
    ] = "uint64",
    force: Annotated[
        bool,
        typer.Option("--force", "-f", help="Force overwriting of output files."),
    ] = False,
) -> None:
    """Convert tractograms between formats.

    Supports conversion of .tck, .trk, .fib, .vtk, .trx and .dpy files.
    TCK files always need a reference NIFTI file for conversion.

    Parameters
    ----------
    in_tractogram : Path
        Input tractogram file.
    out_tractogram : Path
        Output tractogram path.
    reference : Path or None, optional
        Reference anatomy required for some input formats.
    positions_dtype : str, optional
        Datatype for positions in TRX output.
    offsets_dtype : str, optional
        Datatype for offsets in TRX output.
    force : bool, optional
        Overwrite output if it already exists.

    Returns
    -------
    None
        Writes the converted tractogram to disk.
    """
    # NOTE(review): the docstring wording is kept as-is because Typer
    # presumably shows it as the command help text; confirm before rewording.

    # Refuse to clobber an existing output unless --force was given.
    _check_overwrite(out_tractogram, force)

    # The workflow receives plain strings; a missing reference stays None.
    ref_path = None
    if reference:
        ref_path = str(reference)

    convert_tractogram(
        str(in_tractogram),
        str(out_tractogram),
        ref_path,
        pos_dtype=positions_dtype,
        offsets_dtype=offsets_dtype,
    )

    done_msg = typer.style(
        f"Successfully converted {in_tractogram} to {out_tractogram}",
        fg=typer.colors.GREEN,
    )
    typer.echo(done_msg)
force : bool, optional Overwrite output if it already exists. Returns ------- None Writes the converted tractogram to disk. """ _check_overwrite(out_tractogram, force) if remove_invalid and keep_invalid: typer.echo( typer.style( "Error: Cannot use both --remove-invalid and --keep-invalid.", fg=typer.colors.RED, ), err=True, ) raise typer.Exit(code=1) convert_dsi_studio( str(in_dsi_tractogram), str(in_dsi_fa), str(out_tractogram), remove_invalid=remove_invalid, keep_invalid=keep_invalid, ) typer.echo( typer.style( f"Successfully converted DSI-Studio tractogram to {out_tractogram}", fg=typer.colors.GREEN, ) ) @app.command("generate") def generate( reference: Annotated[ Path, typer.Argument(help="Reference anatomy (.nii or .nii.gz)."), ], out_tractogram: Annotated[ Path, typer.Argument(help="Output tractogram. Format: trk, tck, vtk, fib, dpy, trx."), ], positions: Annotated[ Optional[Path], typer.Option( "--positions", help="Binary file with streamline coordinates (Nx3 .npy).", ), ] = None, offsets: Annotated[ Optional[Path], typer.Option( "--offsets", help="Binary file with streamline offsets (.npy).", ), ] = None, positions_csv: Annotated[ Optional[Path], typer.Option( "--positions-csv", help="CSV file with streamline coordinates (x1,y1,z1,x2,y2,z2,...).", ), ] = None, space: Annotated[ str, typer.Option( "--space", help="Coordinate space. Non-default requires Dipy.", ), ] = "RASMM", origin: Annotated[ str, typer.Option( "--origin", help="Coordinate origin. 
Non-default requires Dipy.", ), ] = "NIFTI", positions_dtype: Annotated[ str, typer.Option("--positions-dtype", help="Datatype for positions."), ] = "float32", offsets_dtype: Annotated[ str, typer.Option("--offsets-dtype", help="Datatype for offsets."), ] = "uint64", dpv: Annotated[ Optional[List[str]], typer.Option( "--dpv", help="Data per vertex: FILE,DTYPE (e.g., color.npy,uint8).", ), ] = None, dps: Annotated[ Optional[List[str]], typer.Option( "--dps", help="Data per streamline: FILE,DTYPE (e.g., algo.npy,uint8).", ), ] = None, groups: Annotated[ Optional[List[str]], typer.Option( "--groups", help="Groups: FILE,DTYPE (e.g., AF_L.npy,int32).", ), ] = None, dpg: Annotated[ Optional[List[str]], typer.Option( "--dpg", help="Data per group: GROUP,FILE,DTYPE (e.g., AF_L,mean_fa.npy,float32).", ), ] = None, verify_invalid: Annotated[ bool, typer.Option( "--verify-invalid", help="Verify positions are valid (within bounding box). Requires Dipy.", ), ] = False, force: Annotated[ bool, typer.Option("--force", "-f", help="Force overwriting of output files."), ] = False, ) -> None: """Generate a TRX file from raw data files. Create a TRX file from CSV, TXT, or NPY files by specifying positions, offsets, data_per_vertex, data_per_streamlines, groups, and data_per_group. Parameters ---------- reference : Path Reference anatomy (.nii or .nii.gz). out_tractogram : Path Output tractogram (.trk, .tck, .vtk, .fib, .dpy, .trx). positions : Path or None, optional Binary file with streamline coordinates (Nx3 .npy). offsets : Path or None, optional Binary file with streamline offsets (.npy). positions_csv : Path or None, optional CSV file with flattened streamline coordinates. space : str, optional Coordinate space. Non-default requires Dipy. origin : str, optional Coordinate origin. Non-default requires Dipy. positions_dtype : str, optional Datatype for positions. offsets_dtype : str, optional Datatype for offsets. 
dpv : list of str or None, optional Data per vertex entries as FILE,DTYPE pairs. dps : list of str or None, optional Data per streamline entries as FILE,DTYPE pairs. groups : list of str or None, optional Group entries as FILE,DTYPE pairs. dpg : list of str or None, optional Data per group entries as GROUP,FILE,DTYPE triplets. verify_invalid : bool, optional Verify positions are inside bounding box (requires Dipy). force : bool, optional Overwrite output if it already exists. Returns ------- None Writes the generated tractogram to disk. """ _check_overwrite(out_tractogram, force) # Validate input combinations if not positions and not positions_csv: typer.echo( typer.style( "Error: At least one positions option must be provided " "(--positions or --positions-csv).", fg=typer.colors.RED, ), err=True, ) raise typer.Exit(code=1) if positions_csv and positions: typer.echo( typer.style( "Error: Cannot use both --positions and --positions-csv.", fg=typer.colors.RED, ), err=True, ) raise typer.Exit(code=1) if positions and offsets is None: typer.echo( typer.style( "Error: --offsets must be provided if --positions is used.", fg=typer.colors.RED, ), err=True, ) raise typer.Exit(code=1) if offsets and positions is None: typer.echo( typer.style( "Error: --positions must be provided if --offsets is used.", fg=typer.colors.RED, ), err=True, ) raise typer.Exit(code=1) # Parse comma-separated arguments to tuples dpv_list = None if dpv: dpv_list = [tuple(item.split(",")) for item in dpv] dps_list = None if dps: dps_list = [tuple(item.split(",")) for item in dps] groups_list = None if groups: groups_list = [tuple(item.split(",")) for item in groups] dpg_list = None if dpg: dpg_list = [tuple(item.split(",")) for item in dpg] generate_trx_from_scratch( str(reference), str(out_tractogram), positions_csv=str(positions_csv) if positions_csv else None, positions=str(positions) if positions else None, offsets=str(offsets) if offsets else None, positions_dtype=positions_dtype, 
def _resolve_dtype(dtype_name: str, option: str) -> np.dtype:
    """Convert a dtype name to ``np.dtype``, reporting failures as CLI errors."""
    try:
        return np.dtype(dtype_name)
    except TypeError as err:
        raise typer.BadParameter(
            f"'{dtype_name}' is not a valid numpy dtype.", param_hint=option
        ) from err


def _split_dtype_spec(item: str, n_fields: int, option: str) -> List[str]:
    """Split a comma-separated option value and check its number of fields."""
    fields = item.split(",")
    if len(fields) != n_fields:
        raise typer.BadParameter(
            f"'{item}' must have exactly {n_fields} comma-separated fields.",
            param_hint=option,
        )
    return fields


@app.command("manipulate-dtype")
def manipulate_dtype(
    in_tractogram: Annotated[
        Path,
        typer.Argument(help="Input TRX file."),
    ],
    out_tractogram: Annotated[
        Path,
        typer.Argument(help="Output tractogram file."),
    ],
    positions_dtype: Annotated[
        Optional[str],
        typer.Option(
            "--positions-dtype",
            help="Datatype for positions (float16, float32, float64).",
        ),
    ] = None,
    offsets_dtype: Annotated[
        Optional[str],
        typer.Option(
            "--offsets-dtype",
            help="Datatype for offsets (uint32, uint64).",
        ),
    ] = None,
    dpv: Annotated[
        Optional[List[str]],
        typer.Option(
            "--dpv",
            help="Data per vertex dtype: NAME,DTYPE (e.g., color_x,uint8).",
        ),
    ] = None,
    dps: Annotated[
        Optional[List[str]],
        typer.Option(
            "--dps",
            help="Data per streamline dtype: NAME,DTYPE (e.g., algo,uint8).",
        ),
    ] = None,
    groups: Annotated[
        Optional[List[str]],
        typer.Option(
            "--groups",
            help="Groups dtype: NAME,DTYPE (e.g., CC,uint64).",
        ),
    ] = None,
    dpg: Annotated[
        Optional[List[str]],
        typer.Option(
            "--dpg",
            help="Data per group dtype: GROUP,NAME,DTYPE (e.g., CC,mean_fa,float64).",
        ),
    ] = None,
    force: Annotated[
        bool,
        typer.Option("--force", "-f", help="Force overwriting of output files."),
    ] = False,
) -> None:  # noqa: C901
    """Manipulate TRX file internal array data types.

    Change the data types of positions, offsets, data_per_vertex,
    data_per_streamline, groups, and data_per_group arrays.

    Parameters
    ----------
    in_tractogram : Path
        Input TRX file.
    out_tractogram : Path
        Output TRX file.
    positions_dtype : str or None, optional
        Target dtype for positions (float16, float32, float64).
    offsets_dtype : str or None, optional
        Target dtype for offsets (uint32, uint64).
    dpv : list of str or None, optional
        Data per vertex dtype overrides as NAME,DTYPE pairs.
    dps : list of str or None, optional
        Data per streamline dtype overrides as NAME,DTYPE pairs.
    groups : list of str or None, optional
        Group dtype overrides as NAME,DTYPE pairs.
    dpg : list of str or None, optional
        Data per group dtype overrides as GROUP,NAME,DTYPE triplets.
    force : bool, optional
        Overwrite output if it already exists.

    Returns
    -------
    None
        Writes the dtype-converted TRX file.

    Raises
    ------
    typer.BadParameter
        If an option value has the wrong number of comma-separated fields
        or names an invalid numpy dtype.
    """
    _check_overwrite(out_tractogram, force)

    dtype_dict = {}
    if positions_dtype:
        dtype_dict["positions"] = _resolve_dtype(positions_dtype, "--positions-dtype")
    if offsets_dtype:
        dtype_dict["offsets"] = _resolve_dtype(offsets_dtype, "--offsets-dtype")

    # dpv / dps / groups share the same NAME,DTYPE layout; the dictionary key
    # is also the option name without its leading dashes.
    for key, items in (("dpv", dpv), ("dps", dps), ("groups", groups)):
        if not items:
            continue
        option = f"--{key}"
        dtype_dict[key] = {}
        for item in items:
            name, dtype = _split_dtype_spec(item, 2, option)
            dtype_dict[key][name] = _resolve_dtype(dtype, option)

    if dpg:
        # Data per group is nested one level deeper: group -> name -> dtype.
        dtype_dict["dpg"] = {}
        for item in dpg:
            group, name, dtype = _split_dtype_spec(item, 3, "--dpg")
            per_group = dtype_dict["dpg"].setdefault(group, {})
            per_group[name] = _resolve_dtype(dtype, "--dpg")

    manipulate_trx_datatype(str(in_tractogram), str(out_tractogram), dtype_dict)

    typer.echo(
        typer.style(
            f"Successfully manipulated datatypes and saved to {out_tractogram}",
            fg=typer.colors.GREEN,
        )
    )
reference : Path or None, optional Reference anatomy for formats requiring it. Returns ------- None Prints comparison summary to stdout. """ ref = str(reference) if reference else None tractogram_simple_compare([str(in_tractogram1), str(in_tractogram2)], ref) @app.command("validate") def validate( in_tractogram: Annotated[ Path, typer.Argument(help="Input tractogram. Format: trk, tck, vtk, fib, dpy, trx."), ], out_tractogram: Annotated[ Optional[Path], typer.Option( "--out", "-o", help="Output tractogram after removing invalid streamlines.", ), ] = None, remove_identical: Annotated[ bool, typer.Option( "--remove-identical", help="Remove identical streamlines from the set.", ), ] = False, precision: Annotated[ int, typer.Option( "--precision", "-p", help="Number of decimals when hashing streamline points.", ), ] = 1, reference: Annotated[ Optional[Path], typer.Option( "--reference", "-r", help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).", ), ] = None, force: Annotated[ bool, typer.Option("--force", "-f", help="Force overwriting of output files."), ] = False, ) -> None: """Validate a tractogram and optionally clean invalid/duplicate streamlines. Parameters ---------- in_tractogram : Path Input tractogram (.trk, .tck, .vtk, .fib, .dpy, .trx). out_tractogram : Path or None, optional Optional output tractogram with invalid streamlines removed. remove_identical : bool, optional Remove duplicate streamlines based on hashing precision. precision : int, optional Number of decimals when hashing streamline points. reference : Path or None, optional Reference anatomy for formats requiring it. force : bool, optional Overwrite output if it already exists. Returns ------- None Prints validation summary and optionally writes cleaned output. 
""" if out_tractogram: _check_overwrite(out_tractogram, force) ref = str(reference) if reference else None out = str(out_tractogram) if out_tractogram else None validate_tractogram( str(in_tractogram), reference=ref, out_tractogram=out, remove_identical_streamlines=remove_identical, precision=precision, ) if out_tractogram: typer.echo( typer.style( f"Validation complete. Output saved to {out_tractogram}", fg=typer.colors.GREEN, ) ) else: typer.echo( typer.style( "Validation complete.", fg=typer.colors.GREEN, ) ) @app.command("verify-header") def verify_header( in_files: Annotated[ List[Path], typer.Argument(help="Files to compare (trk, trx, and nii)."), ], ) -> None: """Compare spatial attributes of input files. Parameters ---------- in_files : list of Path Files to compare (.trk, .trx, .nii, .nii.gz). Returns ------- None Prints compatibility results to stdout. """ verify_header_compatibility([str(f) for f in in_files]) @app.command("visualize") def visualize( in_tractogram: Annotated[ Path, typer.Argument(help="Input tractogram. Format: trk, tck, vtk, fib, dpy, trx."), ], reference: Annotated[ Path, typer.Argument(help="Reference anatomy (.nii or .nii.gz)."), ], remove_invalid: Annotated[ bool, typer.Option( "--remove-invalid", help="Remove invalid streamlines to avoid density_map crash.", ), ] = False, ) -> None: """Display tractogram and density map with bounding box. Parameters ---------- in_tractogram : Path Input tractogram (.trk, .tck, .vtk, .fib, .dpy, .trx). reference : Path Reference anatomy (.nii or .nii.gz). remove_invalid : bool, optional Remove invalid streamlines to avoid density map crashes. Returns ------- None Opens visualization windows when fury is available. """ tractogram_visualize_overlap( str(in_tractogram), str(reference), remove_invalid, ) def _format_size(size_bytes: int) -> str: """Format byte size to human readable string. Parameters ---------- size_bytes : int Size in bytes. 
Returns ------- str Human readable size string (e.g., "1.5 MB"). """ for unit in ["B", "KB", "MB", "GB"]: if size_bytes < 1024: return f"{size_bytes:.1f} {unit}" if unit != "B" else f"{size_bytes} {unit}" size_bytes /= 1024 return f"{size_bytes:.1f} TB" @app.command("info") def info( in_tractogram: Annotated[ Path, typer.Argument(help="Input TRX file."), ], ) -> None: """Display detailed information about a TRX file. Shows file size, compression status, header metadata (affine, dimensions, voxel sizes), streamline/vertex counts, data keys (dpv, dps, dpg), groups, and archive contents listing similar to ``unzip -l``. Parameters ---------- in_tractogram : Path Input TRX file (.trx extension required). Returns ------- None Prints TRX file information to stdout. Examples -------- $ trx info tractogram.trx $ trx_info tractogram.trx """ import zipfile if not in_tractogram.exists(): typer.echo( typer.style(f"Error: {in_tractogram} does not exist.", fg=typer.colors.RED), err=True, ) raise typer.Exit(code=1) if in_tractogram.suffix.lower() != ".trx": typer.echo( typer.style( f"Error: {in_tractogram.name} is not a TRX file. 
def _create_standalone_app(command_func, name: str, help_text: str):
    """Wrap one command function in its own single-command Typer app.

    Parameters
    ----------
    command_func : callable
        The command function to register.
    name : str
        Name given to the Typer app.
    help_text : str
        Help text given to the Typer app.

    Returns
    -------
    callable
        Zero-argument entry point that runs the app.
    """
    cli = typer.Typer(
        name=name,
        help=help_text,
        add_completion=False,
        rich_markup_mode="rich",
    )
    register = cli.command()
    register(command_func)

    def _entry_point():
        return cli()

    return _entry_point
def _file_hexdigest(filename, hash_factory):
    """Return the hex digest of a file, reading it in fixed-size chunks.

    Parameters
    ----------
    filename : str
        Path to file to hash.
    hash_factory : callable
        Zero-argument constructor of a ``hashlib`` hash object
        (e.g. ``hashlib.md5``).

    Returns
    -------
    str
        Hexadecimal digest of the file contents.
    """
    h = hash_factory()
    # Read in multiples of the hash block size so large files are never
    # loaded into memory at once.
    chunk_size = 128 * h.block_size
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def md5sum(filename):
    """Compute the MD5 checksum of a file.

    Parameters
    ----------
    filename : str
        Path to file to hash.

    Returns
    -------
    str
        Hexadecimal MD5 digest.
    """
    return _file_hexdigest(filename, hashlib.md5)


def sha256sum(filename):
    """Compute the SHA256 checksum of a file.

    Parameters
    ----------
    filename : str
        Path to file to hash.

    Returns
    -------
    str
        Hexadecimal SHA256 digest.
    """
    return _file_hexdigest(filename, hashlib.sha256)
def get_trx_tmp_dir():
    """Return a temporary directory honoring the ``TRX_TMPDIR`` setting.

    When the ``TRX_TMPDIR`` environment variable is set to
    ``"use_working_dir"`` the current working directory is used. Otherwise,
    the value of ``TRX_TMPDIR`` is used directly. If the variable is not set,
    the system temporary directory is used.

    Returns
    -------
    tempfile.TemporaryDirectory
        Context-managed temporary directory, prefixed with ``trx_`` and
        placed according to the environment configuration.
    """
    # Read the setting once so every branch sees the same value.
    tmpdir_setting = os.getenv("TRX_TMPDIR")
    if tmpdir_setting is None:
        trx_tmp_dir = tempfile.gettempdir()
    elif tmpdir_setting == "use_working_dir":
        trx_tmp_dir = os.getcwd()
    else:
        trx_tmp_dir = tmpdir_setting

    # ``ignore_cleanup_errors`` only exists from Python 3.10 on. Compare the
    # (major, minor) tuple: testing the minor number alone gives the wrong
    # answer for any other major version.
    if sys.version_info >= (3, 10):
        return tempfile.TemporaryDirectory(
            dir=trx_tmp_dir, prefix="trx_", ignore_cleanup_errors=True
        )
    return tempfile.TemporaryDirectory(dir=trx_tmp_dir, prefix="trx_")
def load(tractogram_filename, reference):
    """Read a tractogram, choosing the loader from the path.

    Parameters
    ----------
    tractogram_filename : str
        Path to the input tractogram. TRX directories are supported.
    reference : str or nibabel.Nifti1Image
        Reference image used for formats without embedded affine information.

    Returns
    -------
    TrxFile or StatefulTractogram
        Result of the TRX loader for ``.trx`` files and directories,
        otherwise the result of :func:`load_sft_with_reference` (called
        with the bounding-box check disabled).
    """
    import trx.trx_file_memmap as tmm

    extension = split_name_with_gz(tractogram_filename)[1]
    # A path is handled as TRX when it ends in ".trx" or is a directory.
    is_trx = extension == ".trx" or os.path.isdir(tractogram_filename)
    if is_trx:
        return tmm.load(tractogram_filename)

    return load_sft_with_reference(tractogram_filename, reference, bbox_check=False)
# -*- coding: utf-8 -*-
"""Set operations on streamlines with precision-based matching."""

from functools import reduce
import itertools

import numpy as np

MIN_NB_POINTS = 5
# Indices of the first and the last MIN_NB_POINTS points of a streamline.
# Derived from MIN_NB_POINTS so the two constants cannot drift apart.
KEY_INDEX = np.concatenate(
    (np.arange(MIN_NB_POINTS), np.arange(-1, -MIN_NB_POINTS - 1, -1))
)


def intersection(left, right):
    """Return the intersection of two streamline hash dictionaries.

    Parameters
    ----------
    left : dict
        Hash dictionary returned by :func:`hash_streamlines`.
    right : dict
        Hash dictionary returned by :func:`hash_streamlines`.

    Returns
    -------
    dict
        Dictionary containing only keys present in both inputs; the values
        are taken from ``left``.
    """
    return {k: v for k, v in left.items() if k in right}


def difference(left, right):
    """Return the difference of two streamline hash dictionaries.

    Parameters
    ----------
    left : dict
        Hash dictionary returned by :func:`hash_streamlines`.
    right : dict
        Hash dictionary returned by :func:`hash_streamlines`.

    Returns
    -------
    dict
        Dictionary containing keys present in ``left`` but not in ``right``.
    """
    return {k: v for k, v in left.items() if k not in right}


def union(left, right):
    """Return the union of two streamline hash dictionaries.

    Parameters
    ----------
    left : dict
        Hash dictionary returned by :func:`hash_streamlines`.
    right : dict
        Hash dictionary returned by :func:`hash_streamlines`.

    Returns
    -------
    dict
        Dictionary containing all keys from both inputs. Values from ``left``
        overwrite those from ``right`` when keys overlap.
    """
    result = right.copy()
    result.update(left)
    return result


def get_streamline_key(streamline, precision=None):
    """Produce a hashable key from a streamline using a few points.

    Parameters
    ----------
    streamline : ndarray
        A single streamline (N, 3).
    precision : int, optional
        The number of decimals to keep when hashing the points of the
        streamlines. Allows a soft comparison of streamlines. If None, no
        rounding is performed.

    Returns
    -------
    bytes
        Raw bytes of the first/last MIN_NB_POINTS points of the streamline
        (of all its points when it has fewer than MIN_NB_POINTS).
    """
    # Use just a few data points as key. Using all the data of the
    # streamline would make the cost grow with the number of points.
    if len(streamline) < MIN_NB_POINTS:
        key = streamline.copy()
    else:
        key = streamline[KEY_INDEX].copy()

    if precision is not None:
        key = np.round(key, precision)
        # Rounding a tiny negative value gives -0.0, whose byte pattern
        # differs from +0.0. Normalise the sign of zero so that points that
        # are equal after rounding also produce identical keys.
        key[key == 0] = 0

    return key.tobytes()


def hash_streamlines(streamlines, start_index=0, precision=None):
    """Produce a dict from streamlines.

    Produce a dict from streamlines by using the points as keys and the
    indices of the streamlines as values. When several streamlines share a
    key, the index of the last one is kept.

    Parameters
    ----------
    streamlines : list of ndarray
        The list of streamlines used to produce the dict.
    start_index : int, optional
        The index of the first streamline. 0 by default.
    precision : int, optional
        The number of decimals to keep when hashing the points of the
        streamlines. Allows a soft comparison of streamlines. If None, no
        rounding is performed.

    Returns
    -------
    dict
        A dict where the keys are streamline points and the values are
        indices starting at start_index.
    """
    keys = [get_streamline_key(s, precision) for s in streamlines]
    return {k: i for i, k in enumerate(keys, start_index)}


def perform_streamlines_operation(operation, streamlines, precision=0):
    """Perform an operation on a list of list of streamlines.

    Given a list of list of streamlines, this function applies the operation
    to the first two lists of streamlines. The result is then used
    recursively with the third, fourth, etc. lists of streamlines.

    A valid operation is any function that takes two streamlines dict as
    input and produces a new streamlines dict (see hash_streamlines). Union,
    difference, and intersection are valid examples of operations.

    Parameters
    ----------
    operation : callable
        A callable that takes two streamlines dicts as inputs and produces a
        new streamline dict.
    streamlines : list of list of streamlines
        The streamlines used in the operation.
    precision : int, optional
        The number of decimals to keep when hashing the points of the
        streamlines. Allows a soft comparison of streamlines. Defaults to 0
        (round to integers). If None, no rounding is performed.

    Returns
    -------
    streamlines : list of ndarray
        The streamlines obtained after performing the operation on all the
        input streamlines.
    indices : np.ndarray
        The indices (``uint32``, into the concatenation of all the input
        lists) of the streamlines that are used in the output.
    """
    # Nothing to combine: return empty results instead of failing inside
    # ``reduce``, which cannot fold an empty sequence.
    if len(streamlines) == 0:
        return [], np.array([], dtype=np.uint32)

    # Hash the streamlines using the desired precision. Each list is indexed
    # from the cumulated length of the lists that precede it.
    start_indices = np.cumsum([0] + [len(s) for s in streamlines[:-1]])
    hashes = [
        hash_streamlines(s, i, precision) for s, i in zip(streamlines, start_indices)
    ]

    # Perform the operation on the hashes and get the output streamlines.
    to_keep = reduce(operation, hashes)
    all_streamlines = list(itertools.chain.from_iterable(streamlines))
    indices = np.array(sorted(to_keep.values())).astype(np.uint32)

    streamlines = [all_streamlines[i] for i in indices]
    return streamlines, indices
""" normalized = {} for key, value in dtype_dict.items(): if isinstance(value, dict): normalized[key] = _normalize_dtype_dict(value) elif isinstance(value, np.dtype): # Normalize to little-endian for multi-byte types if value.byteorder == "=" and value.itemsize > 1: normalized[key] = value.newbyteorder("<") else: normalized[key] = value else: normalized[key] = value return normalized # Tests for standalone CLI commands (trx_* commands) class TestStandaloneCommands: """Tests for standalone CLI commands.""" def test_help_option_convert_dsi(self, script_runner): ret = script_runner.run(["trx_convert_dsi_studio", "--help"]) assert ret.success def test_help_option_convert(self, script_runner): ret = script_runner.run(["trx_convert_tractogram", "--help"]) assert ret.success def test_help_option_generate_from_scratch(self, script_runner): ret = script_runner.run(["trx_generate_from_scratch", "--help"]) assert ret.success def test_help_option_concatenate(self, script_runner): ret = script_runner.run(["trx_concatenate_tractograms", "--help"]) assert ret.success def test_help_option_manipulate(self, script_runner): ret = script_runner.run(["trx_manipulate_datatype", "--help"]) assert ret.success def test_help_option_compare(self, script_runner): ret = script_runner.run(["trx_simple_compare", "--help"]) assert ret.success def test_help_option_validate(self, script_runner): ret = script_runner.run(["trx_validate", "--help"]) assert ret.success def test_help_option_verify_header(self, script_runner): ret = script_runner.run(["trx_verify_header_compatibility", "--help"]) assert ret.success def test_help_option_visualize(self, script_runner): ret = script_runner.run(["trx_visualize_overlap", "--help"]) assert ret.success def test_help_option_info(self, script_runner): ret = script_runner.run(["trx_info", "--help"]) assert ret.success # Tests for unified trx CLI class TestUnifiedCLI: """Tests for the unified trx CLI.""" def test_trx_help(self, script_runner): ret = 
def test_trx_info_wrong_extension(self, script_runner):
    """``trx info`` must fail with an explicit message on a non-TRX file."""
    gold_dir = os.path.join(get_home(), "gold_standard")
    outcome = script_runner.run(["trx", "info", os.path.join(gold_dir, "gs.tck")])
    # The command is expected to exit with an error and explain why.
    assert not outcome.success
    assert "not a TRX file" in outcome.stderr
@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_execution_convert_dtype_p64_o32(self):
    """Converting with float64 positions and uint32 offsets keeps those dtypes."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        in_trk = os.path.join(get_home(), "DSI", "CC_fix.trk")
        # Name the output after the dtypes requested in *this* test; the
        # previous name was copied from the float16/uint64 variant.
        out_convert_path = os.path.join(tmp_dir, "CC_fix_p64_o32.trx")
        convert_tractogram(
            in_trk,
            out_convert_path,
            None,
            pos_dtype="float64",
            offsets_dtype="uint32",
        )

        trx = tmm.load(out_convert_path)
        try:
            assert_equal(trx.streamlines._data.dtype, np.float64)
            assert_equal(trx.streamlines._offsets.dtype, np.uint32)
        finally:
            # Close the handle even when an assertion fails.
            trx.close()
@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_execution_concatenate_validate_trx(self):
    """Concatenate the gold standard with itself, then deduplicate it again."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        gs_path = os.path.join(get_home(), "gold_standard", "gs.trx")
        first = tmm.load(gs_path)
        second = tmm.load(gs_path)
        merged = tmm.concatenate([first, second], preallocation=False)

        # The concatenation holds twice as many streamlines.
        assert_equal(len(merged.streamlines), 2 * len(first.streamlines))

        # Positions: first half comes from `first`, second half from `second`.
        nb_vertices = first.header["NB_VERTICES"]
        merged_positions = merged.streamlines._data
        assert_allclose(merged_positions[:nb_vertices], first.streamlines._data)
        assert_allclose(merged_positions[nb_vertices:], second.streamlines._data)

        # Per-vertex data is split at the same vertex boundary.
        for key in merged.data_per_vertex.keys():
            merged_dpv = merged.data_per_vertex[key]._data
            assert_equal(merged_dpv[:nb_vertices], first.data_per_vertex[key]._data)
            assert_equal(merged_dpv[nb_vertices:], second.data_per_vertex[key]._data)

        # Per-streamline data is split at the streamline boundary.
        nb_streamlines = first.header["NB_STREAMLINES"]
        for key in merged.data_per_streamline.keys():
            merged_dps = merged.data_per_streamline[key]
            assert_equal(merged_dps[:nb_streamlines], first.data_per_streamline[key])
            assert_equal(merged_dps[nb_streamlines:], second.data_per_streamline[key])

        # Save, then validate while removing the identical streamlines.
        out_concat_path = os.path.join(tmp_dir, "concat.trx")
        out_valid_path = os.path.join(tmp_dir, "valid.trx")
        tmm.save(merged, out_concat_path)
        validate_tractogram(
            out_concat_path,
            None,
            out_valid_path,
            remove_identical_streamlines=True,
            precision=0,
        )
        validated = tmm.load(out_valid_path)

        # Same dtypes as the concatenation, same size as a single input.
        assert DeepDiff(merged.get_dtype_dict(), validated.get_dtype_dict()) == {}
        assert_equal(len(first.streamlines), len(validated.streamlines))

        merged.close()
        first.close()
        second.close()
        validated.close()
@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"), ("gs.vtk")])
@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_load_vox(path):
    """Streamlines moved to voxel space must match the reference coordinates."""
    gold_dir = os.path.join(get_home(), "gold_standard")
    tractogram_path = os.path.join(gold_dir, path)
    expected = np.loadtxt(
        os.path.join(get_home(), "gold_standard", "gs_vox_space.txt")
    )

    loaded = load(tractogram_path, os.path.join(gold_dir, "gs.nii"))
    is_trx = isinstance(loaded, TrxFile)

    # TRX inputs are converted first; other formats already come back as SFT.
    if is_trx:
        sft = loaded.to_sft()
    else:
        sft = loaded
    sft.to_vox()
    assert_allclose(sft.streamlines._data, expected, rtol=1e-04, atol=1e-06)

    if is_trx:
        loaded.close()
@pytest.mark.parametrize("path", [("gs.trk"), ("gs.trx"), ("gs_fldr.trx")])
@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_multi_load_save_rasmm(path):
    """Three save/load round trips must leave the RASMM coordinates intact."""
    with TemporaryDirectory() as tmp_gs_dir:
        gold_dir = os.path.join(get_home(), "gold_standard")
        reference = os.path.join(gold_dir, "gs.nii")
        stem, suffix = os.path.splitext(path)
        expected = np.loadtxt(
            os.path.join(get_home(), "gold_standard", "gs_rasmm_space.txt")
        )

        current = load(os.path.join(gold_dir, path), reference)
        for round_idx in range(3):
            out_name = "{}_tmp{}_{}".format(stem, round_idx, suffix)
            out_path = os.path.join(tmp_gs_dir, out_name)
            save(current, out_path)
            # Release the previous TRX handle before loading the new copy.
            if isinstance(current, TrxFile):
                current.close()
            current = load(out_path, reference)

        assert_allclose(current.streamlines._data, expected, rtol=1e-04, atol=1e-06)
        if isinstance(current, TrxFile):
            current.close()
@pytest.mark.parametrize("tmp_path", [("~"), ("use_working_dir")])
def test_change_tmp_dir(tmp_path):
    """The temporary extraction folder must follow the ``TRX_TMPDIR`` setting."""
    gs_dir = os.path.join(get_home(), "gold_standard")
    path = os.path.join(gs_dir, "gs.trx")

    if tmp_path == "use_working_dir":
        setting = "use_working_dir"
        expected_parent = os.getcwd()
    else:
        setting = os.path.expanduser(tmp_path)
        expected_parent = setting

    # Remember the previous value: the variable is process-wide and would
    # otherwise leak into every test that runs after this one.
    previous = os.environ.get("TRX_TMPDIR")
    os.environ["TRX_TMPDIR"] = setting
    try:
        trx = tmm.load(path)
        try:
            tmp_gs_dir = deepcopy(trx._uncompressed_folder_handle.name)
            assert os.path.dirname(tmp_gs_dir) == expected_parent
        finally:
            # Close the handle even when the assertion above fails.
            trx.close()
        assert not os.path.isdir(tmp_gs_dir)
    finally:
        if previous is None:
            os.environ.pop("TRX_TMPDIR", None)
        else:
            os.environ["TRX_TMPDIR"] = previous
@pytest.mark.parametrize(
    "arr,expected,value_error",
    [
        (np.ones((5, 5, 5), dtype=np.int16), None, True),
        (np.ones((5, 4), dtype=np.int16), "mean_fa.4.int16", False),
        (np.ones((5, 4), dtype=np.float64), "mean_fa.4.float64", False),
        (np.ones((5, 1), dtype=np.float64), "mean_fa.float64", False),
        (np.ones((1), dtype=np.float64), "mean_fa.float64", False),
    ],
)
def test__generate_filename_from_data(
    arr, expected, value_error, filename="mean_fa.bit"
):
    """Check the generated file name, or the ``ValueError`` for invalid arrays."""
    if value_error:
        # Only the call belongs inside the context: any statement placed
        # after it would never run once the expected exception is raised.
        with pytest.raises(ValueError):
            tmm._generate_filename_from_data(arr=arr, filename=filename)
    else:
        new_fn = tmm._generate_filename_from_data(arr=arr, filename=filename)
        assert new_fn == expected
# need dpg test with missing keys
@pytest.mark.parametrize(
    "path,check_dpg,value_error",
    [
        ("small_compressed.trx", False, False),
        ("small.trx", True, False),
        ("small_fldr.trx", False, False),
        ("dontexist.trx", False, True),
    ],
)
def test_load(path, check_dpg, value_error):
    """Load each sample TRX and check the result type, or that a missing path raises."""
    full_path = os.path.join(get_home(), "memmap_test_data", path)

    # Need to perhaps improve test
    if not value_error:
        loaded = tmm.load(input_obj=full_path, check_dpg=check_dpg)
        assert isinstance(loaded, tmm.TrxFile)
        return

    with pytest.raises(ValueError):
        assert not isinstance(
            tmm.load(input_obj=full_path, check_dpg=check_dpg), tmm.TrxFile
        )
@pytest.mark.parametrize("path", [("small.trx")])
def test_concatenate(path):
    """Concatenating a TRX with itself doubles the number of streamlines."""
    full_path = os.path.join(get_home(), "memmap_test_data", path)
    first = tmm.load(full_path)
    second = tmm.load(full_path)

    merged = tmm.concatenate([first, second])
    assert len(merged) == 2 * len(second)

    # Release the inputs first, then the concatenated result.
    for handle in (first, second, merged):
        handle.close()
@pytest.mark.parametrize(
    "path, size, buffer",
    [
        ("small.trx", 50, 10000),
        ("small.trx", 0, 10000),
        ("small.trx", 25000, 10000),
        ("small.trx", 50, 0),
        ("small.trx", 0, 0),
        # NOTE(review): duplicate of the third case; possibly meant
        # ("small.trx", 25000, 0) - confirm before changing.
        ("small.trx", 25000, 10000),
    ],
)
def test_from_lazy_tractogram(path, size, buffer):
    """Build a TrxFile from a LazyTractogram and check the streamline count.

    Parameters
    ----------
    path : str
        Name of the reference TRX inside the ``memmap_test_data`` folder.
    size : int
        Approximate number of streamlines to generate (rounded down to a
        multiple of 5).
    buffer : int
        Value forwarded as ``extra_buffer`` to ``from_lazy_tractogram``.
    """
    streamlines = []
    fa = []
    commit_weights = []
    clusters_QB = []

    # Repeating pattern of points-per-streamline.
    gen_range = [1, 2, 5, 2, 1] * (size // 5)
    for nb_points in gen_range:
        streamline, data_per_point, data_for_streamline = make_dummy_streamline(
            nb_points
        )
        streamlines.append(streamline)
        fa.append(data_per_point["fa"].astype(np.float16))
        commit_weights.append(data_for_streamline["mean_curvature"].astype(np.float32))
        clusters_QB.append(data_for_streamline["mean_torsion"].astype(np.uint16))

    def streamlines_func():
        return (e for e in streamlines)

    data_per_point_func = {"fa": lambda: (e for e in fa)}
    data_per_streamline_func = {
        "commit_weights": lambda: (e for e in commit_weights),
        "clusters_QB": lambda: (e for e in clusters_QB),
    }

    obj = LazyTractogram(
        streamlines_func,
        data_per_streamline_func,
        data_per_point_func,
        affine_to_rasmm=np.eye(4),
    )

    dtype_dict = {
        "positions": np.float32,
        "offsets": np.uint32,
        "dpv": {"fa": np.float16},
        "dps": {"commit_weights": np.float32, "clusters_QB": np.uint16},
    }

    path = os.path.join(get_home(), "memmap_test_data", path)
    trx = tmm.TrxFile.from_lazy_tractogram(
        obj, reference=path, extra_buffer=buffer, chunk_size=1000, dtype_dict=dtype_dict
    )
    try:
        assert len(trx) == len(gen_range)
    finally:
        # Release the TrxFile even when the assertion fails, as sibling tests do.
        trx.close()
def test_initialize_empty_trx():
    """Create an empty TRX, save it, and verify the archive and the reloaded object."""
    empty_trx = tmm.TrxFile()
    assert empty_trx.header["NB_STREAMLINES"] == 0
    assert empty_trx.header["NB_VERTICES"] == 0
    assert len(empty_trx.streamlines) == 0

    with tempfile.TemporaryDirectory() as work_dir:
        out_path = os.path.join(work_dir, "empty.trx")
        tmm.save(empty_trx, out_path)
        assert os.path.exists(out_path)

        # Should be very small, just header.json in zip
        assert os.path.getsize(out_path) < 500

        with zipfile.ZipFile(out_path, "r") as archive:
            member_names = archive.namelist()
            assert "header.json" in member_names
            nb_positions = sum(
                1 for name in member_names if name.startswith("positions")
            )
            nb_offsets = sum(1 for name in member_names if name.startswith("offsets"))
            assert nb_positions == 0
            assert nb_offsets == 0

        reloaded = tmm.load(out_path)
        assert reloaded.header["NB_STREAMLINES"] == 0
        assert reloaded.header["NB_VERTICES"] == 0
        assert len(reloaded.streamlines) == 0
        assert len(reloaded.groups) == 0
        assert len(reloaded.data_per_streamline) == 0
        assert len(reloaded.data_per_vertex) == 0
        assert len(reloaded.data_per_group) == 0

        reloaded.close()
close()" trx.close() assert mmap_obj.closed, ( "mmap is still open after close() — the mmap teardown was skipped " "because _uncompressed_folder_handle was None" ) # Endianness tests for cross-platform compatibility (Issue #83) @pytest.mark.parametrize( "dtype_input,expected_byteorder", [ # Native dtypes should be converted to little-endian (np.float32, "<"), (np.float64, "<"), (np.int32, "<"), (np.int64, "<"), (np.uint32, "<"), (np.uint64, "<"), ("float32", "<"), ("float64", "<"), # Big-endian dtypes should be converted to little-endian (">f4", "<"), (">f8", "<"), (">i4", "<"), (">u4", "<"), # Little-endian dtypes should remain little-endian ("u4") arr = np.array([0x12345678], dtype=big_endian_dtype) # Ensure little endian result = tmm._ensure_little_endian(arr) # Result should be little-endian assert result.dtype.byteorder == "<" # Value should be preserved assert result[0] == 0x12345678 def test_load_zip_with_local_header_extra_field(): """Test loading ZIP where local header has extra field not in central dir. Regression test for a bug where zip_info.FileHeader() was used to calculate data offset. The ZIP spec allows local headers to have different extra fields than central directory entries. The fix reads the actual local file header to get the correct offset. 
""" positions = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) offsets = np.array([0, 2], dtype=np.uint64) header = { "DIMENSIONS": [10, 10, 10], "VOXEL_TO_RASMM": np.eye(4).tolist(), "NB_VERTICES": 2, "NB_STREAMLINES": 1, } with tempfile.TemporaryDirectory() as tmp_dir: trx_path = os.path.join(tmp_dir, "test.trx") # Build ZIP with extra bytes in local headers but not central directory with open(trx_path, "wb") as f: local_info = [] extra = b"\x00\x00\x04\x00TEST" # 8-byte extra field for name, data in [ ("header.json", json.dumps(header).encode()), ("positions.3.float32", positions.tobytes()), ("offsets.uint64", offsets.tobytes()), ]: offset = f.tell() fname = name.encode() crc = zipfile.crc32(data) # Local header WITH extra field f.write( struct.pack( "<4sHHHHHIIIHH", b"PK\x03\x04", 20, 0, 0, 0, 0, crc, len(data), len(data), len(fname), len(extra), ) ) f.write(fname) f.write(extra) f.write(data) local_info.append((name, offset, crc, len(data))) cd_start = f.tell() for name, offset, crc, size in local_info: fname = name.encode() # Central directory WITHOUT extra field (mismatch!) 
def test_endianness_roundtrip():
    """Arrays written as little-endian are read back unchanged through ``_create_memmap``."""
    # Values that would be corrupted if the byte order were wrong.
    expected_positions = np.array(
        [[1.5, 2.5, 3.5], [4.5, 5.5, 6.5], [7.5, 8.5, 9.5]], dtype=np.float32
    )
    expected_offsets = np.array([0, 3], dtype=np.uint32)

    with get_trx_tmp_dir() as dirname:
        pos_file = os.path.join(dirname, "test_positions.3.float32")
        off_file = os.path.join(dirname, "test_offsets.uint32")

        # Write both arrays as little-endian.
        for array, target in (
            (expected_positions, pos_file),
            (expected_offsets, off_file),
        ):
            tmm._ensure_little_endian(array).tofile(target)

        # Read back using _create_memmap (which enforces little-endian).
        read_positions = tmm._create_memmap(
            pos_file, mode="r", shape=(3, 3), dtype="float32"
        )
        read_offsets = tmm._create_memmap(
            off_file, mode="r", shape=(2,), dtype="uint32"
        )

        np.testing.assert_array_almost_equal(read_positions, expected_positions)
        np.testing.assert_array_equal(read_offsets, expected_offsets)
@pytest.mark.parametrize(
    "precision, noise, expected",
    [
        (0, 0.0001, 6),
        (1, 0.0001, 6),
        (2, 0.0001, 6),
        (4, 0.0001, 7),
        (0, 0.01, 6),
        (1, 0.01, 6),
        (2, 0.01, 7),
        (0, 1, 7),
    ],
)
def test_union(precision, noise, expected):
    """Union of noisy copies with the originals yields the expected number of streamlines."""
    # Four copies with full-amplitude noise, then one copy with ``noise``-scaled noise.
    streamlines_new = [
        streamlines_ori[1]
        + (
            np.random.random((30, 3))
            if idx < 4
            else noise * np.random.random((30, 3))
        )
        for idx in range(5)
    ]

    unique_streamlines, _ = perform_streamlines_operation(
        union, [streamlines_new, streamlines_ori], precision=precision
    )
    assert len(unique_streamlines) == expected
os import shutil import struct from typing import Any, List, Optional, Tuple, Type, Union import zipfile import nibabel as nib from nibabel.affines import voxel_sizes from nibabel.nifti1 import Nifti1Header, Nifti1Image from nibabel.orientations import aff2axcodes from nibabel.streamlines.array_sequence import ArraySequence from nibabel.streamlines.tractogram import LazyTractogram, Tractogram from nibabel.streamlines.trk import TrkFile import numpy as np from trx.io import get_trx_tmp_dir from trx.utils import ( append_generator_to_dict, close_or_delete_mmap, convert_data_dict_to_tractogram, get_reference_info_wrapper, ) try: import dipy # noqa: F401 dipy_available = True except ImportError: dipy_available = False def _get_dtype_little_endian(dtype: Union[np.dtype, str, type]) -> np.dtype: """Convert a dtype to its little-endian equivalent. The TRX file format uses little-endian byte order for cross-platform compatibility. This function ensures that dtypes are always interpreted as little-endian when reading/writing TRX files. Parameters ---------- dtype : np.dtype, str, or type Input dtype specification (e.g., np.float32, 'float32', '>f4'). Returns ------- np.dtype Little-endian dtype. For single-byte types (uint8, int8, bool), returns the original dtype as endianness is not applicable. """ dt = np.dtype(dtype) # Single-byte types don't have endianness if dt.byteorder == "|" or dt.itemsize == 1: return dt # Already little-endian if dt.byteorder == "<": return dt # Convert to little-endian return dt.newbyteorder("<") def _ensure_little_endian(arr: np.ndarray) -> np.ndarray: """Ensure array data is in little-endian byte order for writing. Parameters ---------- arr : np.ndarray Input array. Returns ------- np.ndarray Array with little-endian byte order. Returns a copy if conversion was needed, otherwise returns the original array. 
""" dt = arr.dtype # Single-byte types don't have endianness if dt.byteorder == "|" or dt.itemsize == 1: return arr # Already little-endian if dt.byteorder == "<": return arr # Native byte order on little-endian system if dt.byteorder == "=" and np.little_endian: return arr # Convert to little-endian return arr.astype(dt.newbyteorder("<")) def _append_last_offsets(nib_offsets: np.ndarray, nb_vertices: int) -> np.ndarray: """Append the last element of offsets from header information. Parameters ---------- nib_offsets : np.ndarray Array of offsets with the last element being the start of the last streamline (nibabel convention). nb_vertices : int Total number of vertices in the streamlines. Returns ------- np.ndarray Offsets array (VTK convention). """ def is_sorted(a): """Return True if array is sorted non-decreasing. Parameters ---------- a : np.ndarray 1D array of numeric offsets. Returns ------- bool True when ``a`` is monotonically non-decreasing. """ return np.all(a[:-1] <= a[1:]) if not is_sorted(nib_offsets): raise ValueError("Offsets must be sorted values.") return np.append(nib_offsets, nb_vertices).astype(nib_offsets.dtype) def _generate_filename_from_data(arr: np.ndarray, filename: str) -> str: """Determine the data type from array data and generate the appropriate filename. Parameters ---------- arr : np.ndarray A NumPy array (1-2D, otherwise ValueError raised). filename : str The original filename. Returns ------- str An updated filename with appropriate extension. 
""" base, ext = os.path.splitext(filename) if ext: logging.warning("Will overwrite provided extension if needed.") dtype = arr.dtype dtype = "bit" if dtype is np.dtype(bool) else dtype.name if arr.ndim == 1: new_filename = "{}.{}".format(base, dtype) elif arr.ndim == 2: dim = arr.shape[-1] if dim == 1: new_filename = "{}.{}".format(base, dtype) else: new_filename = "{}.{}.{}".format(base, arr.shape[-1], dtype) else: raise ValueError("Invalid dimensionality.") return new_filename def _split_ext_with_dimensionality(filename: str) -> Tuple[str, int, str]: """Take a filename and split it into its components. Parameters ---------- filename : str Input filename. Returns ------- tuple A tuple of (basename, dimension, extension). """ basename = os.path.basename(filename) split = basename.split(".") if len(split) != 2 and len(split) != 3: raise ValueError("Invalid filename.") basename = split[0] ext = ".{}".format(split[-1]) dim = 1 if len(split) == 2 else split[1] _is_dtype_valid(ext) return basename, int(dim), ext def _compute_lengths(offsets: np.ndarray) -> np.ndarray: """Compute lengths from offsets. Parameters ---------- offsets : np.ndarray An array of offsets. Returns ------- np.ndarray An array of lengths. """ if len(offsets) > 0: last_elem_pos = _dichotomic_search(offsets) lengths = np.ediff1d(offsets) if len(lengths) > last_elem_pos: lengths[last_elem_pos] = 0 else: lengths = np.array([0]) return lengths.astype(np.uint32) def _is_dtype_valid(ext: str) -> bool: """Verify that filename extension is a valid datatype. Parameters ---------- ext : str Filename extension. Returns ------- bool True if the provided datatype is valid, False otherwise. """ if ext.replace(".", "") == "bit": return True try: isinstance(np.dtype(ext.replace(".", "")), np.dtype) return True except TypeError: return False def _dichotomic_search( x: np.ndarray, l_bound: Optional[int] = None, r_bound: Optional[int] = None ) -> int: """Find where data of a contiguous array is actually ending. 
def _create_memmap(
    filename: str,
    mode: str = "r",
    shape: Tuple = (1,),
    dtype: np.dtype = np.float32,
    offset: int = 0,
    order: str = "C",
) -> np.ndarray:
    """Wrap memmap creation to support empty arrays.

    Parameters
    ----------
    filename : str
        Filename where the empty memmap should be created.
    mode : str, optional
        File open mode (see np.memmap for options). Default is 'r'.
    shape : tuple, optional
        Shape of memmapped array. Default is (1,).
    dtype : np.dtype, optional
        Datatype of memmapped array. Default is np.float32.
    offset : int, optional
        Offset of the data within the file. Default is 0.
    order : str, optional
        Data representation on disk ('C' or 'F'). Default is 'C'.

    Returns
    -------
    np.ndarray
        Memory-mapped array or a zero-filled array if shape[0] is 0.
    """
    if np.dtype(dtype) == bool:
        # Boolean data is stored with a ".bit" extension. Only the file
        # extension is rewritten, never a ".bool" found elsewhere in the path.
        base, ext = os.path.splitext(filename)
        if ext == ".bool":
            filename = base + ".bit"

    # TRX format uses little-endian byte order for cross-platform compatibility
    dtype = _get_dtype_little_endian(dtype)

    if shape[0]:
        return np.memmap(
            filename, mode=mode, offset=offset, shape=shape, dtype=dtype, order=order
        )

    # np.memmap cannot map zero elements: make sure an (empty) file exists on
    # disk and hand back an in-memory array of the requested shape instead.
    if not os.path.isfile(filename):
        with open(filename, "wb"):
            pass
    return np.zeros(shape, dtype=dtype)
Default is True. Returns ------- TrxFile TrxFile object representing the read data. """ # TODO Check if 0 streamlines, then 0 vertices is expected (vice-versa) # TODO 4x4 affine matrices should contains values (no all-zeros) # TODO 3x1 dimensions array should contains values at each position (int) if os.path.isfile(input_obj): was_compressed = False with zipfile.ZipFile(input_obj, "r") as zf: for info in zf.infolist(): if info.compress_type != 0: was_compressed = True break if was_compressed: with zipfile.ZipFile(input_obj, "r") as zf: tmp_dir = get_trx_tmp_dir() zf.extractall(tmp_dir.name) trx = load_from_directory(tmp_dir.name) trx._uncompressed_folder_handle = tmp_dir logging.info( "File was compressed, call the close() function before exiting." ) else: trx = load_from_zip(input_obj) elif os.path.isdir(input_obj): trx = load_from_directory(input_obj) else: raise ValueError("File/Folder does not exist") # Example of robust check for metadata if check_dpg: for dpg in trx.data_per_group.keys(): if dpg not in trx.groups.keys(): raise ValueError( "An undeclared group ({}) has data_per_group.".format(dpg) ) return trx def load_from_zip(filename: str) -> Type["TrxFile"]: """Load a TrxFile from a single zipfile. Note: Does not work with compressed zipfiles. Parameters ---------- filename : str Path of the zipped TrxFile. Returns ------- TrxFile TrxFile representing the read data. 
""" with zipfile.ZipFile(filename, mode="r") as zf: with zf.open("header.json") as zf_header: header = json.load(zf_header) header["VOXEL_TO_RASMM"] = np.reshape( header["VOXEL_TO_RASMM"], (4, 4) ).astype(np.float32) header["DIMENSIONS"] = np.array(header["DIMENSIONS"], dtype=np.uint16) files_pointer_size = {} for zip_info in zf.filelist: elem_filename = zip_info.filename _, ext = os.path.splitext(elem_filename) if ext == ".json" or zip_info.is_dir(): continue if not _is_dtype_valid(ext): continue raise ValueError("The dtype {} is not supported".format(elem_filename)) if ext == ".bit": ext = ".bool" # Read actual local file header to get correct data offset. # We can't use zip_info.FileHeader() because ZIP spec allows local # headers to differ from central directory entries. # See: https://pkware.cachefly.net/webdocs/casestudies/APPNOTE.TXT _ZIP_LOCAL_HEADER_SIZE = 30 _ZIP_LOCAL_HEADER_SIGNATURE = b"PK\x03\x04" zf.fp.seek(zip_info.header_offset) local_header = zf.fp.read(_ZIP_LOCAL_HEADER_SIZE) if len(local_header) < _ZIP_LOCAL_HEADER_SIZE: raise ValueError(f"Truncated local file header for {elem_filename}") if local_header[:4] != _ZIP_LOCAL_HEADER_SIGNATURE: raise ValueError( f"Invalid local file header signature for {elem_filename}" ) fname_len, extra_len = struct.unpack(" Type["TrxFile"]: """Load a TrxFile from a folder containing memmaps. Parameters ---------- directory : str Path of the directory containing TRX data. Returns ------- TrxFile TrxFile representing the read data. 
""" directory = os.path.abspath(directory) with open(os.path.join(directory, "header.json")) as header: header = json.load(header) header["VOXEL_TO_RASMM"] = np.reshape(header["VOXEL_TO_RASMM"], (4, 4)).astype( np.float32 ) header["DIMENSIONS"] = np.array(header["DIMENSIONS"], dtype=np.uint16) files_pointer_size = {} for root, _dirs, files in os.walk(directory): for name in files: elem_filename = os.path.join(root, name) _, ext = os.path.splitext(elem_filename) if ext == ".json": continue if not _is_dtype_valid(ext): raise ValueError( "The dtype of {} is not supported".format(elem_filename) ) if ext == ".bit": ext = ".bool" dtype_size = np.dtype(ext[1:]).itemsize size = os.path.getsize(elem_filename) / dtype_size if size.is_integer(): files_pointer_size[elem_filename] = 0, int(size) elif os.path.getsize(elem_filename) == 1: files_pointer_size[elem_filename] = 0, 0 else: raise ValueError("Wrong size or datatype") return TrxFile._create_trx_from_pointer(header, files_pointer_size, root=directory) def _filter_empty_trx_files(trx_list: List["TrxFile"]) -> List["TrxFile"]: """Remove empty TrxFiles from the list. Parameters ---------- trx_list : list of TrxFile class instances Collection of tractograms to filter. Returns ------- list of TrxFile class instances Only entries containing at least one streamline. """ return [curr_trx for curr_trx in trx_list if curr_trx.header["NB_STREAMLINES"] > 0] def _get_all_data_keys(trx_list: List["TrxFile"]) -> Tuple[set, set]: """Get all dps and dpv keys from the TrxFile list. Parameters ---------- trx_list : list of TrxFile class instances Collection of tractograms. Returns ------- tuple of set Sets of `data_per_streamline` keys and `data_per_vertex` keys. 
""" all_dps = [] all_dpv = [] for curr_trx in trx_list: all_dps.extend(list(curr_trx.data_per_streamline.keys())) all_dpv.extend(list(curr_trx.data_per_vertex.keys())) return set(all_dps), set(all_dpv) def _check_space_attributes(trx_list: List["TrxFile"]) -> None: """Verify that space attributes are consistent across TrxFiles. Parameters ---------- trx_list : list of TrxFile Tractograms to compare for affine and dimension consistency. Raises ------ ValueError If voxel-to-RASMM matrices or dimensions differ. """ ref_trx = trx_list[0] for curr_trx in trx_list[1:]: if not np.allclose( ref_trx.header["VOXEL_TO_RASMM"], curr_trx.header["VOXEL_TO_RASMM"] ) or not np.array_equal( ref_trx.header["DIMENSIONS"], curr_trx.header["DIMENSIONS"] ): raise ValueError("Wrong space attributes.") def _verify_dpv_coherence( trx_list: List["TrxFile"], all_dpv: set, ref_trx: "TrxFile", delete_dpv: bool ) -> None: """Verify dpv coherence across TrxFiles. Parameters ---------- trx_list : list of TrxFile class instances Tractograms being concatenated. all_dpv : set Union of `data_per_vertex` keys across tractograms. ref_trx : TrxFile class instance Reference tractogram for dtype/key checks. delete_dpv : bool Drop mismatched dpv keys instead of raising when True. Raises ------ ValueError If dpv keys or dtypes differ and `delete_dpv` is False. 
""" for curr_trx in trx_list: for key in all_dpv: if ( key not in ref_trx.data_per_vertex.keys() or key not in curr_trx.data_per_vertex.keys() ): if not delete_dpv: logging.debug( "{} dpv key does not exist in all TrxFile.".format(key) ) raise ValueError("TrxFile must be sharing identical dpv keys.") elif ( ref_trx.data_per_vertex[key]._data.dtype != curr_trx.data_per_vertex[key]._data.dtype ): logging.debug( "{} dpv key is not declared with the same dtype " "in all TrxFile.".format(key) ) raise ValueError("Shared dpv key, has different dtype.") def _verify_dps_coherence( trx_list: List["TrxFile"], all_dps: set, ref_trx: "TrxFile", delete_dps: bool ) -> None: """Verify dps coherence across TrxFiles. Parameters ---------- trx_list : list of TrxFile class instances Tractograms being concatenated. all_dps : set Union of data_per_streamline keys across tractograms. ref_trx : TrxFile class instance Reference tractogram for dtype/key checks. delete_dps : bool Drop mismatched dps keys instead of raising when True. Raises ------ ValueError If dps keys or dtypes differ and `delete_dps` is False. """ for curr_trx in trx_list: for key in all_dps: if ( key not in ref_trx.data_per_streamline.keys() or key not in curr_trx.data_per_streamline.keys() ): if not delete_dps: logging.debug( "{} dps key does not exist in all TrxFile.".format(key) ) raise ValueError("TrxFile must be sharing identical dps keys.") elif ( ref_trx.data_per_streamline[key].dtype != curr_trx.data_per_streamline[key].dtype ): logging.debug( "{} dps key is not declared with the same dtype " "in all TrxFile.".format(key) ) raise ValueError("Shared dps key, has different dtype.") def _compute_groups_info(trx_list: List["TrxFile"]) -> Tuple[dict, dict]: """Compute group length and dtype information. Parameters ---------- trx_list : list of TrxFile class instances Tractograms being concatenated. Returns ------- tuple of dict (group lengths, group dtypes) keyed by group name. 
""" all_groups_len = {} all_groups_dtype = {} for trx_1 in trx_list: for group_key in trx_1.groups.keys(): if group_key in all_groups_len: all_groups_len[group_key] += len(trx_1.groups[group_key]) else: all_groups_len[group_key] = len(trx_1.groups[group_key]) if ( group_key in all_groups_dtype and trx_1.groups[group_key].dtype != all_groups_dtype[group_key] ): raise ValueError("Shared group key, has different dtype.") else: all_groups_dtype[group_key] = trx_1.groups[group_key].dtype return all_groups_len, all_groups_dtype def _create_new_trx_for_concatenation( trx_list: List["TrxFile"], ref_trx: "TrxFile", delete_dps: bool, delete_dpv: bool, delete_groups: bool, ) -> "TrxFile": """Create a new TrxFile for concatenation. Parameters ---------- trx_list : list of TrxFile class instances Input tractograms to concatenate. ref_trx : TrxFile class instance Reference tractogram for header/dtype template. delete_dps : bool Drop `data_per_streamline` keys not shared. delete_dpv : bool Drop `data_per_vertex` keys not shared. delete_groups : bool Drop groups when metadata differ. Returns ------- TrxFile Empty TRX ready to receive concatenated data. """ nb_vertices = 0 nb_streamlines = 0 for curr_trx in trx_list: curr_strs_len, curr_pts_len = curr_trx._get_real_len() nb_streamlines += curr_strs_len nb_vertices += curr_pts_len new_trx = TrxFile( nb_vertices=nb_vertices, nb_streamlines=nb_streamlines, init_as=ref_trx ) if delete_dps: new_trx.data_per_streamline = {} if delete_dpv: new_trx.data_per_vertex = {} if delete_groups: new_trx.groups = {} return new_trx def _setup_groups_for_concatenation( new_trx: "TrxFile", trx_list: List["TrxFile"], all_groups_len: dict, all_groups_dtype: dict, delete_groups: bool, ) -> None: """Setup groups in the new TrxFile for concatenation. Parameters ---------- new_trx : TrxFile class instance Destination tractogram. trx_list : list of TrxFile class instances Source tractograms. all_groups_len : dict Mapping of group name to total length. 
all_groups_dtype : dict Mapping of group name to dtype. delete_groups : bool If True, skip creating group arrays. """ if delete_groups: return tmp_dir = new_trx._uncompressed_folder_handle.name for group_key in all_groups_len.keys(): if not os.path.isdir(os.path.join(tmp_dir, "groups/")): os.mkdir(os.path.join(tmp_dir, "groups/")) dtype = all_groups_dtype[group_key] group_filename = os.path.join( tmp_dir, "groups/{}.{}".format(group_key, dtype.name) ) group_len = all_groups_len[group_key] new_trx.groups[group_key] = _create_memmap( group_filename, mode="w+", shape=(group_len,), dtype=dtype ) pos = 0 count = 0 for curr_trx in trx_list: curr_len = len(curr_trx.groups[group_key]) new_trx.groups[group_key][pos : pos + curr_len] = ( curr_trx.groups[group_key] + count ) pos += curr_len count += curr_trx.header["NB_STREAMLINES"] def concatenate( trx_list: List["TrxFile"], delete_dpv: bool = False, delete_dps: bool = False, delete_groups: bool = False, check_space_attributes: bool = True, preallocation: bool = False, ) -> "TrxFile": """Concatenate multiple TrxFile together, with support for preallocation. Parameters ---------- trx_list : list of TrxFile A list containing TrxFiles to concatenate. delete_dpv : bool, optional Delete dpv keys that do not exist in all the provided TrxFiles. Default is False. delete_dps : bool, optional Delete dps keys that do not exist in all the provided TrxFiles. Default is False. delete_groups : bool, optional Delete all the groups that currently exist in the TrxFiles. Default is False. check_space_attributes : bool, optional Verify that dimensions and size of data are similar between all the TrxFiles. Default is True. preallocation : bool, optional Preallocated TrxFile has already been generated and is the first element in trx_list. Note: delete_groups must be set to True as well. Default is False. Returns ------- TrxFile TrxFile representing the concatenated data. 
""" trx_list = _filter_empty_trx_files(trx_list) if len(trx_list) == 0: logging.warning("Inputs of concatenation were empty.") return TrxFile() ref_trx = trx_list[0] all_dps, all_dpv = _get_all_data_keys(trx_list) if check_space_attributes: _check_space_attributes(trx_list) if preallocation and not delete_groups: raise ValueError("Groups are variables, cannot be handled with preallocation") _verify_dpv_coherence(trx_list, all_dpv, ref_trx, delete_dpv) _verify_dps_coherence(trx_list, all_dps, ref_trx, delete_dps) all_groups_len, all_groups_dtype = _compute_groups_info(trx_list) to_concat_list = trx_list[1:] if preallocation else trx_list if not preallocation: new_trx = _create_new_trx_for_concatenation( to_concat_list, ref_trx, delete_dps, delete_dpv, delete_groups ) _setup_groups_for_concatenation( new_trx, trx_list, all_groups_len, all_groups_dtype, delete_groups ) strs_end, pts_end = 0, 0 else: new_trx = ref_trx strs_end, pts_end = new_trx._get_real_len() for curr_trx in to_concat_list: strs_end, pts_end = new_trx._copy_fixed_arrays_from( curr_trx, strs_start=strs_end, pts_start=pts_end ) return new_trx def save( trx: "TrxFile", filename: str, compression_standard: Any = zipfile.ZIP_STORED ) -> None: """Save a TrxFile (compressed or not). Parameters ---------- trx : TrxFile The TrxFile to save. filename : str The path to save the TrxFile to. compression_standard : int, optional The compression standard to use, as defined by the ZipFile library. Default is zipfile.ZIP_STORED. 
""" _, ext = os.path.splitext(filename) if ext not in [".zip", ".trx", ""]: raise ValueError("Unsupported extension.") copy_trx = trx.deepcopy() copy_trx.resize() tmp_dir_name = copy_trx._uncompressed_folder_handle.name if ext in [".zip", ".trx"]: zip_from_folder(tmp_dir_name, filename, compression_standard) else: if os.path.isdir(filename): shutil.rmtree(filename) shutil.copytree(tmp_dir_name, filename) copy_trx.close() def zip_from_folder( directory: str, filename: str, compression_standard: Any = zipfile.ZIP_STORED ) -> None: """Zip on-disk memmaps into a single file. Parameters ---------- directory : str The path to the on-disk memmap directory. filename : str The path where the zip file should be created. compression_standard : int, optional The compression standard to use, as defined by the ZipFile library. Default is zipfile.ZIP_STORED. """ with zipfile.ZipFile(filename, mode="w", compression=compression_standard) as zf: for root, _, files in os.walk(directory): for name in files: curr_filename = os.path.join(root, name) tmp_filename = curr_filename.replace(directory, "")[1:] zf.write(curr_filename, tmp_filename) class TrxFile: """Core class of the TrxFile. Parameters ---------- nb_vertices : int, optional The number of vertices to use in the new TrxFile. nb_streamlines : int, optional The number of streamlines in the new TrxFile. init_as : TrxFile class instance, optional A TrxFile to use as reference. reference : str, dict, Nifti1Image, TrkFile, or Nifti1Header, optional A Nifti or Trk file/obj to use as reference. """ header: dict streamlines: Type[ArraySequence] groups: dict data_per_streamline: dict data_per_vertex: dict data_per_group: dict def __init__( self, nb_vertices: Optional[int] = None, nb_streamlines: Optional[int] = None, init_as: Optional[Type["TrxFile"]] = None, reference: Union[ str, dict, Type[Nifti1Image], Type[TrkFile], Type[Nifti1Header], None, ] = None, ) -> None: """Initialize an empty TrxFile with support for preallocation. 
def __str__(self) -> str:
    """Build the multi-line text summary used when printing the TrxFile.

    The summary lists the spatial attributes of the header, the streamline
    and vertex counts (plus the allocated sizes when they differ from the
    counts), the dpv/dps/group keys, the dpg keys of each group and the
    ``copy_safe`` flag.

    Returns
    -------
    str
        The formatted description, one entry per line.
    """
    affine = np.array(self.header["VOXEL_TO_RASMM"], dtype=np.float32)
    dimensions = np.array(self.header["DIMENSIONS"], dtype=np.uint16)
    vox_sizes = np.array(voxel_sizes(affine), dtype=np.float32)
    vox_order = "".join(aff2axcodes(affine))

    def _float_formatter(fmt):
        # Formatter dict understood by np.array2string for float entries
        return {"float_kind": lambda x: fmt % x}

    lines = [
        "VOXEL_TO_RASMM: \n{}".format(
            np.array2string(affine, formatter=_float_formatter("%.6f"))
        ),
        "DIMENSIONS: {}".format(np.array2string(dimensions)),
        "VOX_SIZES: {}".format(
            np.array2string(vox_sizes, formatter=_float_formatter("%.2f"))
        ),
        "VOX_ORDER: {}".format(vox_order),
    ]

    allocated_strs = self.header["NB_STREAMLINES"]
    allocated_pts = self.header["NB_VERTICES"]
    used_strs, used_pts = self._get_real_len()

    # Allocated sizes are only reported when they differ from what is used
    if not (allocated_strs == used_strs and allocated_pts == used_pts):
        lines.append("streamline_size: {}".format(allocated_strs))
        lines.append("vertex_size: {}".format(allocated_pts))

    lines.append("streamline_count: {}".format(used_strs))
    lines.append("vertex_count: {}".format(used_pts))

    dpv_keys = list(self.data_per_vertex.keys())
    if dpv_keys:
        lines.append("data_per_vertex keys: {}".format(dpv_keys))
    else:
        lines.append("No data per vertex (dpv) keys")

    dps_keys = list(self.data_per_streamline.keys())
    if dps_keys:
        lines.append("data_per_streamline keys: {}".format(dps_keys))
    else:
        lines.append("No data per streamline (dps) keys")

    group_keys = list(self.groups.keys())
    if group_keys:
        lines.append("groups keys: {}".format(group_keys))
    else:
        lines.append("No group keys")

    for group_key in group_keys:
        if group_key in self.data_per_group:
            lines.append(
                "data_per_groups ({}) keys: {}".format(
                    group_key, list(self.data_per_group[group_key].keys())
                )
            )

    lines.append("copy_safe: {}".format(self._copy_safe))
    return "\n".join(lines)
def __deepcopy__(self, memo: Optional[dict] = None) -> Type["TrxFile"]:
    """Return a deep copy of the TrxFile (hook used by ``copy.deepcopy``).

    Parameters
    ----------
    memo : dict, optional
        Memo dictionary supplied by ``copy.deepcopy``. It is accepted so the
        protocol call ``__deepcopy__(memo)`` works, but it is not used: the
        copy is delegated to ``self.deepcopy()``.

    Returns
    -------
    TrxFile class instance
        Deep-copied instance.
    """
    return self.deepcopy()
def _get_real_len(self) -> Tuple[int, int]:
    """Get the real size of data (ignoring zeros of preallocation).

    The position returned by ``_dichotomic_search`` on the streamline lengths
    is taken as the last used streamline; ``-1`` means nothing is used.

    Returns
    -------
    tuple of int
        A tuple (strs_end, pts_end): the number of used streamlines and the
        sum of their lengths. ``(0, 0)`` when there is no streamline or when
        the search reports ``-1``.
    """
    lengths = self.streamlines._lengths
    if len(lengths) == 0:
        return 0, 0

    last_used = _dichotomic_search(lengths)
    if last_used == -1:
        return 0, 0

    nb_strs = int(last_used + 1)
    nb_pts = int(np.sum(lengths[0:nb_strs]))
    return nb_strs, nb_pts
@staticmethod
def _initialize_empty_trx(  # noqa: C901
    nb_streamlines: int,
    nb_vertices: int,
    init_as: Optional[Type["TrxFile"]] = None,
) -> Type["TrxFile"]:
    """Create on-disk memmaps of a certain size (preallocation).

    A new temporary folder is obtained from ``get_trx_tmp_dir()`` and the
    positions and offsets arrays are created inside it as memmaps; the
    lengths array is a regular in-memory array of zeros. When ``init_as`` is
    given, its spatial header fields, its dtypes and the names and
    dimensionality of its dpv/dps arrays are reproduced (not their content).
    Groups and data_per_group are not created here.

    Parameters
    ----------
    nb_streamlines : int
        The number of streamlines that the empty TrxFile will be
        initialized with.
    nb_vertices : int
        The number of vertices that the empty TrxFile will be initialized
        with.
    init_as : TrxFile, optional
        A TrxFile to initialize the empty TrxFile with.

    Returns
    -------
    TrxFile
        An empty TrxFile preallocated with a certain size.

    Raises
    ------
    ValueError
        If a dpv or dps array of ``init_as`` is neither 1D nor 2D.
    """
    trx = TrxFile()
    tmp_dir = get_trx_tmp_dir()
    logging.info("Temporary folder for memmaps: {}".format(tmp_dir.name))

    trx.header["NB_VERTICES"] = nb_vertices
    trx.header["NB_STREAMLINES"] = nb_streamlines

    if init_as is not None:
        # Reuse the spatial attributes and dtypes of the template
        trx.header["VOXEL_TO_RASMM"] = init_as.header["VOXEL_TO_RASMM"]
        trx.header["DIMENSIONS"] = init_as.header["DIMENSIONS"]
        positions_dtype = init_as.streamlines._data.dtype
        offsets_dtype = init_as.streamlines._offsets.dtype
        lengths_dtype = init_as.streamlines._lengths.dtype
    else:
        # NOTE(review): positions default to float16 here, whereas the
        # converters in this file default to float32 -- confirm intended.
        positions_dtype = np.dtype(np.float16)
        offsets_dtype = np.dtype(np.uint32)
        lengths_dtype = np.dtype(np.uint32)

    logging.debug(
        "Initializing positions with dtype: {}".format(positions_dtype.name)
    )
    logging.debug("Initializing offsets with dtype: {}".format(offsets_dtype.name))
    logging.debug("Initializing lengths with dtype: {}".format(lengths_dtype.name))

    # A TrxFile without init_as only contains the essential arrays.
    # Filenames encode the dimensionality and the dtype of the array.
    positions_filename = os.path.join(
        tmp_dir.name, "positions.3.{}".format(positions_dtype.name)
    )
    trx.streamlines._data = _create_memmap(
        positions_filename, mode="w+", shape=(nb_vertices, 3), dtype=positions_dtype
    )

    offsets_filename = os.path.join(
        tmp_dir.name, "offsets.{}".format(offsets_dtype.name)
    )
    trx.streamlines._offsets = _create_memmap(
        offsets_filename, mode="w+", shape=(nb_streamlines,), dtype=offsets_dtype
    )
    # Lengths are not written to disk: plain array of zeros
    trx.streamlines._lengths = np.zeros(
        shape=(nb_streamlines,), dtype=lengths_dtype
    )

    # Only the structure of fixed-size arrays is copied
    if init_as is not None:
        # Sub-folders are created only when there is something to put in them
        if len(init_as.data_per_vertex.keys()) > 0:
            os.mkdir(os.path.join(tmp_dir.name, "dpv/"))
        if len(init_as.data_per_streamline.keys()) > 0:
            os.mkdir(os.path.join(tmp_dir.name, "dps/"))

        for dpv_key in init_as.data_per_vertex.keys():
            dtype = init_as.data_per_vertex[dpv_key]._data.dtype
            tmp_as = init_as.data_per_vertex[dpv_key]._data
            if tmp_as.ndim == 1:
                # 1D template: no dimensionality in the filename, one column
                dpv_filename = os.path.join(
                    tmp_dir.name, "dpv/{}.{}".format(dpv_key, dtype.name)
                )
                shape = (nb_vertices, 1)
            elif tmp_as.ndim == 2:
                dim = tmp_as.shape[-1]
                shape = (nb_vertices, dim)
                dpv_filename = os.path.join(
                    tmp_dir.name, "dpv/{}.{}.{}".format(dpv_key, dim, dtype.name)
                )
            else:
                raise ValueError("Invalid dimensionality.")

            logging.debug(
                "Initializing {} (dpv) with dtype: {}".format(dpv_key, dtype.name)
            )
            trx.data_per_vertex[dpv_key] = ArraySequence()
            trx.data_per_vertex[dpv_key]._data = _create_memmap(
                dpv_filename, mode="w+", shape=shape, dtype=dtype
            )
            # dpv shares the offsets/lengths objects of the streamlines
            trx.data_per_vertex[dpv_key]._offsets = trx.streamlines._offsets
            trx.data_per_vertex[dpv_key]._lengths = trx.streamlines._lengths

        for dps_key in init_as.data_per_streamline.keys():
            dtype = init_as.data_per_streamline[dps_key].dtype
            tmp_as = init_as.data_per_streamline[dps_key]
            if tmp_as.ndim == 1:
                # Unlike dpv, a 1D dps template stays 1D
                dps_filename = os.path.join(
                    tmp_dir.name, "dps/{}.{}".format(dps_key, dtype.name)
                )
                shape = (nb_streamlines,)
            elif tmp_as.ndim == 2:
                dim = tmp_as.shape[-1]
                shape = (nb_streamlines, dim)
                dps_filename = os.path.join(
                    tmp_dir.name, "dps/{}.{}.{}".format(dps_key, dim, dtype.name)
                )
            else:
                raise ValueError("Invalid dimensionality.")

            logging.debug(
                "Initializing {} (dps) with and dtype: {}".format(
                    dps_key, dtype.name
                )
            )
            trx.data_per_streamline[dps_key] = _create_memmap(
                dps_filename, mode="w+", shape=shape, dtype=dtype
            )

    # Keep the handle on the returned object (presumably what keeps the
    # temporary folder alive -- depends on get_trx_tmp_dir, not shown here)
    trx._uncompressed_folder_handle = tmp_dir

    return trx
trx.header["NB_STREAMLINES"] + 1 or dim != 1: raise ValueError("Wrong offsets size/dimensionality.") offsets = _create_memmap( filename, mode="r+", offset=mem_adress, shape=(trx.header["NB_STREAMLINES"] + 1,), dtype=ext[1:], ) if offsets[-1] != 0: lengths = _compute_lengths(offsets) else: lengths = [0] elif folder == "dps": nb_scalar = size / trx.header["NB_STREAMLINES"] if not nb_scalar.is_integer() or nb_scalar != dim: raise ValueError("Wrong dps size/dimensionality.") else: shape = (trx.header["NB_STREAMLINES"], int(nb_scalar)) trx.data_per_streamline[base] = _create_memmap( filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:] ) elif folder == "dpv": nb_scalar = size / trx.header["NB_VERTICES"] if not nb_scalar.is_integer() or nb_scalar != dim: raise ValueError("Wrong dpv size/dimensionality.") else: shape = (trx.header["NB_VERTICES"], int(nb_scalar)) trx.data_per_vertex[base] = _create_memmap( filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:] ) elif folder.startswith("dpg"): if int(size) != dim: raise ValueError("Wrong dpg size/dimensionality.") else: shape = (1, int(size)) # Handle the two-layers architecture data_name = os.path.basename(base) sub_folder = os.path.basename(folder) if sub_folder not in trx.data_per_group: trx.data_per_group[sub_folder] = {} trx.data_per_group[sub_folder][data_name] = _create_memmap( filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:] ) elif folder == "groups": # Groups are simply indices, nothing else # TODO Crash if not uint? 
if dim != 1: raise ValueError("Wrong group dimensionality.") else: shape = (int(size),) trx.groups[base] = _create_memmap( filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:] ) else: logging.error( "{} is not part of a valid structure.".format(elem_filename) ) # All essential array must be declared if positions is not None and offsets is not None: trx.streamlines._data = positions trx.streamlines._offsets = offsets[:-1] trx.streamlines._lengths = lengths else: raise ValueError("Missing essential data.") for dpv_key in trx.data_per_vertex: tmp = trx.data_per_vertex[dpv_key] trx.data_per_vertex[dpv_key] = ArraySequence() trx.data_per_vertex[dpv_key]._data = tmp trx.data_per_vertex[dpv_key]._offsets = offsets[:-1] trx.data_per_vertex[dpv_key]._lengths = lengths return trx def resize( # noqa: C901 self, nb_streamlines: Optional[int] = None, nb_vertices: Optional[int] = None, delete_dpg: bool = False, ) -> None: """Remove the unused portion of preallocated memmaps. Parameters ---------- nb_streamlines : int, optional The number of streamlines to keep. nb_vertices : int, optional The number of vertices to keep. delete_dpg : bool, optional Remove data_per_group when resizing. Default is False. """ if not self._copy_safe: raise ValueError("Cannot resize a sliced datasets.") strs_end, pts_end = self._get_real_len() if nb_streamlines is not None and nb_streamlines < strs_end: strs_end = nb_streamlines logging.info( "Resizing (down) memmaps, less streamlines than it actually contains." 
) if nb_vertices is None: pts_end = int(np.sum(self.streamlines._lengths[0:nb_streamlines])) nb_vertices = pts_end elif nb_vertices < pts_end: # Resizing vertices only is too dangerous, not allowed logging.warning("Cannot resize (down) vertices for consistency.") return if nb_streamlines is None: nb_streamlines = strs_end if ( nb_streamlines == self.header["NB_STREAMLINES"] and nb_vertices == self.header["NB_VERTICES"] ): logging.debug("TrxFile of the right size, no resizing.") return trx = self._initialize_empty_trx(nb_streamlines, nb_vertices, init_as=self) logging.info( "Resizing streamlines from size {} to {}".format( len(self.streamlines), nb_streamlines ) ) logging.info( "Resizing vertices from size {} to {}".format( len(self.streamlines._data), nb_vertices ) ) # Copy the fixed-sized info from the original TrxFile to the new # (resized) one. if nb_streamlines < self.header["NB_STREAMLINES"]: trx._copy_fixed_arrays_from(self, nb_strs_to_copy=nb_streamlines) else: trx._copy_fixed_arrays_from(self) tmp_dir = trx._uncompressed_folder_handle.name if len(self.groups.keys()) > 0: os.mkdir(os.path.join(tmp_dir, "groups/")) for group_key in self.groups.keys(): group_dtype = self.groups[group_key].dtype group_name = os.path.join( tmp_dir, "groups/", "{}.{}".format(group_key, group_dtype.name) ) ori_len = len(self.groups[group_key]) # Remove groups indices if resizing down tmp = self.groups[group_key][self.groups[group_key] < strs_end] trx.groups[group_key] = _create_memmap( group_name, mode="w+", shape=(len(tmp),), dtype=group_dtype ) logging.debug( "{} group went from {} items to {}".format(group_key, ori_len, len(tmp)) ) trx.groups[group_key][:] = tmp if delete_dpg: self.close() self.__dict__ = trx.__dict__ return if len(self.data_per_group.keys()) > 0: os.mkdir(os.path.join(tmp_dir, "dpg/")) for group_key in self.data_per_group: if not os.path.isdir(os.path.join(tmp_dir, "dpg/", group_key)): os.mkdir(os.path.join(tmp_dir, "dpg/", group_key)) if group_key not in 
def get_dtype_dict(self):
    """Get the dtype dictionary for the TrxFile.

    Returns
    -------
    dict
        A dictionary containing the dtype for each data element:
        ``"positions"`` and ``"offsets"`` map to a dtype, ``"dpv"``,
        ``"dps"`` and ``"groups"`` map each key to its dtype, and ``"dpg"``
        maps each group name to a ``{dpg_key: dtype}`` dictionary.
    """
    dtype_dict = {
        "positions": self.streamlines._data.dtype,
        "offsets": self.streamlines._offsets.dtype,
        "dpv": {},
        "dps": {},
        "dpg": {},
        "groups": {},
    }

    for key, dpv in self.data_per_vertex.items():
        dtype_dict["dpv"][key] = dpv._data.dtype

    for key, dps in self.data_per_streamline.items():
        dtype_dict["dps"][key] = dps.dtype

    # Every group is reported, whether or not it has data_per_group attached
    for group_key, group in self.groups.items():
        dtype_dict["groups"][group_key] = group.dtype

    for group_key, dpg in self.data_per_group.items():
        dtype_dict["dpg"][group_key] = {
            dpg_key: data.dtype for dpg_key, data in dpg.items()
        }

    return dtype_dict
""" curr_dtype_dict = self.get_dtype_dict() if dipy_available: from dipy.io.stateful_tractogram import StatefulTractogram if not isinstance(obj, (TrxFile, Tractogram)) and ( dipy_available and not isinstance(obj, StatefulTractogram) ): raise TypeError( "{} is not a supported object type for appending.".format(type(obj)) ) elif isinstance(obj, Tractogram): obj = self.from_tractogram( obj, reference=self.header, dtype_dict=curr_dtype_dict ) elif dipy_available and isinstance(obj, StatefulTractogram): obj = self.from_sft(obj, dtype_dict=curr_dtype_dict) self._append_trx(obj, extra_buffer=extra_buffer) def _append_trx(self, trx: Type["TrxFile"], extra_buffer: int = 0) -> None: """Append a TrxFile to another (with buffer support). Parameters ---------- trx : TrxFile The TrxFile to append to the current TrxFile. extra_buffer : int, optional The additional buffer space required to append data. Default is 0. """ strs_end, pts_end = self._get_real_len() nb_streamlines = strs_end + trx.header["NB_STREAMLINES"] nb_vertices = pts_end + trx.header["NB_VERTICES"] if ( self.header["NB_STREAMLINES"] < nb_streamlines or self.header["NB_VERTICES"] < nb_vertices ): self.resize( nb_streamlines=nb_streamlines + extra_buffer, nb_vertices=nb_vertices + extra_buffer * 100, ) _ = concatenate([self, trx], preallocation=True, delete_groups=True) def get_group( self, key: str, keep_group: bool = True, copy_safe: bool = False ) -> Type["TrxFile"]: """Get a particular group from the TrxFile. Parameters ---------- key : str The group name to select. keep_group : bool, optional Make sure group exists in returned TrxFile. Default is True. copy_safe : bool, optional Perform a deepcopy. Default is False. Returns ------- TrxFile A TrxFile exclusively containing data from said group. 
""" return self.select(self.groups[key], keep_group=keep_group, copy_safe=copy_safe) def select( self, indices: np.ndarray, keep_group: bool = True, copy_safe: bool = False ) -> Type["TrxFile"]: """Get a subset of items, always pointing to the same memmaps. Parameters ---------- indices : np.ndarray The list of indices of elements to return. keep_group : bool, optional Ensure group is returned in output TrxFile. Default is True. copy_safe : bool, optional Perform a deep-copy. Default is False. Returns ------- TrxFile A TrxFile containing data originating from the selected indices. """ indices = np.array(indices, dtype=np.uint32) new_trx = TrxFile() new_trx._copy_safe = copy_safe new_trx.header = deepcopy(self.header) if isinstance(indices, np.ndarray) and len(indices) == 0: # Even while empty, basic dtype and header must be coherent positions_dtype = self.streamlines._data.dtype offsets_dtype = self.streamlines._offsets.dtype lengths_dtype = self.streamlines._lengths.dtype new_trx.streamlines._data = new_trx.streamlines._data.reshape( (0, 3) ).astype(positions_dtype) new_trx.streamlines._offsets = new_trx.streamlines._offsets.astype( offsets_dtype ) new_trx.streamlines._lengths = new_trx.streamlines._lengths.astype( lengths_dtype ) new_trx.header["NB_VERTICES"] = len(new_trx.streamlines._data) new_trx.header["NB_STREAMLINES"] = len(new_trx.streamlines._lengths) return new_trx.deepcopy() if copy_safe else new_trx new_trx.streamlines = ( self.streamlines[indices].copy() if copy_safe else self.streamlines[indices] ) for dpv_key in self.data_per_vertex.keys(): new_trx.data_per_vertex[dpv_key] = ( self.data_per_vertex[dpv_key][indices].copy() if copy_safe else self.data_per_vertex[dpv_key][indices] ) for dps_key in self.data_per_streamline.keys(): new_trx.data_per_streamline[dps_key] = ( self.data_per_streamline[dps_key][indices].copy() if copy_safe else self.data_per_streamline[dps_key][indices] ) # Not keeping group is equivalent to the [] operator if keep_group: 
logging.warning("Keeping dpg despite affecting the group items.") for group_key in self.groups.keys(): # Keep the group indices even when fancy slicing index = np.argsort(indices) sorted_x = indices[index] sorted_index = np.searchsorted(sorted_x, self.groups[group_key]) yindex = np.take(index, sorted_index, mode="clip") mask = indices[yindex] != self.groups[group_key] intersect = yindex[~mask] if len(intersect) == 0: continue new_trx.groups[group_key] = intersect if group_key in self.data_per_group: for dpg_key in self.data_per_group[group_key].keys(): if group_key not in new_trx.data_per_group: new_trx.data_per_group[group_key] = {} new_trx.data_per_group[group_key][dpg_key] = ( self.data_per_group[group_key][dpg_key] ) new_trx.header["NB_VERTICES"] = len(new_trx.streamlines._data) new_trx.header["NB_STREAMLINES"] = len(new_trx.streamlines._lengths) return new_trx.deepcopy() if copy_safe else new_trx @staticmethod def from_lazy_tractogram( obj: ["LazyTractogram"], reference, extra_buffer: int = 0, chunk_size: int = 10000, dtype_dict: dict = None, ) -> Type["TrxFile"]: """Create a TrxFile from a LazyTractogram with buffer support. Parameters ---------- obj : LazyTractogram The LazyTractogram to convert. reference : object Reference for spatial information. extra_buffer : int, optional The buffer space between reallocation. This number should be a number of streamlines. Use 0 for no buffer. Default is 0. chunk_size : int, optional The number of streamlines to save at a time. Default is 10000. dtype_dict : dict, optional Dictionary specifying dtypes for positions, offsets, dpv, and dps. Returns ------- TrxFile A TrxFile created from the LazyTractogram. 
""" if dtype_dict is None: dtype_dict = { "positions": np.float32, "offsets": np.uint32, "dpv": {}, "dps": {}, } data = {"strs": [], "dpv": {}, "dps": {}} concat = None count = 0 iterator = iter(obj) while True: if count < chunk_size: try: i = next(iterator) count += 1 except StopIteration: obj = convert_data_dict_to_tractogram(data) if concat is None: if len(obj.streamlines) == 0: concat = TrxFile() else: concat = TrxFile.from_tractogram( obj, reference=reference, dtype_dict=dtype_dict ) elif len(obj.streamlines) > 0: curr_obj = TrxFile.from_tractogram( obj, reference=reference, dtype_dict=dtype_dict ) concat.append(curr_obj) break append_generator_to_dict(i, data) else: obj = convert_data_dict_to_tractogram(data) if concat is None: concat = TrxFile.from_tractogram( obj, reference=reference, dtype_dict=dtype_dict ) else: curr_obj = TrxFile.from_tractogram( obj, reference=reference, dtype_dict=dtype_dict ) concat.append(curr_obj, extra_buffer=extra_buffer) data = {"strs": [], "dpv": {}, "dps": {}} count = 0 concat.resize() return concat @staticmethod def from_sft(sft, dtype_dict=None): """Generate a TrxFile from a StatefulTractogram. Parameters ---------- sft : StatefulTractogram class instance Input tractogram. dtype_dict : dict or None, optional Mapping of target dtypes for positions, offsets, dpv, and dps. When None, uses ``sft.dtype_dict`` or sensible defaults. Returns ------- TrxFile TRX representation of the StatefulTractogram. 
""" if dtype_dict is None: dtype_dict = {} if len(sft.dtype_dict) > 0: dtype_dict = sft.dtype_dict if "dpp" in dtype_dict: dtype_dict["dpv"] = dtype_dict.pop("dpp") elif len(dtype_dict) == 0: dtype_dict = { "positions": np.float32, "offsets": np.uint32, "dpv": {}, "dps": {}, } positions_dtype = dtype_dict["positions"] offsets_dtype = dtype_dict["offsets"] if not np.issubdtype(positions_dtype, np.floating): logging.warning( "Casting positions as {}, considering using a floating point " "dtype.".format(positions_dtype) ) if not np.issubdtype(offsets_dtype, np.integer): logging.warning( "Casting offsets as {}, considering using a integer dtype.".format( offsets_dtype ) ) trx = TrxFile( nb_vertices=len(sft.streamlines._data), nb_streamlines=len(sft.streamlines) ) trx.header = { "DIMENSIONS": sft.dimensions.tolist(), "VOXEL_TO_RASMM": sft.affine.tolist(), "NB_VERTICES": len(sft.streamlines._data), "NB_STREAMLINES": len(sft.streamlines), } old_space = deepcopy(sft.space) old_origin = deepcopy(sft.origin) # TrxFile are written on disk in RASMM/center convention sft.to_rasmm() sft.to_center() tmp_streamlines = deepcopy(sft.streamlines) # Cast the int64 of Nibabel to uint32 tmp_streamlines._offsets = tmp_streamlines._offsets.astype(offsets_dtype) tmp_streamlines._data = tmp_streamlines._data.astype(positions_dtype) trx.streamlines = tmp_streamlines for key in sft.data_per_point: dtype_to_use = ( dtype_dict["dpv"][key] if key in dtype_dict["dpv"] else np.float32 ) trx.data_per_vertex[key] = sft.data_per_point[key] trx.data_per_vertex[key]._data = sft.data_per_point[key]._data.astype( dtype_to_use ) for key in sft.data_per_streamline: dtype_to_use = ( dtype_dict["dps"][key] if key in dtype_dict["dps"] else np.float32 ) trx.data_per_streamline[key] = sft.data_per_streamline[key].astype( dtype_to_use ) # For safety and for RAM, convert the whole object to memmaps tmp_dir = get_trx_tmp_dir() save(trx, tmp_dir.name) trx.close() trx = load_from_directory(tmp_dir.name) 
@staticmethod
def from_tractogram(
    tractogram,
    reference,
    dtype_dict=None,
):
    """Generate a TrxFile from a nibabel Tractogram.

    Streamlines and metadata are copied and cast to the requested dtypes,
    then saved to a temporary directory and reloaded from it. The input
    ``tractogram`` is not modified.

    Parameters
    ----------
    tractogram : nibabel.streamlines.Tractogram class instance
        Input tractogram to convert.
    reference : object
        Reference anatomy used to populate header fields; forwarded to
        ``get_reference_info_wrapper``.
    dtype_dict : dict or None, optional
        Mapping of target dtypes. Recognized keys are ``positions``,
        ``offsets``, ``dpv`` and ``dps``; any missing entry falls back to
        float32 (uint32 for offsets).

    Returns
    -------
    TrxFile class instance
        TRX representation of the tractogram. The temporary directory
        backing it is attached to the returned object so that ``close()``
        removes it.
    """
    if dtype_dict is None:
        dtype_dict = {}

    # Every entry is optional: a partial dict must not raise KeyError.
    positions_dtype = dtype_dict.get("positions", np.float32)
    offsets_dtype = dtype_dict.get("offsets", np.uint32)
    dpv_dtypes = dtype_dict.get("dpv", {})
    dps_dtypes = dtype_dict.get("dps", {})

    if not np.issubdtype(positions_dtype, np.floating):
        logging.warning(
            "Casting positions as {}, considering using a floating point "
            "dtype.".format(positions_dtype)
        )

    if not np.issubdtype(offsets_dtype, np.integer):
        logging.warning(
            "Casting offsets as {}, considering using a integer dtype.".format(
                offsets_dtype
            )
        )

    nb_vertices = len(tractogram.streamlines._data)
    nb_streamlines = len(tractogram.streamlines)
    trx = TrxFile(nb_vertices=nb_vertices, nb_streamlines=nb_streamlines)

    affine, dimensions, _, _ = get_reference_info_wrapper(reference)
    trx.header = {
        "DIMENSIONS": dimensions,
        "VOXEL_TO_RASMM": affine,
        "NB_VERTICES": nb_vertices,
        "NB_STREAMLINES": nb_streamlines,
    }

    tmp_streamlines = deepcopy(tractogram.streamlines)

    # Cast nibabel's arrays to the requested on-disk dtypes
    tmp_streamlines._offsets = tmp_streamlines._offsets.astype(offsets_dtype)
    tmp_streamlines._data = tmp_streamlines._data.astype(positions_dtype)

    trx.streamlines = tmp_streamlines
    for key in tractogram.data_per_point:
        dtype_to_use = dpv_dtypes.get(key, np.float32)
        # Work on a copy: assigning ``_data`` on the caller's sequence
        # would silently change the dtype of the input tractogram.
        dpv = deepcopy(tractogram.data_per_point[key])
        dpv._data = dpv._data.astype(dtype_to_use)
        trx.data_per_vertex[key] = dpv

    for key in tractogram.data_per_streamline:
        dtype_to_use = dps_dtypes.get(key, np.float32)
        trx.data_per_streamline[key] = tractogram.data_per_streamline[key].astype(
            dtype_to_use
        )

    # For safety and for RAM, convert the whole object to memmaps
    tmp_dir = get_trx_tmp_dir()
    save(trx, tmp_dir.name)
    trx.close()

    trx = load_from_directory(tmp_dir.name)
    # Keep the directory handle on the object (as from_sft does) so the
    # folder lives as long as the TrxFile and is removed by close().
    trx._uncompressed_folder_handle = tmp_dir
    del tmp_streamlines

    return trx
""" if resize: self.resize() trx_obj = TrxFile() trx_obj.header = deepcopy(self.header) trx_obj.streamlines = deepcopy(self.streamlines) for key in self.data_per_vertex: trx_obj.data_per_vertex[key] = deepcopy(self.data_per_vertex[key]) for key in self.data_per_streamline: trx_obj.data_per_streamline[key] = deepcopy(self.data_per_streamline[key]) for key in self.groups: trx_obj.groups[key] = deepcopy(self.groups[key]) for key in self.data_per_group: trx_obj.data_per_group[key] = deepcopy(self.data_per_group[key]) return trx_obj def to_sft(self, resize=False): """Convert this TrxFile to a StatefulTractogram. Parameters ---------- resize : bool, optional If True, resize to actual data length before conversion. Returns ------- StatefulTractogram class instance or None StatefulTractogram object, or None if dipy is unavailable. """ try: from dipy.io.stateful_tractogram import Space, StatefulTractogram except ImportError: logging.error( "Dipy library is missing, cannot convert to StatefulTractogram." ) return None affine = np.array(self.header["VOXEL_TO_RASMM"], dtype=np.float32) dimensions = np.array(self.header["DIMENSIONS"], dtype=np.uint16) vox_sizes = np.array(voxel_sizes(affine), dtype=np.float32) vox_order = "".join(aff2axcodes(affine)) space_attributes = (affine, dimensions, vox_sizes, vox_order) if resize: self.resize() sft = StatefulTractogram( deepcopy(self.streamlines), space_attributes, Space.RASMM, data_per_point=deepcopy(self.data_per_vertex), data_per_streamline=deepcopy(self.data_per_streamline), ) tmp_dict = self.get_dtype_dict() if "dpv" in tmp_dict: tmp_dict["dpp"] = tmp_dict.pop("dpv") sft.dtype_dict = self.get_dtype_dict() return sft def close(self) -> None: """Cleanup on-disk temporary folder and memmaps. Returns ------- None Releases file handles and removes temporary storage. 
""" close_or_delete_mmap(self.streamlines) for key in self.data_per_vertex: close_or_delete_mmap(self.data_per_vertex[key]) for key in self.data_per_streamline: close_or_delete_mmap(self.data_per_streamline[key]) for key in self.groups: close_or_delete_mmap(self.groups[key]) for key in self.data_per_group: for dpg in self.data_per_group[key]: close_or_delete_mmap(self.data_per_group[key][dpg]) if self._uncompressed_folder_handle is not None: try: self._uncompressed_folder_handle.cleanup() except PermissionError: logging.error( "Windows PermissionError, temporary directory %s was not deleted!", self._uncompressed_folder_handle.name, ) self.__init__() logging.debug("Deleted memmaps and initialized empty TrxFile.") tee-ar-ex-trx-python-a304ac2/trx/utils.py000066400000000000000000000362031515240773700204000ustar00rootroot00000000000000# -*- coding: utf-8 -*- """Utility functions for reference handling, coordinate flips, and file operations.""" import logging import os import nibabel as nib from nibabel.streamlines.array_sequence import ArraySequence from nibabel.streamlines.tractogram import Tractogram, TractogramItem import numpy as np try: import dipy dipy_available = True except ImportError: dipy_available = False def close_or_delete_mmap(obj): """Close the memory-mapped file if it exists, otherwise set the object to None. Parameters ---------- obj : object The object that potentially has a memory-mapped file to be closed. """ if hasattr(obj, "_mmap") and obj._mmap is not None: obj._mmap.close() elif isinstance(obj, ArraySequence): close_or_delete_mmap(obj._data) close_or_delete_mmap(obj._offsets) close_or_delete_mmap(obj._lengths) elif isinstance(obj, np.memmap): del obj else: logging.debug("Object to be close or deleted must be np.memmap") def split_name_with_gz(filename): """Return the clean basename and extension of a file. Correctly manages the ".nii.gz" extensions. Parameters ---------- filename : str The filename to clean. 
def get_reference_info_wrapper(reference):  # noqa: C901
    """Extract spatial attributes from a reference object.

    Parameters
    ----------
    reference : str or dict or Nifti1Image or TrkFile or Nifti1Header or TrxFile
        Reference that provides the spatial attribute. A string is treated
        as a path to a ``.nii``/``.nii.gz``, ``.trk`` or ``.trx`` file; a
        dict is treated as a TRK header (contains ``magic_number``) or a
        TRX header (contains ``NB_VERTICES``). A StatefulTractogram is
        accepted when dipy is installed.

    Returns
    -------
    affine : ndarray (4, 4)
        Transformation of VOX to RASMM, as stored in the reference.
    dimensions : ndarray (3,)
        Volume shape for each axis, as stored in the reference.
    voxel_sizes : ndarray (3,)
        Size of voxel for each axis.
    voxel_order : str
        Typically 'RAS' or 'LPS'.

    Raises
    ------
    ValueError
        If a NIfTI reference has an all-zero rotation/scaling block.
    TypeError
        If the reference is not one of the supported formats.
    """
    from trx import trx_file_memmap

    is_nifti = False
    is_trk = False
    is_sft = False
    is_trx = False
    if isinstance(reference, str):
        _, ext = split_name_with_gz(reference)
        if ext in [".nii", ".nii.gz"]:
            header = nib.load(reference).header
            is_nifti = True
        elif ext == ".trk":
            header = nib.streamlines.load(reference, lazy_load=True).header
            is_trk = True
        elif ext == ".trx":
            # Only the header is needed: copy it, then release the file
            # (and any temporary folder) instead of leaking the TrxFile.
            trx = trx_file_memmap.load(reference)
            try:
                header = dict(trx.header)
            finally:
                trx.close()
            is_trx = True
    elif isinstance(reference, trx_file_memmap.TrxFile):
        header = reference.header
        is_trx = True
    elif isinstance(reference, nib.nifti1.Nifti1Image):
        header = reference.header
        is_nifti = True
    elif isinstance(reference, nib.streamlines.trk.TrkFile):
        header = reference.header
        is_trk = True
    elif isinstance(reference, nib.nifti1.Nifti1Header):
        header = reference
        is_nifti = True
    elif isinstance(reference, dict) and "magic_number" in reference:
        header = reference
        is_trk = True
    elif isinstance(reference, dict) and "NB_VERTICES" in reference:
        header = reference
        is_trx = True
    elif dipy_available and isinstance(
        reference, dipy.io.stateful_tractogram.StatefulTractogram
    ):
        is_sft = True

    if is_nifti:
        affine = header.get_best_affine()
        dimensions = header["dim"][1:4]
        voxel_sizes = header["pixdim"][1:4]

        if not affine[0:3, 0:3].any():
            raise ValueError(
                "Invalid affine, contains only zeros. "
                "Cannot determine voxel order from transformation"
            )
        voxel_order = "".join(nib.aff2axcodes(affine))
    elif is_trk:
        affine = header["voxel_to_rasmm"]
        dimensions = header["dimensions"]
        voxel_sizes = header["voxel_sizes"]
        voxel_order = header["voxel_order"]
    elif is_sft:
        affine, dimensions, voxel_sizes, voxel_order = reference.space_attributes
    elif is_trx:
        affine = header["VOXEL_TO_RASMM"]
        dimensions = header["DIMENSIONS"]
        voxel_sizes = nib.affines.voxel_sizes(affine)
        voxel_order = "".join(nib.aff2axcodes(affine))
    else:
        # Also reached for a string path with an unsupported extension.
        raise TypeError("Input reference is not one of the supported format")

    # np.bytes_ is a subclass of bytes, so this covers both kinds.
    if isinstance(voxel_order, bytes):
        voxel_order = voxel_order.decode("utf-8")

    if dipy_available:
        from dipy.io.utils import is_reference_info_valid

        is_reference_info_valid(affine, dimensions, voxel_sizes, voxel_order)

    return affine, dimensions, voxel_sizes, voxel_order
""" affine_1, dimensions_1, voxel_sizes_1, voxel_order_1 = get_reference_info_wrapper( reference_1 ) affine_2, dimensions_2, voxel_sizes_2, voxel_order_2 = get_reference_info_wrapper( reference_2 ) identical_header = True if not np.allclose(affine_1, affine_2, rtol=1e-03, atol=1e-03): logging.error("Affine not equal") identical_header = False if not np.array_equal(dimensions_1, dimensions_2): logging.error("Dimensions not equal") identical_header = False if not np.allclose(voxel_sizes_1, voxel_sizes_2, rtol=1e-03, atol=1e-03): logging.error("Voxel_size not equal") identical_header = False if voxel_order_1 != voxel_order_2: logging.error("Voxel_order not equal") identical_header = False return identical_header def get_axis_shift_vector(flip_axes): """Return a shift vector for the given axes. Parameters ---------- flip_axes : list of str String containing the axis to flip. Possible values are 'x', 'y', 'z'. Returns ------- shift_vector : np.ndarray (3,) Vector containing the axis to shift. Possible values are -1, 0. """ shift_vector = np.zeros(3) if "x" in flip_axes: shift_vector[0] = -1.0 if "y" in flip_axes: shift_vector[1] = -1.0 if "z" in flip_axes: shift_vector[2] = -1.0 return shift_vector def get_axis_flip_vector(flip_axes): """Return a flip vector for the given axes. Parameters ---------- flip_axes : list of str String containing the axis to flip. Possible values are 'x', 'y', 'z'. Returns ------- flip_vector : np.ndarray (3,) Vector containing the axis to flip. Possible values are -1, 1. """ flip_vector = np.ones(3) if "x" in flip_axes: flip_vector[0] = -1.0 if "y" in flip_axes: flip_vector[1] = -1.0 if "z" in flip_axes: flip_vector[2] = -1.0 return flip_vector def get_shift_vector(sft): """Return the shift vector for flipping a tractogram. When flipping a tractogram the shift vector is used to change the origin of the grid from the corner to the center of the grid. Parameters ---------- sft : StatefulTractogram StatefulTractogram object. 
def load_matrix_in_any_format(filepath):
    """Read a matrix stored either as text (.txt) or as a NumPy file (.npy).

    Parameters
    ----------
    filepath : str
        Path to the matrix file.

    Returns
    -------
    matrix : numpy.ndarray
        The matrix.

    Raises
    ------
    ValueError
        If the file extension is neither ``.txt`` nor ``.npy``.
    """
    readers = {".txt": np.loadtxt, ".npy": np.load}
    extension = os.path.splitext(filepath)[1]
    reader = readers.get(extension)
    if reader is None:
        raise ValueError("Extension {} is not supported".format(extension))
    return reader(filepath)
""" if not dipy_available: logging.error( "Dipy library is missing, cannot use functions related " "to the StatefulTractogram." ) return None from dipy.io.stateful_tractogram import Origin, Space origin = Origin.NIFTI if origin_str.lower() == "nifti" else Origin.TRACKVIS if space_str.lower() == "rasmm": space = Space.RASMM elif space_str.lower() == "voxmm": space = Space.VOXMM else: space = Space.VOX return space, origin def convert_data_dict_to_tractogram(data): """Convert data from a lazy tractogram to a tractogram. Parameters ---------- data : dict The data dictionary to convert into a nibabel tractogram. Returns ------- Tractogram A Tractogram object. """ streamlines = ArraySequence(data["strs"]) streamlines._data = streamlines._data for key in data["dps"]: shape = (len(streamlines), len(data["dps"][key]) // len(streamlines)) data["dps"][key] = np.array(data["dps"][key]).reshape(shape) for key in data["dpv"]: shape = ( len(streamlines._data), len(data["dpv"][key]) // len(streamlines._data), ) data["dpv"][key] = np.array(data["dpv"][key]).reshape(shape) tmp_arr = ArraySequence() tmp_arr._data = data["dpv"][key] tmp_arr._offsets = streamlines._offsets tmp_arr._lengths = streamlines._lengths data["dpv"][key] = tmp_arr obj = Tractogram( streamlines, data_per_point=data["dpv"], data_per_streamline=data["dps"] ) return obj def append_generator_to_dict(gen, data): """Append items yielded by a tractogram generator into data dict. Parameters ---------- gen : TractogramItem class instance or np.ndarray Item produced by a tractogram generator. Structured entries include per-point and per-streamline metadata. data : dict Accumulator containing ``strs`` (positions), ``dpv`` and ``dps`` dictionaries that will be extended in-place. Returns ------- None The function mutates ``data`` and returns ``None``. 
""" if isinstance(gen, TractogramItem): data["strs"].append(gen.streamline.tolist()) for key in gen.data_for_points: if key not in data["dpv"]: data["dpv"][key] = np.array([]) data["dpv"][key] = np.append(data["dpv"][key], gen.data_for_points[key]) for key in gen.data_for_streamline: if key not in data["dps"]: data["dps"][key] = np.array([]) data["dps"][key] = np.append(data["dps"][key], gen.data_for_streamline[key]) else: data["strs"].append(gen.tolist()) def verify_trx_dtype(trx, dict_dtype): # noqa: C901 """Verify that data dtypes in the trx match the given dict. Parameters ---------- trx : Tractogram Tractogram to verify. dict_dtype : dict Dictionary containing all elements dtype to verify. Returns ------- bool True if the dtype is the same, False otherwise. """ identical = True for key in dict_dtype: if key == "positions": if trx.streamlines._data.dtype != dict_dtype[key]: logging.warning("Positions dtype is different") identical = False elif key == "offsets": if trx.streamlines._offsets.dtype != dict_dtype[key]: logging.warning("Offsets dtype is different") identical = False elif key == "dpv": for key_dpv in dict_dtype[key]: if trx.data_per_vertex[key_dpv]._data.dtype != dict_dtype[key][key_dpv]: logging.warning( "Data per vertex ({}) dtype is different".format(key_dpv) ) identical = False elif key == "dps": for key_dps in dict_dtype[key]: if trx.data_per_streamline[key_dps].dtype != dict_dtype[key][key_dps]: logging.warning( "Data per streamline ({}) dtype is different".format(key_dps) ) identical = False elif key == "dpg": for key_group in dict_dtype[key]: for key_dpg in dict_dtype[key][key_group]: if ( trx.data_per_point[key_group][key_dpg].dtype != dict_dtype[key][key_group][key_dpg] ): logging.warning( "Data per group ({}) dtype is different".format(key_dpg) ) identical = False elif key == "groups": for key_group in dict_dtype[key]: if ( trx.data_per_point[key_group]._data.dtype != dict_dtype[key][key_group] ): logging.warning( "Data per group ({}) dtype 
is different".format(key_group) ) identical = False return identical tee-ar-ex-trx-python-a304ac2/trx/viz.py000066400000000000000000000076431515240773700200560ustar00rootroot00000000000000# -*- coding: utf-8 -*- """Optional 3D visualization using FURY/VTK.""" import itertools import logging import numpy as np try: from dipy.viz import actor, colormap, window import fury.utils as ut_vtk from fury.utils import get_bounds import vtk fury_available = True except ImportError: fury_available = False def display( volume, volume_affine=None, streamlines=None, title="FURY", display_bounds=True ): """Display a volume with optional streamlines using fury. Parameters ---------- volume : np.ndarray 3D volume to display. volume_affine : np.ndarray or None, optional Affine matrix for the volume; None assumes identity. streamlines : sequence or None, optional Streamlines to render as lines. title : str, optional Window title. display_bounds : bool, optional If True, draw bounding box and coordinate annotations. Returns ------- None Opens an interactive visualization window when fury is available. """ if not fury_available: logging.error( "Fury library is missing, visualization functions are not available." 
) return None volume = volume.astype(float) scene = window.Scene() scene.background((1.0, 0.5, 0.0)) # Show the X/Y/Z plane intersecting, mid-slices slicer_actor_1 = actor.slicer( volume, affine=volume_affine, value_range=(volume.min(), volume.max()), interpolation="nearest", opacity=0.8, ) slicer_actor_2 = actor.slicer( volume, affine=volume_affine, value_range=(volume.min(), volume.max()), interpolation="nearest", opacity=0.8, ) slicer_actor_3 = actor.slicer( volume, affine=volume_affine, value_range=(volume.min(), volume.max()), interpolation="nearest", opacity=0.8, ) slicer_actor_1.display(y=volume.shape[1] // 2) slicer_actor_2.display(x=volume.shape[0] // 2) slicer_actor_3.display(z=volume.shape[2] // 2) scene.add(slicer_actor_1) scene.add(slicer_actor_2) scene.add(slicer_actor_3) # Bounding box to facilitate error detections if display_bounds: src = vtk.vtkCubeSource() bounds = np.round(get_bounds(slicer_actor_1), 6) src.SetBounds(bounds) src.Update() cube_actor = ut_vtk.get_actor_from_polydata(src.GetOutput()) cube_actor.GetProperty().SetRepresentationToWireframe() scene.add(cube_actor) # Show each corner's coordinates corners = itertools.product(bounds[0:2], bounds[2:4], bounds[4:6]) for corner in corners: text_actor = actor.text_3d( "{}, {}, {}".format(*corner), corner, font_size=6, justification="center", ) scene.add(text_actor) # Show the X/Y/Z dimensions text_actor_x = actor.text_3d( "{}".format(np.abs(bounds[0] - bounds[1])), ((bounds[0] + bounds[1]) / 2, bounds[2], bounds[4]), font_size=10, justification="center", ) text_actor_y = actor.text_3d( "{}".format(np.abs(bounds[2] - bounds[3])), (bounds[0], (bounds[2] + bounds[3]) / 2, bounds[4]), font_size=10, justification="center", ) text_actor_z = actor.text_3d( "{}".format(np.abs(bounds[4] - bounds[5])), (bounds[0], bounds[2], (bounds[4] + bounds[5]) / 2), font_size=10, justification="center", ) scene.add(text_actor_x) scene.add(text_actor_y) scene.add(text_actor_z) if streamlines is not None: 
def convert_dsi_studio(
    in_dsi_tractogram,
    in_dsi_fa,
    out_tractogram,
    remove_invalid=True,
    keep_invalid=False,
):
    """Convert a DSI-Studio TRK file to TRX, fixing space metadata.

    Parameters
    ----------
    in_dsi_tractogram : str
        Input DSI-Studio TRK path (optionally .trk.gz).
    in_dsi_fa : str
        FA image (.nii.gz) used as reference anatomy.
    out_tractogram : str
        Destination tractogram path; ``.trx`` will be written using TRX writer.
    remove_invalid : bool, optional
        Remove streamlines falling outside the bounding box. Defaults to True.
    keep_invalid : bool, optional
        Skip the bounding-box check when saving to a non-TRX format.
        Defaults to False.

    Returns
    -------
    None
        Writes the converted tractogram to disk.

    Raises
    ------
    IOError
        If the input is neither ``.trk`` nor ``.trk.gz``.
    """
    if not dipy_available:
        logging.error("Dipy library is missing, scripts are not available.")
        return None
    import shutil

    from dipy.io.stateful_tractogram import Space, StatefulTractogram
    from dipy.io.streamline import load_tractogram, save_tractogram

    in_ext = split_name_with_gz(in_dsi_tractogram)[1]
    out_ext = split_name_with_gz(out_tractogram)[1]

    if in_ext == ".trk.gz":
        # Decompress into a private temporary folder rather than a fixed
        # "tmp.trk" in the working directory: no collision between
        # concurrent runs, no dependency on a writable CWD, and the file is
        # removed even when loading fails.
        with tempfile.TemporaryDirectory() as tmp_folder:
            tmp_trk = os.path.join(tmp_folder, "tmp.trk")
            with gzip.open(in_dsi_tractogram, "rb") as f_in:
                with open(tmp_trk, "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
            # The output file is closed (flushed) before being read back.
            sft = load_tractogram(tmp_trk, "same", bbox_valid_check=False)
    elif in_ext == ".trk":
        sft = load_tractogram(in_dsi_tractogram, "same", bbox_valid_check=False)
    else:
        raise IOError("{} is not currently supported.".format(in_ext))

    sft.to_vox()
    sft_fix = StatefulTractogram(
        sft.streamlines,
        in_dsi_fa,
        Space.VOXMM,
        data_per_point=sft.data_per_point,
        data_per_streamline=sft.data_per_streamline,
    )
    sft_fix.to_vox()
    flip_axis = ["x", "y"]
    sft_fix.streamlines._data -= get_axis_shift_vector(flip_axis)
    sft_flip = flip_sft(sft_fix, flip_axis)

    sft_flip.to_rasmm()
    sft_flip.streamlines._data -= [0.5, 0.5, -0.5]

    if remove_invalid:
        sft_flip.remove_invalid_streamlines()

    if out_ext != ".trx":
        save_tractogram(sft_flip, out_tractogram, bbox_valid_check=not keep_invalid)
    else:
        trx = tmm.TrxFile.from_sft(sft_flip)
        tmm.save(trx, out_tractogram)
        # Release the temporary storage created by from_sft, as
        # convert_tractogram does after saving.
        trx.close()
""" if not dipy_available: logging.error("Dipy library is missing, scripts are not available.") return None from dipy.io.streamline import save_tractogram in_ext = split_name_with_gz(in_tractogram)[1] out_ext = split_name_with_gz(out_tractogram)[1] if in_ext == out_ext: raise IOError("Input and output cannot be of the same file format.") if in_ext != ".trx": sft = load_sft_with_reference(in_tractogram, reference, bbox_check=False) else: trx = tmm.load(in_tractogram) sft = trx.to_sft() trx.close() if out_ext != ".trx": if out_ext == ".vtk": if sft.streamlines._data.dtype.name != pos_dtype: sft.streamlines._data = sft.streamlines._data.astype(pos_dtype) if offsets_dtype == "uint64" or offsets_dtype == "uint32": offsets_dtype = offsets_dtype[1:] if sft.streamlines._offsets.dtype.name != offsets_dtype: sft.streamlines._offsets = sft.streamlines._offsets.astype( offsets_dtype ) save_tractogram(sft, out_tractogram, bbox_valid_check=False) else: trx = tmm.TrxFile.from_sft(sft) if trx.streamlines._data.dtype.name != pos_dtype: trx.streamlines._data = trx.streamlines._data.astype(pos_dtype) if trx.streamlines._offsets.dtype.name != offsets_dtype: trx.streamlines._offsets = trx.streamlines._offsets.astype(offsets_dtype) tmm.save(trx, out_tractogram) trx.close() def tractogram_simple_compare(in_tractograms, reference): """Compare tractograms against a reference and return a summary diff. Parameters ---------- in_tractograms : list of str Paths to tractograms to compare. reference : str Reference tractogram path. Returns ------- dict Dictionary capturing differences across tractograms. 
""" if not dipy_available: logging.error("Dipy library is missing, scripts are not available.") return from dipy.io.stateful_tractogram import StatefulTractogram tractogram_obj = load(in_tractograms[0], reference) if not isinstance(tractogram_obj, StatefulTractogram): sft_1 = tractogram_obj.to_sft() tractogram_obj.close() else: sft_1 = tractogram_obj tractogram_obj = load(in_tractograms[1], reference) if not isinstance(tractogram_obj, StatefulTractogram): sft_2 = tractogram_obj.to_sft() tractogram_obj.close() else: sft_2 = tractogram_obj if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, atol=0.001): print("Matching tractograms in rasmm!") else: print( "Average difference in rasmm of {}".format( np.average(sft_1.streamlines._data - sft_2.streamlines._data, axis=0) ) ) sft_1.to_voxmm() sft_2.to_voxmm() if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, atol=0.001): print("Matching tractograms in voxmm!") else: print( "Average difference in voxmm of {}".format( np.average(sft_1.streamlines._data - sft_2.streamlines._data, axis=0) ) ) sft_1.to_vox() sft_2.to_vox() if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, atol=0.001): print("Matching tractograms in vox!") else: print( "Average difference in vox of {}".format( np.average(sft_1.streamlines._data - sft_2.streamlines._data, axis=0) ) ) def verify_header_compatibility(in_files): """Verify that multiple tractogram headers are mutually compatible. Parameters ---------- in_files : list of str Paths to tractogram or NIfTI files to compare. Returns ------- None Prints compatibility results to stdout. 
""" if not dipy_available: logging.error("Dipy library is missing, scripts are not available.") return all_valid = True for filepath in in_files: if not os.path.isfile(filepath): print("{} does not exist".format(filepath)) _, in_extension = split_name_with_gz(filepath) if in_extension not in [".trk", ".nii", ".nii.gz", ".trx"]: raise IOError("{} does not have a supported extension".format(filepath)) if not is_header_compatible(in_files[0], filepath): print( "{} and {} do not have compatible header.".format(in_files[0], filepath) ) all_valid = False if all_valid: print("All input files have compatible headers.") def tractogram_visualize_overlap(in_tractogram, reference, remove_invalid=True): """Visualize overlap between tractogram density maps in different spaces. Parameters ---------- in_tractogram : str Input tractogram path. reference : str Reference anatomy (.nii or .nii.gz). remove_invalid : bool, optional Remove streamlines outside bounding box before visualization. Returns ------- None Opens interactive windows when fury is available. 
""" if not dipy_available: logging.error("Dipy library is missing, scripts are not available.") return None from dipy.io.stateful_tractogram import StatefulTractogram from dipy.tracking.streamline import set_number_of_points from dipy.tracking.utils import density_map tractogram_obj = load(in_tractogram, reference) if not isinstance(tractogram_obj, StatefulTractogram): sft = tractogram_obj.to_sft() tractogram_obj.close() else: sft = tractogram_obj sft.streamlines._data = sft.streamlines._data.astype(float) sft.data_per_point = None sft.streamlines = set_number_of_points(sft.streamlines, 200) if remove_invalid: sft.remove_invalid_streamlines() # Approach (1) density_1 = density_map(sft.streamlines, sft.affine, sft.dimensions) img = nib.load(reference) display( img.get_fdata(), volume_affine=img.affine, streamlines=sft.streamlines, title="RASMM", ) # Approach (2) sft.to_vox() density_2 = density_map(sft.streamlines, np.eye(4), sft.dimensions) # Small difference due to casting of the affine as float32 or float64 diff = density_1 - density_2 print( "Total difference of {} voxels with total value of {}".format( np.count_nonzero(diff), np.sum(np.abs(diff)) ) ) display(img.get_fdata(), streamlines=sft.streamlines, title="VOX") # Try VOXMM sft.to_voxmm() affine = np.eye(4) affine[0:3, 0:3] *= sft.voxel_sizes display( img.get_fdata(), volume_affine=affine, streamlines=sft.streamlines, title="VOXMM", ) def validate_tractogram( in_tractogram, reference, out_tractogram, remove_identical_streamlines=True, precision=1, ): """Validate a tractogram and optionally remove invalid/duplicate streamlines. Parameters ---------- in_tractogram : str Input tractogram path. reference : str Reference anatomy for formats requiring it. out_tractogram : str or None Optional output path to save the cleaned tractogram. remove_identical_streamlines : bool, optional Remove duplicate streamlines based on hashing precision. precision : int, optional Number of decimals when hashing streamline points. 
def validate_tractogram(
    in_tractogram,
    reference,
    out_tractogram,
    remove_identical_streamlines=True,
    precision=1,
):
    """Validate a tractogram and optionally remove invalid/duplicate streamlines.

    The checks run in sequence and each one logs how many streamlines it
    rejected on its own:

    1. streamlines with invalid coordinates (removed in place from the SFT),
    2. streamlines made of 0 or 1 point,
    3. streamlines with two consecutive points closer than 0.001,
    4. (optional) identical streamlines, compared at ``precision`` decimals.

    Parameters
    ----------
    in_tractogram : str
        Input tractogram path.
    reference : str
        Reference anatomy for formats requiring it.
    out_tractogram : str or None
        Optional output path to save the cleaned tractogram.
    remove_identical_streamlines : bool, optional
        Remove duplicate streamlines based on hashing precision.
    precision : int, optional
        Number of decimals when hashing streamline points.

    Returns
    -------
    None
        Logs warnings and optionally writes a cleaned tractogram.
    """
    if not dipy_available:
        logging.error("Dipy library is missing, scripts are not available.")
        return None
    from dipy.io.stateful_tractogram import StatefulTractogram

    tractogram_obj = load(in_tractogram, reference)
    if not isinstance(tractogram_obj, StatefulTractogram):
        # NOTE(review): the loaded object is deliberately left open (the
        # close() call was disabled in the original) - confirm whether it
        # can safely be released once the output has been saved.
        sft = tractogram_obj.to_sft()
    else:
        sft = tractogram_obj

    ori_dtype = sft.dtype_dict

    invalid_coord_ind, _ = sft.remove_invalid_streamlines()
    logging.warning(
        "Removed {} streamlines with invalid coordinates.".format(
            len(invalid_coord_ind)
        )
    )

    # From here on, every index refers to the tractogram *after* the
    # invalid-coordinate streamlines were dropped above.
    nb_remaining = len(sft)

    too_short = [i for i in range(nb_remaining) if len(sft.streamlines[i]) <= 1]
    logging.warning(
        "Removed {} invalid streamlines (1 or 0 points).".format(len(too_short))
    )

    # Streamlines with fewer than 2 points have no segment to measure.
    skipped = set(too_short)
    overlapping = []
    for i in range(nb_remaining):
        if i in skipped:
            continue
        step_norm = np.linalg.norm(np.diff(sft.streamlines[i], axis=0), axis=1)
        if (step_norm < 0.001).any():
            overlapping.append(i)
    logging.warning(
        "Removed {} invalid streamlines (overlapping points).".format(
            len(overlapping)
        )
    )

    indices_val = np.setdiff1d(
        np.arange(nb_remaining), too_short + overlapping
    ).astype(np.uint32)

    if remove_identical_streamlines:
        _, indices_uniq = perform_streamlines_operation(
            intersection, [sft.streamlines], precision=precision
        )
        indices_final = np.intersect1d(indices_val, indices_uniq)
        # Only count what this step removed on top of the previous checks.
        logging.warning(
            "Removed {} overlapping streamlines.".format(
                len(indices_val) - len(indices_final)
            )
        )
    else:
        indices_final = indices_val

    if out_tractogram:
        streamlines = sft.streamlines[indices_final].copy()
        dpp = {
            key: sft.data_per_point[key][indices_final].copy()
            for key in sft.data_per_point.keys()
        }
        dps = {
            key: sft.data_per_streamline[key][indices_final]
            for key in sft.data_per_streamline.keys()
        }

        new_sft = StatefulTractogram.from_sft(
            streamlines, sft, data_per_point=dpp, data_per_streamline=dps
        )
        # Preserve the on-disk dtypes of the input.
        new_sft.dtype_dict = ori_dtype
        save(new_sft, out_tractogram)
def _load_streamlines_from_csv(positions_csv):
    """Build streamlines from a CSV file holding one streamline per row.

    Each row is read as a flat list of values and reshaped to
    ``(len(row) // 3, 3)`` before being converted to float.

    Parameters
    ----------
    positions_csv : str
        Path to CSV containing flattened coordinates.

    Returns
    -------
    nibabel.streamlines.ArraySequence class instance
        Streamlines reconstructed from the CSV rows.
    """
    with open(positions_csv, newline="") as csv_file:
        rows = list(csv.reader(csv_file))

    per_streamline = []
    for row in rows:
        nb_points = len(row) // 3
        per_streamline.append(np.reshape(row, (nb_points, 3)).astype(float))
    return ArraySequence(per_streamline)


def _load_streamlines_from_arrays(positions, offsets):
    """Build streamlines from separate position and offset array files.

    The ArraySequence is assembled by hand: positions become its data, a
    deep copy of the offsets becomes its offsets, and the lengths come from
    ``tmm._compute_lengths``.

    Parameters
    ----------
    positions : str
        Path to positions array (.npy or text) shaped (N, 3).
    offsets : str
        Path to offsets array marking streamline boundaries.

    Returns
    -------
    tuple
        (ArraySequence, np.ndarray) of streamlines and offsets.
    """
    positions_arr = load_matrix_in_any_format(positions)
    offsets_arr = load_matrix_in_any_format(offsets)
    lengths_arr = tmm._compute_lengths(offsets_arr)

    sequence = ArraySequence()
    sequence._data = positions_arr
    # Copied so the sequence does not alias the array returned to the caller.
    sequence._offsets = deepcopy(offsets_arr)
    sequence._lengths = lengths_arr
    return sequence, offsets_arr
""" if not dipy_available: logging.error( "Dipy library is missing, advanced options " "related to spatial transforms and invalid " "streamlines are not available." ) return None from dipy.io.stateful_tractogram import StatefulTractogram space, origin = get_reverse_enum(space_str, origin_str) sft = StatefulTractogram(streamlines, reference, space, origin) if verify_invalid: rem, _ = sft.remove_invalid_streamlines() print( "{} streamlines were removed becaused they were invalid.".format(len(rem)) ) sft.to_rasmm() sft.to_center() streamlines = sft.streamlines streamlines._offsets = offsets return streamlines def _write_header(tmp_dir_name, reference, streamlines): """Write TRX header file to a temporary directory. Parameters ---------- tmp_dir_name : str Temporary directory where header.json is written. reference : str Reference anatomy used to derive affine and dimensions. streamlines : ArraySequence class instance Streamlines whose counts populate the header. """ affine, dimensions, _, _ = get_reference_info_wrapper(reference) header = { "DIMENSIONS": dimensions.tolist(), "VOXEL_TO_RASMM": affine.tolist(), "NB_VERTICES": len(streamlines._data), "NB_STREAMLINES": len(streamlines) - 1, } if header["NB_STREAMLINES"] <= 1: raise IOError("To use this script, you need at least 2streamlines.") with open(os.path.join(tmp_dir_name, "header.json"), "w") as out_json: json.dump(header, out_json) def _write_streamline_data(tmp_dir_name, streamlines, positions_dtype, offsets_dtype): """Write streamline position and offset data. Parameters ---------- tmp_dir_name : str Temporary directory to store binary arrays. streamlines : ArraySequence class instance Streamlines to serialize. positions_dtype : str Datatype for positions array. offsets_dtype : str Datatype for offsets array. 
""" curr_filename = os.path.join(tmp_dir_name, "positions.3.{}".format(positions_dtype)) positions = streamlines._data.astype(positions_dtype) tmm._ensure_little_endian(positions).tofile(curr_filename) curr_filename = os.path.join(tmp_dir_name, "offsets.{}".format(offsets_dtype)) offsets = streamlines._offsets.astype(offsets_dtype) tmm._ensure_little_endian(offsets).tofile(curr_filename) def _normalize_dtype(dtype_str): """Normalize dtype string format for file naming. Parameters ---------- dtype_str : str Input dtype string (e.g., \"bool\", \"float32\"). Returns ------- str Normalized dtype string where ``bool`` is mapped to ``bit``. """ return "bit" if dtype_str == "bool" else dtype_str def _write_data_array(tmp_dir_name, subdir_name, args, is_dpg=False): """Write a data array (dpv/dps/group/dpg) to disk. Parameters ---------- tmp_dir_name : str Base temporary directory. subdir_name : str Subdirectory name (dpv, dps, groups, dpg). args : tuple Tuple describing the array path and dtype (and group when dpg). is_dpg : bool, optional True when writing data_per_group arrays. Returns ------- None Writes the array to disk. 
""" if is_dpg: os.makedirs(os.path.join(tmp_dir_name, "dpg", args[0]), exist_ok=True) curr_arr = load_matrix_in_any_format(args[1]).astype(args[2]) basename = os.path.basename(os.path.splitext(args[1])[0]) dtype_str = _normalize_dtype(args[1]) if args[1] != "bool" else "bit" dtype = args[2] else: os.makedirs(os.path.join(tmp_dir_name, subdir_name), exist_ok=True) curr_arr = np.squeeze(load_matrix_in_any_format(args[0]).astype(args[1])) basename = os.path.basename(os.path.splitext(args[0])[0]) dtype_str = _normalize_dtype(args[1]) dtype = dtype_str if curr_arr.ndim > 2: raise IOError("Maximum of 2 dimensions for dpv/dps/dpg.") if curr_arr.shape == (1, 1): curr_arr = curr_arr.reshape((1,)) dim = "" if curr_arr.ndim == 1 else "{}.".format(curr_arr.shape[-1]) if is_dpg: curr_filename = os.path.join( tmp_dir_name, "dpg", args[0], "{}.{}{}".format(basename, dim, dtype) ) else: curr_filename = os.path.join( tmp_dir_name, subdir_name, "{}.{}{}".format(basename, dim, dtype) ) tmm._ensure_little_endian(curr_arr).tofile(curr_filename) def generate_trx_from_scratch( # noqa: C901 reference, out_tractogram, positions_csv=False, positions=False, offsets=False, positions_dtype="float32", offsets_dtype="uint64", space_str="rasmm", origin_str="nifti", verify_invalid=True, dpv=None, dps=None, groups=None, dpg=None, ): """Generate TRX file from scratch using various input formats. Parameters ---------- reference : str Reference anatomy used to set affine and dimensions. out_tractogram : str Output TRX filename. positions_csv : str or bool, optional CSV file containing streamline coordinates; False to disable. positions : str or bool, optional Binary positions array file; False to disable. offsets : str or bool, optional Offsets array file; False to disable. positions_dtype : str, optional Datatype for positions. offsets_dtype : str, optional Datatype for offsets. space_str : str, optional Desired space for generated streamlines. 
def generate_trx_from_scratch(  # noqa: C901
    reference,
    out_tractogram,
    positions_csv=False,
    positions=False,
    offsets=False,
    positions_dtype="float32",
    offsets_dtype="uint64",
    space_str="rasmm",
    origin_str="nifti",
    verify_invalid=True,
    dpv=None,
    dps=None,
    groups=None,
    dpg=None,
):
    """Generate TRX file from scratch using various input formats.

    The TRX content (header, positions, offsets and optional data arrays) is
    first laid out in a temporary directory, then loaded and saved to
    ``out_tractogram``.

    Parameters
    ----------
    reference : str
        Reference anatomy used to set affine and dimensions.
    out_tractogram : str
        Output TRX filename.
    positions_csv : str or bool, optional
        CSV file containing streamline coordinates; False to disable.
    positions : str or bool, optional
        Binary positions array file; False to disable.
    offsets : str or bool, optional
        Offsets array file; False to disable.
    positions_dtype : str, optional
        Datatype for positions.
    offsets_dtype : str, optional
        Datatype for offsets.
    space_str : str, optional
        Desired space for generated streamlines.
    origin_str : str, optional
        Desired origin for generated streamlines.
    verify_invalid : bool, optional
        Remove invalid streamlines when True.
    dpv : list or None, optional
        Data per vertex definitions.
    dps : list or None, optional
        Data per streamline definitions.
    groups : list or None, optional
        Group definitions.
    dpg : list or None, optional
        Data per group definitions.

    Returns
    -------
    None
        Writes the generated TRX file to disk. Nothing is written when the
        spatial transforms are required but unavailable.
    """
    with get_trx_tmp_dir() as tmp_dir_name:
        if positions_csv:
            streamlines = _load_streamlines_from_csv(positions_csv)
            # NOTE(review): with a CSV input no offsets array exists, yet
            # None is forwarded to the transform step below - confirm the
            # CSV path behaves as intended.
            offsets = None
        else:
            streamlines, offsets = _load_streamlines_from_arrays(positions, offsets)

        needs_transform = (
            space_str.lower() != "rasmm"
            or origin_str.lower() != "nifti"
            or verify_invalid
        )
        if needs_transform:
            streamlines = _apply_spatial_transforms(
                streamlines, reference, space_str, origin_str, verify_invalid, offsets
            )
            if streamlines is None:
                return

        _write_header(tmp_dir_name, reference, streamlines)
        _write_streamline_data(
            tmp_dir_name, streamlines, positions_dtype, offsets_dtype
        )

        # None and empty lists are both treated as "nothing to write".
        for arg in dpv or ():
            _write_data_array(tmp_dir_name, "dpv", arg)
        for arg in dps or ():
            _write_data_array(tmp_dir_name, "dps", arg)
        for arg in groups or ():
            _write_data_array(tmp_dir_name, "groups", arg)
        for arg in dpg or ():
            _write_data_array(tmp_dir_name, "dpg", arg, is_dpg=True)

        trx = tmm.load(tmp_dir_name)
        try:
            tmm.save(trx, out_tractogram)
        finally:
            # Release the loaded TRX even if saving fails.
            trx.close()
""" trx = tmm.load(in_filename) # For each key in dict_dtype, we create a new memmap with the new dtype # and we copy the data from the old memmap to the new one. for key in dict_dtype: if key == "positions": tmp_mm = np.memmap( tempfile.NamedTemporaryFile(), dtype=dict_dtype[key], mode="w+", shape=trx.streamlines._data.shape, ) tmp_mm[:] = trx.streamlines._data[:] trx.streamlines._data = tmp_mm elif key == "offsets": tmp_mm = np.memmap( tempfile.NamedTemporaryFile(), dtype=dict_dtype[key], mode="w+", shape=trx.streamlines._offsets.shape, ) tmp_mm[:] = trx.streamlines._offsets[:] trx.streamlines._offsets = tmp_mm elif key == "dpv": for key_dpv in dict_dtype[key]: tmp_mm = np.memmap( tempfile.NamedTemporaryFile(), dtype=dict_dtype[key][key_dpv], mode="w+", shape=trx.data_per_vertex[key_dpv]._data.shape, ) tmp_mm[:] = trx.data_per_vertex[key_dpv]._data[:] trx.data_per_vertex[key_dpv]._data = tmp_mm elif key == "dps": for key_dps in dict_dtype[key]: tmp_mm = np.memmap( tempfile.NamedTemporaryFile(), dtype=dict_dtype[key][key_dps], mode="w+", shape=trx.data_per_streamline[key_dps].shape, ) tmp_mm[:] = trx.data_per_streamline[key_dps][:] trx.data_per_streamline[key_dps] = tmp_mm elif key == "dpg": for key_group in dict_dtype[key]: for key_dpg in dict_dtype[key][key_group]: tmp_mm = np.memmap( tempfile.NamedTemporaryFile(), dtype=dict_dtype[key][key_group][key_dpg], mode="w+", shape=trx.data_per_group[key_group][key_dpg].shape, ) tmp_mm[:] = trx.data_per_group[key_group][key_dpg][:] trx.data_per_group[key_group][key_dpg] = tmp_mm elif key == "groups": for key_group in dict_dtype[key]: tmp_mm = np.memmap( tempfile.NamedTemporaryFile(), dtype=dict_dtype[key][key_group], mode="w+", shape=trx.groups[key_group].shape, ) tmp_mm[:] = trx.groups[key_group][:] trx.groups[key_group] = tmp_mm tmm.save(trx, out_filename) trx.close()