ufo-tofu-0.13.0/LICENSE

GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc.
Everyone is permitted to copy and distribute verbatim copies of this license document, but
changing it is not allowed.

(The rest of the file is the standard text of the GNU Lesser General Public License, version 3,
sections 0-6.)
ufo-tofu-0.13.0/MANIFEST.in

include pkgconfig.py
include README.md

ufo-tofu-0.13.0/PKG-INFO

Metadata-Version: 2.1
Name: ufo-tofu
Version: 0.13.0
Summary: A fast, versatile and user-friendly image processing toolkit for computed tomography
Home-page: http://github.com/ufo-kit/tofu
Author: Matthias Vogelgesang
Author-email: matthias.vogelgesang@kit.edu
License: LGPL
Requires-Python: >=3
Description-Content-Type: text/markdown
License-File: LICENSE
ufo-tofu-0.13.0/README.md

## About

[![PyPI version](https://badge.fury.io/py/ufo-tofu.png)](http://badge.fury.io/py/ufo-tofu)
[![Documentation status](https://readthedocs.org/projects/tofu/badge/?version=latest)](http://tofu.readthedocs.io/en/latest/?badge=latest)

This repository contains Python data processing scripts to be used with the UFO framework. At the
moment they are targeted at high-performance reconstruction of tomographic data sets.
If you use this software for publishing your data, we kindly ask you to cite the article

**Faragó, T., Gasilov, S., Emslie, I., Zuber, M., Helfen, L., Vogelgesang, M. & Baumbach, T.
(2022). J. Synchrotron Rad. 29, https://doi.org/10.1107/S160057752200282X**

If you want to stay updated, subscribe to our
[newsletter](mailto:sympa@lists.kit.edu?subject=subscribe%20ufo%20YourFirstName%20YourLastName).
Simply leave the body of the e-mail empty and change ``YourFirstName YourLastName`` in the subject
accordingly.

## Installation

First make sure you have [ufo-core](https://github.com/ufo-kit/ufo-core) and
[ufo-filters](https://github.com/ufo-kit/ufo-filters) installed. For that, please follow the
[installation instructions](https://ufo-core.readthedocs.io/en/latest/install/index.html). You can
either install the prerequisites yourself on
[Linux](https://ufo-core.readthedocs.io/en/latest/install/linux.html), or use one of our
[Docker containers](https://ufo-core.readthedocs.io/en/latest/install/docker.html).

Then, for the newest version run the following in *tofu*'s top directory:

    pip install .

or to install via PyPI:

    pip install ufo-tofu

in a prepared virtualenv or as root for a system-wide installation. Note that if you plan to use
the graphical user interface you also need PyQt5, pyqtgraph and PyOpenGL. You are strongly advised
to install PyQt5 through your system package manager; pyqtgraph and PyOpenGL can be installed with
pip:

    pip install pyqtgraph PyOpenGL

## Usage

### Flow

`tofu flow` is a visual flow programming tool. You can create a flow by using any task from
[ufo-filters](https://github.com/ufo-kit/ufo-filters) and execute it. It includes visualization of
2D and 3D results, so you can quickly check the output of your flow, which is useful for finding
algorithm parameters.

![flow](https://user-images.githubusercontent.com/2648829/150096902-fdbf1b7e-b34e-4368-98ac-c924cad8a6cd.jpg)

### Reconstruction

To do a tomographic reconstruction you simply call

    $ tofu tomo --sinograms $PATH_TO_SINOGRAMS

from the command line. To get correct results, you may need to append options such as
`--axis-pos/-a` and `--angle-step/-a` (the angle step is given in radians!). Input paths are
either directories or glob patterns. Output paths are either directories or a format that contains
one `%i` [specifier](http://www.pixelbeat.org/programming/gcc/format_specs.html):

    $ tofu tomo --axis-pos=123.4 --angle-step=0.000123 \
        --sinograms="/foo/bar/*.tif" --output="/output/slices-%05i.tif"

You can get help for all options by running

    $ tofu tomo --help

and more verbose output by running with the `-v/--verbose` flag.

You can also load reconstruction parameters from a configuration file called `reco.conf`. You may
create a template with

    $ tofu init

Note that options passed via the command line always override configuration parameters!
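For orientation, a hand-written `reco.conf` could look like the sketch below; the section and
option names follow the parameter sections defined in `tofu/config.py` later in this package,
while the paths and values are only placeholders:

    [general]
    output = /output/slices-%05i.tif
    verbose = True

    [reconstruction]
    sinograms = /data/sinograms/*.tif

    [tomographic-reconstruction]
    axis = 123.4

Lines starting with `#` are treated as comments by the configuration parser.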
Besides scripted reconstructions, one can also run a standalone GUI for both reconstruction and
quick assessment of the reconstructed data via

    $ tofu gui

![GUI](https://cloud.githubusercontent.com/assets/115270/6442540/db0b55fe-c0f0-11e4-9577-0048fddae8b7.png)

### Performance measurement

If you are running at least ufo-core/filters 0.6, you can evaluate the performance of the filtered
backprojection (without sinogram transposition!) with

    $ tofu perf

You can customize parameter scans easily via

    $ tofu perf --width 256:8192:256 --height 512

which will reconstruct all combinations of widths between 256 and 8192 with a step of 256 and a
fixed height of 512 pixels.

### Estimating the center of rotation

If you do not know the correct center of rotation from your experimental setup, you can estimate
it with:

    $ tofu estimate -i $PATH_TO_SINOGRAMS

Currently, a modified algorithm based on the work of
[Donath et al.](http://dx.doi.org/10.1364/JOSAA.23.001048) is used to determine the center.
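The estimation step can also be driven from Python. The snippet below is an untested sketch that
only uses helpers defined in this package (`config.Params`, `config.TOMO_PARAMS` and
`reco.estimate_center()` as called by `bin/tofu`); the sinogram path is a placeholder and,
depending on `--estimate-method`, other inputs (projections, flats, darks) may be needed:

    from tofu import config, reco

    # Build a namespace with the same defaults the "tofu estimate" sub-command uses.
    params = config.Params(sections=config.TOMO_PARAMS + ('estimate',)).get_defaults()
    params.sinograms = '/data/sinograms/*.tif'  # placeholder input path
    center = reco.estimate_center(params)
    print(center)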
ufo-tofu-0.13.0/bin/tofu

#!/usr/bin/env python3
import os
import sys
import argparse
import logging
import time
import re
import gi

from tofu import config, __version__

gi.require_version('Ufo', '0.0')

LOG = logging.getLogger('tofu')


def init(args):
    if not os.path.exists(args.config):
        config.write(args.config)
    else:
        raise RuntimeError("{0} already exists".format(args.config))


# Sub-command handlers import their implementation lazily, only when the command is actually run.
def run_tomo(args):
    from tofu import reco
    reco.tomo(args)


def run_lamino(args):
    from tofu import lamino
    lamino.lamino(args)


def run_genreco(args):
    from tofu import genreco
    genreco.genreco(args)


def run_flat_correct(args):
    from tofu import preprocess
    preprocess.run_flat_correct(args)


def run_preprocessing(args):
    from tofu import preprocess
    preprocess.run_preprocessing(args)


def run_sinos(args):
    from tofu import preprocess
    preprocess.run_sinogram_generation(args)


def run_ez(args):
    from tofu.ez.GUI.ezufo_launcher import main_qt
    main_qt(args)


def get_ipython_shell(config=None):
    import IPython

    version = IPython.__version__
    shell = None

    def cmp_versions(v1, v2):
        """Compare two version numbers and return cmp compatible result"""
        def normalize(v):
            return [int(x) for x in re.sub(r'(\.0+)*$', '', v).split(".")]

        n1 = normalize(v1)
        n2 = normalize(v2)
        return (n1 > n2) - (n1 < n2)

    # Pick the embedding API matching the installed IPython version.
    if cmp_versions(version, '0.11') < 0:
        from IPython.Shell import IPShellEmbed
        shell = IPShellEmbed()
    elif cmp_versions(version, '1.0') < 0:
        from IPython.frontend.terminal.embed import InteractiveShellEmbed
        shell = InteractiveShellEmbed(config=config, banner1='')
    else:
        from IPython.terminal.embed import InteractiveShellEmbed
        shell = InteractiveShellEmbed(config=config, banner1='')

    return shell


def run_shell(args):
    from tofu import reco
    shell = get_ipython_shell()
    shell()


def run_find_large_spots(args):
    from tofu.find_large_spots import find_large_spots, find_large_spots_median

    if args.method == 'grow':
        find_large_spots(args)
    else:
        find_large_spots_median(args)


def run_inpaint(args):
    from tofu import inpaint
    inpaint.run(args)


def gui(args):
    try:
        from tofu import gui
        gui.main(args)
    except ImportError as e:
        LOG.error(str(e))


def run_flow(args):
    from tofu.flow.main import main as flow_main
    flow_main()


def estimate(params):
    from tofu import reco
    center = reco.estimate_center(params)
    if params.verbose:
        out = '>>> Best axis of rotation: {}'.format(center)
    else:
        out = center
    print(out)


def perf(args):
    from tofu import reco

    def measure(args):
        exec_times = []
        total_times = []

        for i in range(args.num_runs):
            start = time.time()
            exec_times.append(reco.tomo(args))
            total_times.append(time.time() - start)

        exec_time = sum(exec_times) / len(exec_times)
        total_time = sum(total_times) / len(total_times)
        overhead = (total_time / exec_time - 1.0) * 100
        input_bandwidth = args.width * args.height * num_projections * 4 / exec_time / 1024. / 1024.
        output_bandwidth = args.width * args.width * height * 4 / exec_time / 1024. / 1024.
        slice_bandwidth = args.height / exec_time

        # Four bytes of our output bandwidth constitute one slice pixel, for each
        # pixel we have to do roughly n * 6 floating point ops (2 mad, 1 add, 1
        # interpolation)
        flops = output_bandwidth / 4 * 6 * num_projections / 1024

        msg = ("width={:<6d} height={:<6d} n_proj={:<6d} "
               "exec={:.4f}s total={:.4f}s overhead={:.2f}% "
               "bandwidth_i={:.2f}MB/s bandwidth_o={:.2f}MB/s slices={:.2f}/s "
               "flops={:.2f}GFLOPs\n")
        sys.stdout.write(msg.format(args.width, args.height, args.number,
                                    exec_time, total_time, overhead,
                                    input_bandwidth, output_bandwidth,
                                    slice_bandwidth, flops))
        sys.stdout.flush()

    args.projections = None
    args.sinograms = None
    args.dry_run = True

    for width in range(*args.width_range):
        for height in range(*args.height_range):
            for num_projections in range(*args.num_projection_range):
                args.width = width
                args.height = height
                args.number = num_projections
                measure(args)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', **config.SECTIONS['general']['config'])
    parser.add_argument('--version', action='version',
                        version='%(prog)s {}'.format(__version__))

    sino_params = ('flat-correction', 'sinos')
    reco_params = ('flat-correction', 'reconstruction')
    tomo_params = config.TOMO_PARAMS
    lamino_params = config.LAMINO_PARAMS
    gui_params = tomo_params + ('gui', )

    cmd_parsers = [
        ('init', init, (), "Create configuration file"),
        ('preprocess', run_preprocessing, config.PREPROC_PARAMS, "Run preprocessing"),
        ('flatcorrect', run_flat_correct, ('flat-correction',), "Run flat field correction"),
        ('sinos', run_sinos, sino_params, "Generate sinograms from projections"),
        ('tomo', run_tomo, tomo_params, "Run tomographic reconstruction"),
        ('lamino', run_lamino, lamino_params, "Run laminographic reconstruction"),
        ('reco', run_genreco, config.GEN_RECO_PARAMS, "Run general projection-based "
         "reconstruction for tomographic/laminographic cone/parallel beam"),
        ('gui', gui, tomo_params + ('gui',), "GUI for tomographic reconstruction"),
        ('flow', run_flow, (), "Visual flow creation"),
        ('ez', run_ez, (), "GUI for making ufo-kit data processing pipelines"),
        ('estimate', estimate, tomo_params + ('estimate',), "Estimate center of rotation"),
        ('perf', perf, tomo_params + ('perf',), "Check reconstruction performance"),
        ('interactive', run_shell, tomo_params, "Run interactive mode"),
        ('find-large-spots', run_find_large_spots, ('find-large-spots',),
         "Find large spots on images"),
        ('inpaint', run_inpaint, ('inpaint',), "Inpaint images"),
    ]

    # argparse supports required=True for subparsers only since Python 3.7. Compare version_info
    # tuples instead of the sys.version string, which would sort '3.10' before '3.7'.
    if sys.version_info < (3, 7):
        subparsers = parser.add_subparsers(title="Commands", dest='commands')
    else:
        subparsers = parser.add_subparsers(title="Commands", dest='commands', required=True)

    for cmd, func, sections, text in cmd_parsers:
        cmd_params = config.Params(sections=sections)
        cmd_parser = subparsers.add_parser(
            cmd, help=text, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        cmd_parser = cmd_params.add_arguments(cmd_parser)
        cmd_parser.set_defaults(_func=func)

    args = config.parse_known_args(parser, subparser=True)

    log_level = logging.DEBUG if args.verbose else logging.INFO
    LOG.setLevel(log_level)

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
    LOG.addHandler(stream_handler)

    if args.log:
        file_handler = logging.FileHandler(args.log)
        file_handler.setFormatter(logging.Formatter('[%(asctime)s] %(name)s:%(levelname)s: %(message)s'))
        LOG.addHandler(file_handler)

    try:
        config.log_values(args)
        args._func(args)
    except RuntimeError as e:
        LOG.error(str(e))
        sys.exit(1)


if __name__ == '__main__':
    main()

# vim: ft=python

ufo-tofu-0.13.0/setup.cfg

[egg_info]
tag_build =
tag_date = 0

ufo-tofu-0.13.0/setup.py

from setuptools import setup, find_packages
from tofu import __version__

setup(
    name='ufo-tofu',
    python_requires='>=3',
    version=__version__,
    author='Matthias Vogelgesang',
    author_email='matthias.vogelgesang@kit.edu',
    url='http://github.com/ufo-kit/tofu',
    license='LGPL',
    packages=find_packages(),
    package_data={'tofu': ['gui.ui'],
                  'tofu.flow': ['composites/*.cm', 'config.json']},
    scripts=['bin/tofu'],
    exclude_package_data={'': ['README.rst']},
    install_requires=[
        'PyGObject',
        'imageio',
        'numpy',
        'networkx',
        'PyQt5',
        'pyqtconsole',
        'pyxdg',
        'qtpynodeeditor'
    ],
    description="A fast, versatile and user-friendly image "
                "processing toolkit for computed tomography",
    long_description=open('README.md').read(),
    long_description_content_type='text/markdown',
)

ufo-tofu-0.13.0/tofu/__init__.py

__version__ = '0.13.0'

ufo-tofu-0.13.0/tofu/config.py

import argparse
import sys
import logging
import configparser as configparser
from collections import OrderedDict
from tofu.util import convert_filesize, restrict_value, tupleize, range_list

LOG = logging.getLogger(__name__)
NAME = "reco.conf"
SECTIONS = OrderedDict()

SECTIONS['general'] = {
    'config': {
        'default': NAME,
        'type': str,
        'help': "File name of configuration",
        'metavar': 'FILE'},
    'verbose': {
        'default': False,
        'ezdefault': False,
        'help': 'Verbose output',
        'action': 'store_true'},
    'output': {
        'default': 'result-%05i.tif',
        'type': str,
        'help': "Path to location or format-specified file path "
                "for storing reconstructed slices",
        'metavar': 'PATH'},
    'output-bitdepth': {
        'default': 32,
        'ezdefault': 8,
        'type': restrict_value((0, None), dtype=int),
        'help': "Bit depth of output, either 8, 16 or 32",
        'metavar':
'BITDEPTH'}, 'output-minimum': { 'default': None, 'ezdefault': 0.0, 'type': float, 'help': "Minimum value that maps to zero (turns on --output-rescale)", 'metavar': 'MIN'}, 'output-maximum': { 'default': None, 'ezdefault': 0.0, 'type': float, 'help': "Maximum input value that maps to largest output value (turns on --output-rescale)", 'metavar': 'MAX'}, 'output-rescale': { 'default': False, 'action': 'store_true', 'help': "If true rescale grey values either automatically or according to set " "--output-minimum and --output-maximum"}, 'output-bytes-per-file': { 'default': '128g', 'type': convert_filesize, 'help': "Maximum bytes per file (0=single-image output, otherwise multi-image output)\ , 'k', 'm', 'g', 't' suffixes can be used", 'metavar': 'BYTESPERFILE'}, 'output-append': { 'default': False, 'action': 'store_true', 'help': 'Append images instead of overwriting existing files'}, 'log': { 'default': None, 'type': str, 'help': "File name of optional log", 'metavar': 'FILE'}, 'width': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Input width"}} SECTIONS['reading'] = { 'y': { 'default': 0, 'ezdefault': 100, 'type': restrict_value((0, None), dtype=int), 'help': 'Vertical coordinate from where to start reading the input image'}, 'height': { 'default': None, 'ezdefault': 200, 'type': restrict_value((0, None), dtype=int), 'help': "Number of rows which will be read"}, 'bitdepth': { 'default': 32, 'type': restrict_value((0, None), dtype=int), 'help': "Bit depth of raw files"}, 'y-step': { 'default': 1, 'ezdefault': 20, 'type': restrict_value((0, None), dtype=int), 'help': "Read every \"step\" row from the input"}, 'start': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'help': 'Offset to the first read file'}, 'number': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': 'Number of files to read'}, 'step': { 'default': 1, 'ezdefault': 1, 'type': restrict_value((0, None), dtype=int), 'help': 'Read every \"step\" file'}, 'resize': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': 'Bin pixels before processing'}, 'retries': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'metavar': 'NUMBER', 'help': 'How many times to wait for new files'}, 'retry-timeout': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'metavar': 'TIME', 'help': 'How long to wait for new files per trial'}} SECTIONS['flat-correction'] = { 'projections': { 'default': None, 'type': str, 'help': "Location with projections", 'metavar': 'PATH'}, 'darks': { 'default': None, 'type': str, 'help': "Location with darks", 'metavar': 'PATH'}, 'dark-scale': { 'default': 1, 'type': float, 'help': "Scaling dark"}, 'reduction-mode': { 'default': "Average", 'type': str, 'help': "Flat-field correction options: Average (darks) or median (flats)"}, 'fix-nan-and-inf': { 'default': False, 'help': "Fix nan and inf", 'action': 'store_true'}, 'flats': { 'default': None, 'type': str, 'help': "Location with flats", 'metavar': 'PATH'}, 'flats2': { 'default': None, 'type': str, 'help': "Location with flats 2 for interpolation correction", 'metavar': 'PATH'}, 'flat-scale': { 'default': 1, 'type': float, 'help': "Scaling flat"}, 'absorptivity': { 'default': False, 'action': 'store_true', 'help': 'Do absorption correction'}} SECTIONS['retrieve-phase'] = { 'retrieval-method': { 'choices': ['tie', 'ctf', 'qp', 'qp2'], 'default': 'tie', 'help': "Phase retrieval method"}, 'energy': { 'default': None, 'ezdefault': 20, 'type': float, 'help': "X-ray energy 
[keV]"}, 'propagation-distance': { 'default': None, 'ezdefault': "0.1", 'type': tupleize(), 'help': ("Sample <-> detector distance (if one value, then use the same for x and y " "direction, otherwise first specifies x and second y direction) [m]")}, 'pixel-size': { 'default': 1e-6, 'ezdefault': 3.6e-6, 'type': float, 'help': "Pixel size [m]"}, 'regularization-rate': { 'default': 2, 'ezdefault': 2.3, 'type': float, 'help': "Regularization rate (typical values between [2, 3])"}, 'delta': { 'default': None, 'type': float, 'help': "Real part of the complex refractive index of the material. " "If specified, phase retrieval returns projected thickness, " "if not, it returns phase"}, 'tie-approximate-logarithm': { 'default': False, 'help': ("Approximate the logarithm of the tie method by the first order Taylor series " "expansion [ln(x) ~ ln(a) + (x - a) / a at a, a specified with " "--tie-approximate-point]. This way we may do the filtering for FBP already " "by the phase retrieval and save one forward and one backward 1D FFT needed " "if the filtering occurse separately. This is mostly useful for online reconstruction " "when one reconstruct only a few slices."), 'action': 'store_true'}, 'tie-approximate-point': { 'default': 0.75, 'type': float, 'help': ("Taylor series point of expansion used by --tie-approximate-logarithm. " "The error of the approximation will be smallest around this point, " "so you can tune this for the desired grey level of interest " "(given by the sample based on e^(-mju * projected_thickness)).")}, 'retrieval-padded-width': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'help': "Padded width used for phase retrieval"}, 'retrieval-padded-height': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'help': "Padded height used for phase retrieval"}, 'retrieval-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp_to_edge', 'help': "Padded values assignment"}, 'thresholding-rate': { 'default': 0.01, 'type': float, 'help': "Thresholding rate (typical values between [0.01, 0.1])"}, 'frequency-cutoff': { 'default': 1e30, 'type': float, 'help': "Phase retrieval frequency cutoff [rad]"}} SECTIONS['sinos'] = { 'pass-size': { 'type': restrict_value((0, None), dtype=int), 'default': 0, 'help': 'Number of sinograms to process per pass'}} SECTIONS['reconstruction'] = { 'sinograms': { 'default': None, 'type': str, 'help': "Location with sinograms", 'metavar': 'PATH'}, 'angle': { 'default': None, 'type': float, 'help': "Angle step between projections in radians"}, 'enable-tracing': { 'default': False, 'help': "Enable tracing and store result in .PID.json", 'action': 'store_true'}, 'remotes': { 'default': None, 'type': str, 'help': "Addresses to remote ufo-nodes", 'nargs': '+'}, 'projection-filter': { 'default': 'ramp-fromreal', 'type': str, 'help': "Projection filter", 'choices': ['none', 'ramp', 'ramp-fromreal', 'butterworth', 'faris-byer', 'bh3', 'hamming']}, 'projection-filter-cutoff': { 'default': 0.5, 'type': float, 'help': "Relative cutoff frequency"}, 'projection-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp_to_edge', 'help': "Padded values assignment"}} SECTIONS['tomographic-reconstruction'] = { 'axis': { 'default': None, 'type': float, 'help': "Axis position"}, 'dry-run': { 'default': False, 'help': "Reconstruct without writing data", 'action': 'store_true'}, 'offset': { 'default': 0.0, 'type': float, 'help': "Angle offset of first 
projection in radians"}, 'method': { 'default': 'fbp', 'type': str, 'help': "Reconstruction method", 'choices': ['fbp', 'dfi', 'sart', 'sirt', 'sbtv', 'asdpocs']}} SECTIONS['laminographic-reconstruction'] = { 'angle': { 'default': None, 'type': float, 'help': "Angle step between projections in radians"}, 'dry-run': { 'default': False, 'help': "Reconstruct without writing data", 'action': 'store_true'}, 'axis': { 'default': None, 'required': True, 'type': tupleize(num_items=2), 'help': "Axis position"}, 'x-region': { 'default': "0,-1,1", 'type': tupleize(num_items=3, conv=int), 'help': "X region as from,to,step"}, 'y-region': { 'default': "0,-1,1", 'type': tupleize(num_items=3, conv=int), 'help': "Y region as from,to,step"}, 'z': { 'default': 0, 'type': int, 'help': "Z coordinate of the reconstructed slice"}, 'z-parameter': { 'default': 'z', 'type': str, 'choices': ['z', 'x-center', 'lamino-angle', 'roll-angle'], 'help': "Parameter to vary along the reconstructed z-axis"}, 'region': { 'default': "0,-1,1", 'type': tupleize(num_items=3), 'help': "Z-axis parameter region as from,to,step"}, 'overall-angle': { 'default': None, 'type': float, 'help': "The total angle over which projections were taken in degrees"}, 'lamino-angle': { 'default': None, 'required': True, 'type': float, 'help': "The laminographic angle in degrees"}, 'roll-angle': { 'default': 0.0, 'type': float, 'help': "Sample angular misalignment to the side (roll) in degrees, positive angles mean\ clockwise misalignment"}, 'slices-per-device': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Number of slices computed by one computing device"}, 'only-bp': { 'default': False, 'action': 'store_true', 'help': "Do only backprojection with no other processing steps"}, 'lamino-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp', 'help': "Padded values assignment for the filtered projection"}} SECTIONS['fbp'] = { 'crop-width': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Width of final slice"}, 'projection-crop-after': { 'choices': ['filter', 'backprojection'], 'default': 'backprojection', 'help': "Whether to crop projections after filtering (can cause truncation " "artifacts) or after backprojection"}} SECTIONS['dfi'] = { 'oversampling': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Oversample factor"}} SECTIONS['ir'] = { 'num-iterations': { 'default': 10, 'type': restrict_value((0, None), dtype=int), 'help': "Maximum number of iterations"}} SECTIONS['sart'] = { 'relaxation-factor': { 'default': 0.25, 'type': float, 'help': "Relaxation factor"}} SECTIONS['sbtv'] = { 'lambda': { 'default': 0.1, 'type': float, 'help': "Lambda"}, 'mu': { 'default': 0.5, 'type': float, 'help': "mu"}} SECTIONS['gui'] = { 'enable-cropping': { 'default': False, 'help': "Enable cropping width", 'action': 'store_true'}, 'show-2d': { 'default': False, 'help': "Show 2D slices with pyqtgraph", 'action': 'store_true'}, 'show-3d': { 'default': False, 'help': "Show 3D slices with pyqtgraph", 'action': 'store_true'}, 'last-dir': { 'default': '.', 'type': str, 'help': "Path of the last used directory", 'metavar': 'PATH'}, 'deg0': { 'default': '.', 'type': str, 'help': "Location with 0 deg projection", 'metavar': 'PATH'}, 'deg180': { 'default': '.', 'type': str, 'help': "Location with 180 deg projection", 'metavar': 'PATH'}, 'ffc-correction': { 'default': False, 'help': "Enable darks or flats correction", 'action': 'store_true'}, 
'num-flats': { 'default': 0, 'type': int, 'help': "Number of flats for ffc correction."}} SECTIONS['estimate'] = { 'estimate-method': { 'type': str, 'default': 'correlation', 'help': 'Rotation axis estimation algorithm', 'choices': ['reconstruction', 'correlation']}} SECTIONS['perf'] = { 'num-runs': { 'default': 3, 'type': restrict_value((0, None), dtype=int), 'help': "Number of runs"}, 'width-range': { 'default': '1024', 'type': range_list, 'help': "Width or range of widths of generated projections"}, 'height-range': { 'default': '1024', 'type': range_list, 'help': "Height or range of heights of generated projections"}, 'num-projection-range': { 'default': '512', 'type': range_list, 'help': "Number or range of number of projections"}} SECTIONS['preprocess'] = { 'transpose-input': { 'default': False, 'action': 'store_true', 'help': "Transpose projections before they are backprojected (after phase retrieval)"}, 'projection-filter': { 'default': 'ramp-fromreal', 'type': str, 'help': "Projection filter", 'choices': ['none', 'ramp', 'ramp-fromreal', 'butterworth', 'faris-byer', 'bh3', 'hamming']}, 'projection-filter-cutoff': { 'default': 0.5, 'type': float, 'help': "Relative cutoff frequency"}, 'projection-filter-scale': { 'default': 1., 'type': float, 'help': "Multiplicative factor of the projection filter"}, 'projection-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp_to_edge', 'help': "Padded values assignment"}, 'projection-crop-after': { 'choices': ['filter', 'backprojection'], 'default': 'backprojection', 'help': "Whether to crop projections after filtering (can cause truncation " "artifacts) or after backprojection"}} SECTIONS['cone-beam-weight'] = { 'source-position-y': { 'default': "-Inf", 'type': tupleize(dtype=list), 'help': "Y source position (along beam direction) in global coordinates [pixels] " "(multiple of detector pixel size)"}, 'detector-position-y': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Y detector position (along beam direction) in global coordinates [pixels] " "(multiple of detector pixel size)"}, 'center-position-x': { 'default': None, 'type': tupleize(), 'help': "X rotation axis position on a projection [pixels]"}, 'center-position-z': { 'default': None, 'ezdefault': "0", 'type': tupleize(), 'help': "Z rotation axis position on a projection [pixels]"}, 'axis-angle-x': { 'default': "0", 'ezdefault': "30", 'type': tupleize(dtype=list), 'help': "Rotation axis rotation around the x axis" "(laminographic angle, 0 = tomography) [deg]"}} SECTIONS['general-reconstruction'] = { 'enable-tracing': { 'default': False, 'help': "Enable tracing and store result in .PID.json", 'action': 'store_true'}, 'disable-cone-beam-weight': { 'default': False, 'action': 'store_true', 'help': "Disable cone beam weighting"}, 'slice-memory-coeff': { 'default': 0.8, 'ezdefault': 0.7, 'type': restrict_value((0.01, 0.95)), 'help': "Portion of the GPU memory used for slices (from 0.01 to 0.9) [fraction]. " "The total amount of consumed memory will be larger depending on the " "complexity of the graph. 
In case of OpenCL memory allocation errors, " "try reducing this value."}, 'num-gpu-threads': { 'default': 1, 'ezdefault': None, 'type': restrict_value((1, None), dtype=int), 'help': "Number of parallel reconstruction threads on one GPU"}, 'disable-projection-crop': { 'default': False, 'action': 'store_true', 'help': "Disable automatic cropping of projections computed from volume region"}, 'dry-run': { 'default': False, 'help': "Reconstruct without reading or writing data", 'action': 'store_true'}, 'data-splitting-policy': { 'default': 'one', 'ezdefault': 'one', 'type': str, 'help': "'one': one GPU should process as many slices as possible, " "'many': slices should be spread across as many GPUs as possible", 'choices': ['one', 'many']}, 'projection-margin': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'help': "By optimization of the read projection region, the read region will be " "[y - margin, y + height + margin]"}, 'slices-per-device': { 'default': None, 'ezdefault': None, 'type': restrict_value((0, None), dtype=int), 'help': "Number of slices computed by one computing device"}, 'gpus': { 'default': None, 'nargs': '+', 'type': int, 'help': "GPUs with these indices will be used (0-based)"}, 'burst': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Number of projections processed per kernel invocation"}, 'x-region': { 'default': "0,-1,1", 'type': tupleize(num_items=3), 'help': "x region as from,to,step"}, 'y-region': { 'default': "0,-1,1", 'type': tupleize(num_items=3), 'help': "y region as from,to,step"}, 'z': { 'default': 0, 'type': int, 'help': "z coordinate of the reconstructed slice"}, 'z-parameter': { 'default': 'z', 'type': str, 'choices': ['axis-angle-x', 'axis-angle-y', 'axis-angle-z', 'volume-angle-x', 'volume-angle-y', 'volume-angle-z', 'detector-angle-x', 'detector-angle-y', 'detector-angle-z', 'detector-position-x', 'detector-position-y', 'detector-position-z', 'source-position-x', 'source-position-y', 'source-position-z', 'center-position-x', 'center-position-z', 'z'], 'help': "Parameter to vary along the reconstructed z-axis"}, 'region': { 'default': "0,1,1", 'type': tupleize(num_items=3), 'help': "z axis parameter region as from,to,step"}, 'source-position-x': { 'default': "0", 'type': tupleize(dtype=list), 'help': "X source position (horizontal) in global coordinates [pixels]"}, 'source-position-z': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Z source position (vertical) in global coordinates [pixels]"}, 'detector-position-x': { 'default': "0", 'type': tupleize(dtype=list), 'help': "X detector position (horizontal) in global coordinates [pixels]"}, 'detector-position-z': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Z detector position (vertical) in global coordinates [pixels]"}, 'detector-angle-x': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Detector rotation around the x axis (horizontal) [deg]"}, 'detector-angle-y': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Detector rotation around the y axis (along beam direction) [deg]"}, 'detector-angle-z': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Detector rotation around the z axis (vertical) [deg]"}, 'axis-angle-y': { 'default': "0", 'ezdefault': "0", 'type': tupleize(dtype=list), 'help': "Rotation axis rotation around the y axis (along beam direction) [deg]"}, 'axis-angle-z': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Rotation axis rotation around the z axis (vertical) [deg]"}, 'volume-angle-x': { 
'default': "0", 'type': tupleize(dtype=list), 'help': "Volume rotation around the x axis (horizontal) [deg]"}, 'volume-angle-y': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Volume rotation around the y axis (along beam direction) [deg]"}, 'volume-angle-z': { 'default': "0", 'ezdefault': "0.0", 'type': tupleize(dtype=list), 'help': "Volume rotation around the z axis (vertical) [deg]"}, 'compute-type': { 'default': 'float', 'type': str, 'help': "Data type for performing kernel math operations", 'choices': ['half', 'float', 'double']}, 'result-type': { 'default': 'float', 'type': str, 'help': "Data type for storing the intermediate gray value for a voxel " "from various rotation angles", 'choices': ['half', 'float', 'double']}, 'store-type': { 'default': 'float', 'type': str, 'help': "Data type of the output volume", 'choices': ['half', 'float', 'double', 'uchar', 'ushort', 'uint']}, 'overall-angle': { 'default': None, 'ezdefault': 360, 'type': float, 'help': "The total angle over which projections were taken in degrees"}, 'genreco-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp', 'help': "Padded values assignment for the filtered projection"}, 'slice-gray-map': { 'default': "0,0", 'type': tupleize(num_items=2, conv=float), 'help': "Minimum and maximum gray value mapping if store-type is integer-based"} } SECTIONS['find-large-spots'] = { 'method': { 'default': 'grow', 'type': str, 'help': "Data type of the output volume", 'choices': ['grow', 'median']}, # median arguments 'median-width': { 'default': 10, 'type': int, 'help': "Width of the median filter (operates only horizontally)"}, 'dilation-disk-radius': { 'default': 2, 'type': int, 'help': "Dilation disk radius used for enlarging the found mask"}, # grow arguments 'images': { 'default': None, 'type': str, 'help': "Location with input images", 'metavar': 'PATH'}, 'transpose-input': { 'default': False, 'action': 'store_true', 'help': "Transpose image when *vertical_sigma* is True, i.e. 
filter horizontal stripes " "instead of vertical"}, 'gauss-sigma': { 'default': 0.0, 'ezdefault': 2.0, 'type': float, 'help': "Gaussian sigma for removing low frequencies (filter will be 1 - gauss window)"}, 'vertical-sigma': { 'default': False, 'action': 'store_true', 'help': "*gauss-sigma* will be used for removing low frequencies in a horizontal stripe " "(vertical Gaussian profile applied around frequency ky=0 for all kx in a 1 - " "gauss window fashion)"}, 'blurred-output': { 'default': None, 'type': str, 'help': "Path where to store the blurred input"}, 'spot-threshold': { 'default': 0.0, 'ezdefault': 1000, 'type': float, 'help': "Pixels with grey value larger than this are considered as spots"}, 'spot-threshold-mode': { 'default': 'absolute', 'type': str, 'help': "Pixels must be either \"below\", \"above\" the spot threshold, or \ their \"absolute\" value can be compared", 'choices': ['below', 'above', 'absolute']}, 'grow-threshold': { 'default': 0.0, 'type': float, 'help': "Spot growing threshold, if 0 it will be set to FWTM times noise standard deviation"}, 'find-large-spots-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'repeat', 'help': "Padded values assignment for the filtered input image"}, } SECTIONS['inpaint'] = { 'projections': { 'default': None, 'type': str, 'help': "Location with projections", 'metavar': 'PATH'}, 'guidance-image': { 'default': None, 'type': str, 'help': "Guidance image, structure which will be inpainted into input images"}, 'mask-image': { 'default': None, 'type': str, 'help': "Mask image, pixels with ones will use the guidance image, pixels with zeros \ the original image"}, 'inpaint-padded-width': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'help': "Padded width used for inpainting"}, 'inpaint-padded-height': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'help': "Padded height used for inpainting"}, 'inpaint-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp_to_edge', 'help': "Padded values assignment for inpainting"}, 'preserve-mean': { 'default': False, 'action': 'store_true', 'help': "Mean value of the inpainted result will be the same as the one of the input"}, 'harmonize-borders': { 'default': False, 'action': 'store_true', 'help': "Harmonize transitions between image borders useful for the removal of the " "cross in the power spectrum"}, } TOMO_PARAMS = ('flat-correction', 'reconstruction', 'tomographic-reconstruction', 'fbp', 'dfi', 'ir', 'sart', 'sbtv') PREPROC_PARAMS = ('preprocess', 'cone-beam-weight', 'flat-correction', 'retrieve-phase') LAMINO_PARAMS = PREPROC_PARAMS + ('laminographic-reconstruction',) GEN_RECO_PARAMS = PREPROC_PARAMS + ('general-reconstruction',) NICE_NAMES = ('General', 'Input', 'Flat field correction', 'Phase retrieval', 'Sinogram generation', 'General reconstruction', 'Tomographic reconstruction', 'Laminographic reconstruction', 'Filtered backprojection', 'Direct Fourier Inversion', 'Iterative reconstruction', 'SART', 'SBTV', 'GUI settings', 'Estimation', 'Performance', 'Preprocess', 'Cone beam weight', 'General reconstruction', 'Find large spots', 'Inpaint') def get_config_name(): """Get the command line --config option.""" name = '' for i, arg in enumerate(sys.argv): if arg.startswith('--config'): if arg == '--config': return sys.argv[i + 1] else: name = sys.argv[i].split('--config')[1] if name[0] == '=': name = name[1:] return name return name def 
parse_known_args(parser, subparser=False): """ Parse arguments from file and then override by the ones specified on the command line. Use *parser* for parsing and is *subparser* is True take into account that there is a value on the command line specifying the subparser. """ if len(sys.argv) > 1: subparser_value = [sys.argv[1]] if subparser else [] config_values = config_to_list(config_name=get_config_name()) values = subparser_value + config_values + sys.argv[1:] args = None if config_values: args = parser.parse_known_args(args=subparser_value + config_values)[0] parser.parse_args(args=sys.argv[1:], namespace=args) else: values = "" return parser.parse_known_args(values)[0] def config_to_list(config_name=''): """ Read arguments from config file and convert them to a list of keys and values as sys.argv does when they are specified on the command line. *config_name* is the file name of the config file. """ result = [] config = configparser.ConfigParser() if not config.read([config_name]): return [] for section in SECTIONS: for name, opts in ((n, o) for n, o in list(SECTIONS[section].items()) if config.has_option(section, n)): value = config.get(section, name) if value != '' and value != 'None': action = opts.get('action', None) if action == 'store_true' and value == 'True': # Only the key is on the command line for this action result.append('--{}'.format(name)) if not action == 'store_true': if opts.get('nargs', None) == '+': result.append('--{}'.format(name)) result.extend((v.strip() for v in value.split(','))) else: result.append('--{}={}'.format(name, value)) return result def without_keys(d, keys): return {k: v for k, v in d.items() if k not in keys} class Params(object): def __init__(self, sections=()): self.sections = sections + ('general', 'reading') def add_parser_args(self, parser): for section in self.sections: for name in sorted(SECTIONS[section]): opts = without_keys(SECTIONS[section][name], {'ezdefault'}) parser.add_argument('--{}'.format(name), **opts) def add_arguments(self, parser): self.add_parser_args(parser) return parser def get_defaults(self): parser = argparse.ArgumentParser() self.add_arguments(parser) return parser.parse_args('') def write(config_file, args=None, sections=None): """ Write *config_file* with values from *args* if they are specified, otherwise use the defaults. If *sections* are specified, write values from *args* only to those sections, use the defaults on the remaining ones. """ config = configparser.ConfigParser() for section in SECTIONS: config.add_section(section) for name, opts in list(SECTIONS[section].items()): if args and sections and section in sections and hasattr(args, name.replace('-', '_')): value = getattr(args, name.replace('-', '_')) if isinstance(value, list): value = ', '.join(value) else: value = opts['default'] if opts['default'] != None else '' prefix = '# ' if value == '' else '' if name != 'config': config.set(section, prefix + name, value) with open(config_file, 'wb') as f: config.write(f) def log_values(args): """Log all values set in the args namespace. Arguments are grouped according to their section and logged alphabetically using the DEBUG log level thus --verbose is required. 
""" args = args.__dict__ for section, name in zip(SECTIONS, NICE_NAMES): entries = sorted((k for k in list(args.keys()) if k.replace('_', '-') in SECTIONS[section])) if entries: LOG.debug(name) for entry in entries: value = args[entry] if args[entry] is not None else "-" LOG.debug(" {:<16} {}".format(entry, value)) ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1698416097.765776 ufo-tofu-0.13.0/tofu/ez/0000775000175000017500000000000000000000000015225 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1698416097.765776 ufo-tofu-0.13.0/tofu/ez/GUI/0000775000175000017500000000000000000000000015651 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1698416097.765776 ufo-tofu-0.13.0/tofu/ez/GUI/Advanced/0000775000175000017500000000000000000000000017356 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/ez/GUI/Advanced/__init__.py0000664000175000017500000000000000000000000021455 0ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Advanced/advanced.py0000664000175000017500000001503600000000000021502 0ustar00tomastomas00000000000000import logging from PyQt5.QtWidgets import QGridLayout, QLabel, QGroupBox, QLineEdit from tofu.ez.params import EZVARS from tofu.config import SECTIONS from tofu.ez.util import add_value_to_dict_entry, get_double_validator, reverse_tupleize LOG = logging.getLogger(__name__) class AdvancedGroup(QGroupBox): """ Advanced Tofu Reco settings """ def __init__(self): super().__init__() self.setTitle("Advanced TOFU Reconstruction Settings") self.setStyleSheet("QGroupBox {color: green;}") # LAMINO self.lamino_group = QGroupBox("Extended Settings of Reconstruction Algorithms") self.lamino_group.clicked.connect(self.set_lamino_group) self.lamino_angle_label = QLabel("Laminographic angle ") self.lamino_angle_entry = QLineEdit() self.lamino_angle_entry.setValidator(get_double_validator()) self.lamino_angle_entry.editingFinished.connect(self.set_lamino_angle) self.overall_rotation_label = QLabel("Overall rotation range about CT Z-axis") self.overall_rotation_entry = QLineEdit() self.overall_rotation_entry.setValidator(get_double_validator()) self.overall_rotation_entry.editingFinished.connect(self.set_overall_rotation) self.center_position_z_label = QLabel("Center Position Z ") self.center_position_z_entry = QLineEdit() self.center_position_z_entry.setValidator(get_double_validator()) self.center_position_z_entry.editingFinished.connect(self.set_center_position_z) self.axis_rotation_y_label = QLabel( "Sample rotation about the beam Y-axis " ) self.axis_rotation_y_entry = QLineEdit() self.axis_rotation_y_entry.editingFinished.connect(self.set_rotation_about_beam) # AUXILIARY FFC self.dark_scale_label = QLabel("Dark scale ") self.dark_scale_entry = QLineEdit() self.dark_scale_entry.setValidator(get_double_validator()) self.dark_scale_entry.editingFinished.connect(self.set_dark_scale) self.flat_scale_label = QLabel("Flat scale ") self.flat_scale_entry = QLineEdit() self.flat_scale_entry.setValidator(get_double_validator()) self.flat_scale_entry.editingFinished.connect(self.set_flat_scale) self.set_layout() def set_layout(self): layout = QGridLayout() 
self.lamino_group.setCheckable(True) self.lamino_group.setChecked(False) lamino_layout = QGridLayout() lamino_layout.addWidget(self.lamino_angle_label, 0, 0) lamino_layout.addWidget(self.lamino_angle_entry, 0, 1) lamino_layout.addWidget(self.overall_rotation_label, 1, 0) lamino_layout.addWidget(self.overall_rotation_entry, 1, 1) lamino_layout.addWidget(self.center_position_z_label, 2, 0) lamino_layout.addWidget(self.center_position_z_entry, 2, 1) lamino_layout.addWidget(self.axis_rotation_y_label, 3, 0) lamino_layout.addWidget(self.axis_rotation_y_entry, 3, 1) self.lamino_group.setLayout(lamino_layout) aux_group = QGroupBox("Auxiliary FFC Settings") aux_group.setCheckable(True) aux_group.setChecked(False) aux_layout = QGridLayout() aux_layout.addWidget(self.dark_scale_label, 0, 0) aux_layout.addWidget(self.dark_scale_entry, 0, 1) aux_layout.addWidget(self.flat_scale_label, 1, 0) aux_layout.addWidget(self.flat_scale_entry, 1, 1) aux_group.setLayout(aux_layout) layout.addWidget(self.lamino_group) layout.addWidget(aux_group) self.setLayout(layout) def load_values(self): self.lamino_group.setChecked(EZVARS['advanced']['more-reco-params']['value']) self.lamino_angle_entry.setText(str(reverse_tupleize()(SECTIONS['cone-beam-weight']['axis-angle-x']['value']))) self.overall_rotation_entry.setText(str(SECTIONS['general-reconstruction']['overall-angle']['value'])) self.center_position_z_entry.setText(str(reverse_tupleize()(SECTIONS['cone-beam-weight']['center-position-z']['value']))) self.axis_rotation_y_entry.setText(str(reverse_tupleize()(SECTIONS['general-reconstruction']['axis-angle-y']['value']))) self.dark_scale_entry.setText(str(EZVARS['flat-correction']['dark-scale']['value'])) self.flat_scale_entry.setText(str(EZVARS['flat-correction']['flat-scale']['value'])) def set_lamino_group(self): LOG.debug("Lamino: " + str(self.lamino_group.isChecked())) dict_entry = EZVARS['advanced']['more-reco-params'] add_value_to_dict_entry(dict_entry, self.lamino_group.isChecked()) def set_lamino_angle(self): LOG.debug(self.lamino_angle_entry.text()) dict_entry = SECTIONS['cone-beam-weight']['axis-angle-x'] add_value_to_dict_entry(dict_entry, str(self.lamino_angle_entry.text())) self.lamino_angle_entry.setText(str(reverse_tupleize()(dict_entry['value']))) def set_overall_rotation(self): LOG.debug(self.overall_rotation_entry.text()) dict_entry = SECTIONS['general-reconstruction']['overall-angle'] add_value_to_dict_entry(dict_entry, str(self.overall_rotation_entry.text())) self.overall_rotation_entry.setText(str(dict_entry['value'])) def set_center_position_z(self): LOG.debug(self.center_position_z_entry.text()) dict_entry = SECTIONS['cone-beam-weight']['center-position-z'] add_value_to_dict_entry(dict_entry, str(self.center_position_z_entry.text())) self.center_position_z_entry.setText(str(reverse_tupleize()(dict_entry['value']))) def set_rotation_about_beam(self): LOG.debug(self.axis_rotation_y_entry.text()) dict_entry = SECTIONS['general-reconstruction']['axis-angle-y'] add_value_to_dict_entry(dict_entry, str(self.axis_rotation_y_entry.text())) self.axis_rotation_y_entry.setText(str(reverse_tupleize()(dict_entry['value']))) def set_dark_scale(self): LOG.debug(self.dark_scale_entry.text()) dict_entry = EZVARS['flat-correction']['dark-scale'] add_value_to_dict_entry(dict_entry, str(self.dark_scale_entry.text())) self.dark_scale_entry.setText(str(dict_entry['value'])) def set_flat_scale(self): LOG.debug(self.flat_scale_entry.text()) dict_entry = EZVARS['flat-correction']['flat-scale'] 
add_value_to_dict_entry(dict_entry, str(self.flat_scale_entry.text())) self.flat_scale_entry.setText(str(dict_entry['value'])) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Advanced/ffc.py0000664000175000017500000001313200000000000020466 0ustar00tomastomas00000000000000import logging from PyQt5.QtWidgets import ( QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox, QRadioButton, QHBoxLayout, ) from tofu.ez.params import EZVARS from tofu.ez.util import add_value_to_dict_entry, get_int_validator LOG = logging.getLogger(__name__) class FFCGroup(QGroupBox): """ Flat Field Correction Settings """ def __init__(self): super().__init__() self.setTitle("Flat Field Correction") self.setStyleSheet("QGroupBox {color: indigo;}") self.method_label = QLabel("Method:") self.average_rButton = QRadioButton("Average") self.average_rButton.clicked.connect(self.set_method) self.ssim_rButton = QRadioButton("SSIM") self.ssim_rButton.clicked.connect(self.set_method) self.eigen_rButton = QRadioButton("Eigen") self.eigen_rButton.clicked.connect(self.set_method) self.enable_sinFFC_checkbox = QCheckBox( "Use Smart Intensity Normalization Flat Field Correction" ) self.enable_sinFFC_checkbox.stateChanged.connect(self.set_sinFFC) self.eigen_pco_repetitions_label = QLabel("Eigen PCO Repetitions") self.eigen_pco_repetitions_entry = QLineEdit() self.eigen_pco_repetitions_entry.setValidator(get_int_validator()) self.eigen_pco_repetitions_entry.editingFinished.connect(self.set_pcoReps) self.eigen_pco_downsample_label = QLabel("Eigen PCO Downsample") self.eigen_pco_downsample_entry = QLineEdit() self.eigen_pco_downsample_entry.setValidator(get_int_validator()) self.eigen_pco_downsample_entry.editingFinished.connect(self.set_pcoDowns) self.downsample_label = QLabel("Downsample") self.downsample_entry = QLineEdit() self.downsample_entry.setValidator(get_int_validator()) self.downsample_entry.editingFinished.connect(self.set_downsample) self.set_layout() def set_layout(self): layout = QGridLayout() rbutton_layout = QHBoxLayout() rbutton_layout.addWidget(self.method_label) rbutton_layout.addWidget(self.eigen_rButton) rbutton_layout.addWidget(self.average_rButton) rbutton_layout.addWidget(self.ssim_rButton) layout.addWidget(self.enable_sinFFC_checkbox, 0, 0) layout.addItem(rbutton_layout, 1, 0, 1, 2) layout.addWidget(self.eigen_pco_repetitions_label, 2, 0) layout.addWidget(self.eigen_pco_repetitions_entry, 2, 1) layout.addWidget(self.eigen_pco_downsample_label, 3, 0) layout.addWidget(self.eigen_pco_downsample_entry, 3, 1) layout.addWidget(self.downsample_label, 4, 0) layout.addWidget(self.downsample_entry, 4, 1) self.setLayout(layout) def load_values(self): self.enable_sinFFC_checkbox.setChecked(EZVARS['flat-correction']['smart-ffc']['value']) self.set_method_from_params() self.eigen_pco_repetitions_entry.setText(str(EZVARS['flat-correction']['eigen-pco-reps']['value'])) self.eigen_pco_downsample_entry.setText(str(EZVARS['flat-correction']['eigen-pco-downsample']['value'])) self.downsample_entry.setText(str(EZVARS['flat-correction']['downsample']['value'])) def set_sinFFC(self): LOG.debug("sinFFC: " + str(self.enable_sinFFC_checkbox.isChecked())) dict_entry = EZVARS['flat-correction']['smart-ffc'] add_value_to_dict_entry(dict_entry, self.enable_sinFFC_checkbox.isChecked()) def set_pcoReps(self): LOG.debug("PCO Reps: " + str(self.eigen_pco_repetitions_entry.text())) dict_entry = EZVARS['flat-correction']['eigen-pco-reps'] 
add_value_to_dict_entry(dict_entry, str(self.eigen_pco_repetitions_entry.text())) self.eigen_pco_repetitions_entry.setText(str(dict_entry['value'])) def set_pcoDowns(self): LOG.debug("PCO Downsample: " + str(self.eigen_pco_downsample_entry.text())) dict_entry = EZVARS['flat-correction']['eigen-pco-downsample'] add_value_to_dict_entry(dict_entry, str(self.eigen_pco_downsample_entry.text())) self.eigen_pco_downsample_entry.setText(str(dict_entry['value'])) def set_downsample(self): LOG.debug("Downsample: " + str(self.downsample_entry.text())) dict_entry = EZVARS['flat-correction']['downsample'] add_value_to_dict_entry(dict_entry, str(self.downsample_entry.text())) self.downsample_entry.setText(str(dict_entry['value'])) def set_method(self): if self.eigen_rButton.isChecked(): LOG.debug("Method: Eigen") EZVARS['flat-correction']['smart-ffc-method']['value'] = "eigen" elif self.average_rButton.isChecked(): LOG.debug("Method: Average") EZVARS['flat-correction']['smart-ffc-method']['value'] = "average" elif self.ssim_rButton.isChecked(): LOG.debug("Method: SSIM") EZVARS['flat-correction']['smart-ffc-method']['value'] = "ssim" def set_method_from_params(self): if EZVARS['flat-correction']['smart-ffc-method']['value'] == "eigen": self.eigen_rButton.setChecked(True) self.average_rButton.setChecked(False) self.ssim_rButton.setChecked(False) elif EZVARS['flat-correction']['smart-ffc-method']['value'] == "average": self.eigen_rButton.setChecked(False) self.average_rButton.setChecked(True) self.ssim_rButton.setChecked(False) elif EZVARS['flat-correction']['smart-ffc-method']['value'] == "ssim": self.eigen_rButton.setChecked(False) self.average_rButton.setChecked(False) self.ssim_rButton.setChecked(True) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Advanced/nlmdn.py0000664000175000017500000003437200000000000021051 0ustar00tomastomas00000000000000import logging import os from shutil import rmtree from PyQt5.QtWidgets import ( QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox, QPushButton, QFileDialog, QMessageBox, ) from PyQt5.QtCore import Qt from tofu.ez.ufo_cmd_gen import fmt_nlmdn_ufo_cmd from tofu.ez.params import EZVARS from tofu.ez.util import add_value_to_dict_entry, get_int_validator, get_double_validator LOG = logging.getLogger(__name__) class NLMDNGroup(QGroupBox): """ Non-local means de-noising settings """ def __init__(self): super().__init__() self.setTitle("Non-local-means Denoising") self.setStyleSheet("QGroupBox {color: royalblue;}") self.apply_to_reco_checkbox = QCheckBox("Automatically apply NLMDN to reconstructed slices") self.apply_to_reco_checkbox.stateChanged.connect(self.set_apply_to_reco) self.input_dir_button = QPushButton("Select input directory") self.input_dir_button.clicked.connect(self.set_indir_button) self.select_img_button = QPushButton("Select one image") self.select_img_button.clicked.connect(self.select_image) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_indir_entry) self.output_dir_button = QPushButton("Select output directory or filename pattern") self.output_dir_button.clicked.connect(self.set_outdir_button) self.save_bigtif_checkbox = QCheckBox("Save in bigtif container") self.save_bigtif_checkbox.clicked.connect(self.set_save_bigtif) self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_outdir_entry) self.similarity_radius_label = QLabel("Radius for similarity search") 
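# --- Illustrative sketch (standalone, not part of ffc.py) ---------------------
# FFCGroup.set_method()/set_method_from_params() above keep three QRadioButtons
# in sync with one string setting ('eigen', 'average' or 'ssim').  The same
# bookkeeping can be written once with a name->button mapping; 'buttons' is a
# dict of {str: QRadioButton} and only sketches the idea, it is not the GUI code.

def sync_buttons_to_value(buttons, value):
    """Check the button whose name matches the stored string, uncheck the rest."""
    for name, button in buttons.items():
        button.setChecked(name == value)

def value_from_buttons(buttons, default="eigen"):
    """Return the name of the checked button, falling back to a default."""
    for name, button in buttons.items():
        if button.isChecked():
            return name
    return default
# ------------------------------------------------------------------------------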
self.similarity_radius_entry = QLineEdit() self.similarity_radius_entry.setValidator(get_int_validator()) self.similarity_radius_entry.editingFinished.connect(self.set_rad_sim_entry) self.patch_radius_label = QLabel("Radius of patches") self.patch_radius_entry = QLineEdit() self.patch_radius_entry.setValidator(get_int_validator()) self.patch_radius_entry.editingFinished.connect(self.set_rad_patch_entry) self.smoothing_label = QLabel("Smoothing control parameter") self.smoothing_entry = QLineEdit() self.smoothing_entry.setValidator(get_double_validator()) self.smoothing_entry.editingFinished.connect(self.set_smoothing_entry) self.noise_std_label = QLabel("Noise standard deviation") self.noise_std_entry = QLineEdit() self.noise_std_entry.setValidator(get_double_validator()) self.noise_std_entry.editingFinished.connect(self.set_noise_entry) self.window_label = QLabel("Window (optional)") self.window_entry = QLineEdit() self.window_entry.setValidator(get_double_validator()) self.window_entry.editingFinished.connect(self.set_window_entry) self.fast_checkbox = QCheckBox("Fast") self.fast_checkbox.clicked.connect(self.set_fast_checkbox) self.sigma_checkbox = QCheckBox("Estimate sigma") self.sigma_checkbox.clicked.connect(self.set_sigma_checkbox) self.help_button = QPushButton("Help") self.help_button.clicked.connect(self.help_button_pressed) self.delete_button = QPushButton("Delete reco dir") self.delete_button.clicked.connect(self.delete_button_pressed) self.dry_button = QPushButton("Dry run") self.dry_button.clicked.connect(self.dry_button_pressed) self.apply_button = QPushButton("Apply filter") self.apply_button.clicked.connect(self.apply_button_pressed) # self.apply_button.setStyleSheet("color:royalblue; font-weight: bold;") self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.apply_to_reco_checkbox, 0, 0, 1, 1) layout.addWidget(self.input_dir_button, 1, 0, 1, 2) layout.addWidget(self.select_img_button, 1, 2, 1, 2) layout.addWidget(self.input_dir_entry, 2, 0, 1, 4) layout.addWidget(self.output_dir_button, 3, 0, 1, 2) layout.addWidget(self.save_bigtif_checkbox, 3, 2, 1, 2, Qt.AlignCenter) layout.addWidget(self.output_dir_entry, 4, 0, 1, 4) layout.addWidget(self.similarity_radius_label, 5, 0, 1, 2) layout.addWidget(self.similarity_radius_entry, 5, 2, 1, 2) layout.addWidget(self.patch_radius_label, 6, 0, 1, 2) layout.addWidget(self.patch_radius_entry, 6, 2, 1, 2) layout.addWidget(self.smoothing_label, 7, 0, 1, 2) layout.addWidget(self.smoothing_entry, 7, 2, 1, 2) layout.addWidget(self.noise_std_label, 8, 0, 1, 2) layout.addWidget(self.noise_std_entry, 8, 2, 1, 2) layout.addWidget(self.window_label, 9, 0, 1, 2) layout.addWidget(self.window_entry, 9, 2, 1, 2) layout.addWidget(self.fast_checkbox, 10, 0, 1, 2, Qt.AlignCenter) layout.addWidget(self.sigma_checkbox, 10, 2, 1, 2, Qt.AlignCenter) layout.addWidget(self.help_button, 11, 0, 1, 1) layout.addWidget(self.delete_button, 11, 1) layout.addWidget(self.dry_button, 11, 2) layout.addWidget(self.apply_button, 11, 3) self.setLayout(layout) def load_values(self): self.apply_to_reco_checkbox.setChecked(bool(EZVARS['nlmdn']['do-after-reco']['value'])) self.input_dir_entry.setText(str(EZVARS['nlmdn']['input-dir']['value'])) self.output_dir_entry.setText(str(EZVARS['nlmdn']['output_pattern']['value'])) self.save_bigtif_checkbox.setChecked(bool(EZVARS['nlmdn']['bigtiff_output']['value'])) self.similarity_radius_entry.setText(str(EZVARS['nlmdn']['search-radius']['value'])) 
self.patch_radius_entry.setText(str(EZVARS['nlmdn']['patch-radius']['value'])) self.smoothing_entry.setText(str(EZVARS['nlmdn']['h']['value'])) self.noise_std_entry.setText(str(EZVARS['nlmdn']['sigma']['value'])) self.window_entry.setText(str(EZVARS['nlmdn']['window']['value'])) self.fast_checkbox.setChecked(bool(EZVARS['nlmdn']['fast']['value'])) self.sigma_checkbox.setChecked(bool(EZVARS['nlmdn']['estimate-sigma']['value'])) def set_apply_to_reco(self): LOG.debug( "Apply NLMDN to reconstructed slices checkbox: " + str(self.apply_to_reco_checkbox.isChecked()) ) dict_entry = EZVARS['nlmdn']['do-after-reco'] add_value_to_dict_entry(dict_entry, self.apply_to_reco_checkbox.isChecked()) if self.apply_to_reco_checkbox.isChecked(): self.input_dir_button.setDisabled(True) self.select_img_button.setDisabled(True) self.input_dir_entry.setDisabled(True) self.dry_button.setDisabled(True) self.apply_button.setDisabled(True) self.output_dir_button.setDisabled(True) self.output_dir_entry.setDisabled(True) elif not self.apply_to_reco_checkbox.isChecked(): self.input_dir_button.setDisabled(False) self.select_img_button.setDisabled(False) self.input_dir_entry.setDisabled(False) self.dry_button.setDisabled(False) self.apply_button.setDisabled(False) self.output_dir_button.setDisabled(False) self.output_dir_entry.setDisabled(False) def set_indir_button(self): """ Saves directory specified by user in file-dialog for input tomographic data """ LOG.debug("Select input directory pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory() if directory: self.input_dir_entry.setText(str(directory)) self.set_indir_entry() self.output_dir_entry.setText(str(os.path.join(directory+'-nlmdn', 'im-%05i.tif'))) self.set_outdir_entry() dict_entry = EZVARS['nlmdn']['input-is-1file'] add_value_to_dict_entry(dict_entry, False) def set_indir_entry(self): LOG.debug("Indir entry: " + str(self.input_dir_entry.text())) dict_entry = EZVARS['nlmdn']['input-dir'] dir = self.input_dir_entry.text().strip() add_value_to_dict_entry(dict_entry, str(dir)) self.input_dir_entry.setText(str(dict_entry['value'])) def select_image(self): LOG.debug("Select one image button pressed") options = QFileDialog.Options() file_path, _ = QFileDialog.getOpenFileName( self, "Open .tif Image File", "", "Tiff Files (*.tif *.tiff)", options=options ) if file_path: img_name, img_ext = os.path.splitext(file_path) tmp = img_name + "-nlmfilt" + img_ext self.input_dir_entry.setText(str(file_path)) self.set_indir_entry() self.output_dir_entry.setText(str(tmp)) self.set_outdir_entry() dict_entry = EZVARS['nlmdn']['input-is-1file'] add_value_to_dict_entry(dict_entry, True) def set_outdir_button(self): LOG.debug("Select output directory pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory() if directory: self.output_dir_entry.setText(str(os.path.join(directory,'im-nlmdn-%05i.tif'))) self.set_outdir_entry() def set_save_bigtif(self): LOG.debug("Save bigtif checkbox: " + str(self.save_bigtif_checkbox.isChecked())) dict_entry = EZVARS['nlmdn']['bigtiff_output'] add_value_to_dict_entry(dict_entry, self.save_bigtif_checkbox.isChecked()) def set_outdir_entry(self): LOG.debug("Outdir entry: " + str(self.output_dir_entry.text())) dict_entry = EZVARS['nlmdn']['output_pattern'] dir = self.output_dir_entry.text().strip() add_value_to_dict_entry(dict_entry, str(dir)) self.output_dir_entry.setText(str(dict_entry['value'])) def set_rad_sim_entry(self): LOG.debug("Radius for similarity: " + 
str(self.similarity_radius_entry.text())) dict_entry = EZVARS['nlmdn']['search-radius'] add_value_to_dict_entry(dict_entry, str(self.similarity_radius_entry.text())) self.similarity_radius_entry.setText(str(dict_entry['value'])) def set_rad_patch_entry(self): LOG.debug("Radius of patches: " + str(self.patch_radius_entry.text())) dict_entry = EZVARS['nlmdn']['patch-radius'] add_value_to_dict_entry(dict_entry, str(self.patch_radius_entry.text())) self.patch_radius_entry.setText(str(dict_entry['value'])) def set_smoothing_entry(self): LOG.debug("Smoothing control: " + str(self.smoothing_entry.text())) dict_entry = EZVARS['nlmdn']['h'] add_value_to_dict_entry(dict_entry, str(self.smoothing_entry.text())) self.smoothing_entry.setText(str(dict_entry['value'])) def set_noise_entry(self): LOG.debug("Noise std: " + str(self.noise_std_entry.text())) dict_entry = EZVARS['nlmdn']['sigma'] add_value_to_dict_entry(dict_entry, str(self.noise_std_entry.text())) self.noise_std_entry.setText(str(dict_entry['value'])) def set_window_entry(self): LOG.debug("Window: " + str(self.window_entry.text())) dict_entry = EZVARS['nlmdn']['window'] add_value_to_dict_entry(dict_entry, str(self.window_entry.text())) self.window_entry.setText(str(dict_entry['value'])) def set_fast_checkbox(self): LOG.debug("Fast: " + str(self.fast_checkbox.isChecked())) dict_entry = EZVARS['nlmdn']['fast'] add_value_to_dict_entry(dict_entry, self.fast_checkbox.isChecked()) def set_sigma_checkbox(self): LOG.debug("Estimate sigma: " + str(self.sigma_checkbox.isChecked())) dict_entry = EZVARS['nlmdn']['estimate-sigma'] add_value_to_dict_entry(dict_entry, self.sigma_checkbox.isChecked()) def help_button_pressed(self): LOG.debug("Help Button Pressed") h = "" h += 'Note4: set to "flats" if "flats2" exist but you need to ignore them; \n' h += "SerG, BMIT CLS, Dec. 2020." QMessageBox.information(self, "Help", h) def delete_button_pressed(self): LOG.debug("Delete Reco Button Pressed") """ Deletes the directory that contains reconstructed data """ LOG.debug("DELETE") msg = "Delete directory with reconstructed data?" dialog = QMessageBox.warning(self, "Warning: data can be lost", msg, QMessageBox.Yes | QMessageBox.No) if dialog == QMessageBox.Yes: if os.path.exists(str(EZVARS['nlmdn']['output_pattern']['value'])): LOG.debug("YES") if EZVARS['nlmdn']['output_pattern']['value'] == EZVARS['nlmdn']['input-dir']['value']: LOG.debug("Cannot delete: output directory is the same as input") else: rmtree(EZVARS['nlmdn']['output_pattern']['value']) LOG.debug("Directory with denoised images was removed") else: LOG.debug("Directory does not exist") else: LOG.debug("NO") def dry_button_pressed(self): LOG.debug("Dry Run Button Pressed") dict_entry = EZVARS['nlmdn']['dryrun'] add_value_to_dict_entry(dict_entry, True) self.apply_button_pressed() add_value_to_dict_entry(dict_entry, False) def apply_button_pressed(self): LOG.debug("Apply Filter Button Pressed") if os.path.exists(EZVARS['nlmdn']['output_pattern']['value']) and not \ EZVARS['nlmdn']['dryrun']['value']: title_text = "Warning: files can be overwritten" text1 = "Output directory exists. Files can be overwritten. Proceed?" 
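# --- Illustrative sketch (standalone, not part of nlmdn.py) -------------------
# The Dry run button above temporarily sets the 'dryrun' flag, reuses the normal
# Apply handler, and the handler then prints the assembled command instead of
# executing it.  A generic sketch of that pattern; build_command() is a
# hypothetical stand-in for fmt_nlmdn_ufo_cmd, not the real command builder.
import shlex
import subprocess

def build_command(input_path, output_pattern):
    # Hypothetical command line, used only to illustrate the flow.
    return "some-denoiser {} {}".format(shlex.quote(input_path),
                                        shlex.quote(output_pattern))

def apply_filter(input_path, output_pattern, dryrun=False):
    cmd = build_command(input_path, output_pattern)
    if dryrun:
        print(cmd)                                   # dry run: only show the command
    else:
        subprocess.run(cmd, shell=True, check=True)  # real run: execute it
# ------------------------------------------------------------------------------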
dialog = QMessageBox.warning(self, title_text, text1, QMessageBox.Yes | QMessageBox.No) if dialog == QMessageBox.Yes: cmd = fmt_nlmdn_ufo_cmd(EZVARS['nlmdn']['input-dir']['value'], EZVARS['nlmdn']['output_pattern']['value']) else: cmd = fmt_nlmdn_ufo_cmd(EZVARS['nlmdn']['input-dir']['value'], EZVARS['nlmdn']['output_pattern']['value']) if EZVARS['nlmdn']['dryrun']['value']: print(cmd) else: os.system(cmd) QMessageBox.information(self, "Finished", "Finished") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Advanced/optimization.py0000664000175000017500000001012600000000000022456 0ustar00tomastomas00000000000000import logging from PyQt5.QtWidgets import QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox, QComboBox from tofu.ez.params import EZVARS from tofu.config import SECTIONS from tofu.ez.util import add_value_to_dict_entry, get_double_validator LOG = logging.getLogger(__name__) class OptimizationGroup(QGroupBox): """ Optimization settings """ def __init__(self): super().__init__() self.setTitle("Optimization Settings") self.setStyleSheet("QGroupBox {color: orange;}") self.verbose_switch = QCheckBox("Enable verbose console output") self.verbose_switch.stateChanged.connect(self.set_verbose_switch) self.slice_memory_label = QLabel("Slice memory coefficient") self.slice_memory_entry = QLineEdit() self.slice_memory_entry.setValidator(get_double_validator()) tmpstr="Fraction of VRAM which will be used to store images \n" \ "Reserve ~2 GB of VRAM for computation \n" \ "Decrease the coefficient if you have very large data and start getting errors" self.slice_memory_entry.setToolTip(tmpstr) self.slice_memory_label.setToolTip(tmpstr) self.slice_memory_entry.editingFinished.connect(self.set_slice) self.data_spllitting_policy_label = QLabel("Data Splitting Policy") self.data_spllitting_policy_combobox = QComboBox() self.data_spllitting_policy_label.setToolTip(SECTIONS['general-reconstruction']['data-splitting-policy']['help']) self.data_spllitting_policy_combobox.setToolTip(SECTIONS['general-reconstruction']['data-splitting-policy']['help']) self.data_spllitting_policy_combobox.addItems(["one","many"]) self.data_spllitting_policy_combobox.currentIndexChanged.connect(self.set_data_splitting_policy) self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.verbose_switch, 0, 0) gpu_group = QGroupBox("GPU optimization") gpu_group.setCheckable(True) gpu_group.setChecked(bool(EZVARS['advanced']['enable-optimization']['value'])) gpu_group.clicked.connect(self.set_enable_optimization) gpu_layout = QGridLayout() gpu_layout.addWidget(self.slice_memory_label, 0, 0) gpu_layout.addWidget(self.slice_memory_entry, 0, 1) gpu_layout.addWidget(self.data_spllitting_policy_label, 1, 0) gpu_layout.addWidget(self.data_spllitting_policy_combobox, 1, 1) gpu_group.setLayout(gpu_layout) layout.addWidget(gpu_group, 1, 0) self.setLayout(layout) def load_values(self): self.verbose_switch.setChecked(bool(SECTIONS['general']['verbose']['value'])) self.slice_memory_entry.setText(str(SECTIONS['general-reconstruction']['slice-memory-coeff']['value'])) idx = self.data_spllitting_policy_combobox.findText(SECTIONS['general-reconstruction']['data-splitting-policy']['value']) if idx >= 0: self.data_spllitting_policy_combobox.setCurrentIndex(idx) def set_verbose_switch(self): LOG.debug("Verbose: " + str(self.verbose_switch.isChecked())) dict_entry = SECTIONS['general']['verbose'] add_value_to_dict_entry(dict_entry, 
self.verbose_switch.isChecked()) def set_enable_optimization(self): checkbox = self.sender() LOG.debug("GPU Optimization: " + str(checkbox.isChecked())) dict_entry = EZVARS['advanced']['enable-optimization'] add_value_to_dict_entry(dict_entry, checkbox.isChecked()) def set_slice(self): LOG.debug(self.slice_memory_entry.text()) dict_entry = SECTIONS['general-reconstruction']['slice-memory-coeff'] add_value_to_dict_entry(dict_entry, str(self.slice_memory_entry.text())) self.slice_memory_entry.setText(str(dict_entry['value'])) def set_data_splitting_policy(self): LOG.debug(self.data_spllitting_policy_combobox.currentText()) dict_entry = SECTIONS['general-reconstruction']['data-splitting-policy'] add_value_to_dict_entry(dict_entry, str(self.data_spllitting_policy_combobox.currentText())) ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1698416097.7697759 ufo-tofu-0.13.0/tofu/ez/GUI/Main/0000775000175000017500000000000000000000000016535 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/ez/GUI/Main/__init__.py0000664000175000017500000000000000000000000020634 0ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Main/batch_process.py0000664000175000017500000001157100000000000021733 0ustar00tomastomas00000000000000import yaml import logging import glob import os from PyQt5.QtWidgets import QGroupBox, QLabel, QGridLayout, QPushButton, QFileDialog, QLineEdit from tofu.ez.GUI.Main.config import ConfigGroup from tofu.ez.GUI.Stitch_tools_tab.auto_horizontal_stitch_funcs import AutoHorizontalStitchFunctions class BatchProcessGroup(QGroupBox): def __init__(self): super().__init__() self.parameters = {} self.config_group = None self.auto_stitch_funcs = None self.info_label = QLabel() self.set_info_label() self.input_dir_button = QPushButton("Select input directory") self.input_dir_button.setFixedWidth(500) self.input_dir_button.clicked.connect(self.input_dir_button_pressed) self.input_dir_entry = QLineEdit("...Enter the path to the input directory") self.input_dir_entry.setFixedWidth(450) self.input_dir_entry.textChanged.connect(self.set_input_entry) self.batch_proc_button = QPushButton("Begin Batch Process") self.batch_proc_button.clicked.connect(self.batch_proc_button_pressed) self.batch_proc_button.setStyleSheet("background-color:orangered; font-size:26px") self.batch_proc_button.setFixedHeight(100) self.set_layout() def set_layout(self): self.setMaximumSize(1000, 400) layout = QGridLayout() layout.addWidget(self.input_dir_button, 0, 0) layout.addWidget(self.input_dir_entry, 0, 1) layout.addWidget(self.info_label, 1, 0) layout.addWidget(self.batch_proc_button, 2, 0, 1, 2) self.setLayout(layout) self.show() def set_info_label(self): info_str = "EZ Batch Process allows for batch reconstruction and processing of images.\n\n" info_str += "The program reads a list of .yaml parameter files from the input directory and executes\n" \ "them sequentially in alpha-numeric order.\n" info_str += "It is the user's responsibility to name files so that they are executed in the desired order.\n" info_str += "It is suggested to prepend descriptive filenames with numbers to indicate the order.\n" \ "For example: \n\n" info_str += "00_horizontal_stitch_params.yaml\n" info_str += "01_ezufo_params.yaml\n" info_str += "02_vertical_stitch_params.yaml\n" 
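# --- Illustrative sketch (standalone, not part of batch_process.py) -----------
# The help text above describes the batch mode: collect the *.yaml parameter
# files from the input directory, sort them alpha-numerically and run each one
# according to its 'parameters_type' key.  The full handler follows below; this
# is only a condensed sketch of that dispatch loop.
import glob
import os
import yaml

def iter_param_files(input_dir):
    """Yield (path, parsed parameters) for every .yaml file, in sorted order."""
    for path in sorted(glob.glob(os.path.join(input_dir, "*.yaml"))):
        with open(path) as fh:
            yield path, yaml.safe_load(fh)

def dispatch(params):
    kind = params['parameters_type']
    if kind == "auto_horizontal_stitch":
        pass  # horizontal stitching
    elif kind == "ez_ufo_reco":
        pass  # EZ-UFO reconstruction
    elif kind == "auto_vertical_stitch":
        pass  # vertical stitching
# ------------------------------------------------------------------------------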
self.info_label.setText(info_str) def input_dir_button_pressed(self): logging.debug("Input Button Pressed") dir_explore = QFileDialog(self) input_dir = dir_explore.getExistingDirectory() self.input_dir_entry.setText(input_dir) self.parameters['input_dir'] = input_dir def set_input_entry(self): logging.debug("Input Entry: " + str(self.input_dir_entry.text())) self.parameters['input_dir'] = str(self.input_dir_entry.text()) def batch_proc_button_pressed(self): logging.debug("Batch Process Button Pressed") try: param_files_list = sorted(glob.glob(os.path.join(self.parameters['input_dir'], "*.yaml"))) if len(param_files_list) == 0: print("=> Error: Did not find any .yaml files in the input directory. Please try again.") else: print("*************************************************************************") print("************************** Begin Batch Process **************************") print("*************************************************************************\n") print("=> Found the following .yaml files:") for file in param_files_list: print("--> " + file) # Open .yaml file and store the parameters try: file_in = open(file, 'r') params = yaml.load(file_in, Loader=yaml.FullLoader) except FileNotFoundError: print("Something went wrong") params_type = params['parameters_type'] print(" type: " + params_type) if params_type == "auto_horizontal_stitch": # Call functions to begin auto horizontal stitch and pass params self.auto_stitch_funcs = AutoHorizontalStitchFunctions(params) self.auto_stitch_funcs.run_horizontal_auto_stitch() elif params_type == "ez_ufo_reco": # Call functions to begin ezufo reco and pass params self.config_group = ConfigGroup() self.config_group.run_reconstruction(params, batch_run=True) elif params_type == "auto_vertical_stitch": pass # Call functions to begin auto horizontal stitch and pass params except KeyError: print("Please select an input directory") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Main/centre_of_rotation.py0000664000175000017500000002062200000000000022774 0ustar00tomastomas00000000000000import logging from PyQt5.QtWidgets import QGridLayout, QLabel, QRadioButton, QGroupBox, QLineEdit, QCheckBox from tofu.ez.params import EZVARS from tofu.ez.util import add_value_to_dict_entry, get_int_validator, get_tuple_validator, get_double_validator LOG = logging.getLogger(__name__) class CentreOfRotationGroup(QGroupBox): """ Centre of Rotation settings """ def __init__(self): super().__init__() self.setTitle("Centre of Rotation") self.setStyleSheet("QGroupBox {color: green;}") self.auto_correlate_rButton = QRadioButton() self.auto_correlate_rButton.setText("Auto: Correlate first/last projections") self.auto_correlate_rButton.clicked.connect(self.set_rButton) self.auto_minimize_rButton = QRadioButton() self.auto_minimize_rButton.setText("Auto: Minimize STD of a slice") self.auto_minimize_rButton.setToolTip( "Reconstructed patches are saved \nin your-temporary-data-folder\\axis-search" ) self.auto_minimize_rButton.clicked.connect(self.set_rButton) self.auto_minimize_apply_pr = QCheckBox() self.auto_minimize_apply_pr.setText("Apply PR while searching") self.auto_minimize_apply_pr.stateChanged.connect(self.set_minimize_apply_pr) self.define_axis_rButton = QRadioButton() self.define_axis_rButton.setText("Define rotation axis manually") self.define_axis_rButton.clicked.connect(self.set_rButton) self.search_rotation_label = QLabel() self.search_rotation_label.setText("Search 
rotation axis in [start, stop, step] interval") self.search_rotation_entry = QLineEdit() self.search_rotation_entry.setValidator(get_tuple_validator()) self.search_rotation_entry.editingFinished.connect(self.set_search_rotation) self.search_in_slice_label = QLabel() self.search_in_slice_label.setText("Search in slice from row number") self.search_in_slice_entry = QLineEdit() self.search_in_slice_entry.setValidator(get_int_validator()) self.search_in_slice_entry.editingFinished.connect(self.set_search_slice) self.size_of_recon_label = QLabel() self.size_of_recon_label.setText("Size of reconstructed patch [pixel]") self.size_of_recon_entry = QLineEdit() self.size_of_recon_entry.setValidator(get_int_validator()) self.size_of_recon_entry.editingFinished.connect(self.set_size_of_reco) self.axis_col_label = QLabel() self.axis_col_label.setText("Axis is in column No [pixel]") self.axis_col_entry = QLineEdit() self.axis_col_entry.setValidator(get_double_validator()) self.axis_col_entry.editingFinished.connect(self.set_axis_col) self.inc_axis_label = QLabel() self.inc_axis_label.setText("Increment axis every reconstruction") self.inc_axis_entry = QLineEdit() self.inc_axis_entry.setValidator(get_double_validator()) self.inc_axis_entry.editingFinished.connect(self.set_axis_inc) self.image_midpoint_rButton = QRadioButton() self.image_midpoint_rButton.setText("Use image midpoint (for half-acquisition)") self.image_midpoint_rButton.clicked.connect(self.set_rButton) # TODO Used for proper spacing - should be a better way self.blank_label = QLabel(" ") self.blank_label2 = QLabel(" ") self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.auto_correlate_rButton, 0, 0) layout.addWidget(self.blank_label, 0, 1) layout.addWidget(self.blank_label2, 0, 2) layout.addWidget(self.auto_minimize_rButton, 1, 0) layout.addWidget(self.auto_minimize_apply_pr, 1, 1) layout.addWidget(self.search_rotation_label, 2, 0) layout.addWidget(self.search_rotation_entry, 2, 1, 1, 2) layout.addWidget(self.search_in_slice_label, 3, 0) layout.addWidget(self.search_in_slice_entry, 3, 1, 1, 2) layout.addWidget(self.size_of_recon_label, 4, 0) layout.addWidget(self.size_of_recon_entry, 4, 1, 1, 2) layout.addWidget(self.define_axis_rButton, 5, 0) layout.addWidget(self.axis_col_label, 6, 0) layout.addWidget(self.axis_col_entry, 6, 1, 1, 2) layout.addWidget(self.inc_axis_label, 7, 0) layout.addWidget(self.inc_axis_entry, 7, 1, 1, 2) layout.addWidget(self.image_midpoint_rButton, 8, 0) self.setLayout(layout) def load_values(self): self.set_rButton_from_params() self.search_rotation_entry.setText(str(EZVARS['COR']['search-interval']['value'])) self.search_in_slice_entry.setText(str(EZVARS['COR']['search-row']['value'])) self.size_of_recon_entry.setText(str(EZVARS['COR']['patch-size']['value'])) self.axis_col_entry.setText(str(EZVARS['COR']['user-defined-ax']['value'])) self.inc_axis_entry.setText(str(EZVARS['COR']['user-defined-dax']['value'])) def set_rButton(self): dict_entry = EZVARS['COR']['search-method'] if self.auto_correlate_rButton.isChecked(): LOG.debug("Auto Correlate") add_value_to_dict_entry(dict_entry, 1) elif self.auto_minimize_rButton.isChecked(): LOG.debug("Auto Minimize") add_value_to_dict_entry(dict_entry, 2) elif self.define_axis_rButton.isChecked(): LOG.debug("Define axis") add_value_to_dict_entry(dict_entry, 3) elif self.image_midpoint_rButton.isChecked(): LOG.debug("Use image midpoint") add_value_to_dict_entry(dict_entry, 4) def set_rButton_from_params(self): if 
EZVARS['COR']['search-method']['value'] == 1: self.auto_correlate_rButton.setChecked(True) self.auto_minimize_rButton.setChecked(False) self.define_axis_rButton.setChecked(False) self.image_midpoint_rButton.setChecked(False) elif EZVARS['COR']['search-method']['value'] == 2: self.auto_correlate_rButton.setChecked(False) self.auto_minimize_rButton.setChecked(True) self.define_axis_rButton.setChecked(False) self.image_midpoint_rButton.setChecked(False) elif EZVARS['COR']['search-method']['value'] == 3: self.auto_correlate_rButton.setChecked(False) self.auto_minimize_rButton.setChecked(False) self.define_axis_rButton.setChecked(True) self.image_midpoint_rButton.setChecked(False) elif EZVARS['COR']['search-method']['value'] == 4: self.auto_correlate_rButton.setChecked(False) self.auto_minimize_rButton.setChecked(False) self.define_axis_rButton.setChecked(False) self.image_midpoint_rButton.setChecked(True) def set_search_rotation(self): LOG.debug(self.search_rotation_entry.text()) dict_entry = EZVARS['COR']['search-interval'] add_value_to_dict_entry(dict_entry, str(self.search_rotation_entry.text())) self.search_rotation_entry.setText(str(dict_entry['value'])) def set_search_slice(self): LOG.debug(self.search_in_slice_entry.text()) dict_entry = EZVARS['COR']['search-row'] add_value_to_dict_entry(dict_entry, str(self.search_in_slice_entry.text())) self.search_in_slice_entry.setText(str(dict_entry['value'])) def set_size_of_reco(self): LOG.debug(self.size_of_recon_entry.text()) dict_entry = EZVARS['COR']['patch-size'] add_value_to_dict_entry(dict_entry, str(self.size_of_recon_entry.text())) self.size_of_recon_entry.setText(str(dict_entry['value'])) def set_minimize_apply_pr(self): LOG.debug("PR while min std ax search: " + str(self.auto_minimize_apply_pr.isChecked())) dict_entry = EZVARS['COR']['min-std-apply-pr'] add_value_to_dict_entry(dict_entry, self.auto_minimize_apply_pr.isChecked()) def set_axis_col(self): LOG.debug(self.axis_col_entry.text()) dict_entry = EZVARS['COR']['user-defined-ax'] add_value_to_dict_entry(dict_entry, str(self.axis_col_entry.text())) self.axis_col_entry.setText(str(dict_entry['value'])) def set_axis_inc(self): LOG.debug(self.inc_axis_entry.text()) dict_entry = EZVARS['COR']['user-defined-dax'] add_value_to_dict_entry(dict_entry, str(self.inc_axis_entry.text())) self.inc_axis_entry.setText(str(dict_entry['value'])) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Main/config.py0000664000175000017500000006631700000000000020371 0ustar00tomastomas00000000000000import os import logging from functools import partial from shutil import rmtree from PyQt5.QtWidgets import ( QMessageBox, QFileDialog, QCheckBox, QPushButton, QGridLayout, QLabel, QGroupBox, QLineEdit, ) from PyQt5.QtCore import QCoreApplication, QTimer, pyqtSignal, Qt from tofu.ez.main import execute_reconstruction, clean_tmp_dirs from tofu.ez.util import import_values, export_values from tofu.ez.GUI.message_dialog import warning_message from tofu.ez.params import EZVARS from tofu.ez.util import add_value_to_dict_entry LOG = logging.getLogger(__name__) class ConfigGroup(QGroupBox): """ Setup and configuration settings """ # Used to send signal to ezufo_launcher when settings are imported # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect signal_update_vals_from_params = pyqtSignal() # Used to send signal when reconstruction is done signal_reco_done = pyqtSignal() def 
__init__(self): super().__init__() self.setTitle("Input/output and misc settings") self.setStyleSheet("QGroupBox {color: purple;}") # Select input directory self.input_dir_select = QPushButton("Select input directory (or paste abs. path)") self.input_dir_select.setStyleSheet("background-color:lightgrey; font: 12pt;") self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_input_dir) self.input_dir_select.pressed.connect(self.select_input_dir) # Save .params checkbox self.save_params_checkbox = QCheckBox("Save args in .params file") self.save_params_checkbox.stateChanged.connect(self.set_save_args) # Select output directory self.output_dir_select = QPushButton() self.output_dir_select.setText("Select output directory (or paste abs. path)") self.output_dir_select.setStyleSheet("background-color:lightgrey; font: 12pt;") self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_output_dir) self.output_dir_select.pressed.connect(self.select_output_dir) # Save in separate files or in one huge tiff file self.bigtiff_checkbox = QCheckBox() self.bigtiff_checkbox.setText("Save slices in multipage tiffs") self.bigtiff_checkbox.setToolTip( "Will save images in bigtiff containers. \n" "Note that some temporary data is always saved in bigtiffs.\n" "Use bio-formats importer plugin for imagej or fiji to open the bigtiffs." ) self.bigtiff_checkbox.stateChanged.connect(self.set_big_tiff) # Crop in the reconstruction plane self.preproc_checkbox = QCheckBox() self.preproc_checkbox.setText("Preprocess with a generic ufo-launch pipeline, f.i.") self.preproc_checkbox.setToolTip( "Selected ufo filters will be applied to each " "image before reconstruction begins. \n" 'To print the list of filters use "ufo-query -l" command. \n' 'Parameters of each filter can be seen with "ufo-query -p filtername".' ) self.preproc_checkbox.stateChanged.connect(self.set_preproc) self.preproc_entry = QLineEdit() self.preproc_entry.editingFinished.connect(self.set_preproc_entry) # Names of directories with flats/darks/projections frames self.dir_name_label = QLabel() self.dir_name_label.setText("Name of flats/darks/tomo subdirectories in each CT data set") self.darks_entry = QLineEdit() self.darks_entry.editingFinished.connect(self.set_darks) self.flats_entry = QLineEdit() self.flats_entry.editingFinished.connect(self.set_flats) self.tomo_entry = QLineEdit() self.tomo_entry.editingFinished.connect(self.set_tomo) self.flats2_entry = QLineEdit() self.flats2_entry.editingFinished.connect(self.set_flats2) # Select flats/darks/flats2 for use in multiple reconstructions self.use_common_flats_darks_checkbox = QCheckBox() self.use_common_flats_darks_checkbox.setText( "Use common flats/darks across multiple experiments" ) self.use_common_flats_darks_checkbox.stateChanged.connect(self.set_flats_darks_checkbox) self.select_darks_button = QPushButton("Select path to darks (or paste abs. path)") self.select_darks_button.setToolTip("Background detector noise") self.select_darks_button.clicked.connect(self.select_darks_button_pressed) self.select_flats_button = QPushButton("Select path to flats (or paste abs. path)") self.select_flats_button.setToolTip("Images without sample in the beam") self.select_flats_button.clicked.connect(self.select_flats_button_pressed) self.select_flats2_button = QPushButton("Select path to flats2 (or paste abs. 
path)") self.select_flats2_button.setToolTip( "If selected, it will be assumed that flats were \n" "acquired before projections while flats2 after \n" "and interpolation will be used to compute intensity of flat image \n" "for each projection between flats and flats2" ) self.select_flats2_button.clicked.connect(self.select_flats2_button_pressed) self.darks_absolute_entry = QLineEdit() self.darks_absolute_entry.setText("Absolute path to darks") self.darks_absolute_entry.editingFinished.connect(self.set_common_darks) self.flats_absolute_entry = QLineEdit() self.flats_absolute_entry.setText("Absolute path to flats") self.flats_absolute_entry.editingFinished.connect(self.set_common_flats) self.use_flats2_checkbox = QCheckBox("Use common flats2") self.use_flats2_checkbox.clicked.connect(self.set_use_flats2) self.flats2_absolute_entry = QLineEdit() self.flats2_absolute_entry.editingFinished.connect(self.set_common_flats2) self.flats2_absolute_entry.setText("Absolute path to flats2") # Select temporary directory self.temp_dir_select = QPushButton() self.temp_dir_select.setText("Select temporary directory (or paste abs. path)") self.temp_dir_select.setToolTip( "Temporary data will be saved there.\n" "note that the size of temporary data can exceed 300 GB in some cases." ) self.temp_dir_select.pressed.connect(self.select_temp_dir) self.temp_dir_select.setStyleSheet("background-color:lightgrey; font: 12pt;") self.temp_dir_entry = QLineEdit() self.temp_dir_entry.editingFinished.connect(self.set_temp_dir) # Keep temp data selection self.keep_tmp_data_checkbox = QCheckBox() self.keep_tmp_data_checkbox.setText("Keep all temp data till the end of reconstruction") self.keep_tmp_data_checkbox.setToolTip( "Useful option to inspect how images change at each step" ) self.keep_tmp_data_checkbox.stateChanged.connect(self.set_keep_tmp_data) # IMPORT SETTINGS FROM FILE self.open_settings_file = QPushButton() self.open_settings_file.setText("Import parameters from file") self.open_settings_file.setStyleSheet("background-color:lightgrey; font: 12pt;") self.open_settings_file.pressed.connect(self.import_settings_button_pressed) # EXPORT SETTINGS TO FILE self.save_settings_file = QPushButton() self.save_settings_file.setText("Export parameters to file") self.save_settings_file.setStyleSheet("background-color:lightgrey; font: 12pt;") self.save_settings_file.pressed.connect(self.export_settings_button_pressed) # QUIT self.quit_button = QPushButton() self.quit_button.setText("Quit") self.quit_button.setStyleSheet("background-color:lightgrey; font: 13pt; font-weight: bold;") self.quit_button.clicked.connect(self.quit_button_pressed) # HELP self.help_button = QPushButton() self.help_button.setText("Help") self.help_button.setStyleSheet("background-color:lightgrey; font: 13pt; font-weight: bold") self.help_button.clicked.connect(self.help_button_pressed) # DELETE self.delete_reco_dir_button = QPushButton() self.delete_reco_dir_button.setText("Delete reco dir") self.delete_reco_dir_button.setStyleSheet( "background-color:lightgrey; font: 13pt; font-weight: bold" ) self.delete_reco_dir_button.clicked.connect(self.delete_button_pressed) # DRY RUN self.dry_run_button = QPushButton() self.dry_run_button.setText("Dry run") self.dry_run_button.setStyleSheet( "background-color:lightgrey; font: 13pt; font-weight: bold" ) self.dry_run_button.clicked.connect(self.dryrun_button_pressed) # RECONSTRUCT self.reco_button = QPushButton() self.reco_button.setText("Reconstruct") self.reco_button.setStyleSheet( 
"background-color:lightgrey;color:royalblue; font: 14pt; font-weight: bold;" ) self.reco_button.clicked.connect(self.reco_button_pressed) # OPEN IMAGE AFTER RECONSTRUCT self.open_image_after_reco_checkbox = QCheckBox() self.open_image_after_reco_checkbox.setText( "Load images and open viewer after reconstruction" ) self.open_image_after_reco_checkbox.clicked.connect(self.set_open_image_after_reco) self.set_layout() def set_layout(self): """ Sets the layout of buttons, labels, etc. for config group """ layout = QGridLayout() checkbox_groupbox = QGroupBox() checkbox_layout = QGridLayout() checkbox_layout.addWidget(self.save_params_checkbox, 0, 0) checkbox_layout.addWidget(self.bigtiff_checkbox, 1, 0) checkbox_layout.addWidget(self.open_image_after_reco_checkbox, 2, 0) checkbox_layout.addWidget(self.keep_tmp_data_checkbox, 3, 0) checkbox_groupbox.setLayout(checkbox_layout) layout.addWidget(checkbox_groupbox, 0, 4, 4, 1) layout.addWidget(self.input_dir_select, 0, 0) layout.addWidget(self.input_dir_entry, 0, 1, 1, 3) layout.addWidget(self.output_dir_select, 1, 0) layout.addWidget(self.output_dir_entry, 1, 1, 1, 3) layout.addWidget(self.temp_dir_select, 2, 0) layout.addWidget(self.temp_dir_entry, 2, 1, 1, 3) layout.addWidget(self.preproc_checkbox, 3, 0) layout.addWidget(self.preproc_entry, 3, 1, 1, 3) fdt_groupbox = QGroupBox() fdt_layout = QGridLayout() fdt_layout.addWidget(self.dir_name_label, 0, 0) fdt_layout.addWidget(self.darks_entry, 0, 1) fdt_layout.addWidget(self.flats_entry, 0, 2) fdt_layout.addWidget(self.tomo_entry, 0, 3) fdt_layout.addWidget(self.flats2_entry, 0, 4) fdt_layout.addWidget(self.use_common_flats_darks_checkbox, 1, 0) fdt_layout.addWidget(self.select_darks_button, 1, 1) fdt_layout.addWidget(self.select_flats_button, 1, 2) fdt_layout.addWidget(self.select_flats2_button, 1, 4) fdt_layout.addWidget(self.darks_absolute_entry, 2, 1) fdt_layout.addWidget(self.flats_absolute_entry, 2, 2) fdt_layout.addWidget(self.use_flats2_checkbox, 2, 3, Qt.AlignRight) fdt_layout.addWidget(self.flats2_absolute_entry, 2, 4) fdt_groupbox.setLayout(fdt_layout) layout.addWidget(fdt_groupbox, 4, 0, 1, 5) layout.addWidget(self.open_settings_file, 5, 0, 1, 3) layout.addWidget(self.save_settings_file, 5, 3, 1, 2) layout.addWidget(self.quit_button, 6, 0) layout.addWidget(self.help_button, 6, 1) layout.addWidget(self.delete_reco_dir_button, 6, 2) layout.addWidget(self.dry_run_button, 6, 3) layout.addWidget(self.reco_button, 6, 4) self.setLayout(layout) def load_values(self): """ Updates displayed values for config group """ self.input_dir_entry.setText(EZVARS['inout']['input-dir']['value']) self.save_params_checkbox.setChecked(EZVARS['inout']['save-params']['value']) self.output_dir_entry.setText(EZVARS['inout']['output-dir']['value']) self.bigtiff_checkbox.setChecked(EZVARS['inout']['bigtiff-output']['value']) self.preproc_checkbox.setChecked(EZVARS['inout']['preprocess']['value']) self.preproc_entry.setText(EZVARS['inout']['preprocess-command']['value']) self.darks_entry.setText(EZVARS['inout']['darks-dir']['value']) self.flats_entry.setText(EZVARS['inout']['flats-dir']['value']) self.tomo_entry.setText(EZVARS['inout']['tomo-dir']['value']) self.flats2_entry.setText(EZVARS['inout']['flats2-dir']['value']) self.temp_dir_entry.setText(EZVARS['inout']['tmp-dir']['value']) self.keep_tmp_data_checkbox.setChecked(EZVARS['inout']['keep-tmp']['value']) self.dry_run_button.setChecked(EZVARS['inout']['dryrun']['value']) self.open_image_after_reco_checkbox.setChecked(EZVARS['inout']['open-viewer']['value']) 
self.use_common_flats_darks_checkbox.setChecked(EZVARS['inout']['shared-flatsdarks']['value']) self.darks_absolute_entry.setText(EZVARS['inout']['path2-shared-darks']['value']) self.flats_absolute_entry.setText(EZVARS['inout']['path2-shared-flats']['value']) self.use_flats2_checkbox.setChecked(EZVARS['inout']['shared-flats-after']['value']) self.flats2_absolute_entry.setText(EZVARS['inout']['path2-shared-flats2']['value']) def select_input_dir(self): """ Saves directory specified by user in file-dialog for input tomographic data """ dir_explore = QFileDialog(self) dir = dir_explore.getExistingDirectory(directory=self.input_dir_entry.text()) if dir: self.input_dir_entry.setText(dir) self.set_input_dir() def set_input_dir(self): LOG.debug(str(self.input_dir_entry.text())) dict_entry = EZVARS['inout']['input-dir'] dir = self.input_dir_entry.text().strip() add_value_to_dict_entry(dict_entry, dir) self.input_dir_entry.setText(dir) def select_output_dir(self): dir_explore = QFileDialog(self) dir = dir_explore.getExistingDirectory(directory=self.output_dir_entry.text()) if dir: self.output_dir_entry.setText(dir) self.set_output_dir() def set_output_dir(self): LOG.debug(str(self.output_dir_entry.text())) dict_entry = EZVARS['inout']['output-dir'] dir = self.output_dir_entry.text().strip() add_value_to_dict_entry(dict_entry, dir) self.output_dir_entry.setText(dir) def set_big_tiff(self): LOG.debug("Bigtiff: " + str(self.bigtiff_checkbox.isChecked())) dict_entry = EZVARS['inout']['bigtiff-output'] add_value_to_dict_entry(dict_entry, self.bigtiff_checkbox.isChecked()) def set_preproc(self): LOG.debug("Preproc: " + str(self.preproc_checkbox.isChecked())) dict_entry = EZVARS['inout']['preprocess'] add_value_to_dict_entry(dict_entry, self.preproc_checkbox.isChecked()) def set_preproc_entry(self): LOG.debug(self.preproc_entry.text()) dict_entry = EZVARS['inout']['preprocess-command'] text = self.preproc_entry.text().strip() add_value_to_dict_entry(dict_entry, text) self.preproc_entry.setText(text) def set_open_image_after_reco(self): LOG.debug( "Switch to Image Viewer After Reco: " + str(self.open_image_after_reco_checkbox.isChecked()) ) dict_entry = EZVARS['inout']['open-viewer'] add_value_to_dict_entry(dict_entry, self.open_image_after_reco_checkbox.isChecked()) def set_darks(self): LOG.debug(self.darks_entry.text()) dict_entry = EZVARS['inout']['darks-dir'] dir = self.darks_entry.text().strip() add_value_to_dict_entry(dict_entry, dir) self.darks_entry.setText(dir) def set_flats(self): LOG.debug(self.flats_entry.text()) dict_entry = EZVARS['inout']['flats-dir'] dir = self.flats_entry.text().strip() add_value_to_dict_entry(dict_entry, dir) self.flats_entry.setText(dir) def set_tomo(self): LOG.debug(self.tomo_entry.text()) dict_entry = EZVARS['inout']['tomo-dir'] dir = self.tomo_entry.text().strip() add_value_to_dict_entry(dict_entry, dir) self.tomo_entry.setText(dir) def set_flats2(self): LOG.debug(self.flats2_entry.text()) dict_entry = EZVARS['inout']['flats2-dir'] dir = self.flats2_entry.text().strip() add_value_to_dict_entry(dict_entry, dir) self.flats2_entry.setText(dir) def set_fdt_names(self): self.set_darks() self.set_flats() self.set_flats2() self.set_tomo() def set_flats_darks_checkbox(self): LOG.debug( "Use same flats/darks across multiple experiments: " + str(self.use_common_flats_darks_checkbox.isChecked()) ) dict_entry = EZVARS['inout']['shared-flatsdarks'] add_value_to_dict_entry(dict_entry, self.use_common_flats_darks_checkbox.isChecked()) def select_darks_button_pressed(self): 
LOG.debug("Select path to darks pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory(directory=EZVARS['inout']['input-dir']['value']) if directory: self.darks_absolute_entry.setText(directory) self.set_common_darks() def select_flats_button_pressed(self): LOG.debug("Select path to flats pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory(directory=EZVARS['inout']['input-dir']['value']) if directory: self.flats_absolute_entry.setText(directory) self.set_common_flats() def select_flats2_button_pressed(self): LOG.debug("Select path to flats2 pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory(directory=EZVARS['inout']['input-dir']['value']) if directory: self.flats2_absolute_entry.setText(directory) self.set_common_flats2() def set_common_darks(self): LOG.debug("Common darks path: " + str(self.darks_absolute_entry.text())) dict_entry = EZVARS['inout']['path2-shared-darks'] text = self.darks_absolute_entry.text().strip() add_value_to_dict_entry(dict_entry, text) self.darks_absolute_entry.setText(text) def set_common_flats(self): LOG.debug("Common flats path: " + str(self.flats_absolute_entry.text())) dict_entry = EZVARS['inout']['path2-shared-flats'] text = self.flats_absolute_entry.text().strip() add_value_to_dict_entry(dict_entry, text) self.flats_absolute_entry.setText(text) def set_use_flats2(self): LOG.debug("Use common flats2 checkbox: " + str(self.use_flats2_checkbox.isChecked())) dict_entry = EZVARS['inout']['shared-flats-after'] text = self.use_flats2_checkbox.text().strip() add_value_to_dict_entry(dict_entry, text) self.use_flats2_checkbox.setText(text) def set_common_flats2(self): LOG.debug("Common flats2 path: " + str(self.flats2_absolute_entry.text())) dict_entry = EZVARS['inout']['path2-shared-flats2'] text = self.flats2_absolute_entry.text().strip() add_value_to_dict_entry(dict_entry, text) self.flats2_absolute_entry.setText(text) def select_temp_dir(self): dir_explore = QFileDialog(self) tmp_dir = dir_explore.getExistingDirectory(directory=self.temp_dir_entry.text()) if tmp_dir: self.temp_dir_entry.setText(tmp_dir) self.set_temp_dir() def set_temp_dir(self): LOG.debug(str(self.temp_dir_entry.text())) dict_entry = EZVARS['inout']['tmp-dir'] text = self.temp_dir_entry.text().strip() add_value_to_dict_entry(dict_entry, text) self.temp_dir_entry.setText(text) def set_keep_tmp_data(self): LOG.debug("Keep tmp: " + str(self.keep_tmp_data_checkbox.isChecked())) dict_entry = EZVARS['inout']['keep-tmp'] add_value_to_dict_entry(dict_entry, self.keep_tmp_data_checkbox.isChecked()) def quit_button_pressed(self): """ Displays confirmation dialog and cleans temporary directories """ LOG.debug("QUIT") reply = QMessageBox.question( self, "Quit", "Are you sure you want to quit?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No, ) if reply == QMessageBox.Yes: # remove all directories with projections clean_tmp_dirs(EZVARS['inout']['tmp-dir']['value'], self.get_fdt_names()) # remove axis-search dir too tmp = os.path.join(EZVARS['inout']['tmp-dir']['value'], 'axis-search') QCoreApplication.instance().quit() else: pass def help_button_pressed(self): """ Displays pop-up help information """ LOG.debug("HELP") h = "This utility provides an interface to the ufo-kit software package.\n" h += "Use it for batch processing and optimization of reconstruction parameters.\n" h += "It creates a list of paths to all CT directories in the _input_ directory.\n" h += "A CT directory is defined as directory with 
at least \n" h += "_flats_, _darks_, _tomo_, and, optionally, _flats2_ subdirectories, \n" h += "which are not empty and contain only *.tif files. Names of CT\n" h += "directories are compared with the directory tree in the _output_ directory.\n" h += ( "(Note: relative directory tree in _input_ is preserved when writing results to the" " _output_.)\n" ) h += ( "Those CT sets will be reconstructed, whose names are not yet in the _output_" " directory." ) h += "Program will create an array of ufo/tofu commands according to defined parameters \n" h += ( "and then execute them sequentially. These commands can be also printed on the" " screen.\n" ) h += "Note2: if you bin in preprocess the center of rotation will change a lot; \n" h += 'Note4: set to "flats" if "flats2" exist but you need to ignore them; \n' h += ( "Created by Sergei Gasilov, BMIT CLS, Dec. 2018.\n Extended by Iain Emslie, Summer" " 2021." ) QMessageBox.information(self, "Help", h) def delete_button_pressed(self): """ Deletes the directory that contains reconstructed data """ LOG.debug("DELETE") msg = "Delete directory with reconstructed data?" dialog = QMessageBox.warning( self, "Warning: data can be lost", msg, QMessageBox.Yes | QMessageBox.No ) if dialog == QMessageBox.Yes: if os.path.exists(str(EZVARS['inout']['output-dir']['value'])): LOG.debug("YES") if EZVARS['inout']['output-dir']['value'] == EZVARS['inout']['input-dir']['value']: LOG.debug("Cannot delete: output directory is the same as input") else: try: rmtree(EZVARS['inout']['output-dir']['value']) except: warning_message('Error while deleting directory') LOG.debug("Directory with reconstructed data was removed") else: LOG.debug("Directory does not exist") else: LOG.debug("NO") def dryrun_button_pressed(self): """ Sets the dry-run parameter for Tofu to True and calls reconstruction """ LOG.debug("DRY") EZVARS['inout']['dryrun']['value'] = str(True) self.reco_button_pressed() def set_save_args(self): LOG.debug("Save args: " + str(self.save_params_checkbox.isChecked())) EZVARS['inout']['save-params']['value'] = bool(self.save_params_checkbox.isChecked()) def export_settings_button_pressed(self): """ Saves currently displayed GUI settings to an external .yaml file specified by user """ LOG.debug("Save settings pressed") options = QFileDialog.Options() fileName, _ = QFileDialog.getSaveFileName( self, "QFileDialog.getSaveFileName()", "", "YAML Files (*.yaml);; All Files (*)", options=options, ) if fileName: LOG.debug("Export YAML Path: " + fileName) file_extension = os.path.splitext(fileName) if file_extension[-1] == "": fileName = fileName + ".yaml" # Create and write to YAML file based on given fileName # self.yaml_io.write_yaml(fileName, parameters.params) export_values(fileName) def import_settings_button_pressed(self): """ Loads external settings from .yaml file specified by user Signal is sent to enable updating of displayed GUI values """ LOG.debug("Import settings pressed") options = QFileDialog.Options() filePath, _ = QFileDialog.getOpenFileName( self, "QFileDialog.getOpenFileName()", "", "YAML Files (*.yaml);; All Files (*)", options=options, ) if filePath: LOG.debug("Import YAML Path: " + filePath) import_values(filePath) self.signal_update_vals_from_params.emit() def reco_button_pressed(self): """ Gets the settings set by the user in the GUI These are then passed to execute_reconstruction """ #LOG.debug("RECO") self.set_fdt_names() self.set_common_darks() self.set_common_flats() self.set_common_flats2() self.set_big_tiff() self.set_input_dir() 
self.set_output_dir() self.set_temp_dir() self.set_preproc() self.set_preproc_entry() run_reco = partial(self.run_reconstruction, batch_run=False) #I had to add a little sleep as on some Linux ditributions params won't fully set before the main() begins QTimer.singleShot(100, run_reco) def run_reconstruction(self, batch_run): try: execute_reconstruction(self.get_fdt_names()) if batch_run is False: msg = "Done. See output in terminal for details." QMessageBox.information(self, "Finished", msg) if not EZVARS['inout']['dryrun']['value']: self.signal_reco_done.emit() EZVARS['inout']['dryrun']['value'] = bool(False) except InvalidInputError as err: msg = "" err_arg = err.args msg += err.args[0] QMessageBox.information(self, "Invalid Input Error", msg) def get_fdt_names(self): return [EZVARS['inout']['darks-dir']['value'], EZVARS['inout']['flats-dir']['value'], EZVARS['inout']['tomo-dir']['value'], EZVARS['inout']['flats2-dir']['value']] class InvalidInputError(Exception): """ Error to be raised when input values from GUI are out of range or invalid """ ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Main/filters.py0000664000175000017500000003123600000000000020564 0ustar00tomastomas00000000000000import logging from PyQt5.QtWidgets import ( QButtonGroup, QGridLayout, QLabel, QRadioButton, QCheckBox, QGroupBox, QLineEdit, ) from PyQt5.QtCore import Qt from tofu.ez.params import EZVARS from tofu.config import SECTIONS from tofu.ez.util import add_value_to_dict_entry, get_int_validator, get_double_validator LOG = logging.getLogger(__name__) class FiltersGroup(QGroupBox): """ Filter settings """ def __init__(self): super().__init__() self.setTitle("Filters") self.setStyleSheet("QGroupBox {color: orange;}") self.remove_spots_checkBox = QCheckBox() self.remove_spots_checkBox.setText("Remove large spots from projections") self.remove_spots_checkBox.setToolTip( "Efficiently suppresses very intense rings \n stemming from defects in scintillator" ) self.remove_spots_checkBox.stateChanged.connect(self.set_remove_spots) self.threshold_label = QLabel() self.threshold_label.setText("Threshold (prominence of the spot) [counts]") self.threshold_label.setToolTip( "Outliers which will be considered as the part of the large spot" ) self.threshold_entry = QLineEdit() self.threshold_entry.setValidator(get_double_validator()) self.threshold_entry.editingFinished.connect(self.set_threshold) self.spot_blur_label = QLabel() self.spot_blur_label.setText("Low-pass filter sigma [pixels]") # self.spot_blur_label.setToolTip( # "Regulates extent of the masked region around the detected outlier" # ) self.spot_blur_label.setToolTip('Low pass filter will be applied before spots are identified' 'to remove very low-frequency changes in the flat field') self.spot_blur_entry = QLineEdit() self.spot_blur_entry.setValidator(get_double_validator()) self.spot_blur_entry.editingFinished.connect(self.set_spot_blur) self.enable_RR_checkbox = QCheckBox() self.enable_RR_checkbox.setText("Enable ring removal") self.remove_spots_checkBox.setToolTip( "To suppress ring artifacts" " stemming from intensity fluctuations and detector non-linearities" ) self.enable_RR_checkbox.stateChanged.connect(self.set_ring_removal) self.use_LPF_rButton = QRadioButton() self.use_LPF_rButton.setText("Use ufo Fourier-transform based filter") self.use_LPF_rButton.clicked.connect(self.select_rButton) self.use_LPF_rButton.setToolTip( "To suppress ring artifacts" " stemming from 
intensity fluctuations and detector non-linearities" ) self.sarepy_rButton = QRadioButton() self.sarepy_rButton.setText("Use sarepy sorting: ") self.sarepy_rButton.clicked.connect(self.select_rButton) self.sarepy_rButton.setToolTip( "Non-FFT based algorithms from \n /Nghia T. Vo et al, Opt. Express 26, 28396 (2018)" ) self.filter_rButton_group = QButtonGroup(self) self.filter_rButton_group.addButton(self.use_LPF_rButton) self.filter_rButton_group.addButton(self.sarepy_rButton) self.one_dimens_rButton = QRadioButton() self.one_dimens_rButton.setText("1D") self.one_dimens_rButton.clicked.connect(self.select_dimens_rButton) self.one_dimens_rButton.setToolTip("Only low-pass filter along the lines of sinogram") self.two_dimens_rButton = QRadioButton() self.two_dimens_rButton.setText("2D") self.two_dimens_rButton.clicked.connect(self.select_dimens_rButton) self.two_dimens_rButton.setToolTip( "Low-pass filter along the lines and high-pass filter along the columns" ) self.dimens_rButton_group = QButtonGroup(self) self.dimens_rButton_group.addButton(self.one_dimens_rButton) self.dimens_rButton_group.addButton(self.two_dimens_rButton) self.sigma_horizontal_label = QLabel() self.sigma_horizontal_label.setText("sigma horizontal") self.sigma_horizontal_label.setToolTip( "Width [pixels] of Gaussian-shaped low-pass filter in frequency domain" ) self.sigma_horizontal_entry = QLineEdit() self.sigma_horizontal_entry.setValidator(get_int_validator()) self.sigma_horizontal_entry.editingFinished.connect(self.set_sigma_horizontal) self.sigma_vertical_label = QLabel() self.sigma_vertical_label.setText("sigma vertical") self.sigma_vertical_label.setToolTip( "Width [pixels] of Gaussian-shaped high-pass filter in frequency domain" ) self.sigma_vertical_entry = QLineEdit() self.sigma_vertical_entry.setValidator(get_int_validator()) self.sigma_vertical_entry.editingFinished.connect(self.set_sigma_vertical) self.wind_size_label = QLabel() self.wind_size_label.setText("window size") self.wind_size_label.setToolTip("Window size in remove_stripe_based_sorting algorithm") self.wind_size_entry = QLineEdit() self.wind_size_entry.setValidator(get_int_validator()) self.wind_size_entry.editingFinished.connect(self.set_window_size) self.wind_size_entry.setToolTip("Typically in the range 31..51 ") self.remove_wide_checkbox = QCheckBox() self.remove_wide_checkbox.setText("Remove wide") self.remove_wide_checkbox.setToolTip("Window size in remove_large_stripe algorithm") self.remove_wide_checkbox.stateChanged.connect(self.set_remove_wide) self.remove_wide_label = QLabel() self.remove_wide_label.setText("window") self.remove_wide_label.setToolTip("Typically in the range 51..131 ") self.remove_wide_entry = QLineEdit() self.remove_wide_entry.setValidator(get_int_validator()) self.remove_wide_entry.editingFinished.connect(self.set_wind) self.SNR_label = QLabel() self.SNR_label.setText("SNR") self.SNR_label.setToolTip("SNR param in remove_large_stripe algorithm") self.SNR_entry = QLineEdit() self.SNR_entry.setValidator(get_int_validator()) self.SNR_entry.editingFinished.connect(self.set_SNR) self.set_layout() def set_layout(self): layout = QGridLayout() remove_spots_groupbox = QGroupBox() remove_spots_layout = QGridLayout() remove_spots_layout.addWidget(self.remove_spots_checkBox, 0, 0) remove_spots_layout.addWidget(self.threshold_label, 1, 0) remove_spots_layout.addWidget(self.threshold_entry, 1, 1, 1, 7) remove_spots_layout.addWidget(self.spot_blur_label, 2, 0) remove_spots_layout.addWidget(self.spot_blur_entry, 2, 1, 1, 7) 
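# The ring-removal options collected by this group box all act on sinogram
# "stripes" (a near-constant per-column offset over all projection angles),
# which backproject into rings in the reconstructed slice.  The snippet below
# is only a minimal numpy/scipy illustration of that principle using a simple
# column-profile (mean-subtraction) correction; it is neither the ufo Fourier
# filter nor the sarepy sorting algorithm used by tofu, and
# remove_stripes_simple and its sigma argument are illustrative names.
import numpy as np
from scipy.ndimage import gaussian_filter1d

def remove_stripes_simple(sinogram, sigma=3.0):
    # Column-wise mean over all angles: defective detector columns show up as
    # sharp deviations of this 1D profile from its smooth trend.
    profile = sinogram.mean(axis=0)
    smooth = gaussian_filter1d(profile, sigma)
    # Subtract the sharp, ring-producing part of the profile from every row.
    return sinogram - (profile - smooth)[np.newaxis, :]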
remove_spots_groupbox.setLayout(remove_spots_layout) layout.addWidget(remove_spots_groupbox) rr_groupbox = QGroupBox() rr_layout = QGridLayout() rr_layout.addWidget(self.enable_RR_checkbox, 3, 0) rr_layout.addWidget(self.use_LPF_rButton, 4, 0) rr_layout.addWidget(self.one_dimens_rButton, 4, 1) rr_layout.addWidget(self.two_dimens_rButton, 4, 2) rr_layout.addWidget(self.sigma_horizontal_label, 4, 3, Qt.AlignRight) rr_layout.addWidget(self.sigma_horizontal_entry, 4, 4) rr_layout.addWidget(self.sigma_vertical_label, 4, 5, Qt.AlignRight) rr_layout.addWidget(self.sigma_vertical_entry, 4, 6) rr_layout.addWidget(self.sarepy_rButton, 5, 0) rr_layout.addWidget(self.wind_size_label, 5, 1) rr_layout.addWidget(self.wind_size_entry, 5, 2) rr_layout.addWidget(self.remove_wide_checkbox, 5, 3) rr_layout.addWidget(self.remove_wide_label, 5, 4, Qt.AlignRight) rr_layout.addWidget(self.remove_wide_entry, 5, 5) rr_layout.addWidget(self.SNR_label, 5, 6) rr_layout.addWidget(self.SNR_entry, 5, 7) rr_groupbox.setLayout(rr_layout) layout.addWidget(rr_groupbox, 3, 0) self.setLayout(layout) def load_values(self): self.remove_spots_checkBox.setChecked(EZVARS['filters']['rm_spots']['value']) self.threshold_entry.setText(str(SECTIONS['find-large-spots']['spot-threshold']['value'])) self.spot_blur_entry.setText(str(SECTIONS['find-large-spots']['gauss-sigma']['value'])) self.enable_RR_checkbox.setChecked(EZVARS['RR']['enable-RR']['value']) if EZVARS['RR']['use-ufo']['value'] == True: self.use_LPF_rButton.setChecked(True) elif EZVARS['RR']['use-ufo']['value'] == False: self.use_LPF_rButton.setChecked(False) if EZVARS['RR']['ufo-2d']['value'] == True: self.one_dimens_rButton.setChecked(True) self.two_dimens_rButton.setChecked(False) elif EZVARS['RR']['ufo-2d']['value'] == False: self.one_dimens_rButton.setChecked(False) self.two_dimens_rButton.setChecked(True) self.sigma_horizontal_entry.setText(str(EZVARS['RR']['sx']['value'])) self.sigma_vertical_entry.setText(str(EZVARS['RR']['sy']['value'])) self.wind_size_entry.setText(str(EZVARS['RR']['spy-narrow-window']['value'])) self.remove_wide_checkbox.setChecked(EZVARS['RR']['spy-rm-wide']['value']) self.remove_wide_entry.setText(str(EZVARS['RR']['spy-wide-window']['value'])) self.SNR_entry.setText(str(EZVARS['RR']['spy-wide-SNR']['value'])) def set_remove_spots(self): LOG.debug("Remove large spots:" + str(self.remove_spots_checkBox.isChecked())) dict_entry = EZVARS['filters']['rm_spots'] add_value_to_dict_entry(dict_entry, self.remove_spots_checkBox.isChecked()) def set_threshold(self): LOG.debug(self.threshold_entry.text()) dict_entry = SECTIONS['find-large-spots']['spot-threshold'] add_value_to_dict_entry(dict_entry, self.threshold_entry.text()) self.threshold_entry.setText(str(dict_entry['value'])) def set_spot_blur(self): LOG.debug(self.spot_blur_entry.text()) dict_entry = SECTIONS['find-large-spots']['gauss-sigma'] add_value_to_dict_entry(dict_entry, self.spot_blur_entry.text()) self.spot_blur_entry.setText(str(dict_entry['value'])) def set_ring_removal(self): LOG.debug("RR: " + str(self.enable_RR_checkbox.isChecked())) dict_entry = EZVARS['RR']['enable-RR'] add_value_to_dict_entry(dict_entry, self.enable_RR_checkbox.isChecked()) def select_rButton(self): dict_entry = EZVARS['RR']['use-ufo'] if self.use_LPF_rButton.isChecked(): LOG.debug("Use LPF") add_value_to_dict_entry(dict_entry, True) elif self.sarepy_rButton.isChecked(): LOG.debug("Use Sarepy") add_value_to_dict_entry(dict_entry, False) def select_dimens_rButton(self): dict_entry = EZVARS['RR']['ufo-2d'] if 
self.one_dimens_rButton.isChecked(): LOG.debug("One dimension") add_value_to_dict_entry(dict_entry, True) elif self.two_dimens_rButton.isChecked(): LOG.debug("Two dimensions") add_value_to_dict_entry(dict_entry, False) def set_sigma_horizontal(self): LOG.debug(self.sigma_horizontal_entry.text()) dict_entry = EZVARS['RR']['sx'] add_value_to_dict_entry(dict_entry, self.sigma_horizontal_entry.text()) self.sigma_horizontal_entry.setText(str(dict_entry['value'])) def set_sigma_vertical(self): LOG.debug(self.sigma_vertical_entry.text()) dict_entry = EZVARS['RR']['sy'] add_value_to_dict_entry(dict_entry, self.sigma_vertical_entry.text()) self.sigma_vertical_entry.setText(str(dict_entry['value'])) def set_ufoRR_params_for_360_axis_search(self): self.set_sigma_vertical() self.set_sigma_horizontal() def set_window_size(self): LOG.debug(self.wind_size_entry.text()) dict_entry = EZVARS['RR']['spy-narrow-window'] add_value_to_dict_entry(dict_entry, self.wind_size_entry.text()) self.wind_size_entry.setText(str(dict_entry['value'])) def set_remove_wide(self): LOG.debug("Wide: " + str(self.remove_wide_checkbox.isChecked())) dict_entry = EZVARS['RR']['spy-rm-wide'] add_value_to_dict_entry(dict_entry, self.remove_wide_checkbox.text()) def set_wind(self): LOG.debug(self.remove_wide_entry.text()) dict_entry = EZVARS['RR']['spy-wide-window'] add_value_to_dict_entry(dict_entry, self.remove_wide_entry.text()) self.remove_wide_entry.setText(str(dict_entry['value'])) def set_SNR(self): LOG.debug(self.SNR_entry.text()) dict_entry = EZVARS['RR']['spy-wide-SNR'] add_value_to_dict_entry(dict_entry, self.SNR_entry.text()) self.SNR_entry.setText(str(dict_entry['value'])) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Main/phase_retrieval.py0000664000175000017500000001203600000000000022266 0ustar00tomastomas00000000000000import logging import math from PyQt5.QtWidgets import QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox from tofu.config import SECTIONS from tofu.ez.params import EZVARS from tofu.ez.util import add_value_to_dict_entry, reverse_tupleize, get_double_validator, get_tuple_validator LOG = logging.getLogger(__name__) class PhaseRetrievalGroup(QGroupBox): """ Phase Retrieval settings """ def __init__(self): super().__init__() self.setTitle("Phase Retrieval") self.setStyleSheet("QGroupBox {color: blue;}") self.enable_PR_checkBox = QCheckBox() self.enable_PR_checkBox.setText("Enable Paganin/TIE phase retrieval") self.enable_PR_checkBox.stateChanged.connect(self.set_PR) self.photon_energy_label = QLabel() self.photon_energy_label.setText("Photon energy [keV]") self.photon_energy_entry = QLineEdit() self.photon_energy_entry.setValidator(get_double_validator()) self.photon_energy_entry.editingFinished.connect(self.set_photon_energy) self.pixel_size_label = QLabel() self.pixel_size_label.setText("Pixel size [micron]") self.pixel_size_entry = QLineEdit() self.pixel_size_entry.setValidator(get_double_validator()) self.pixel_size_entry.editingFinished.connect(self.set_pixel_size) self.detector_distance_label = QLabel() self.detector_distance_label.setText("Sample-detector distance [m]") self.detector_distance_entry = QLineEdit() self.detector_distance_entry.setValidator(get_tuple_validator()) self.detector_distance_entry.editingFinished.connect(self.set_detector_distance) self.delta_beta_ratio_label = QLabel() self.delta_beta_ratio_label.setText("Delta/beta ratio: (try default if unsure)") self.delta_beta_ratio_entry = 
QLineEdit() self.delta_beta_ratio_entry.setValidator(get_double_validator()) self.delta_beta_ratio_entry.editingFinished.connect(self.set_delta_beta) self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.enable_PR_checkBox, 0, 0) layout.addWidget(self.photon_energy_label, 1, 0) layout.addWidget(self.photon_energy_entry, 1, 1) layout.addWidget(self.pixel_size_label, 2, 0) layout.addWidget(self.pixel_size_entry, 2, 1) layout.addWidget(self.detector_distance_label, 3, 0) layout.addWidget(self.detector_distance_entry, 3, 1) layout.addWidget(self.delta_beta_ratio_label, 4, 0) layout.addWidget(self.delta_beta_ratio_entry, 4, 1) self.setLayout(layout) def load_values(self): self.enable_PR_checkBox.setChecked(EZVARS['retrieve-phase']['apply-pr']['value']) self.photon_energy_entry.setText(str(SECTIONS['retrieve-phase']['energy']['value'])) self.pixel_size_entry.setText(str( round(self.meters_to_microns(SECTIONS['retrieve-phase']['pixel-size']['value']),6))) self.detector_distance_entry.setText(str(reverse_tupleize()(SECTIONS['retrieve-phase']['propagation-distance']['value']))) self.delta_beta_ratio_entry.setText(str( round(self.regularization_rate_to_delta_beta_ratio(SECTIONS['retrieve-phase']['regularization-rate']['value']),6))) def set_PR(self): LOG.debug("PR: " + str(self.enable_PR_checkBox.isChecked())) dict_entry = EZVARS['retrieve-phase']['apply-pr'] add_value_to_dict_entry(dict_entry, self.enable_PR_checkBox.isChecked()) def set_photon_energy(self): LOG.debug(self.photon_energy_entry.text()) dict_entry = SECTIONS['retrieve-phase']['energy'] add_value_to_dict_entry(dict_entry, str(self.photon_energy_entry.text())) self.photon_energy_entry.setText(str(dict_entry['value'])) def set_pixel_size(self): LOG.debug(self.pixel_size_entry.text()) dict_entry = SECTIONS['retrieve-phase']['pixel-size'] add_value_to_dict_entry(dict_entry, self.microns_to_meters(float(self.pixel_size_entry.text()))) def set_detector_distance(self): LOG.debug(self.detector_distance_entry.text()) dict_entry = SECTIONS['retrieve-phase']['propagation-distance'] add_value_to_dict_entry(dict_entry, str(self.detector_distance_entry.text())) self.detector_distance_entry.setText(str(reverse_tupleize()(dict_entry['value']))) def set_delta_beta(self): LOG.debug(self.delta_beta_ratio_entry.text()) dict_entry = SECTIONS['retrieve-phase']['regularization-rate'] add_value_to_dict_entry(dict_entry, self.delta_beta_ratio_to_regularization_rate(float(self.delta_beta_ratio_entry.text()))) def meters_to_microns(self,value)->float: return value * 1e6 def microns_to_meters(self,value)->float: return value * 1e-6 def delta_beta_ratio_to_regularization_rate(self,value:float)->float: return math.log10(value) def regularization_rate_to_delta_beta_ratio(self,value)->float: return 10**value././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Main/region_and_histogram.py0000664000175000017500000002674200000000000023304 0ustar00tomastomas00000000000000import logging from PyQt5.QtWidgets import QGridLayout, QRadioButton, QLabel, QGroupBox, QLineEdit, QCheckBox from PyQt5.QtCore import Qt from tofu.ez.params import EZVARS from tofu.config import SECTIONS from tofu.ez.util import add_value_to_dict_entry, get_int_validator, get_double_validator, reverse_tupleize LOG = logging.getLogger(__name__) class ROIandHistGroup(QGroupBox): """ Binning settings """ def __init__(self): super().__init__() self.setTitle("Region of Interest and Histogram 
Settings") self.setStyleSheet("QGroupBox {color: red;}") self.select_rows_checkbox = QCheckBox() self.select_rows_checkbox.setText("Select rows which will be reconstructed") self.select_rows_checkbox.stateChanged.connect(self.set_select_rows) self.first_row_label = QLabel() self.first_row_label.setText("First row in projections") self.first_row_label.setToolTip("Counting from the top") self.first_row_entry = QLineEdit() self.first_row_entry.setValidator(get_int_validator()) self.first_row_entry.editingFinished.connect(self.set_first_row) self.num_rows_label = QLabel() self.num_rows_label.setText("Number of rows (ROI height)") self.num_rows_entry = QLineEdit() self.num_rows_entry.setValidator(get_int_validator()) self.num_rows_entry.editingFinished.connect(self.set_num_rows) self.nth_row_label = QLabel() self.nth_row_label.setText("Step (reconstruct every Nth row)") self.nth_row_entry = QLineEdit() self.nth_row_entry.setValidator(get_int_validator()) self.nth_row_entry.editingFinished.connect(self.set_reco_nth_rows) self.clip_histo_checkbox = QCheckBox() self.clip_histo_checkbox.setText("Clip histogram and save slices in") self.clip_histo_checkbox.stateChanged.connect(self.set_clip_histo) self.eight_bit_rButton = QRadioButton() self.eight_bit_rButton.setText("8-bit") self.eight_bit_rButton.setChecked(True) self.eight_bit_rButton.clicked.connect(self.set_bitdepth) self.sixteen_bit_rButton = QRadioButton() self.sixteen_bit_rButton.setText("16-bit") self.sixteen_bit_rButton.clicked.connect(self.set_bitdepth) self.min_val_label = QLabel() self.min_val_label.setText("Min value in 32-bit histogram") self.min_val_entry = QLineEdit() #self.min_val_entry.setValidator(get_double_validator()) self.min_val_entry.editingFinished.connect(self.set_min_val) self.max_val_label = QLabel() self.max_val_label.setText("Max value in 32-bit histogram") self.max_val_entry = QLineEdit() #self.max_val_entry.setValidator(get_double_validator()) self.max_val_entry.editingFinished.connect(self.set_max_val) self.crop_slices_checkbox = QCheckBox() self.crop_slices_checkbox.setText("Crop slices") self.crop_slices_checkbox.setToolTip("Crop slices in the reconstruction plane \n" "(x,y) - top left corner of selection \n" "(width, height) - size of selection") self.crop_slices_checkbox.stateChanged.connect(self.set_crop_slices) self.x_val_label = QLabel() self.x_val_label.setText("x") self.x_val_label.setToolTip("First column (counting from left)") self.x_val_entry = QLineEdit() self.x_val_entry.setValidator(get_int_validator()) self.x_val_entry.editingFinished.connect(self.set_x) self.width_val_label = QLabel() self.width_val_label.setText("width") self.width_val_entry = QLineEdit() self.width_val_entry.setValidator(get_int_validator()) self.width_val_entry.editingFinished.connect(self.set_width) self.y_val_label = QLabel() self.y_val_label.setText("y") self.y_val_label.setToolTip("First row (counting from top)") self.y_val_entry = QLineEdit() self.y_val_entry.setValidator(get_int_validator()) self.y_val_entry.editingFinished.connect(self.set_y) self.height_val_label = QLabel() self.height_val_label.setText("height") self.height_val_entry = QLineEdit() self.height_val_entry.setValidator(get_int_validator()) self.height_val_entry.editingFinished.connect(self.set_height) self.rotate_vol_label = QLabel() self.rotate_vol_label.setText("Rotate volume counterclockwise by [deg]") self.rotate_vol_entry = QLineEdit() self.rotate_vol_entry.setValidator(get_double_validator()) 
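# "Clip histogram" in this group box means that the 32-bit float slices are
# clipped to the [min, max] interval entered by the user and rescaled to 8-bit
# or 16-bit integers when written out.  The sketch below only illustrates that
# conversion with numpy; the actual conversion is performed by tofu/ufo
# downstream, and clip_to_bitdepth is an illustrative name.
import numpy as np

def clip_to_bitdepth(slice32, vmin, vmax, bits=8):
    dtype = np.uint8 if bits == 8 else np.uint16
    top = np.iinfo(dtype).max
    # Clip to the requested interval, then map it linearly onto [0, top].
    scaled = (np.clip(slice32, vmin, vmax) - vmin) / float(vmax - vmin)
    return np.round(scaled * top).astype(dtype)

# Example: an 8-bit preview of a float slice clipped to [-0.001, 0.003].
preview = clip_to_bitdepth(np.random.normal(0.001, 0.001, (64, 64)), -0.001, 0.003)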
self.rotate_vol_entry.editingFinished.connect(self.set_rotate_volume) # self.setStyleSheet('background-color:Azure') self.set_layout() def set_layout(self): """ Sets the layout of buttons, labels, etc. for binning group """ layout = QGridLayout() layout.addWidget(self.select_rows_checkbox, 0, 0) layout.addWidget(self.first_row_label, 1, 0) layout.addWidget(self.first_row_entry, 1, 1, 1, 8) layout.addWidget(self.num_rows_label, 2, 0) layout.addWidget(self.num_rows_entry, 2, 1, 1, 8) layout.addWidget(self.nth_row_label, 3, 0) layout.addWidget(self.nth_row_entry, 3, 1, 1, 8) layout.addWidget(self.clip_histo_checkbox, 4, 0) layout.addWidget(self.eight_bit_rButton, 4, 1) layout.addWidget(self.sixteen_bit_rButton, 4, 2) layout.addWidget(self.min_val_label, 5, 0) layout.addWidget(self.min_val_entry, 5, 1, 1, 8) layout.addWidget(self.max_val_label, 6, 0) layout.addWidget(self.max_val_entry, 6, 1, 1, 8) layout.addWidget(self.crop_slices_checkbox, 7, 0) layout.addWidget(self.x_val_label, 7, 1)#, Qt.AlignRight) layout.addWidget(self.x_val_entry, 7, 2) layout.addWidget(self.width_val_label, 7, 3)#, Qt.AlignRight) layout.addWidget(self.width_val_entry, 7, 4) layout.addWidget(self.y_val_label, 7, 5) layout.addWidget(self.y_val_entry, 7, 6) layout.addWidget(self.height_val_label, 7, 7) layout.addWidget(self.height_val_entry, 7, 8) layout.addWidget(self.rotate_vol_label, 8, 0) layout.addWidget(self.rotate_vol_entry, 8, 1, 1, 8) self.setLayout(layout) def load_values(self): self.select_rows_checkbox.setChecked(EZVARS['inout']['input_ROI']['value']) self.first_row_entry.setText(str(SECTIONS['reading']['y']['value'])) self.num_rows_entry.setText(str(SECTIONS['reading']['height']['value'])) self.nth_row_entry.setText(str(SECTIONS['reading']['y-step']['value'])) self.clip_histo_checkbox.setChecked(EZVARS['inout']['clip_hist']['value']) if int(SECTIONS['general']['output-bitdepth']['value']) == 8: self.eight_bit_rButton.setChecked(True) self.sixteen_bit_rButton.setChecked(False) elif int(SECTIONS['general']['output-bitdepth']['value']) == 16: self.eight_bit_rButton.setChecked(False) self.sixteen_bit_rButton.setChecked(True) self.min_val_entry.setText(str(SECTIONS['general']['output-minimum']['value'])) self.max_val_entry.setText(str(SECTIONS['general']['output-maximum']['value'])) self.crop_slices_checkbox.setChecked(EZVARS['inout']['output-ROI']['value']) self.x_val_entry.setText(str(EZVARS['inout']['output-x']['value'])) self.width_val_entry.setText(str(EZVARS['inout']['output-width']['value'])) self.y_val_entry.setText(str(EZVARS['inout']['output-y']['value'])) self.height_val_entry.setText(str(EZVARS['inout']['output-height']['value'])) self.rotate_vol_entry.setText(str(reverse_tupleize()(SECTIONS['general-reconstruction']['volume-angle-z']['value']))) def set_select_rows(self): LOG.debug("Select rows: " + str(self.select_rows_checkbox.isChecked())) dict_entry = EZVARS['inout']['input_ROI'] add_value_to_dict_entry(dict_entry, self.select_rows_checkbox.isChecked()) def set_first_row(self): LOG.debug(self.first_row_entry.text()) dict_entry = SECTIONS['reading']['y'] add_value_to_dict_entry(dict_entry, str(self.first_row_entry.text())) self.first_row_entry.setText(str(dict_entry['value'])) def set_num_rows(self): LOG.debug(self.num_rows_entry.text()) dict_entry = SECTIONS['reading']['height'] add_value_to_dict_entry(dict_entry, str(self.num_rows_entry.text())) self.num_rows_entry.setText(str(dict_entry['value'])) def set_reco_nth_rows(self): LOG.debug(self.nth_row_entry.text()) dict_entry = 
SECTIONS['reading']['y-step'] add_value_to_dict_entry(dict_entry, str(self.nth_row_entry.text())) self.nth_row_entry.setText(str(dict_entry['value'])) def set_clip_histo(self): LOG.debug("Clip histo: " + str(self.clip_histo_checkbox.isChecked())) dict_entry = EZVARS['inout']['clip_hist'] add_value_to_dict_entry(dict_entry, self.clip_histo_checkbox.isChecked()) if EZVARS['inout']['clip_hist']['value']: return self.set_bitdepth() else: return '32' def set_bitdepth(self): dict_entry = SECTIONS['general']['output-bitdepth'] if self.eight_bit_rButton.isChecked(): LOG.debug("8 bit") add_value_to_dict_entry(dict_entry, str(8)) return '8' elif self.sixteen_bit_rButton.isChecked(): LOG.debug("16 bit") add_value_to_dict_entry(dict_entry, str(16)) return '16' def set_min_val(self): LOG.debug(self.min_val_entry.text()) dict_entry = SECTIONS['general']['output-minimum'] add_value_to_dict_entry(dict_entry, self.min_val_entry.text()) def set_max_val(self): LOG.debug(self.max_val_entry.text()) dict_entry = SECTIONS['general']['output-maximum'] add_value_to_dict_entry(dict_entry, self.max_val_entry.text()) def set_crop_slices(self): LOG.debug("Crop slices: " + str(self.crop_slices_checkbox.isChecked())) dict_entry = EZVARS['inout']['output-ROI'] add_value_to_dict_entry(dict_entry, self.crop_slices_checkbox.isChecked()) def set_x(self): LOG.debug(self.x_val_entry.text()) dict_entry = EZVARS['inout']['output-x'] add_value_to_dict_entry(dict_entry, str(self.x_val_entry.text())) self.x_val_entry.setText(str(dict_entry['value'])) def set_width(self): LOG.debug(self.width_val_entry.text()) dict_entry = EZVARS['inout']['output-width'] add_value_to_dict_entry(dict_entry, str(self.width_val_entry.text())) self.width_val_entry.setText(str(dict_entry['value'])) def set_y(self): LOG.debug(self.y_val_entry.text()) dict_entry = EZVARS['inout']['output-y'] add_value_to_dict_entry(dict_entry, str(self.y_val_entry.text())) self.y_val_entry.setText(str(dict_entry['value'])) def set_height(self): LOG.debug(self.height_val_entry.text()) dict_entry = EZVARS['inout']['output-height'] add_value_to_dict_entry(dict_entry, str(self.height_val_entry.text())) self.height_val_entry.setText(str(dict_entry['value'])) def set_rotate_volume(self): LOG.debug(self.rotate_vol_entry.text()) dict_entry = SECTIONS['general-reconstruction']['volume-angle-z'] add_value_to_dict_entry(dict_entry, str(self.rotate_vol_entry.text())) self.rotate_vol_entry.setText(str(reverse_tupleize()(dict_entry['value']))) ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1698416097.7697759 ufo-tofu-0.13.0/tofu/ez/GUI/Stitch_tools_tab/0000775000175000017500000000000000000000000021155 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790733.0 ufo-tofu-0.13.0/tofu/ez/GUI/Stitch_tools_tab/__init__.py0000664000175000017500000000000000000000000023254 0ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Stitch_tools_tab/auto_horizontal_stitch_funcs.py0000664000175000017500000006411100000000000027527 0ustar00tomastomas00000000000000import os import tifffile from collections import defaultdict import numpy as np import multiprocessing as mp from functools import partial from scipy.stats import gmean import math import yaml class AutoHorizontalStitchFunctions: def __init__(self, parameters): self.lvl0 = 
os.path.abspath(parameters["input_dir"]) self.ct_dirs = [] self.ct_axis_dict = {} self.parameters = parameters self.greatest_axis_value = 0 def run_horizontal_auto_stitch(self): """ Main function that calls all other functions """ # Write parameters to .yaml file - quit if something goes wrong if self.write_yaml_params() == -1: return -1 self.print_parameters() # Check input directory and find structure print("--> Finding CT Directories") self.find_ct_dirs() if len(self.ct_dirs) == 0: print("Error: Could not find any input CT directories") print("-> Ensure that the directory you selected contains subdirectories named 'tomo'") return -1 # For each zview we compute the axis of rotation print("--> Finding Axis of Rotation for each Z-View") self.find_images_and_compute_centre() print("\n ==> Found the following z-views and their corresponding axis of rotation <==") # Check the axis values and adjust for any outliers # If difference between two subsequent zdirs is > 3 then just change it to be 1 greater than previous self.correct_outliers() print("--> ct_axis_dict after correction: ") print(self.ct_axis_dict) # Find the greatest axis value for use in determining overall cropping amount when stitching self.find_greatest_axis_value() print("Greatest axis value: " + str(self.greatest_axis_value)) # Output the input parameters and axis values to the log file self.write_to_log_file() # For each ct-dir and z-view we want to stitch all the images using the values in ct_axis_dict if not self.parameters['dry_run']: print("\n--> Stitching Images...") self.find_and_stitch_images() print("--> Finished Stitching") def write_yaml_params(self): try: # Create the output directory root and save the parameters.yaml file os.makedirs(self.parameters['output_dir'], mode=0o777) file_path = os.path.join(self.parameters['output_dir'], 'auto_vertical_stitch_parameters.yaml') file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) return 0 except FileExistsError: print("--> Output Directory Exists - Delete Before Proceeding") return -1 def find_ct_dirs(self): """ Walks directories rooted at "Input Directory" location Appends their absolute path to ct-dir if they contain a directory with same name as "tomo" entry in GUI """ for root, dirs, files in os.walk(self.lvl0): for name in dirs: if name == "tomo": self.ct_dirs.append(root) self.ct_dirs = sorted(list(set(self.ct_dirs))) def find_images_and_compute_centre(self): """ We use multiprocessing across all CPU cores to determine the axis values for each zview in parallel We get a dictionary of z-directory and axis of rotation key-value pairs in self.ct_axis_dict at the end """ index = range(len(self.ct_dirs)) pool = mp.Pool(processes=mp.cpu_count()) exec_func = partial(self.find_center_parallel_proc) temp_axis_list = pool.map(exec_func, index) # Flatten list of dicts to just be a dictionary of key:value pairs for item in temp_axis_list: self.ct_axis_dict.update(item) def find_center_parallel_proc(self, index): """ Finds the images corresponding to the 0-180, 90-270, 180-360 degree pairs These are used to compute the average axis of rotation for each zview in a ct directory :return: A key value pair corresponding to the z-view path and its axis of rotation """ zview_path = self.ct_dirs[index] # Get list of image names in the directory try: tomo_path = os.path.join(zview_path, "tomo") image_list = sorted(os.listdir(tomo_path)) num_images = len(image_list) # If the number of images is divisible by eight we do 
eight 180 degree pairs in 45 degree increments if num_images % 8 == 0: # Get the names of the images in 45 degree increments starting from 0 zero_degree_image_name = image_list[0] one_eighty_degree_image_name = image_list[int(num_images / 2) - 1] forty_five_degree_image_name = image_list[int(num_images / 8) - 1] two_twenty_five_degree_image_name = image_list[int(num_images * 5 / 8) - 1] ninety_degree_image_name = image_list[int(num_images / 4) - 1] two_seventy_degree_image_name = image_list[int(num_images * 3 / 4) - 1] one_thirty_five_degree_image_name = image_list[int(num_images * 3 / 8) - 1] three_fifteen_degree_image_name = image_list[int(num_images * 7 / 8) - 1] three_sixty_degree_image_name = image_list[-1] # Get the paths for the images zero_degree_image_path = os.path.join(tomo_path, zero_degree_image_name) forty_five_degree_image_path = os.path.join(tomo_path, forty_five_degree_image_name) ninety_degree_image_path = os.path.join(tomo_path, ninety_degree_image_name) one_thirty_five_degree_image_path = os.path.join(tomo_path, one_thirty_five_degree_image_name) one_eighty_degree_image_path = os.path.join(tomo_path, one_eighty_degree_image_name) two_twenty_five_degree_image_path = os.path.join(tomo_path, two_twenty_five_degree_image_name) two_seventy_degree_image_path = os.path.join(tomo_path, two_seventy_degree_image_name) three_fifteen_degree_image_path = os.path.join(tomo_path, three_fifteen_degree_image_name) three_sixty_degree_image_path = os.path.join(tomo_path, three_sixty_degree_image_name) axis_list = [self.compute_center(zero_degree_image_path, one_eighty_degree_image_path), self.compute_center(forty_five_degree_image_path, two_twenty_five_degree_image_path), self.compute_center(ninety_degree_image_path, two_seventy_degree_image_path), self.compute_center(one_thirty_five_degree_image_path, three_fifteen_degree_image_path), self.compute_center(one_eighty_degree_image_path, three_sixty_degree_image_path), self.compute_center(two_twenty_five_degree_image_path, forty_five_degree_image_path), self.compute_center(two_seventy_degree_image_path, ninety_degree_image_path), self.compute_center(three_fifteen_degree_image_path, one_thirty_five_degree_image_path)] # If the number of images is not divisible by eight we do four 180 degree pairs in 90 degree increments elif num_images % 4 == 0: # Get the images corresponding to 0, 90, 180, and 270 degree rotations in half-acquisition mode - zero_degree_image_name = image_list[0] one_eighty_degree_image_name = image_list[int(num_images / 2) - 1] ninety_degree_image_name = image_list[int(num_images / 4) - 1] two_seventy_degree_image_name = image_list[int(num_images * 3 / 4) - 1] three_sixty_degree_image_name = image_list[-1] # Get the paths for the images zero_degree_image_path = os.path.join(tomo_path, zero_degree_image_name) one_eighty_degree_image_path = os.path.join(tomo_path, one_eighty_degree_image_name) ninety_degree_image_path = os.path.join(tomo_path, ninety_degree_image_name) two_seventy_degree_image_path = os.path.join(tomo_path, two_seventy_degree_image_name) three_sixty_degree_image_path = os.path.join(tomo_path, three_sixty_degree_image_name) # Determine the axis of rotation for pairs at 0-180, 90-270, 180-360 and 270-90 degrees axis_list = [self.compute_center(zero_degree_image_path, one_eighty_degree_image_path), self.compute_center(ninety_degree_image_path, two_seventy_degree_image_path), self.compute_center(one_eighty_degree_image_path, three_sixty_degree_image_path), self.compute_center(two_seventy_degree_image_path, 
ninety_degree_image_path)] # Otherwise, we compute the centre based on 0-180 and 180-360 pairs else: # Get the images corresponding to 0, 180 and 360 degree rotations in half-acquisition mode - zero_degree_image_name = image_list[0] one_eighty_degree_image_name = image_list[int(num_images / 2) - 1] three_sixty_degree_image_name = image_list[-1] # Get the paths for the images zero_degree_image_path = os.path.join(tomo_path, zero_degree_image_name) one_eighty_degree_image_path = os.path.join(tomo_path, one_eighty_degree_image_name) three_sixty_degree_image_path = os.path.join(tomo_path, three_sixty_degree_image_name) # Determine the axis of rotation for pairs at 0-180, 90-270, 180-360 and 270-90 degrees axis_list = [self.compute_center(zero_degree_image_path, one_eighty_degree_image_path), self.compute_center(one_eighty_degree_image_path, three_sixty_degree_image_path)] # Find the average of 180 degree rotation pairs print("--> " + str(zview_path)) print(axis_list) # If mode occurs more than 4 times then pick it as axis value, otherwise use geometric mean most_common_value = max(set(axis_list), key=axis_list.count) if axis_list.count(most_common_value) > 4: axis_value = self.col_round(most_common_value) else: axis_value = self.col_round(gmean(axis_list)) print("Axis value: " + str(axis_value)) # Return each zview and its axis of rotation value as key-value pair return {zview_path: axis_value} except NotADirectoryError: print("Skipped - Not a Directory: " + tomo_path) def compute_center(self, zero_degree_image_path, one_eighty_degree_image_path): """ Takes two pairs of images in half-acquisition mode separated by a full 180 degree rotation of the sample The images are then flat-corrected and cropped to the overlap region They are then correlated using fft to determine the axis of rotation :param zero_degree_image_path: First sample scan :param one_eighty_degree_image_path: Second sample scan rotated 180 degree from first sample scan :return: The axis of rotation based on the correlation of two 180 degree image pairs """ if self.parameters['sample_on_right'] is False: # Read each image into a numpy array first = self.read_image(zero_degree_image_path, False) second = self.read_image(one_eighty_degree_image_path, False) elif self.parameters['sample_on_right'] is True: # Read each image into a numpy array - flip both images first = self.read_image(zero_degree_image_path, True) second = self.read_image(one_eighty_degree_image_path, True) # Do flat field correction on the images # Case 1: Using darks/flats/flats2 in each CTdir alongside tomo if self.parameters['common_flats_darks'] is False: tomo_path, filename = os.path.split(zero_degree_image_path) zdir_path, tomo_name = os.path.split(tomo_path) flats_path = os.path.join(zdir_path, "flats") darks_path = os.path.join(zdir_path, "darks") flat_files = self.get_filtered_filenames(flats_path) dark_files = self.get_filtered_filenames(darks_path) # Case 2: Using common set of flats and darks elif self.parameters['common_flats_darks'] is True: flat_files = self.get_filtered_filenames(self.parameters['flats_dir']) dark_files = self.get_filtered_filenames(self.parameters['darks_dir']) flats = np.array([tifffile.TiffFile(x).asarray().astype(np.float) for x in flat_files]) darks = np.array([tifffile.TiffFile(x).asarray().astype(np.float) for x in dark_files]) dark = np.mean(darks, axis=0) flat = np.mean(flats, axis=0) - dark first = (first - dark) / flat second = (second - dark) / flat # We must crop the first image from first pixel column up until overlap 
first_cropped = first[:, :int(self.parameters['overlap_region'])] # We must crop the 180 degree rotation (which has been flipped 180) from width-overlap until last pixel column second_cropped = second[:, :int(self.parameters['overlap_region'])] axis = self.compute_rotation_axis(first_cropped, second_cropped) return axis def get_filtered_filenames(self, path, exts=['.tif', '.edf']): result = [] try: for ext in exts: result += [os.path.join(path, f) for f in os.listdir(path) if f.endswith(ext)] except OSError: return [] return sorted(result) def compute_rotation_axis(self, first_projection, last_projection): """ Compute the tomographic rotation axis based on cross-correlation technique. *first_projection* is the projection at 0 deg, *last_projection* is the projection at 180 deg. """ from scipy.signal import fftconvolve width = first_projection.shape[1] first_projection = first_projection - first_projection.mean() last_projection = last_projection - last_projection.mean() # The rotation by 180 deg flips the image horizontally, in order # to do cross-correlation by convolution we must also flip it # vertically, so the image is transposed and we can apply convolution # which will act as cross-correlation convolved = fftconvolve(first_projection, last_projection[::-1, :], mode='same') center = np.unravel_index(convolved.argmax(), convolved.shape)[1] return (width / 2.0 + center) / 2 def write_to_log_file(self): ''' Creates a log file with extension .info at the root of the output_dir tree structure Log file contains directory path and axis value ''' if not os.path.isdir(self.parameters['output_dir']): os.makedirs(self.parameters['output_dir'], mode=0o777) file_path = os.path.join(self.parameters['output_dir'], 'axis_values.info') print("Axis values log file stored at: " + file_path) try: file_handle = open(file_path, 'w') # Print input parameters file_handle.write("======================== Parameters ========================" + "\n") file_handle.write("Input Directory: " + self.parameters['input_dir'] + "\n") file_handle.write("Output Directory: " + self.parameters['output_dir'] + "\n") file_handle.write("Using common set of flats and darks: " + str(self.parameters['common_flats_darks']) + "\n") file_handle.write("Flats Directory: " + self.parameters['flats_dir'] + "\n") file_handle.write("Darks Directory: " + self.parameters['darks_dir'] + "\n") file_handle.write("Overlap Region Size: " + self.parameters['overlap_region'] + "\n") file_handle.write("Sample on right: " + str(self.parameters['sample_on_right']) + "\n") # Print z-directory and corresponding axis value file_handle.write("\n======================== Axis Values ========================\n") for key in self.ct_axis_dict: key_value_str = str(key) + " : " + str(self.ct_axis_dict[key]) print(key_value_str) file_handle.write(key_value_str + '\n') file_handle.write("\nGreatest axis value: " + str(self.greatest_axis_value)) except FileNotFoundError: print("Error: Could not write log file") def correct_outliers(self): """ This function looks at each CTDir containing Z00-Z0N If the axis values for successive zviews are greater than 3 (an outlier) Then we correct this by tying the outlier to the previous Z-View axis plus one self.ct_axis_dict is updated with corrected axis values """ sorted_by_ctdir_dict = defaultdict(dict) for key in self.ct_axis_dict: path_key, zdir = os.path.split(str(key)) axis_value = self.ct_axis_dict[key] sorted_by_ctdir_dict[path_key][zdir] = axis_value for dir_key in sorted_by_ctdir_dict: z_dir_list = 
list(sorted_by_ctdir_dict[dir_key].values()) # Need to account for the case where the first z-view is an outlier min_value = min(z_dir_list) if z_dir_list[0] > min_value + 2: z_dir_list[0] = min_value # Compare the difference of successive pairwise axis values # If the difference is greater than 3 then set the second pair value to be 1 more than the first pair value for index in range(len(z_dir_list) - 1): first_value = z_dir_list[index] second_value = z_dir_list[index + 1] difference = abs(second_value - first_value) if difference > 3: # Set second value to be one more than first z_dir_list[index + 1] = z_dir_list[index] + 1 # Assigns the values in z_dir_list back to the ct_dir_dict index = 0 for zdir in sorted_by_ctdir_dict[dir_key]: corrected_axis_value = z_dir_list[index] sorted_by_ctdir_dict[dir_key][zdir] = corrected_axis_value index += 1 # Assigns the corrected values back to self.ct_axis_dict for path_key in sorted_by_ctdir_dict: for z_key in sorted_by_ctdir_dict[path_key]: path_string = os.path.join(str(path_key), str(z_key)) self.ct_axis_dict[path_string] = sorted_by_ctdir_dict[path_key][z_key] def find_greatest_axis_value(self): """ Looks through all axis values and determines the greatest value """ axis_list = list(self.ct_axis_dict.values()) self.greatest_axis_value = max(axis_list) def find_and_stitch_images(self): index = range(len(self.ct_dirs)) pool = mp.Pool(processes=mp.cpu_count()) exec_func = partial(self.find_and_stitch_parallel_proc) # TODO : Try using pool.map or pool.imap_unordered and compare times # Try imap_unordered() as see if it is faster - with chunksize len(self.ct_dir) / mp.cpu_count() # pool.imap_unordered(exec_func, index, int(len(self.ct_dirs) / mp.cpu_count())) pool.map(exec_func, index) def find_and_stitch_parallel_proc(self, index): z_dir_path = self.ct_dirs[index] # Get list of image names in the directory try: # Want to maintain directory structure for output so we subtract the output-path from z_dir_path # Then we append this to the output_dir path diff_path = os.path.relpath(z_dir_path, self.parameters['input_dir']) out_path = os.path.join(self.parameters['output_dir'], diff_path) rotation_axis = self.ct_axis_dict[z_dir_path] # If using common flats/darks across all zdirs # then use common flats/darks directories as source of images to stitch and save to output zdirs if self.parameters['common_flats_darks'] is True: self.stitch_180_pairs(rotation_axis, z_dir_path, out_path, "tomo") flats_parent_path, garbage = os.path.split(self.parameters['flats_dir']) self.stitch_180_pairs(rotation_axis, flats_parent_path, out_path, "flats") darks_parent_path, garbage = os.path.split(self.parameters['darks_dir']) self.stitch_180_pairs(rotation_axis, darks_parent_path, out_path, "darks") # If using local flats/darks to each zdir then use those as source for stitching elif self.parameters['common_flats_darks'] is False: self.stitch_180_pairs(rotation_axis, z_dir_path, out_path, "tomo") # Need to account for case where flats, darks, flats2 don't exist if os.path.isdir(os.path.join(z_dir_path, "flats")): self.stitch_180_pairs(rotation_axis, z_dir_path, out_path, "flats") if os.path.isdir(os.path.join(z_dir_path, "darks")): self.stitch_180_pairs(rotation_axis, z_dir_path, out_path, "darks") if os.path.isdir(os.path.join(z_dir_path, "flats2")): self.stitch_180_pairs(rotation_axis, z_dir_path, out_path, "flats2") print("--> " + str(z_dir_path)) print("Axis of rotation: " + str(rotation_axis)) except NotADirectoryError as e: print("Skipped - Not a Directory: " + 
e.filename) def stitch_180_pairs(self, rotation_axis, in_path, out_path, type_str): """ Finds images in tomo, flats, darks, flats2 directories corresponding to 180 degree pairs The first image is stitched with the middle image and so on by using the index and midpoint :param rotation_axis: axis of rotation for z-directory :param in_path: absolute path to z-directory :param out_path: absolute path to output directory :param type_str: Type of subdirectory - e.g. "tomo", "flats", "darks", "flats2" """ os.makedirs(os.path.join(out_path, type_str), mode=0o777) image_list = sorted(os.listdir(os.path.join(in_path, type_str))) midpoint = int(len(image_list) / 2) for index in range(midpoint): first_path = os.path.join(in_path, type_str, image_list[index]) second_path = os.path.join(in_path, type_str, image_list[midpoint + index]) output_image_path = os.path.join(out_path, type_str, type_str + "_stitched_{:>04}.tif".format(index)) crop_amount = abs(self.greatest_axis_value - round(rotation_axis)) self.open_images_stitch_write(rotation_axis, crop_amount, first_path, second_path, output_image_path) def print_parameters(self): """ Prints parameter values with line formatting """ print() print("**************************** Running Auto Horizontal Stitch ****************************") print("======================== Parameters ========================") print("Input Directory: " + self.parameters['input_dir']) print("Output Directory: " + self.parameters['output_dir']) print("Using common set of flats and darks: " + str(self.parameters['common_flats_darks'])) print("Flats Directory: " + self.parameters['flats_dir']) print("Darks Directory: " + self.parameters['darks_dir']) print("Overlap Region Size: " + self.parameters['overlap_region']) print("Sample on right: " + str(self.parameters['sample_on_right'])) print("============================================================") """****** BORROWED FUNCTIONS ******""" def read_image(self, file_name, flip_image): """ Reads in a tiff image from disk at location specified by file_name, returns a numpy array :param file_name: Str - path to file :param flip_image: Bool - Whether image is to be flipped horizontally or not :return: A numpy array of type float """ with tifffile.TiffFile(file_name) as tif: image = tif.pages[0].asarray(out='memmap') if flip_image is True: image = np.fliplr(image) return image def open_images_stitch_write(self, ax, crop, first_image_path, second_image_path, out_fmt): if self.parameters['sample_on_right'] is False: # Read each image into a numpy array - We flip the second image first = self.read_image(first_image_path, flip_image=False) second = self.read_image(second_image_path, flip_image=True) if self.parameters['sample_on_right'] is True: # We pass index and formats as argument - We flip the first image before stitching first = self.read_image(first_image_path, flip_image=True) second = self.read_image(second_image_path, flip_image=False) stitched = self.stitch(first, second, ax, crop) tifffile.imwrite(out_fmt, stitched) def stitch(self, first, second, axis, crop): h, w = first.shape if axis > w / 2: dx = int(2 * (w - axis) + 0.5) else: dx = int(2 * axis + 0.5) tmp = np.copy(first) first = second second = tmp result = np.empty((h, 2 * w - dx), dtype=first.dtype) ramp = np.linspace(0, 1, dx) # Mean values of the overlapping regions must match, which corrects flat-field inconsistency # between the two projections # We clip the values in second so that there are no saturated pixel overflow problems k = np.mean(first[:, w - dx:]) / 
np.mean(second[:, :dx]) second = np.clip(second * k, np.iinfo(np.uint16).min, np.iinfo(np.uint16).max).astype(np.uint16) result[:, :w - dx] = first[:, :w - dx] result[:, w - dx:w] = first[:, w - dx:] * (1 - ramp) + second[:, :dx] * ramp result[:, w:] = second[:, dx:] return result[:, slice(int(crop), int(2 * (w - axis) - crop), 1)] def col_round(self, x): frac = x - math.floor(x) if frac < 0.5: return math.floor(x) return math.ceil(x) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Stitch_tools_tab/auto_horizontal_stitch_gui.py0000664000175000017500000003107500000000000027200 0ustar00tomastomas00000000000000import os import logging import shutil import yaml from PyQt5.QtWidgets import QPushButton, QLabel, QLineEdit, QGridLayout, QFileDialog, QCheckBox,\ QMessageBox, QGroupBox from tofu.ez.GUI.Stitch_tools_tab.auto_horizontal_stitch_funcs import AutoHorizontalStitchFunctions class AutoHorizontalStitchGUI(QGroupBox): def __init__(self): super().__init__() self.setTitle('Auto Horizontal Stitch') #logger = logging.getLogger() #logger.setLevel(logging.DEBUG) self.parameters = {'parameters_type': 'auto_horizontal_stitch'} self.auto_horizontal_stitch_funcs = None self.input_button = QPushButton("Select Input Path") self.input_button.clicked.connect(self.input_button_pressed) self.input_entry = QLineEdit() self.input_entry.textChanged.connect(self.set_input_entry) self.output_button = QPushButton("Select Output Path") self.output_button.clicked.connect(self.output_button_pressed) self.output_entry = QLineEdit() self.output_entry.textChanged.connect(self.set_output_entry) self.flats_darks_group = QGroupBox("Use Common Set of Flats and Darks") self.flats_darks_group.clicked.connect(self.set_flats_darks_group) self.flats_button = QPushButton("Select Flats Path") self.flats_button.clicked.connect(self.flats_button_pressed) self.flats_entry = QLineEdit() self.flats_entry.textChanged.connect(self.set_flats_entry) self.darks_button = QPushButton("Select Darks Path") self.darks_button.clicked.connect(self.darks_button_pressed) self.darks_entry = QLineEdit() self.darks_entry.textChanged.connect(self.set_darks_entry) self.overlap_region_label = QLabel("Overlapping Pixels") self.overlap_region_entry = QLineEdit() self.overlap_region_entry.textChanged.connect(self.set_overlap_region_entry) self.sample_on_right_checkbox = QCheckBox("Is the sample on the right side of the image?") self.sample_on_right_checkbox.stateChanged.connect(self.set_sample_on_right_checkbox) self.save_params_button = QPushButton("Save parameters") self.save_params_button.clicked.connect(self.save_params_button_clicked) self.import_params_button = QPushButton("Import parameters") self.import_params_button.clicked.connect(self.import_params_button_clicked) self.help_button = QPushButton("Help") self.help_button.clicked.connect(self.help_button_pressed) self.delete_temp_button = QPushButton("Delete Output Directory") self.delete_temp_button.clicked.connect(self.delete_button_pressed) self.stitch_button = QPushButton("Stitch") self.stitch_button.clicked.connect(self.stitch_button_pressed) self.dry_run_checkbox = QCheckBox("Dry Run") self.dry_run_checkbox.stateChanged.connect(self.set_dry_run_checkbox) self.set_layout() self.init_values() self.show() def set_layout(self): self.setMaximumSize(800, 300) layout = QGridLayout() layout.addWidget(self.input_button, 0, 0, 1, 2) layout.addWidget(self.input_entry, 0, 2, 1, 4) layout.addWidget(self.output_button, 1, 
0, 1, 2) layout.addWidget(self.output_entry, 1, 2, 1, 4) self.flats_darks_group.setCheckable(True) self.flats_darks_group.setChecked(False) flats_darks_layout = QGridLayout() flats_darks_layout.addWidget(self.flats_button, 0, 0, 1, 2) flats_darks_layout.addWidget(self.flats_entry, 0, 2, 1, 2) flats_darks_layout.addWidget(self.darks_button, 1, 0, 1, 2) flats_darks_layout.addWidget(self.darks_entry, 1, 2, 1, 2) self.flats_darks_group.setLayout(flats_darks_layout) layout.addWidget(self.flats_darks_group, 2, 0, 1, 4) layout.addWidget(self.overlap_region_label, 3, 2) layout.addWidget(self.overlap_region_entry, 3, 3) layout.addWidget(self.sample_on_right_checkbox, 3, 0, 1, 2) layout.addWidget(self.save_params_button, 4, 0, 1, 2) layout.addWidget(self.import_params_button, 4, 3, 1, 1) layout.addWidget(self.help_button, 4, 2, 1, 1) layout.addWidget(self.stitch_button, 5, 0, 1, 2) layout.addWidget(self.dry_run_checkbox, 5, 2, 1, 1) layout.addWidget(self.delete_temp_button, 5, 3, 1, 1) self.setLayout(layout) def init_values(self): self.input_entry.setText("...enter input directory") self.output_entry.setText("...enter output directory") self.flats_entry.setText("...enter flats directory") self.parameters['common_flats_darks'] = False self.parameters['flats_dir'] = "" self.darks_entry.setText("...enter darks directory") self.parameters['darks_dir'] = "" self.overlap_region_entry.setText("1540") self.parameters['overlap_region'] = "1540" self.sample_on_right_checkbox.setChecked(False) self.parameters['sample_on_right'] = False self.dry_run_checkbox.setChecked(False) self.parameters['dry_run'] = False def update_parameters(self, new_parameters): logging.debug("Update parameters") # Update parameters dictionary (which is passed to auto_stitch_funcs) self.parameters = new_parameters # Update displayed parameters for GUI self.input_entry.setText(self.parameters['input_dir']) self.output_entry.setText(self.parameters['output_dir']) self.flats_darks_group.setChecked(bool(self.parameters['common_flats_darks'])) self.flats_entry.setText(self.parameters['flats_dir']) self.darks_entry.setText(self.parameters['darks_dir']) self.sample_on_right_checkbox.setChecked(bool(self.parameters['sample_on_right'])) self.overlap_region_entry.setText(self.parameters['overlap_region']) self.dry_run_checkbox.setChecked(bool(self.parameters['dry_run'])) def input_button_pressed(self): logging.debug("Input Button Pressed") dir_explore = QFileDialog(self) input_dir = dir_explore.getExistingDirectory() self.input_entry.setText(input_dir) self.parameters['input_dir'] = input_dir def set_input_entry(self): logging.debug("Input Entry: " + str(self.input_entry.text())) self.parameters['input_dir'] = str(self.input_entry.text()) def output_button_pressed(self): logging.debug("Output Button Pressed") dir_explore = QFileDialog(self) output_dir = dir_explore.getExistingDirectory() self.output_entry.setText(output_dir) self.parameters['output_dir'] = output_dir def set_output_entry(self): logging.debug("Output Entry: " + str(self.output_entry.text())) self.parameters['output_dir'] = str(self.output_entry.text()) def set_flats_darks_group(self): logging.debug("Use Common Flats/Darks: " + str(self.flats_darks_group.isChecked())) if self.parameters['common_flats_darks'] is True: self.parameters['common_flats_darks'] = False else: self.parameters['common_flats_darks'] = True def flats_button_pressed(self): logging.debug("Flats Button Pressed") dir_explore = QFileDialog(self) flats_dir = dir_explore.getExistingDirectory() 
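# All of these handlers only fill in self.parameters; the work itself is
# delegated to AutoHorizontalStitchFunctions when the Stitch button is pressed.
# The same backend can therefore be driven without the GUI, as sketched below.
# The directory paths are placeholders, and the keys mirror those initialised
# in init_values() above.
from tofu.ez.GUI.Stitch_tools_tab.auto_horizontal_stitch_funcs import AutoHorizontalStitchFunctions

headless_params = {
    'parameters_type': 'auto_horizontal_stitch',
    'input_dir': '/data/experiment',            # placeholder input path
    'output_dir': '/data/experiment-stitched',  # placeholder output path
    'common_flats_darks': False,
    'flats_dir': '',
    'darks_dir': '',
    'overlap_region': '1540',
    'sample_on_right': False,
    'dry_run': True,   # report the axis values only, do not write stitched images
}
AutoHorizontalStitchFunctions(headless_params).run_horizontal_auto_stitch()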
self.flats_entry.setText(flats_dir) self.parameters['flats_dir'] = flats_dir def set_flats_entry(self): logging.debug("Flats Entry: " + str(self.flats_entry.text())) self.parameters['flats_dir'] = str(self.flats_entry.text()) def darks_button_pressed(self): logging.debug("Darks Button Pressed") dir_explore = QFileDialog(self) darks_dir = dir_explore.getExistingDirectory() self.darks_entry.setText(darks_dir) self.parameters['darks_dir'] = darks_dir def set_darks_entry(self): logging.debug("Darks Entry: " + str(self.darks_entry.text())) self.parameters['darks_dir'] = str(self.darks_entry.text()) def set_overlap_region_entry(self): logging.debug("Overlap Region: " + str(self.overlap_region_entry.text())) self.parameters['overlap_region'] = str(self.overlap_region_entry.text()) def set_sample_on_right_checkbox(self): logging.debug("Sample is on right side of the image: " + str(self.sample_on_right_checkbox.isChecked())) self.parameters['sample_on_right'] = self.sample_on_right_checkbox.isChecked() def save_params_button_clicked(self): logging.debug("Save params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getSaveFileName(filter="*.yaml") garbage, file_name = os.path.split(params_file_path[0]) file_extension = os.path.splitext(file_name) # If the user doesn't enter the .yaml extension then append it to filepath if file_extension[-1] == "": file_path = params_file_path[0] + ".yaml" else: file_path = params_file_path[0] try: file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) except FileNotFoundError: print("You need to select a directory and use a valid file name") def import_params_button_clicked(self): logging.debug("Import params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getOpenFileName(filter="*.yaml") try: file_in = open(params_file_path[0], 'r') new_parameters = yaml.load(file_in, Loader=yaml.FullLoader) self.update_parameters(new_parameters) print("Parameters file loaded from: " + str(params_file_path[0])) except FileNotFoundError: print("You need to select a valid input file") def help_button_pressed(self): logging.debug("Help Button Pressed") h = "Auto-Stitch is used to automatically find the axis of rotation" \ " in order to stitch pairs of images gathered in half-acquisition mode.\n\n" h += "The input directory must contain at least one directory named 'tomo' containing .tiff image files.\n\n" h += "The output directory, containing the stitched images," \ " maintains the structure of the directory tree rooted at the input directory.\n\n" h += "If the experiment uses one set of flats/darks the" \ " 'Use Common Set of Flats and Darks' checkbox must be selected." \ " These will then be copied and stitched according to the axis of rotation of each z-view.\n\n" h += "If each z-view contains its own set of flats/darks then" \ " auto_stitch will automatically use these for flat-field correction and stitching.\n\n" h += "In most cases of half-acquisition the sample is positioned on the left-side of the image." \ " Select the 'Is sample on the right side of the image?' 
checkbox if it is on the right.\n\n" h += "The 'Overlapping Pixels' entry determines the region of" \ " the images which will be used to determine the axis of rotation.\n\n" h += "Parameters can be saved to and loaded from .yaml files of the user's choice.\n\n" h += "If the dry run button is selected the program will find the axis values without stitching the images.\n\n" QMessageBox.information(self, "Help", h) def delete_button_pressed(self): logging.debug("Delete Output Directory Button Pressed") delete_dialog = QMessageBox.question(self, 'Quit', 'Are you sure you want to delete the output directory?', QMessageBox.Yes | QMessageBox.No) if delete_dialog == QMessageBox.Yes: try: print("Deleting: " + self.parameters['output_dir'] + " ...") shutil.rmtree(self.parameters['output_dir']) print("Deleted directory: " + self.parameters['output_dir']) except FileNotFoundError: print("Directory does not exist: " + self.parameters['output_dir']) def stitch_button_pressed(self): logging.debug("Stitch Button Pressed") self.auto_horizontal_stitch_funcs = AutoHorizontalStitchFunctions(self.parameters) self.auto_horizontal_stitch_funcs.run_horizontal_auto_stitch() def set_dry_run_checkbox(self): logging.debug("Dry Run Checkbox: " + str(self.dry_run_checkbox.isChecked())) self.parameters['dry_run'] = self.dry_run_checkbox.isChecked() ''' if __name__ == '__main__': app = QApplication(sys.argv) window = AutoHorizontalStitchGUI() sys.exit(app.exec_()) ''' ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Stitch_tools_tab/ez_360_multi_stitch_qt.py0000664000175000017500000005206500000000000026041 0ustar00tomastomas00000000000000from PyQt5.QtWidgets import ( QGroupBox, QPushButton, QCheckBox, QLabel, QLineEdit, QGridLayout, QFileDialog, QMessageBox, ) from PyQt5.QtCore import pyqtSignal import logging from shutil import rmtree import os import yaml from tofu.ez.Helpers.stitch_funcs import main_360_mp_depth2 from tofu.ez.GUI.message_dialog import warning_message # Params import tofu.ez.params as params LOG = logging.getLogger(__name__) class MultiStitch360Group(QGroupBox): get_fdt_names_on_stitch_pressed = pyqtSignal() def __init__(self): super().__init__() self.setTitle("360-MULTI-STITCH") self.setToolTip("Converts half-acquistion data sets to ordinary projections \n" "and crops all images to the same size.") self.setStyleSheet('QGroupBox {color: red;}') self.input_dir_button = QPushButton("Select input directory") self.input_dir_button.setToolTip("Contains multiple CT directories with flats/darks/tomo subdirectories. 
\n" "Images in each will be stitched pair-wise [x and x+180 deg]") self.input_dir_button.clicked.connect(self.input_button_pressed) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_input_entry) self.temp_dir_button = QPushButton("Select temporary directory - default value recommended") self.temp_dir_button.clicked.connect(self.temp_button_pressed) self.temp_dir_entry = QLineEdit() self.temp_dir_entry.editingFinished.connect(self.set_temp_entry) self.output_dir_button = QPushButton("Directory to save stitched images") self.output_dir_button.clicked.connect(self.output_button_pressed) self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_output_entry) self.crop_checkbox = QCheckBox("Crop all projections to match the width of smallest stitched projection") self.crop_checkbox.clicked.connect(self.set_crop_projections_checkbox) self.axis_bottom_label = QLabel() self.axis_bottom_label.setText("Axis of Rotation (Dir 00):") self.axis_bottom_entry = QLineEdit() self.axis_bottom_entry.editingFinished.connect(self.set_axis_bottom) self.axis_top_label = QLabel("Axis of Rotation (Dir 0N):") self.axis_group = QGroupBox("Enter axis of rotation manually") self.axis_group.clicked.connect(self.set_axis_group) self.axis_top_entry = QLineEdit() self.axis_top_entry.editingFinished.connect(self.set_axis_top) self.axis_z000_label = QLabel("Axis of Rotation (Dir 00):") self.axis_z000_entry = QLineEdit() self.axis_z000_entry.editingFinished.connect(self.set_z000) self.axis_z001_label = QLabel("Axis of Rotation (Dir 01):") self.axis_z001_entry = QLineEdit() self.axis_z001_entry.editingFinished.connect(self.set_z001) self.axis_z002_label = QLabel("Axis of Rotation (Dir 02):") self.axis_z002_entry = QLineEdit() self.axis_z002_entry.editingFinished.connect(self.set_z002) self.axis_z003_label = QLabel("Axis of Rotation (Dir 03):") self.axis_z003_entry = QLineEdit() self.axis_z003_entry.editingFinished.connect(self.set_z003) self.axis_z004_label = QLabel("Axis of Rotation (Dir 04):") self.axis_z004_entry = QLineEdit() self.axis_z004_entry.editingFinished.connect(self.set_z004) self.axis_z005_label = QLabel("Axis of Rotation (Dir 05):") self.axis_z005_entry = QLineEdit() self.axis_z005_entry.editingFinished.connect(self.set_z005) self.axis_z006_label = QLabel("Axis of Rotation (Dir 06):") self.axis_z006_entry = QLineEdit() self.axis_z006_entry.editingFinished.connect(self.set_z006) self.axis_z007_label = QLabel("Axis of Rotation (Dir 07):") self.axis_z007_entry = QLineEdit() self.axis_z007_entry.editingFinished.connect(self.set_z007) self.axis_z008_label = QLabel("Axis of Rotation (Dir 08):") self.axis_z008_entry = QLineEdit() self.axis_z008_entry.editingFinished.connect(self.set_z008) self.axis_z009_label = QLabel("Axis of Rotation (Dir 09):") self.axis_z009_entry = QLineEdit() self.axis_z009_entry.editingFinished.connect(self.set_z009) self.axis_z010_label = QLabel("Axis of Rotation (Dir 10):") self.axis_z010_entry = QLineEdit() self.axis_z010_entry.editingFinished.connect(self.set_z010) self.axis_z011_label = QLabel("Axis of Rotation (Dir 11):") self.axis_z011_entry = QLineEdit() self.axis_z011_entry.editingFinished.connect(self.set_z011) self.stitch_button = QPushButton("Stitch") self.stitch_button.clicked.connect(self.stitch_button_pressed) self.stitch_button.setStyleSheet("color:royalblue;font-weight:bold") self.delete_button = QPushButton("Delete output dir") self.delete_button.clicked.connect(self.delete_button_pressed) self.help_button 
= QPushButton("Help") self.help_button.clicked.connect(self.help_button_pressed) self.import_parameters_button = QPushButton("Import Parameters from File") self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed) self.save_parameters_button = QPushButton("Save Parameters to File") self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed) self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.input_dir_button, 0, 0, 1, 4) layout.addWidget(self.input_dir_entry, 1, 0, 1, 4) layout.addWidget(self.temp_dir_button, 2, 0, 1, 4) layout.addWidget(self.temp_dir_entry, 3, 0, 1, 4) layout.addWidget(self.output_dir_button, 4, 0, 1, 4) layout.addWidget(self.output_dir_entry, 5, 0, 1, 4) layout.addWidget(self.crop_checkbox, 6, 0, 1, 4) layout.addWidget(self.axis_bottom_label, 7, 0) layout.addWidget(self.axis_bottom_entry, 7, 1) layout.addWidget(self.axis_top_label, 7, 2) layout.addWidget(self.axis_top_entry, 7, 3) self.axis_group.setCheckable(True) self.axis_group.setChecked(False) axis_layout = QGridLayout() axis_layout.addWidget(self.axis_z000_label, 0, 0) axis_layout.addWidget(self.axis_z000_entry, 0, 1) axis_layout.addWidget(self.axis_z006_label, 0, 2) axis_layout.addWidget(self.axis_z006_entry, 0, 3) axis_layout.addWidget(self.axis_z001_label, 1, 0) axis_layout.addWidget(self.axis_z001_entry, 1, 1) axis_layout.addWidget(self.axis_z007_label, 1, 2) axis_layout.addWidget(self.axis_z007_entry, 1, 3) axis_layout.addWidget(self.axis_z002_label, 2, 0) axis_layout.addWidget(self.axis_z002_entry, 2, 1) axis_layout.addWidget(self.axis_z008_label, 2, 2) axis_layout.addWidget(self.axis_z008_entry, 2, 3) axis_layout.addWidget(self.axis_z003_label, 3, 0) axis_layout.addWidget(self.axis_z003_entry, 3, 1) axis_layout.addWidget(self.axis_z009_label, 3, 2) axis_layout.addWidget(self.axis_z009_entry, 3, 3) axis_layout.addWidget(self.axis_z004_label, 4, 0) axis_layout.addWidget(self.axis_z004_entry, 4, 1) axis_layout.addWidget(self.axis_z010_label, 4, 2) axis_layout.addWidget(self.axis_z010_entry, 4, 3) axis_layout.addWidget(self.axis_z005_label, 5, 0) axis_layout.addWidget(self.axis_z005_entry, 5, 1) axis_layout.addWidget(self.axis_z011_label, 5, 2) axis_layout.addWidget(self.axis_z011_entry, 5, 3) self.axis_group.setLayout(axis_layout) self.axis_group.setTabOrder(self.axis_z000_entry, self.axis_z001_entry) self.axis_group.setTabOrder(self.axis_z001_entry, self.axis_z002_entry) self.axis_group.setTabOrder(self.axis_z002_entry, self.axis_z003_entry) self.axis_group.setTabOrder(self.axis_z003_entry, self.axis_z004_entry) self.axis_group.setTabOrder(self.axis_z004_entry, self.axis_z005_entry) self.axis_group.setTabOrder(self.axis_z005_entry, self.axis_z006_entry) self.axis_group.setTabOrder(self.axis_z006_entry, self.axis_z007_entry) self.axis_group.setTabOrder(self.axis_z007_entry, self.axis_z008_entry) self.axis_group.setTabOrder(self.axis_z008_entry, self.axis_z009_entry) self.axis_group.setTabOrder(self.axis_z009_entry, self.axis_z010_entry) self.axis_group.setTabOrder(self.axis_z010_entry, self.axis_z011_entry) layout.addWidget(self.axis_group, 8, 0, 1, 4) layout.addWidget(self.help_button, 9, 0) layout.addWidget(self.delete_button, 9, 1) layout.addWidget(self.stitch_button, 9, 2, 1, 2) layout.addWidget(self.import_parameters_button, 10, 0, 1, 2) layout.addWidget(self.save_parameters_button, 10, 2, 1, 2) self.setLayout(layout) def init_values(self): self.parameters = {'parameters_type': '360_multi_stitch'} 
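        # NOTE: illustrative sketch only -- not used by the code in this file.
        # "Save Parameters to File" simply yaml.dump()s self.parameters, so the
        # saved file holds the keys initialised below.  With made-up paths it
        # would look roughly like:
        #
        #     parameters_type: 360_multi_stitch
        #     360multi_input_dir: /home/user/scan            # example path
        #     360multi_temp_dir: /home/user/tmp-batch360stitch
        #     360multi_output_dir: /home/user/stitched360
        #     360multi_crop_projections: true
        #     360multi_bottom_axis: 245
        #     360multi_top_axis: 245
        #     360multi_axis: 245
        #     360multi_manual_axis: false
        #     360multi_axis_dict: {z000: 0, z001: 0, ..., z011: 0}
        #
        # The paths above are examples; the real defaults follow below.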
self.parameters['360multi_input_dir'] = os.path.expanduser('~')# #EZVARS['360-batch-stitch']['indir'] self.input_dir_entry.setText(self.parameters['360multi_input_dir']) self.parameters['360multi_temp_dir'] = os.path.join( #EZVARS['360-batch-stitch']['tmpdir'] os.path.expanduser('~'), "tmp-batch360stitch") self.temp_dir_entry.setText(self.parameters['360multi_temp_dir']) self.parameters['360multi_output_dir'] = os.path.join(os.path.expanduser('~'),'stitched360') #EZVARS['360-batch-stitch']['outdir'] self.output_dir_entry.setText(self.parameters['360multi_output_dir']) self.parameters['360multi_crop_projections'] = True #EZVARS['360-batch-stitch']['crop'] self.crop_checkbox.setChecked(self.parameters['360multi_crop_projections']) self.parameters['360multi_bottom_axis'] = 245 #EZVARS['360-batch-stitch']['COR-in-first-set'] self.axis_bottom_entry.setText(str(self.parameters['360multi_bottom_axis'])) self.parameters['360multi_top_axis'] = 245 #EZVARS['360-batch-stitch']['COR-in-last-set'] self.axis_top_entry.setText(str(self.parameters['360multi_top_axis'])) self.parameters['360multi_axis'] = self.parameters['360multi_bottom_axis'] self.parameters['360multi_manual_axis'] = False #EZVARS['360-batch-stitch']['COR-user-defined'] self.parameters['360multi_axis_dict'] = dict.fromkeys(['z000', 'z001', 'z002', 'z003', 'z004', 'z005', 'z006', 'z007', 'z008', 'z009', 'z010', 'z011'], 0) # EZVARS['360-batch-stitch']['COR-dict'] def update_parameters(self, new_parameters): LOG.debug("Update parameters") if new_parameters['parameters_type'] != '360_multi_stitch': print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type'])) return -1 # Update parameters dictionary (which is passed to auto_stitch_funcs) self.parameters = new_parameters # Update displayed parameters for GUI self.input_dir_entry.setText(self.parameters['360multi_input_dir']) self.temp_dir_entry.setText(self.parameters['360multi_temp_dir']) self.output_dir_entry.setText(self.parameters['360multi_output_dir']) self.crop_checkbox.setChecked(self.parameters['360multi_crop_projections']) self.axis_bottom_entry.setText(str(self.parameters['360multi_bottom_axis'])) self.axis_top_entry.setText(str(self.parameters['360multi_top_axis'])) self.axis_group.setChecked(bool(self.parameters['360multi_manual_axis'])) self.axis_z000_entry.setText(str(self.parameters['360multi_axis_dict']['z000'])) self.axis_z001_entry.setText(str(self.parameters['360multi_axis_dict']['z001'])) self.axis_z002_entry.setText(str(self.parameters['360multi_axis_dict']['z002'])) self.axis_z003_entry.setText(str(self.parameters['360multi_axis_dict']['z003'])) self.axis_z004_entry.setText(str(self.parameters['360multi_axis_dict']['z004'])) self.axis_z005_entry.setText(str(self.parameters['360multi_axis_dict']['z005'])) self.axis_z006_entry.setText(str(self.parameters['360multi_axis_dict']['z006'])) self.axis_z007_entry.setText(str(self.parameters['360multi_axis_dict']['z007'])) self.axis_z008_entry.setText(str(self.parameters['360multi_axis_dict']['z008'])) self.axis_z009_entry.setText(str(self.parameters['360multi_axis_dict']['z009'])) self.axis_z010_entry.setText(str(self.parameters['360multi_axis_dict']['z010'])) self.axis_z011_entry.setText(str(self.parameters['360multi_axis_dict']['z011'])) return 0 def input_button_pressed(self): LOG.debug("Input button pressed") dir_explore = QFileDialog(self) self.input_dir_entry.setText(dir_explore.getExistingDirectory()) self.set_input_entry() def set_input_entry(self): LOG.debug("Input directory: " + 
str(self.input_dir_entry.text())) self.parameters['360multi_input_dir'] = str(self.input_dir_entry.text()) # Set output directory to automatically follow the input directory structure self.output_dir_entry.setText(self.parameters['360multi_input_dir'] + "/hor-search") self.set_output_entry() def temp_button_pressed(self): LOG.debug("Temp button pressed") dir_explore = QFileDialog(self) self.temp_dir_entry.setText(dir_explore.getExistingDirectory()) self.set_temp_entry() def set_temp_entry(self): LOG.debug("Temp directory: " + str(self.temp_dir_entry.text())) self.parameters['360multi_temp_dir'] = str(self.temp_dir_entry.text()) def output_button_pressed(self): LOG.debug("Output button pressed") dir_explore = QFileDialog(self) self.output_dir_entry.setText(dir_explore.getExistingDirectory()) self.set_output_entry() def set_output_entry(self): LOG.debug("Output directory: " + str(self.output_dir_entry.text())) self.parameters['360multi_output_dir'] = str(self.output_dir_entry.text()) def set_crop_projections_checkbox(self): LOG.debug("Crop projections: " + str(self.crop_checkbox.isChecked())) self.parameters['360multi_crop_projections'] = bool(self.crop_checkbox.isChecked()) def set_axis_bottom(self): LOG.debug("Axis Bottom : " + str(self.axis_bottom_entry.text())) self.parameters['360multi_bottom_axis'] = int(self.axis_bottom_entry.text()) def set_axis_top(self): LOG.debug("Axis Top: " + str(self.axis_top_entry.text())) self.parameters['360multi_top_axis'] = int(self.axis_top_entry.text()) def set_axis_group(self): if self.axis_group.isChecked(): self.axis_bottom_label.setEnabled(False) self.axis_bottom_entry.setEnabled(False) self.axis_top_label.setEnabled(False) self.axis_top_entry.setEnabled(False) self.parameters['360multi_manual_axis'] = True LOG.debug("Enter axis of rotation manually: " + str(self.parameters['360multi_manual_axis'])) else: self.axis_bottom_label.setEnabled(True) self.axis_bottom_entry.setEnabled(True) self.axis_top_label.setEnabled(True) self.axis_top_entry.setEnabled(True) self.parameters['360multi_manual_axis'] = False LOG.debug("Enter axis of rotation manually: " + str(self.parameters['360multi_manual_axis'])) def set_z000(self): LOG.debug("z000 axis: " + str(self.axis_z000_entry.text())) self.parameters['360multi_axis_dict']['z000'] = int(self.axis_z000_entry.text()) def set_z001(self): LOG.debug("z001 axis: " + str(self.axis_z001_entry.text())) self.parameters['360multi_axis_dict']['z001'] = int(self.axis_z001_entry.text()) def set_z002(self): LOG.debug("z002 axis: " + str(self.axis_z002_entry.text())) self.parameters['360multi_axis_dict']['z002'] = int(self.axis_z002_entry.text()) def set_z003(self): LOG.debug("z003 axis: " + str(self.axis_z003_entry.text())) self.parameters['360multi_axis_dict']['z003'] = int(self.axis_z003_entry.text()) def set_z004(self): LOG.debug("z004 axis: " + str(self.axis_z004_entry.text())) self.parameters['360multi_axis_dict']['z004'] = int(self.axis_z004_entry.text()) def set_z005(self): LOG.debug("z005 axis: " + str(self.axis_z005_entry.text())) self.parameters['360multi_axis_dict']['z005'] = int(self.axis_z005_entry.text()) def set_z006(self): LOG.debug("z006 axis: " + str(self.axis_z006_entry.text())) self.parameters['360multi_axis_dict']['z006'] = int(self.axis_z006_entry.text()) def set_z007(self): LOG.debug("z007 axis: " + str(self.axis_z007_entry.text())) self.parameters['360multi_axis_dict']['z007'] = int(self.axis_z007_entry.text()) def set_z008(self): LOG.debug("z008 axis: " + str(self.axis_z008_entry.text())) 
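        # NOTE: illustrative sketch only -- not used by the code in this file.
        # The slots set_z000() ... set_z011() differ only in the key they write.
        # They could be generated in a loop instead; self.axis_entries and
        # _set_axis are hypothetical names introduced for this sketch:
        #
        #     from functools import partial
        #
        #     for idx, entry in enumerate(self.axis_entries):
        #         key = "z{:03d}".format(idx)
        #         entry.editingFinished.connect(partial(self._set_axis, key, entry))
        #
        #     def _set_axis(self, key, entry):
        #         LOG.debug("%s axis: %s", key, entry.text())
        #         self.parameters['360multi_axis_dict'][key] = int(entry.text())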
self.parameters['360multi_axis_dict']['z008'] = int(self.axis_z008_entry.text()) def set_z009(self): LOG.debug("z009 axis: " + str(self.axis_z009_entry.text())) self.parameters['360multi_axis_dict']['z009'] = int(self.axis_z009_entry.text()) def set_z010(self): LOG.debug("z010 axis: " + str(self.axis_z010_entry.text())) self.parameters['360multi_axis_dict']['z010'] = int(self.axis_z010_entry.text()) def set_z011(self): LOG.debug("z011 axis: " + str(self.axis_z011_entry.text())) self.parameters['360multi_axis_dict']['z011'] = int(self.axis_z011_entry.text()) def stitch_button_pressed(self): LOG.debug("Stitch button pressed") self.get_fdt_names_on_stitch_pressed.emit() if os.path.exists(self.parameters['360multi_output_dir']) and \ len(os.listdir(self.parameters['360multi_output_dir'])) > 0: qm = QMessageBox() rep = qm.warning(self, '', "Output directory exists and is not empty.") return print("======= Begin 360 Multi-Stitch =======") main_360_mp_depth2(self.parameters) if os.path.isdir(self.parameters['360multi_output_dir']): params_file_path = os.path.join(self.parameters['360multi_output_dir'], '360_multi_stitch_params.yaml') params.save_parameters(self.parameters, params_file_path) print("==== Waiting for Next Task ====") def delete_button_pressed(self): LOG.debug("Delete button pressed") qm = QMessageBox() rep = qm.question(self, '', "Is it safe to delete the output directory?", qm.Yes | qm.No) if not os.path.exists(self.parameters['360multi_output_dir']): warning_message("Output directory does not exist") elif rep == qm.Yes: print("---- Deleting Data From Output Directory ----") try: rmtree(self.parameters['360multi_output_dir']) except: warning_message("Problems with deleting output directory") else: return def help_button_pressed(self): LOG.debug("Help button pressed") h = "Stitches images horizontally\n" h += "Directory structure is, f.i., Input/000, Input/001,...Input/00N\n" h += "Each 000, 001, ... 
00N directory must have identical subdirectory \"Type\"\n" h += "Selected range of images from \"Type\" directory will be stitched vertically\n" h += "across all subdirectories in the Input directory" h += "to be added as options:\n" h += "(1) orthogonal reslicing, (2) interpolation, (3) horizontal stitching" QMessageBox.information(self, "Help", h) def import_parameters_button_pressed(self): LOG.debug("Import params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getOpenFileName(filter="*.yaml") try: file_in = open(params_file_path[0], 'r') new_parameters = yaml.load(file_in, Loader=yaml.FullLoader) if self.update_parameters(new_parameters) == 0: print("Parameters file loaded from: " + str(params_file_path[0])) except FileNotFoundError: print("You need to select a valid input file") def save_parameters_button_pressed(self): LOG.debug("Save params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getSaveFileName(filter="*.yaml") garbage, file_name = os.path.split(params_file_path[0]) file_extension = os.path.splitext(file_name) # If the user doesn't enter the .yaml extension then append it to filepath if file_extension[-1] == "": file_path = params_file_path[0] + ".yaml" else: file_path = params_file_path[0] try: file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) except FileNotFoundError: print("You need to select a directory and use a valid file name") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Stitch_tools_tab/ez_360_overlap_qt.py0000664000175000017500000003250200000000000024773 0ustar00tomastomas00000000000000from PyQt5.QtWidgets import ( QGroupBox, QPushButton, QCheckBox, QLabel, QLineEdit, QGridLayout, QFileDialog, QMessageBox, ) from PyQt5.QtCore import pyqtSignal import logging from shutil import rmtree import yaml import os from tofu.ez.Helpers.find_360_overlap import find_overlap import tofu.ez.params as params import getpass #TODO Make all stitching tools compatible with the bigtiffs LOG = logging.getLogger(__name__) class Overlap360Group(QGroupBox): get_fdt_names_on_stitch_pressed = pyqtSignal() get_RR_params_on_start_pressed = pyqtSignal() def __init__(self): super().__init__() self.setTitle("360-AXIS-SEARCH") self.setToolTip("Stitches and reconstructs one slice with different axis of rotation positions for half-acqusition mode data set(s)") self.setStyleSheet('QGroupBox {color: Orange;}') self.input_dir_button = QPushButton("Select input directory") self.input_dir_button.clicked.connect(self.input_button_pressed) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_input_entry) self.temp_dir_button = QPushButton("Select temp directory") self.temp_dir_button.clicked.connect(self.temp_button_pressed) self.temp_dir_entry = QLineEdit() self.temp_dir_entry.editingFinished.connect(self.set_temp_entry) self.output_dir_button = QPushButton("Select output directory") self.output_dir_button.clicked.connect(self.output_button_pressed) self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_output_entry) self.pixel_row_label = QLabel("Row to be reconstructed") self.pixel_row_label.setToolTip("TEST") self.pixel_row_entry = QLineEdit() self.pixel_row_entry.editingFinished.connect(self.set_pixel_row) self.min_label = QLabel("Lower limit of stitch/axis search range") self.min_entry = QLineEdit() 
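        # NOTE: clarifying comment added by the editor; the loop shown is a sketch.
        # The row/lower/upper/step entries configured here define the axis search:
        # one slice is reconstructed per candidate overlap between the lower and
        # upper limits, stepping by the increment.  The exact iteration (including
        # whether the upper limit is inclusive) is implemented in
        # tofu.ez.Helpers.find_360_overlap.find_overlap, which is not shown here.
        # With the defaults of 100, 200 and 1 set in init_values it is roughly:
        #
        #     for overlap in range(lower_limit, upper_limit, increment):
        #         reconstruct_slice_with_overlap(overlap)   # hypothetical helper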
self.min_entry.editingFinished.connect(self.set_lower_limit) self.max_label = QLabel("Upper limit of stitch/axis search range") self.max_entry = QLineEdit() self.max_entry.editingFinished.connect(self.set_upper_limit) self.step_label = QLabel("Value by which to increment through search range") self.step_entry = QLineEdit() self.step_entry.editingFinished.connect(self.set_increment) self.doRR = QCheckBox("Apply ring removal") #self.doRR.setEnabled(False) self.doRR.stateChanged.connect(self.set_RR_checkbox) self.help_button = QPushButton("Help") self.help_button.clicked.connect(self.help_button_pressed) self.find_overlap_button = QPushButton("Generate slices") self.find_overlap_button.clicked.connect(self.overlap_button_pressed) self.find_overlap_button.setStyleSheet("color:royalblue;font-weight:bold") self.import_parameters_button = QPushButton("Import Parameters from File") self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed) self.save_parameters_button = QPushButton("Save Parameters to File") self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed) self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.input_dir_button, 0, 0, 1, 2) layout.addWidget(self.input_dir_entry, 1, 0, 1, 2) layout.addWidget(self.temp_dir_button, 2, 0, 1, 2) layout.addWidget(self.temp_dir_entry, 3, 0, 1, 2) layout.addWidget(self.output_dir_button, 4, 0, 1, 2) layout.addWidget(self.output_dir_entry, 5, 0, 1, 2) layout.addWidget(self.pixel_row_label, 6, 0) layout.addWidget(self.pixel_row_entry, 6, 1) layout.addWidget(self.min_label, 7, 0) layout.addWidget(self.min_entry, 7, 1) layout.addWidget(self.max_label, 8, 0) layout.addWidget(self.max_entry, 8, 1) layout.addWidget(self.step_label, 9, 0) layout.addWidget(self.step_entry, 9, 1) layout.addWidget(self.doRR, 10, 0) layout.addWidget(self.help_button, 11, 0) layout.addWidget(self.find_overlap_button, 11, 1) layout.addWidget(self.import_parameters_button, 12, 0) layout.addWidget(self.save_parameters_button, 12, 1) self.setLayout(layout) def init_values(self): self.parameters = {'parameters_type': '360_overlap'} self.parameters['360overlap_input_dir'] = os.path.expanduser('~') #EZVARS['360-olap-search']['indir'] self.input_dir_entry.setText(self.parameters['360overlap_input_dir']) self.parameters['360overlap_temp_dir'] = os.path.join( #EZVARS['360-olap-search']['tmpdir'] os.path.expanduser('~'), "tmp-360axis-search") self.temp_dir_entry.setText(self.parameters['360overlap_temp_dir']) self.parameters['360overlap_output_dir'] = os.path.join( #EZVARS['360-olap-search']['outdir'] os.path.expanduser('~'), "ezufo-360axis-search") self.output_dir_entry.setText(self.parameters['360overlap_output_dir']) self.parameters['360overlap_row'] = 200 #EZVARS['360-olap-search']['y'] self.pixel_row_entry.setText(str(self.parameters['360overlap_row'])) self.parameters['360overlap_lower_limit'] = 100 #EZVARS['360-olap-search']['column_first'] self.min_entry.setText(str(self.parameters['360overlap_lower_limit'])) self.parameters['360overlap_upper_limit'] = 200 #EZVARS['360-olap-search']['column_last'] self.max_entry.setText(str(self.parameters['360overlap_upper_limit'])) self.parameters['360overlap_increment'] = 1 #EZVARS['360-olap-search']['column_step'] self.step_entry.setText(str(self.parameters['360overlap_increment'])) self.parameters['360overlap_doRR'] = False # replace with #EZVARS['360-olap-search']['remove-rings'] self.doRR.setChecked(bool(self.parameters['360overlap_doRR'])) def 
update_parameters(self, new_parameters): LOG.debug("Update parameters") if new_parameters['parameters_type'] != '360_overlap': print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type'])) return -1 # Update parameters dictionary (which is passed to auto_stitch_funcs) self.parameters = new_parameters # Update displayed parameters for GUI self.input_dir_entry.setText(self.parameters['360overlap_input_dir']) self.temp_dir_entry.setText(self.parameters['360overlap_temp_dir']) self.output_dir_entry.setText(self.parameters['360overlap_output_dir']) self.pixel_row_entry.setText(str(self.parameters['360overlap_row'])) self.min_entry.setText(str(self.parameters['360overlap_lower_limit'])) self.max_entry.setText(str(self.parameters['360overlap_upper_limit'])) self.step_entry.setText(str(self.parameters['360overlap_increment'])) self.doRR.setChecked(bool(self.parameters['360overlap_doRR'])) def input_button_pressed(self): LOG.debug("Select input button pressed") dir_explore = QFileDialog(self) self.parameters['360overlap_input_dir'] = dir_explore.getExistingDirectory() self.input_dir_entry.setText(self.parameters['360overlap_input_dir']) def set_input_entry(self): LOG.debug("Input: " + str(self.input_dir_entry.text())) self.parameters['360overlap_input_dir'] = str(self.input_dir_entry.text()) def temp_button_pressed(self): LOG.debug("Select temp button pressed") dir_explore = QFileDialog(self) self.parameters['360overlap_temp_dir'] = dir_explore.getExistingDirectory() self.temp_dir_entry.setText(self.parameters['360overlap_temp_dir']) def set_temp_entry(self): LOG.debug("Temp: " + str(self.temp_dir_entry.text())) self.parameters['360overlap_temp_dir'] = str(self.temp_dir_entry.text()) def output_button_pressed(self): LOG.debug("Select output button pressed") dir_explore = QFileDialog(self) self.parameters['360overlap_output_dir'] = dir_explore.getExistingDirectory() self.output_dir_entry.setText(self.parameters['360overlap_output_dir']) def set_output_entry(self): LOG.debug("Output: " + str(self.output_dir_entry.text())) self.parameters['360overlap_output_dir'] = str(self.output_dir_entry.text()) def set_pixel_row(self): LOG.debug("Pixel row: " + str(self.pixel_row_entry.text())) self.parameters['360overlap_row'] = int(self.pixel_row_entry.text()) def set_lower_limit(self): LOG.debug("Lower limit: " + str(self.min_entry.text())) self.parameters['360overlap_lower_limit'] = int(self.min_entry.text()) def set_upper_limit(self): LOG.debug("Upper limit: " + str(self.max_entry.text())) self.parameters['360overlap_upper_limit'] = int(self.max_entry.text()) def set_increment(self): LOG.debug("Value of increment: " + str(self.step_entry.text())) self.parameters['360overlap_increment'] = int(self.step_entry.text()) def set_RR_checkbox(self): LOG.debug("Apply RR in 360-search: " + str(self.doRR.isChecked())) self.parameters['360overlap_doRR'] = bool(self.doRR.isChecked()) def overlap_button_pressed(self): LOG.debug("Find overlap button pressed") self.get_fdt_names_on_stitch_pressed.emit() self.get_RR_params_on_start_pressed.emit() if os.path.exists(self.parameters['360overlap_output_dir']) and \ len(os.listdir(self.parameters['360overlap_output_dir'])) > 0: qm = QMessageBox() rep = qm.question(self, '', "Output directory exists and not empty. 
Can I delete it to continue?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['360overlap_output_dir']) except: QMessageBox.information(self, "Problem", "Cannot delete existing output dir") return else: return if os.path.exists(self.parameters['360overlap_temp_dir']) and \ len(os.listdir(self.parameters['360overlap_temp_dir'])) > 0: qm = QMessageBox() rep = qm.question(self, '', "Temporary dir exist and not empty. Can I delete it to continue?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['360overlap_temp_dir']) except: QMessageBox.information(self, "Problem", "Cannot delete existing tmp dir") return else: return if not os.path.exists(self.parameters['360overlap_temp_dir']): os.makedirs(self.parameters['360overlap_temp_dir']) if not os.path.exists(self.parameters['360overlap_output_dir']): os.makedirs(self.parameters['360overlap_output_dir']) find_overlap(self.parameters) if os.path.exists(self.parameters['360overlap_output_dir']): params_file_path = os.path.join(self.parameters['360overlap_output_dir'], '360_overlap_params.yaml') params.save_parameters(self.parameters, params_file_path) def help_button_pressed(self): LOG.debug("Help button pressed") h = "This script takes as input a CT scan that has been collected in 'half-acquisition' mode" h += " and produces a series of reconstructed slices, each of which are generated by cropping and" h += " concatenating opposing projections together over a range of 'overlap' values (i.e. the pixel column" h += " at which the images are cropped and concatenated)." h += " The objective is to review this series of images to determine the pixel column at which the axis of rotation" h += " is located (much like the axis search function commonly used in reconstruction software)." QMessageBox.information(self, "Help", h) def import_parameters_button_pressed(self): LOG.debug("Import params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getOpenFileName(filter="*.yaml") try: file_in = open(params_file_path[0], 'r') new_parameters = yaml.load(file_in, Loader=yaml.FullLoader) if self.update_parameters(new_parameters) == 0: print("Parameters file loaded from: " + str(params_file_path[0])) except FileNotFoundError: print("You need to select a valid input file") def save_parameters_button_pressed(self): LOG.debug("Save params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getSaveFileName(filter="*.yaml") garbage, file_name = os.path.split(params_file_path[0]) file_extension = os.path.splitext(file_name) # If the user doesn't enter the .yaml extension then append it to filepath if file_extension[-1] == "": file_path = params_file_path[0] + ".yaml" else: file_path = params_file_path[0] try: file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) except FileNotFoundError: print("You need to select a directory and use a valid file name") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Stitch_tools_tab/ezmview_qt.py0000664000175000017500000002435700000000000023734 0ustar00tomastomas00000000000000import os import logging from PyQt5.QtWidgets import ( QGroupBox, QPushButton, QLineEdit, QLabel, QCheckBox, QGridLayout, QFileDialog, QMessageBox, ) import yaml from tofu.ez.Helpers.mview_main import main_prep LOG = logging.getLogger(__name__) class EZMViewGroup(QGroupBox): def __init__(self): super().__init__() self.args = {} 
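        # NOTE: clarifying comment added by the editor; frame ranges are a sketch.
        # EZ-MVIEW redistributes an in-place sequence of .tif frames into
        # flats/darks/tomo(/flats2) subdirectories, assuming the acquisition
        # order flats -> darks -> tomo -> flats2 (see help_button_pressed).
        # With the defaults set in init_values (10 flats, 10 darks, 3000
        # projections, 1 set) the split would be roughly:
        #
        #     frames    0..9    -> flats/
        #     frames   10..19   -> darks/
        #     frames   20..3019 -> tomo/
        #     trailing frames   -> flats2/  (unless "No trailing flats/darks")
        #
        # The actual moving and naming is done by mview_main.main_prep (not shown).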
self.e_indir = "" self.e_nproj = 0 self.e_nflats = 0 self.e_ndarks = 0 self.e_nviews = 0 self.e_noflats2 = False #self.e_Andor = False self.setTitle("EZ-MVIEW") self.setToolTip("Splits a sequence of tif files over flats/darks/tomo directories") self.setStyleSheet("QGroupBox {color: green;}") self.input_dir_button = QPushButton() self.input_dir_button.setText("Select directory with a CT sequence") self.input_dir_button.clicked.connect(self.select_directory) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_directory_entry) self.num_projections_label = QLabel() self.num_projections_label.setText("Number of projections") self.num_projections_entry = QLineEdit() self.num_projections_entry.editingFinished.connect(self.set_num_projections) self.num_flats_label = QLabel() self.num_flats_label.setText("Number of flats") self.num_flats_entry = QLineEdit() self.num_flats_entry.editingFinished.connect(self.set_num_flats) self.num_darks_label = QLabel() self.num_darks_label.setText("Number of darks") self.num_darks_entry = QLineEdit() self.num_darks_entry.editingFinished.connect(self.set_num_darks) self.num_vert_steps_label = QLabel() self.num_vert_steps_label.setText("Number of CT sets in the sequence") self.num_vert_steps_entry = QLineEdit() self.num_vert_steps_entry.editingFinished.connect(self.set_num_steps) self.no_trailing_flats_darks_checkbox = QCheckBox() self.no_trailing_flats_darks_checkbox.setText("No trailing flats/darks") self.no_trailing_flats_darks_checkbox.stateChanged.connect(self.set_trailing_checkbox) self.filenames_without_padding_checkbox = QCheckBox() self.filenames_without_padding_checkbox.setText("File names without zero padding") self.filenames_without_padding_checkbox.stateChanged.connect(self.set_file_names_checkbox) self.help_button = QPushButton() self.help_button.setText("Help") self.help_button.clicked.connect(self.help_button_pressed) self.undo_button = QPushButton() self.undo_button.setText("Undo") self.undo_button.clicked.connect(self.undo_button_pressed) self.convert_button = QPushButton() self.convert_button.setText("Convert") self.convert_button.clicked.connect(self.convert_button_pressed) self.convert_button.setStyleSheet("color:royalblue;font-weight:bold") self.save_parameters_button = QPushButton("Save Parameters to File") self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed) self.import_parameters_button = QPushButton("Import Parameters from File") self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed) self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.input_dir_button, 0, 0, 1, 3) layout.addWidget(self.input_dir_entry, 1, 0, 1, 3) layout.addWidget(self.num_projections_label, 2, 0) layout.addWidget(self.num_projections_entry, 2, 1, 1, 2) layout.addWidget(self.num_flats_label, 3, 0) layout.addWidget(self.num_flats_entry, 3, 1, 1, 2) layout.addWidget(self.num_darks_label, 4, 0) layout.addWidget(self.num_darks_entry, 4, 1, 1, 2) layout.addWidget(self.num_vert_steps_label, 5, 0) layout.addWidget(self.num_vert_steps_entry, 5, 1, 1, 2) layout.addWidget(self.no_trailing_flats_darks_checkbox, 6, 0) layout.addWidget(self.filenames_without_padding_checkbox, 6, 1, 1, 2) layout.addWidget(self.help_button, 7, 0, 1, 1) layout.addWidget(self.undo_button, 7, 1, 1, 1) layout.addWidget(self.convert_button, 7, 2, 1, 1) layout.addWidget(self.import_parameters_button, 8, 0, 1, 2) layout.addWidget(self.save_parameters_button, 8, 2, 1, 1) 
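        # NOTE: clarifying comment added by the editor.
        # In the QGridLayout calls above the long form is
        # addWidget(widget, row, column, rowSpan, columnSpan); for example
        #
        #     layout.addWidget(self.input_dir_button, 0, 0, 1, 3)
        #
        # puts the button in row 0, column 0 and stretches it over three columns.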
self.setLayout(layout) def init_values(self): self.parameters = {'parameters_type': 'ez_mview'} self.input_dir_entry.setText(os.getcwd()) self.parameters['ezmview_input_dir'] = os.getcwd() self.num_projections_entry.setText("3000") self.parameters['ezmview_num_projections'] = 3000 self.num_flats_entry.setText("10") self.parameters['ezmview_num_flats'] = 10 self.num_darks_entry.setText("10") self.parameters['ezmview_num_darks'] = 10 self.num_vert_steps_entry.setText("1") self.parameters['ezmview_num_sets'] = 1 self.no_trailing_flats_darks_checkbox.setChecked(False) self.parameters['ezmview_flats2'] = False self.filenames_without_padding_checkbox.setChecked(False) self.parameters['ezmview_no_zero_padding'] = False def update_parameters(self, new_parameters): LOG.debug("Update parameters") if new_parameters['parameters_type'] != 'ez_mview': print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type'])) return -1 # Update parameters dictionary (which is passed to auto_stitch_funcs) self.parameters = new_parameters # Update displayed parameters for GUI self.input_dir_entry.setText(str(self.parameters['ezmview_input_dir'])) self.num_projections_entry.setText(str(self.parameters['ezmview_num_projections'])) self.num_flats_entry.setText(str(self.parameters['ezmview_num_flats'])) self.num_darks_entry.setText(str(self.parameters['ezmview_num_darks'])) self.num_vert_steps_entry.setText(str(self.parameters['ezmview_num_sets'])) self.no_trailing_flats_darks_checkbox.setChecked(bool(self.parameters['ezmview_flats2'])) self.filenames_without_padding_checkbox.setChecked(bool(self.parameters['ezmview_no_zero_padding'])) def select_directory(self): LOG.debug("Select directory button pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory() self.input_dir_entry.setText(directory) self.parameters['ezmview_input_dir'] = directory def set_directory_entry(self): LOG.debug("Directory entry: " + str(self.input_dir_entry.text())) self.parameters['ezmview_input_dir'] = str(self.input_dir_entry.text()) def set_num_projections(self): LOG.debug("Num projections: " + str(self.num_projections_entry.text())) self.parameters['ezmview_num_projections'] = int(self.num_projections_entry.text()) def set_num_flats(self): LOG.debug("Num flats: " + str(self.num_flats_entry.text())) self.parameters['ezmview_num_flats'] = int(self.num_flats_entry.text()) def set_num_darks(self): LOG.debug("Num darks: " + str(self.num_darks_entry.text())) self.parameters['ezmview_num_darks'] = int(self.num_darks_entry.text()) def set_num_steps(self): LOG.debug("Num steps: " + str(self.num_vert_steps_entry.text())) self.parameters['ezmview_num_sets'] = int(self.num_vert_steps_entry.text()) def set_trailing_checkbox(self): LOG.debug("No trailing: " + str(self.no_trailing_flats_darks_checkbox.isChecked())) self.parameters['ezmview_flats2'] = bool(self.no_trailing_flats_darks_checkbox.isChecked()) def set_file_names_checkbox(self): LOG.debug("File names without zero padding: " + str(self.filenames_without_padding_checkbox.isChecked())) self.parameters['ezmview_no_zero_padding'] = \ bool(self.filenames_without_padding_checkbox.isChecked()) def convert_button_pressed(self): LOG.debug("Convert button pressed") LOG.debug(self.parameters) main_prep(self.parameters) def undo_button_pressed(self): LOG.debug("Undo button pressed") cmd = "find {} -type f -name \"*.tif\" -exec mv -t {} {{}} +" cmd = cmd.format(str(self.parameters['ezmview_input_dir']), str(self.parameters['ezmview_input_dir'])) os.system(cmd) 
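    # NOTE: illustrative sketch only -- not used by the code in this file.
    # undo_button_pressed() above shells out to GNU find/mv and is therefore
    # Linux-specific.  A pure-Python equivalent (hypothetical helper) that moves
    # every .tif found below the input directory back to its top level:
    #
    #     import shutil
    #
    #     def _flatten_tifs(root):
    #         for dirpath, _dirs, files in os.walk(root):
    #             if dirpath == root:
    #                 continue
    #             for name in files:
    #                 if name.lower().endswith(".tif"):
    #                     shutil.move(os.path.join(dirpath, name),
    #                                 os.path.join(root, name))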
def help_button_pressed(self): LOG.debug("Help button pressed") h = "Distributes a sequence of CT frames in flats/darks/tomo/flats2 directories\n" h += "assuming that acqusition sequence is flats->darks->tomo->flats2\n" h += 'Use only for sequences with flat fields acquired at 0 and 180!\n' h += "Conversions happens in-place but can be undone" QMessageBox.information(self, "Help", h) def import_parameters_button_pressed(self): LOG.debug("Import params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getOpenFileName(filter="*.yaml") try: file_in = open(params_file_path[0], 'r') new_parameters = yaml.load(file_in, Loader=yaml.FullLoader) if self.update_parameters(new_parameters) == 0: print("Parameters file loaded from: " + str(params_file_path[0])) except FileNotFoundError: print("You need to select a valid input file") def save_parameters_button_pressed(self): LOG.debug("Save params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getSaveFileName(filter="*.yaml") garbage, file_name = os.path.split(params_file_path[0]) file_extension = os.path.splitext(file_name) # If the user doesn't enter the .yaml extension then append it to filepath if file_extension[-1] == "": file_path = params_file_path[0] + ".yaml" else: file_path = params_file_path[0] try: file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) except FileNotFoundError: print("You need to select a directory and use a valid file name") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/GUI/Stitch_tools_tab/ezstitch_qt.py0000664000175000017500000005424500000000000024102 0ustar00tomastomas00000000000000import os from PyQt5.QtWidgets import ( QGroupBox, QPushButton, QCheckBox, QLabel, QLineEdit, QGridLayout, QVBoxLayout, QHBoxLayout, QRadioButton, QFileDialog, QMessageBox, ) from shutil import rmtree import logging import getpass import yaml import tofu.ez.params as params from tofu.ez.Helpers.stitch_funcs import main_sti_mp, main_conc_mp, main_360_mp_depth1 from tofu.ez.GUI.message_dialog import warning_message LOG = logging.getLogger(__name__) class EZStitchGroup(QGroupBox): def __init__(self): super().__init__() self.setTitle("EZ-STITCH") self.setToolTip("Reslicing and stitching tool") self.setStyleSheet('QGroupBox {color: purple;}') self.input_dir_button = QPushButton() self.input_dir_button.setText("Select input directory") self.input_dir_button.setToolTip("Normally contains a bunch of directories at the first depth level\n" \ "each of which has a subdirectory with the same name (second depth level). 
\n" "Images with the same index in these second-level subdirectories will be stitched vertically.") self.input_dir_button.clicked.connect(self.input_button_pressed) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_input_entry) self.tmp_dir_button = QPushButton() self.tmp_dir_button.setText("Select temporary directory") self.tmp_dir_button.clicked.connect(self.temp_button_pressed) self.tmp_dir_entry = QLineEdit() self.tmp_dir_entry.editingFinished.connect(self.set_temp_entry) self.output_dir_button = QPushButton() self.output_dir_button.setText("Directory to save stitched images") self.output_dir_button.clicked.connect(self.output_button_pressed) self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_output_entry) self.types_of_images_label = QLabel() tmpstr = "Name of subdirectories which contain the same type of images in every directory in the input" self.types_of_images_label.setToolTip(tmpstr) self.types_of_images_label.setText("Name of subdirectory with the same type of images to stitch") self.types_of_images_label.setToolTip("e.g. sli, tomo, proj-pr, etc.") self.types_of_images_entry = QLineEdit() self.types_of_images_entry.setToolTip(tmpstr) self.types_of_images_entry.editingFinished.connect(self.set_type_images) self.orthogonal_checkbox = QCheckBox() self.orthogonal_checkbox.setText("Stitch orthogonal sections") self.orthogonal_checkbox.setToolTip("Will reslice images in every subdirectory and then stitch") self.orthogonal_checkbox.stateChanged.connect(self.set_stitch_checkbox) self.start_stop_step_label = QLabel() self.start_stop_step_label.setText("Which images to be stitched: start,stop,step:") self.start_stop_step_entry = QLineEdit() self.start_stop_step_entry.editingFinished.connect(self.set_start_stop_step) self.sample_moved_down_checkbox = QCheckBox() self.sample_moved_down_checkbox.setText("Flip images upside down before stitching") self.sample_moved_down_checkbox.stateChanged.connect(self.set_sample_moved_down) self.interpolate_regions_rButton = QRadioButton() self.interpolate_regions_rButton.setText("Interpolate overlapping regions and equalize intensity") self.interpolate_regions_rButton.clicked.connect(self.set_rButton) self.num_overlaps_label = QLabel() self.num_overlaps_label.setText("Number of overlapping rows") self.num_overlaps_entry = QLineEdit() self.num_overlaps_entry.editingFinished.connect(self.set_overlap) self.clip_histogram_checkbox = QCheckBox() self.clip_histogram_checkbox.setText("Clip histogram and convert slices to 8-bit before saving") self.clip_histogram_checkbox.stateChanged.connect(self.set_histogram_checkbox) self.min_value_label = QLabel() self.min_value_label.setText("Min value in 32-bit histogram") self.min_value_entry = QLineEdit() self.min_value_entry.editingFinished.connect(self.set_min_value) self.max_value_label = QLabel() self.max_value_label.setText("Max value in 32-bit histogram") self.max_value_entry = QLineEdit() self.max_value_entry.editingFinished.connect(self.set_max_value) self.concatenate_rButton = QRadioButton() self.concatenate_rButton.setText("Concatenate only") self.concatenate_rButton.clicked.connect(self.set_rButton) self.first_row_label = QLabel() self.first_row_label.setText("First row") self.first_row_entry = QLineEdit() self.first_row_entry.editingFinished.connect(self.set_first_row) self.last_row_label = QLabel() self.last_row_label.setText("Last row") self.last_row_entry = QLineEdit() 
self.last_row_entry.editingFinished.connect(self.set_last_row) self.half_acquisition_rButton = QRadioButton() self.half_acquisition_rButton.setText("Horizontal stitching of half-acq. mode data") self.half_acquisition_rButton.setToolTip("Applies to tif images in all depth-one subdirectories in the Input \n" "unlike 360-MULTI-STITCH which search images at the depth two ") #self.half_acquisition_rButtonYfor a half-acqusition mode data (even number of tif files in the Input directory)") self.half_acquisition_rButton.clicked.connect(self.set_rButton) self.column_of_axis_label = QLabel() self.column_of_axis_label.setText("In which column the axis of rotation is") self.column_of_axis_entry = QLineEdit() self.column_of_axis_entry.editingFinished.connect(self.set_axis_column) self.help_button = QPushButton() self.help_button.setText("Help") self.help_button.clicked.connect(self.help_button_pressed) self.delete_button = QPushButton() self.delete_button.setText("Delete output dir") self.delete_button.clicked.connect(self.delete_button_pressed) self.stitch_button = QPushButton() self.stitch_button.setText("Stitch") self.stitch_button.clicked.connect(self.stitch_button_pressed) self.stitch_button.setStyleSheet("color:royalblue;font-weight:bold") self.import_parameters_button = QPushButton("Import Parameters from File") self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed) self.save_parameters_button = QPushButton("Save Parameters to File") self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed) self.set_layout() def set_layout(self): layout = QGridLayout() vbox1 = QVBoxLayout() vbox1.addWidget(self.input_dir_button) vbox1.addWidget(self.input_dir_entry) vbox1.addWidget(self.tmp_dir_button) vbox1.addWidget(self.tmp_dir_entry) vbox1.addWidget(self.output_dir_button) vbox1.addWidget(self.output_dir_entry) layout.addItem(vbox1, 0, 0) grid = QGridLayout() grid.addWidget(self.types_of_images_label, 0, 0) grid.addWidget(self.types_of_images_entry, 0, 1) grid.addWidget(self.orthogonal_checkbox, 1, 0, 1, 2) grid.addWidget(self.start_stop_step_label, 2, 0) grid.addWidget(self.start_stop_step_entry, 2, 1) grid.addWidget(self.sample_moved_down_checkbox, 3, 0) grid.addWidget(self.interpolate_regions_rButton, 4, 0, 1, 2) grid.addWidget(self.num_overlaps_label, 5, 0) grid.addWidget(self.num_overlaps_entry, 5, 1) grid.addWidget(self.clip_histogram_checkbox, 6, 0) grid.addWidget(self.min_value_label, 7, 0) grid.addWidget(self.min_value_entry, 7, 1) grid.addWidget(self.max_value_label, 8, 0) grid.addWidget(self.max_value_entry, 8, 1) layout.addItem(grid, 1, 0) grid2 = QGridLayout() grid2.addWidget(self.concatenate_rButton, 0, 0, 1, 2) grid2.addWidget(self.first_row_label, 1, 0) grid2.addWidget(self.first_row_entry, 1, 1) grid2.addWidget(self.last_row_label, 1, 2) grid2.addWidget(self.last_row_entry, 1, 3) layout.addItem(grid2, 2, 0) grid3 = QGridLayout() grid3.addWidget(self.half_acquisition_rButton, 0, 0, 1, 2) grid3.addWidget(self.column_of_axis_label, 1, 0) grid3.addWidget(self.column_of_axis_entry, 1, 1) layout.addItem(grid3, 3, 0) grid4 = QGridLayout() grid4.addWidget(self.help_button, 0, 0) grid4.addWidget(self.delete_button, 0, 1) grid4.addWidget(self.stitch_button, 0, 2) grid4.addWidget(self.import_parameters_button, 1, 0, 1, 2) grid4.addWidget(self.save_parameters_button, 1, 2) layout.addItem(grid4, 4, 0) self.setLayout(layout) def init_values(self): self.parameters = {'parameters_type': 'ez_stitch'} self.parameters['ezstitch_input_dir'] = 
os.path.expanduser('~') self.input_dir_entry.setText(self.parameters['ezstitch_input_dir']) self.parameters['ezstitch_temp_dir'] = os.path.join( os.path.expanduser('~'), "tmp-ezstitch") self.tmp_dir_entry.setText(self.parameters['ezstitch_temp_dir']) self.parameters['ezstitch_output_dir'] = os.path.join( os.path.expanduser('~'), "ezufo-stitched-images") self.output_dir_entry.setText(self.parameters['ezstitch_output_dir']) self.parameters['ezstitch_type_image'] = "sli" self.types_of_images_entry.setText(self.parameters['ezstitch_type_image']) self.parameters['ezstitch_stitch_orthogonal'] = True self.orthogonal_checkbox.setChecked(self.parameters['ezstitch_stitch_orthogonal']) self.parameters['ezstitch_start_stop_step'] = "200,2000,200" self.start_stop_step_entry.setText(self.parameters['ezstitch_start_stop_step']) self.parameters['ezstitch_sample_moved_down'] = False self.sample_moved_down_checkbox.setChecked(self.parameters['ezstitch_sample_moved_down']) self.parameters['ezstitch_stitch_type'] = 0 self.interpolate_regions_rButton.setChecked(True) self.concatenate_rButton.setChecked(False) self.half_acquisition_rButton.setChecked(False) self.parameters['ezstitch_num_overlap_rows'] = 60 self.num_overlaps_entry.setText(str(self.parameters['ezstitch_num_overlap_rows'])) self.parameters['ezstitch_clip_histo'] = False self.clip_histogram_checkbox.setChecked(self.parameters['ezstitch_clip_histo']) self.parameters['ezstitch_histo_min'] = -0.0003 self.min_value_entry.setText(str(self.parameters['ezstitch_histo_min'])) self.parameters['ezstitch_histo_max'] = 0.0002 self.max_value_entry.setText(str(self.parameters['ezstitch_histo_max'])) self.parameters['ezstitch_first_row'] = 40 self.first_row_entry.setText(str(self.parameters['ezstitch_first_row'])) self.parameters['ezstitch_last_row'] = 440 self.last_row_entry.setText(str(self.parameters['ezstitch_last_row'])) self.parameters['ezstitch_axis_of_rotation'] = 245 self.column_of_axis_entry.setText(str(self.parameters['ezstitch_axis_of_rotation'])) def update_parameters(self, new_parameters): LOG.debug("Update parameters") if new_parameters['parameters_type'] != 'ez_stitch': print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type'])) return -1 # Update parameters dictionary (which is passed to auto_stitch_funcs) self.parameters = new_parameters # Update displayed parameters for GUI self.input_dir_entry.setText(self.parameters['ezstitch_input_dir']) self.tmp_dir_entry.setText(self.parameters['ezstitch_temp_dir']) self.output_dir_entry.setText(self.parameters['ezstitch_output_dir']) self.types_of_images_entry.setText(self.parameters['ezstitch_type_image']) self.orthogonal_checkbox.setChecked(self.parameters['ezstitch_stitch_orthogonal']) self.start_stop_step_entry.setText(self.parameters['ezstitch_start_stop_step']) self.sample_moved_down_checkbox.setChecked(self.parameters['ezstitch_sample_moved_down']) if self.parameters['ezstitch_stitch_type'] == 0: self.interpolate_regions_rButton.setChecked(True) elif self.parameters['ezstitch_stitch_type'] == 1: self.concatenate_rButton.setChecked(True) elif self.parameters['ezstitch_stitch_type'] == 2: self.half_acquisition_rButton.setChecked(True) self.num_overlaps_entry.setText(str(self.parameters['ezstitch_num_overlap_rows'])) self.clip_histogram_checkbox.setChecked(self.parameters['ezstitch_clip_histo']) self.min_value_entry.setText(str(self.parameters['ezstitch_histo_min'])) self.max_value_entry.setText(str(self.parameters['ezstitch_histo_max'])) 
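        # NOTE: illustrative sketch only -- not used by the code in this file.
        # The histogram min/max handled here only matter when "Clip histogram and
        # convert slices to 8-bit" is enabled; the actual conversion lives in the
        # stitch_funcs helpers (not shown).  A typical mapping of a 32-bit float
        # image onto 0..255 would be:
        #
        #     import numpy as np
        #
        #     def to_8bit(img, lo, hi):
        #         scaled = np.clip((img - lo) / (hi - lo), 0.0, 1.0)
        #         return (scaled * 255).astype(np.uint8)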
self.first_row_entry.setText(str(self.parameters['ezstitch_first_row'])) self.last_row_entry.setText(str(self.parameters['ezstitch_last_row'])) self.column_of_axis_entry.setText(str(self.parameters['ezstitch_axis_of_rotation'])) def set_rButton(self): if self.interpolate_regions_rButton.isChecked(): LOG.debug("Interpolate regions") self.parameters['ezstitch_stitch_type'] = 0 elif self.concatenate_rButton.isChecked(): LOG.debug("Concatenate only") self.parameters['ezstitch_stitch_type'] = 1 elif self.half_acquisition_rButton.isChecked(): LOG.debug("Half-acquisition mode") self.parameters['ezstitch_stitch_type'] = 2 def input_button_pressed(self): LOG.debug("Input button pressed") dir_explore = QFileDialog(self) self.parameters['ezstitch_input_dir'] = dir_explore.getExistingDirectory() self.input_dir_entry.setText(self.parameters['ezstitch_input_dir']) def set_input_entry(self): LOG.debug("Input: " + str(self.input_dir_entry.text())) self.parameters['ezstitch_input_dir'] = str(self.input_dir_entry.text()) def temp_button_pressed(self): LOG.debug("Temp button pressed") dir_explore = QFileDialog(self) self.parameters['ezstitch_temp_dir'] = dir_explore.getExistingDirectory() self.tmp_dir_entry.setText(self.parameters['ezstitch_temp_dir']) def set_temp_entry(self): LOG.debug("Temp: " + str(self.tmp_dir_entry.text())) self.parameters['ezstitch_temp_dir'] = str(self.tmp_dir_entry.text()) def output_button_pressed(self): LOG.debug("Output button pressed") dir_explore = QFileDialog(self) self.parameters['ezstitch_output_dir'] = dir_explore.getExistingDirectory() self.output_dir_entry.setText(self.parameters['ezstitch_output_dir']) def set_output_entry(self): LOG.debug("Output: " + str(self.output_dir_entry.text())) self.parameters['ezstitch_output_dir'] = str(self.output_dir_entry.text()) def set_type_images(self): LOG.debug("Type of images: " + str(self.types_of_images_entry.text())) self.parameters['ezstitch_type_image'] = str(self.types_of_images_entry.text()) def set_stitch_checkbox(self): LOG.debug("Stitch orthogonal: " + str(self.orthogonal_checkbox.isChecked())) self.parameters['ezstitch_stitch_orthogonal'] = bool(self.orthogonal_checkbox.isChecked()) def set_start_stop_step(self): LOG.debug("Images to be stitched: " + str(self.start_stop_step_entry.text())) self.parameters['ezstitch_start_stop_step'] = str(self.start_stop_step_entry.text()) def set_sample_moved_down(self): LOG.debug("Sample moved down: " + str(self.sample_moved_down_checkbox.isChecked())) self.parameters['ezstitch_sample_moved_down'] = bool(self.sample_moved_down_checkbox.isChecked()) def set_overlap(self): LOG.debug("Num overlapping rows: " + str(self.num_overlaps_entry.text())) self.parameters['ezstitch_num_overlap_rows'] = int(self.num_overlaps_entry.text()) def set_histogram_checkbox(self): LOG.debug("Clip histogram: " + str(self.clip_histogram_checkbox.isChecked())) self.parameters['ezstitch_clip_histo'] = bool(self.clip_histogram_checkbox.isChecked()) def set_min_value(self): LOG.debug("Min value: " + str(self.min_value_entry.text())) self.parameters['ezstitch_histo_min'] = float(self.min_value_entry.text()) def set_max_value(self): LOG.debug("Max value: " + str(self.max_value_entry.text())) self.parameters['ezstitch_histo_max'] = float(self.max_value_entry.text()) def set_first_row(self): LOG.debug("First row: " + str(self.first_row_entry.text())) self.parameters['ezstitch_first_row'] = int(self.first_row_entry.text()) def set_last_row(self): LOG.debug("Last row: " + str(self.last_row_entry.text())) 
self.parameters['ezstitch_last_row'] = int(self.last_row_entry.text()) def set_axis_column(self): LOG.debug("Column of axis: " + str(self.column_of_axis_entry.text())) self.parameters['ezstitch_axis_of_rotation'] = int(self.column_of_axis_entry.text()) def help_button_pressed(self): LOG.debug("Help button pressed") h = "Stitches images vertically\n" h += "Directory structure is, f.i., Input/000, Input/001,...Input/00N\n" h += "Each 000, 001, ... 00N directory must have identical subdirectory \"Type\"\n" h += "Selected range of images from \"Type\" directory will be stitched vertically\n" h += "across all subdirectories in the Input directory" h += "to be added as options:\n" h += "(1) orthogonal reslicing, (2) interpolation, (3) horizontal stitching" QMessageBox.information(self, "Help", h) def delete_button_pressed(self): LOG.debug("Delete button pressed") # if os.path.exists(self.parameters['ezstitch_output_dir']): # os.system('rm -r {}'.format(self.parameters['ezstitch_output_dir'])) # print(" - Directory with reconstructed data was removed") if os.path.exists(self.parameters['ezstitch_output_dir']): qm = QMessageBox() rep = qm.question(self, '', f"{self.parameters['ezstitch_output_dir']} \n" "will be removed. Continue?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['ezstitch_output_dir']) except: warning_message('Error while deleting directory') return else: return def stitch_button_pressed(self): LOG.debug("Stitch button pressed") if os.path.exists(self.parameters['ezstitch_temp_dir']) and \ len(os.listdir(self.parameters['ezstitch_temp_dir'])) > 0: qm = QMessageBox() rep = qm.question(self, '', "Temporary dir is not empty. Is it safe to delete it?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['ezstitch_temp_dir']) except: warning_message('Error while deleting directory') return else: return if os.path.exists(self.parameters['ezstitch_output_dir']) and \ len(os.listdir(self.parameters['ezstitch_output_dir'])) > 0: qm = QMessageBox() rep = qm.question(self, '', "Output dir is not empty. 
Is it safe to delete it?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['ezstitch_output_dir']) except: warning_message('Error while deleting directory') return else: return print("======= Begin Stitching =======") # Interpolate overlapping regions and equalize intensity if self.parameters['ezstitch_stitch_type'] == 0: main_sti_mp(self.parameters) # Concatenate only elif self.parameters['ezstitch_stitch_type'] == 1: main_conc_mp(self.parameters) # Half acquisition mode elif self.parameters['ezstitch_stitch_type'] == 2: main_360_mp_depth1(self.parameters['ezstitch_input_dir'], self.parameters['ezstitch_output_dir'], self.parameters['ezstitch_axis_of_rotation'], 0) if os.path.isdir(self.parameters['ezstitch_output_dir']): params_file_path = os.path.join(self.parameters['ezstitch_output_dir'], 'ezmview_params.yaml') params.save_parameters(self.parameters, params_file_path) print("==== Waiting for Next Task ====") def import_parameters_button_pressed(self): LOG.debug("Import params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getOpenFileName(filter="*.yaml") try: file_in = open(params_file_path[0], 'r') new_parameters = yaml.load(file_in, Loader=yaml.FullLoader) if self.update_parameters(new_parameters) == 0: print("Parameters file loaded from: " + str(params_file_path[0])) except FileNotFoundError: print("You need to select a valid input file") def save_parameters_button_pressed(self): LOG.debug("Save params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getSaveFileName(filter="*.yaml") garbage, file_name = os.path.split(params_file_path[0]) file_extension = os.path.splitext(file_name) # If the user doesn't enter the .yaml extension then append it to filepath if file_extension[-1] == "": file_path = params_file_path[0] + ".yaml" else: file_path = params_file_path[0] try: file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) except FileNotFoundError: print("You need to select a directory and use a valid file name") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/ez/GUI/__init__.py0000664000175000017500000000000100000000000017751 0ustar00tomastomas00000000000000 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414600.0 ufo-tofu-0.13.0/tofu/ez/GUI/ezufo_launcher.py0000664000175000017500000002661200000000000021243 0ustar00tomastomas00000000000000import logging import os import sys from PyQt5 import QtWidgets as qtw from tofu.ez.GUI.Main.centre_of_rotation import CentreOfRotationGroup from tofu.ez.GUI.Main.filters import FiltersGroup from tofu.ez.GUI.Advanced.ffc import FFCGroup from tofu.ez.GUI.Main.phase_retrieval import PhaseRetrievalGroup from tofu.ez.GUI.Main.region_and_histogram import ROIandHistGroup from tofu.ez.GUI.Main.config import ConfigGroup from tofu.ez.main import clean_tmp_dirs from tofu.ez.GUI.image_viewer import ImageViewerGroup from tofu.ez.params import EZVARS from tofu.config import SECTIONS from tofu.ez.util import load_values_from_ezdefault from tofu.ez.GUI.Advanced.advanced import AdvancedGroup from tofu.ez.GUI.Advanced.optimization import OptimizationGroup from tofu.ez.GUI.Advanced.nlmdn import NLMDNGroup from tofu.ez.GUI.Stitch_tools_tab.ez_360_multi_stitch_qt import MultiStitch360Group from tofu.ez.GUI.Stitch_tools_tab.ezstitch_qt import EZStitchGroup from 
tofu.ez.GUI.Stitch_tools_tab.ezmview_qt import EZMViewGroup from tofu.ez.GUI.Stitch_tools_tab.ez_360_overlap_qt import Overlap360Group from tofu.ez.GUI.login_dialog import Login from tofu.ez.GUI.Main.batch_process import BatchProcessGroup from tofu.ez.GUI.Stitch_tools_tab.auto_horizontal_stitch_gui import AutoHorizontalStitchGUI LOG = logging.getLogger(__name__) class GUI(qtw.QWidget): """ Creates main GUI """ def __init__(self, *args, **kwargs): super(GUI, self).__init__(*args, **kwargs) self.setWindowTitle("EZ-UFO") self.setStyleSheet("font: 10pt; font-family: Arial") # initialize dictionary entries load_values_from_ezdefault(EZVARS) load_values_from_ezdefault(SECTIONS) # Call login dialog # self.login_parameters = {} # QTimer.singleShot(0, self.login) # Initialize tab screen self.tabs = qtw.QTabWidget() self.tab1 = qtw.QWidget() self.tab2 = qtw.QWidget() self.tab3 = qtw.QWidget() self.tab4 = qtw.QWidget() self.tab5 = qtw.QWidget() self.tab6 = qtw.QWidget() # Create and setup classes for each section of GUI # Main Tab self.config_group = ConfigGroup() self.config_group.load_values() self.centre_of_rotation_group = CentreOfRotationGroup() self.centre_of_rotation_group.load_values() self.filters_group = FiltersGroup() self.filters_group.load_values() self.ffc_group = FFCGroup() self.ffc_group.load_values() self.phase_retrieval_group = PhaseRetrievalGroup() self.phase_retrieval_group.load_values() self.binning_group = ROIandHistGroup() self.binning_group.load_values() # Image Viewer self.image_group = ImageViewerGroup() # Advanced Tab self.advanced_group = AdvancedGroup() self.advanced_group.load_values() self.optimization_group = OptimizationGroup() self.optimization_group.load_values() self.nlmdn_group = NLMDNGroup() self.nlmdn_group.load_values() # Stitch_tools_tab Tab # ----((P)Completed up to here) ----# self.multi_stitch_group = MultiStitch360Group() self.multi_stitch_group.init_values() self.ezmview_group = EZMViewGroup() self.ezmview_group.init_values() self.ezstitch_group = EZStitchGroup() self.ezstitch_group.init_values() self.overlap_group = Overlap360Group() self.overlap_group.init_values() ####################################################### self.set_layout() self.resize(0, 0) # window to minimum size # When new settings are imported signal is sent and this catches it to update params for each GUI object self.config_group.signal_update_vals_from_params.connect(self.update_values) # When RECO is done send signal from config self.config_group.signal_reco_done.connect(self.switch_to_image_tab) # To pass directory names from config tab to stitch tab when button pressed self.multi_stitch_group.get_fdt_names_on_stitch_pressed.connect(self.config_group.set_fdt_names) self.overlap_group.get_fdt_names_on_stitch_pressed.connect(self.config_group.set_fdt_names) # To pass RR params from filters section to 360-search tab when button pressed self.overlap_group.get_RR_params_on_start_pressed.connect( self.filters_group.set_ufoRR_params_for_360_axis_search) finish = qtw.QAction("Quit", self) finish.triggered.connect(self.closeEvent) self.show() def set_layout(self): """ Set the layout of groups/tabs for the overall application layout """ layout = qtw.QVBoxLayout(self) main_layout = qtw.QGridLayout() main_layout.addWidget(self.centre_of_rotation_group, 0, 0) main_layout.addWidget(self.filters_group, 0, 1) main_layout.addWidget(self.phase_retrieval_group, 1, 0) main_layout.addWidget(self.binning_group, 1, 1) main_layout.addWidget(self.config_group, 2, 0, 2, 0) image_layout = qtw.QGridLayout() 
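# Note (descriptive comment, not in the original source): each visible tab gets its own
# QGridLayout below -- main_layout ("Main"), image_layout ("Image Viewer"),
# advanced_layout ("Advanced") and helpers_layout ("Stitching tools 1"); the commented-out
# stitching2/batch layouts belong to tab5 and tab6, which are currently disabled.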
image_layout.addWidget(self.image_group, 0, 0) advanced_layout = qtw.QGridLayout() advanced_layout.addWidget(self.ffc_group, 0, 0) advanced_layout.addWidget(self.advanced_group, 1, 0) advanced_layout.addWidget(self.optimization_group, 1, 1) advanced_layout.addWidget(self.nlmdn_group, 2, 0) helpers_layout = qtw.QGridLayout() helpers_layout.addWidget(self.ezmview_group, 0, 0) helpers_layout.addWidget(self.overlap_group, 0, 1) helpers_layout.addWidget(self.multi_stitch_group, 1, 0) helpers_layout.addWidget(self.ezstitch_group, 1, 1) # stitching2_layout = qtw.QGridLayout() # stitching2_layout.addWidget(self.auto_horizontal_stitch, 0, 0) # # batch_tools_layout = qtw.QGridLayout() # batch_tools_layout.addWidget(self.batch_process_group, 0, 0) # Add tabs self.tabs.addTab(self.tab1, "Main") self.tabs.addTab(self.tab2, "Advanced") self.tabs.addTab(self.tab3, "Stitching tools 1") self.tabs.addTab(self.tab4, "Image Viewer") # self.tabs.addTab(self.tab5, "Stitching Tools 2") # self.tabs.addTab(self.tab6, "Batch Tools") # Create main tab self.tab1.layout = main_layout self.tab1.setLayout(self.tab1.layout) # Create image tab self.tab4.layout = image_layout self.tab4.setLayout(self.tab4.layout) # Create advanced tab self.tab2.layout = advanced_layout self.tab2.setLayout(self.tab2.layout) # Create helpers tab self.tab3.layout = helpers_layout self.tab3.setLayout(self.tab3.layout) # # Create stitching2 tab # self.tab5.layout = stitching2_layout # self.tab5.setLayout(self.tab5.layout) # # # Create batch tools tab # self.tab6.layout = batch_tools_layout # self.tab6.setLayout(self.tab6.layout) # Add tabs to widget layout.addWidget(self.tabs) self.setLayout(layout) def update_values(self): """ Updates displayed values when loaded in from external .yaml file of parameters """ LOG.debug("Update Values from dictionary entries") self.centre_of_rotation_group.load_values() self.filters_group.load_values() self.ffc_group.load_values() self.phase_retrieval_group.load_values() self.binning_group.load_values() self.config_group.load_values() self.nlmdn_group.load_values() self.advanced_group.load_values() self.optimization_group.load_values() def switch_to_image_tab(self): """ Function is called after reconstruction when checkbox "Load images and open viewer after reconstruction" is enabled Automatically loads images from the output reconstruction directory for viewing """ if EZVARS['inout']['open-viewer']['value'] is True: LOG.debug("Switch to Image Tab") self.tabs.setCurrentWidget(self.tab4)  # tab4 holds the Image Viewer if os.path.isdir(str(EZVARS['inout']['output-dir']['value'] + '/sli')): files = os.listdir(str(EZVARS['inout']['output-dir']['value'] + '/sli')) #Start thread here to load images ##CHECK IF ONLY SINGLE IMAGE THEN USE OPEN IMAGE -- OTHERWISE OPEN STACK if len(files) == 1: print("Only one file in {}: Opening single image {}". format(EZVARS['inout']['output-dir']['value'] + '/sli', files[0])) filePath = str(EZVARS['inout']['output-dir']['value'] + '/sli/' + str(files[0])) self.image_group.open_image_from_filepath(filePath) else: print("Multiple files in {}: Opening stack of images".
format(str(EZVARS['inout']['output-dir']['value'] + '/sli'))) self.image_group.open_stack_from_path( str(EZVARS['inout']['output-dir']['value'] + '/sli')) else: print("No output directory found") def closeEvent(self, event): """ Creates verification message box Cleans up temporary directories when user quits application """ logging.debug("QUIT") reply = qtw.QMessageBox.question(self, 'Quit', 'Are you sure you want to quit?', qtw.QMessageBox.Yes | qtw.QMessageBox.No, qtw.QMessageBox.No) if reply == qtw.QMessageBox.Yes: # remove all directories with projections clean_tmp_dirs(EZVARS['inout']['tmp-dir']['value'], self.config_group.get_fdt_names()) # remove axis-search dir too tmp = os.path.join(EZVARS['inout']['tmp-dir']['value'], 'axis-search') event.accept() else: event.ignore() def login(self): login_dialog = Login(self.login_parameters) if login_dialog.exec_() != qtw.QDialog.Accepted: self.exit() else: #self.file_writer_group.root_dir_entry.setText(self.login_parameters['expdir']) self.config_group.input_dir_entry.setText(self.login_parameters['expdir'] + "/raw") self.config_group.set_input_dir() self.config_group.output_dir_entry.setText(self.login_parameters['expdir'] + "/rec") self.config_group.set_output_dir() ''' td = date.today() tdstr = "{}.{}.{}".format(td.year, td.month, td.day) logfname = os.path.join(self.login_parameters['expdir'], 'exp-log-' + tdstr + '.log') if self.login_parameters.has_key('project'): logfname = os.path.join(self.login_parameters['expdir'], '{}-log-{}-{}.log'. format(self.login_parameters['project'], self.login_parameters['bl'], tdstr)) try: open(logfname, 'a').close() except: warning_message('Cannot create log file in the selected directory. \n' 'Check permissions and restart.') self.exit() ''' def exit(self): self.close() def main_qt(args=None): app = qtw.QApplication(sys.argv) window = GUI() sys.exit(app.exec_()) if __name__ == "__main__": main_qt() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790733.0 ufo-tofu-0.13.0/tofu/ez/GUI/image_viewer.py0000664000175000017500000003575200000000000020702 0ustar00tomastomas00000000000000import os import logging import pyqtgraph as pg import numpy as np import tifffile from PyQt5.QtWidgets import ( QPushButton, QGroupBox, QLabel, QDoubleSpinBox, QRadioButton, QScrollBar, QVBoxLayout, QGridLayout, QFileDialog, QMessageBox, ) from PyQt5.QtCore import Qt import tofu.ez.image_read_write as image_read_write #TODO Integrate axis search tab ob tofu gui into this interface LOG = logging.getLogger(__name__) class ImageViewerGroup(QGroupBox): def __init__(self): super().__init__() #TODO: initialize on every opening with explicit data type #mmatching the data format being opened. #must check that there is enough RAM before loading!! 
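# Illustrative sketch only (not part of this module, psutil is an extra dependency):
# one way to address the TODO above is to estimate the stack size before reading it
# and compare it with the free memory reported by psutil, e.g.:
#
#     import psutil
#     needed = n_images * height * width * np.dtype(np.float32).itemsize
#     if needed > psutil.virtual_memory().available:
#         raise MemoryError("Image stack would not fit into RAM")
#
# n_images/height/width are hypothetical names for the stack dimensions.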
self.tiff_arr = np.empty([0, 0, 0]) # float32 self.img_arr = np.empty([0, 0]) self.bit_depth = 32 self.open_file_button = QPushButton("Open Image File") self.open_file_button.clicked.connect(self.open_image_from_file) self.open_file_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.open_stack_button = QPushButton("Open Image Stack") self.open_stack_button.clicked.connect(self.open_stack_from_directory) self.open_stack_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.save_file_button = QPushButton("Save Image File") self.save_file_button.clicked.connect(self.save_image_to_file) self.save_file_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.save_stack_button = QPushButton("Save Image Stack") self.save_stack_button.clicked.connect(self.save_stack_to_directory) self.save_stack_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.open_big_tiff_button = QPushButton("Open BigTiff") self.open_big_tiff_button.clicked.connect(self.open_big_tiff) self.open_big_tiff_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.save_big_tiff_button = QPushButton("Save BigTiff") self.save_big_tiff_button.clicked.connect(self.save_stack_to_big_tiff) self.save_big_tiff_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.save_8bit_rButton = QRadioButton() self.save_8bit_rButton.setText("Save as 8-bit") self.save_8bit_rButton.clicked.connect(self.set_8bit) self.save_8bit_rButton.setChecked(False) self.save_16bit_rButton = QRadioButton() self.save_16bit_rButton.setText("Save as 16-bit") self.save_16bit_rButton.clicked.connect(self.set_16bit) self.save_16bit_rButton.setChecked(False) self.save_32bit_rButton = QRadioButton() self.save_32bit_rButton.setText("Save as 32-bit") self.save_32bit_rButton.clicked.connect(self.set_32bit) self.save_32bit_rButton.setChecked(True) self.hist_min_label = QLabel("Histogram Min:") self.hist_min_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter) self.hist_max_label = QLabel("Histogram Max:") self.hist_max_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter) self.hist_min_input = QDoubleSpinBox() self.hist_min_input.setDecimals(12) self.hist_min_input.setRange(-10, 10) self.hist_min_input.valueChanged.connect(self.min_spin_changed) self.hist_max_input = QDoubleSpinBox() self.hist_max_input.setDecimals(12) self.hist_max_input.setRange(-10, 10) self.hist_max_input.valueChanged.connect(self.max_spin_changed) self.apply_histogram_button = QPushButton("Apply Histogram to Image Stack") self.apply_histogram_button.clicked.connect(self.apply_histogram_button_clicked) self.image_window = pg.ImageView() self.image_window.ui.histogram.gradient.hide() self.histo = self.image_window.getHistogramWidget() self.scroller = QScrollBar(Qt.Horizontal) self.scroller.orientation() self.scroller.setEnabled(False) self.scroller.valueChanged.connect(self.scroll_changed) self.set_layout() def set_layout(self): vbox = QVBoxLayout() vbox.addWidget(self.save_8bit_rButton) vbox.addWidget(self.save_16bit_rButton) vbox.addWidget(self.save_32bit_rButton) gridbox = QGridLayout() gridbox.addWidget(self.hist_max_label, 0, 0) gridbox.addWidget(self.hist_max_input, 0, 1) gridbox.addWidget(self.hist_min_label, 1, 0) gridbox.addWidget(self.hist_min_input, 1, 1) layout = QGridLayout() layout.addWidget(self.open_file_button, 0, 0) layout.addWidget(self.save_file_button, 1, 0) layout.addWidget(self.open_stack_button, 0, 1) layout.addWidget(self.save_stack_button, 1, 1) layout.addWidget(self.open_big_tiff_button, 0, 2) 
layout.addWidget(self.save_big_tiff_button, 1, 2) layout.addItem(vbox, 0, 3, 2, 1) layout.addItem(gridbox, 0, 4, 2, 1) layout.addWidget(self.apply_histogram_button, 0, 5) layout.addWidget(self.image_window, 2, 0, 1, 6) layout.addWidget(self.scroller, 4, 0, 1, 5) self.setLayout(layout) self.resize(640, 480) self.show() def scroll_changed(self): """ Updated the currently displayed image based on position of scroll bar :return: None """ self.image_window.setImage(self.tiff_arr[self.scroller.value()].T) def open_image_from_file(self): """ Opens and displays a single image (.tif) specified by the user in the file dialog :return: None """ LOG.debug("Open image button pressed") options = QFileDialog.Options() filePath, _ = QFileDialog.getOpenFileName( self, "Open .tif Image File", "", "Tiff Files (*.tif *.tiff)", options=options ) if filePath: LOG.debug("Import image path: " + filePath) self.img_arr = image_read_write.read_image(filePath) self.image_window.setImage(self.img_arr.T) self.scroller.setEnabled(False) def open_image_from_filepath(self, filePath): """ Opens and displays a single image (.tif) contained in a directory - (used when one slice is reconstructed) :param filePath: Full path and filename :return: None """ LOG.debug("Open image from filepath: " + str(filePath)) if filePath: LOG.debug("Import image path: " + filePath) self.img_arr = image_read_write.read_image(filePath) self.image_window.setImage(self.img_arr.T) self.scroller.setEnabled(False) def save_image_to_file(self): """ Saves the currently displayed image to a file (.tif) specified by the user in the file dialog :return: None """ LOG.debug("Save image to file") options = QFileDialog.Options() filepath, _ = QFileDialog.getSaveFileName( self, "QFileDialog.getSaveFileName()", "", "Tiff Files (*.tif *.tiff)", options=options ) if filepath: LOG.debug(filepath) bit_depth_string = self.check_bit_depth(self.bit_depth) img = self.image_window.imageItem.qimage # https://www.programmersought.com/article/73475006380/ size = img.size() s = img.bits().asstring( size.width() * size.height() * img.depth() // 8 ) # format 0xffRRGGBB arr = np.fromstring(s, dtype=np.uint8).reshape( (size.height(), size.width(), img.depth() // 8) ) image_read_write.write_image( arr.T[0].T, os.path.dirname(filepath), os.path.basename(filepath), bit_depth_string ) def open_stack_from_directory(self): """ Opens all images (.tif) in a directory and displays them. 
Allows for scrolling through images with slider :return: None """ LOG.debug("Open image stack button pressed") dir_explore = QFileDialog() directory = dir_explore.getExistingDirectory() if directory: try: tiff_list = (".tif", ".tiff") msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Loading Images...") msg.setText("Loading Images from Directory") msg.show() self.tiff_arr = image_read_write.read_all_images(directory, tiff_list) self.scroller.setRange(0, self.tiff_arr.shape[0] - 1) self.scroller.setEnabled(True) self.image_window.setImage(self.tiff_arr[0].T) msg.close() mid_index = self.tiff_arr.shape[0] // 2 self.scroller.setValue(mid_index) except image_read_write.InvalidDataSetError: print("Invalid Data Set") def open_stack_from_path(self, dir_path: str): """ Read images (.tif) from directory path into RAM as 3D numpy array :param dir_path: Path to directory containing multiple .tiff image files """ LOG.debug("Open stack from path") try: tiff_list = (".tif", ".tiff") msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Loading Images...") msg.setText("Loading Images from Directory") msg.show() self.tiff_arr = image_read_write.read_all_images(dir_path, tiff_list) self.scroller.setRange(0, self.tiff_arr.shape[0] - 1) self.scroller.setEnabled(True) self.image_window.setImage(self.tiff_arr[0].T) msg.close() mid_index = self.tiff_arr.shape[0] // 2 self.scroller.setValue(mid_index) except image_read_write.InvalidDataSetError: print("Invalid Data Set") def save_stack_to_directory(self): """ Saves images stored in numpy array to individual files (.tif) in directory specified by user dialog Saves these images as BigTiff if checkbox is set to True """ LOG.debug("Save stack to directory button pressed") LOG.debug("Saving with bitdepth: " + str(self.bit_depth)) dir_explore = QFileDialog() directory = dir_explore.getExistingDirectory() LOG.debug("Writing to directory: " + directory) if directory: bit_depth_string = self.check_bit_depth(self.bit_depth) msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Saving Images...") msg.setText("Saving Images to Directory") msg.show() self.apply_histogram_to_images() image_read_write.write_all_images(self.tiff_arr, directory, bit_depth_string) msg.close() def open_big_tiff(self): """ Opens images stored in a big tiff file (.tif) and displays them. Allows user to view them using scrollbar. 
:return: None """ LOG.debug("Open big tiff button pressed") options = QFileDialog.Options() filePath, _ = QFileDialog.getOpenFileName( self, "QFileDialog.getOpenFileName()", "", "All Files (*)", options=options ) if filePath: LOG.debug("Import image path: " + filePath) msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Loading Images...") msg.setText("Loading Images from BigTiff") msg.show() self.tiff_arr = tifffile.imread(filePath).astype(dtype=np.float32) self.scroller.setRange(0, self.tiff_arr.shape[0] - 1) self.scroller.setEnabled(True) self.image_window.setImage(self.tiff_arr[0].T) msg.close() mid_index = self.tiff_arr.shape[0] // 2 self.scroller.setValue(mid_index) def save_stack_to_big_tiff(self): """ Saves the stack of images currently loaded into RAM to a single bigtif file :return: None """ LOG.debug("Save stack to bigtiff button pressed") LOG.debug("Saving with bitdepth: " + str(self.bit_depth)) dir_explore = QFileDialog() options = QFileDialog.Options() filepath, _ = QFileDialog.getSaveFileName( self, "QFileDialog.getSaveFileName()", "", "Tiff Files (*.tif *.tiff)", options=options ) if filepath: msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Saving Images...") msg.setText("Saving Images to BigTiff") msg.show() # self.apply_histogram_to_images() bit_depth_string = self.check_bit_depth(self.bit_depth) tifffile.imwrite(filepath, self.tiff_arr, bigtiff=True, dtype=bit_depth_string) msg.close() def min_spin_changed(self): """ Changes the levels of the histogram widget if the min spinbox has been changed :return: None """ histo = self.image_window.getHistogramWidget() levels = histo.getLevels() min_level = self.hist_min_input.value() self.image_window.setLevels(min_level, levels[1]) def max_spin_changed(self): """ Changes the levels of the histogram widget if the max spinbox has been changed :return: None """ histo = self.image_window.getHistogramWidget() levels = histo.getLevels() max_level = self.hist_max_input.value() self.image_window.setLevels(levels[0], max_level) def apply_histogram_button_clicked(self): LOG.debug("Apply Histogram Button Clicked") print("Applying histogram to images. 
This may take a moment.") self.apply_histogram_to_images() def apply_histogram_to_images(self): """ Gets the histogram levels of the currently displayed image and applies them to all images in RAM :return: None """ levels = self.histo.getLevels() self.tiff_arr = np.clip(self.tiff_arr, levels[0], levels[1]) def check_bit_depth(self, bit_depth: int) -> str: """ Returns a string indicating the bitdepth to store the images based on value of bit-depth radio buttons :param bit_depth: :return: String specifying datatype for numpy array """ if bit_depth == 8: return "uint8" elif bit_depth == 16: return "uint16" elif bit_depth == 32: return "uint32" def set_8bit(self): """ Sets value of bit_depth variable based on radio button selection :return: None """ LOG.debug("Set 8-bit") self.bit_depth = 8 def set_16bit(self): """ Sets value of bit_depth variable based on radio button selection :return: None """ LOG.debug("Set 16-bit") self.bit_depth = 16 def set_32bit(self): """ Sets value of bit_depth variable based on radio button selection :return: None """ LOG.debug("Set 32-bit") self.bit_depth = 32 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/ez/GUI/login_dialog.py0000664000175000017500000001245200000000000020656 0ustar00tomastomas00000000000000import re from PyQt5.QtCore import Qt from PyQt5.QtWidgets import ( QDialog, QLineEdit, QPushButton, QLabel, QGridLayout, QFileDialog, QComboBox, ) from tofu.ez.GUI.message_dialog import error_message import os class Login(QDialog): def __init__(self, login_parameters_dict, **kwargs): super(Login, self).__init__(**kwargs) # Pass a method from main GUI self.login_parameters_dict = login_parameters_dict self.setWindowTitle("USER LOGIN") self.setWindowModality(Qt.ApplicationModal) self.setAttribute(Qt.WA_DeleteOnClose) self.welcome_label = QLabel() self.welcome_label.setText("Welcome to BMIT!") self.prompt_label_bl = QLabel() self.prompt_label_bl.setText("Please select the beamline and project:") self.bl_label = QLabel() self.bl_label.setText("Beamline:") self.bl_entry = QComboBox() self.bl_entry.addItems(["BM", "ID"]) self.proj_label = QLabel() self.proj_label.setText("Project:") self.proj_entry = QLineEdit() self.prompt_label_expdir = QLabel() self.prompt_label_expdir.setText("OR select the path to the working directory") self.expdir_entry = QLineEdit() # self.expdir_entry.setText("/data/gui-test") self.expdir_entry.setReadOnly(True) self.expdir_select_button = QPushButton("...") self.expdir_select_button.clicked.connect(self.select_expdir_func) self.login_button = QPushButton("LOGIN") self.login_button.clicked.connect(self.on_login_button_clicked) self.set_layout() def set_layout(self): layout = QGridLayout() self.welcome_label.setAlignment(Qt.AlignCenter) self.prompt_label_bl.setAlignment(Qt.AlignCenter) self.prompt_label_expdir.setAlignment(Qt.AlignCenter) layout.addWidget(self.welcome_label, 0, 0, 1, 2) layout.addWidget(self.prompt_label_bl, 1, 0, 1, 2) layout.addWidget(self.bl_label, 2, 0, 1, 1) layout.addWidget(self.bl_entry, 2, 1, 1, 1) layout.addWidget(self.proj_label, 3, 0, 1, 1) layout.addWidget(self.proj_entry, 3, 1, 1, 1) layout.addWidget(self.prompt_label_expdir, 4, 0, 1, 2) layout.addWidget(self.expdir_entry, 5, 0, 1, 1) layout.addWidget(self.expdir_select_button, 5, 1, 1, 1) layout.addWidget(self.login_button, 6, 0, 1, 2) layout.setSpacing(15) layout.setContentsMargins(25, 25, 25, 25) self.setLayout(layout) def select_expdir_func(self): options = 
QFileDialog.Options() options |= QFileDialog.DontUseNativeDialog root_dir = QFileDialog.getExistingDirectory( self, "Select working directory", "/data/gui-test", options=options ) if root_dir: self.expdir_entry.setText(root_dir) def uppercase_project_entry(self): self.proj_entry.setText(self.proj_entry.text().upper()) def strip_spaces_from_user_entry(self): self.user_entry.setText(self.user_entry.text().replace(" ", "")) @property def project_name(self): return self.proj_entry.text() @property def user_name(self): return self.user_entry.text() @property def expdir_name(self): return self.expdir_entry.text() @property def bl_name(self): return self.bl_entry.currentText() def validate_entries(self): self.uppercase_project_entry() # self.strip_spaces_from_user_entry() project_valid = bool(re.match(r"^[0-9]{2}[A-Z][0-9]{5}$", self.project_name)) # username_valid = bool(re.match(r"^[a-zA-Z0-9]*$", self.user_name)) # return project_valid, username_valid return project_valid def validate_dir(self, pdr): return os.access(pdr, os.W_OK) def on_login_button_clicked(self): # project_valid, username_valid = self.validate_entries() if self.project_name != "": prj_dir_name = os.path.join( "/beamlinedata/BMIT/projects/prj" + self.project_name, "raw" ) project_valid = self.validate_entries() can_write = self.validate_dir(prj_dir_name) if project_valid and can_write: self.login_parameters_dict.update({"bl": self.bl_name}) self.login_parameters_dict.update({"project": self.project_name}) # add fileExistsError exception later in Py3 self.login_parameters_dict.update({"expdir": prj_dir_name}) self.accept() # elif not username_valid: # error_message("Username should be alpha-numeric ") elif not project_valid: error_message( "The project should be in format: CCTNNNNN, \n" "where CC is cycle number, " "T is one-letter type, " "and NNNNN is project number" ) elif not can_write: error_message("Cannot write in the project directory") elif self.expdir_name != "": if self.validate_dir(self.expdir_entry.text()): self.login_parameters_dict.update({"expdir": self.expdir_name}) self.accept() else: error_message("Cannot write in the selected directory") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/ez/GUI/message_dialog.py0000664000175000017500000000066600000000000021176 0ustar00tomastomas00000000000000from PyQt5.QtWidgets import QMessageBox def message_dialog(window_title, message_text): alert = QMessageBox() alert.setWindowTitle(window_title) alert.setText(message_text) alert.exec_() def error_message(message_text): message_dialog("Error", message_text) def warning_message(message_text): message_dialog("Warning", message_text) def info_message(message_text): message_dialog("Info", message_text) ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1698416097.7697759 ufo-tofu-0.13.0/tofu/ez/Helpers/0000775000175000017500000000000000000000000016627 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/ez/Helpers/__init__.py0000664000175000017500000000000000000000000020726 0ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/Helpers/find_360_overlap.py0000664000175000017500000001522100000000000022242 0ustar00tomastomas00000000000000""" This script takes as input a CT scan that has been 
collected in "half-acquisition" mode and produces a series of reconstructed slices, each of which are generated by cropping and concatenating opposing projections together over a range of "overlap" values (i.e. the pixel column at which the images are cropped and concatenated). The objective is to review this series of images to determine the pixel column at which the axis of rotation is located (much like the axis search function commonly used in reconstruction software). """ import os import numpy as np import tifffile from tofu.ez.image_read_write import TiffSequenceReader from tofu.ez.params import EZVARS from tofu.ez.Helpers.stitch_funcs import findCTdirs, stitch_float32_output from tofu.util import get_filenames, get_image_shape from tofu.ez.ufo_cmd_gen import get_filter2d_sinos_cmd from tofu.ez.find_axis_cmd_gen import evaluate_images_simp def extract_row(dir_name, row): tsr = TiffSequenceReader(dir_name) tmp = tsr.read(0) (N, M) = tmp.shape if (row < 0) or (row > N): row = N//2 num_images = tsr.num_images if num_images % 2 == 1: print(f"odd number of images ({num_images}) in {dir_name}, " f"discarding the last one before stitching pairs") num_images-=1 A = np.empty((num_images, M), dtype=np.uint16) for i in range(num_images): A[i, :] = tsr.read(i)[row, :] tsr.close() return A def find_overlap(parameters): print("Finding CTDirs...") ctdirs, lvl0 = findCTdirs(parameters['360overlap_input_dir'], EZVARS['inout']['tomo-dir']['value']) print(ctdirs) dirdark = EZVARS['inout']['darks-dir']['value'] dirflats = EZVARS['inout']['flats-dir']['value'] dirflats2 = EZVARS['inout']['flats2-dir']['value'] if EZVARS['inout']['shared-flatsdarks']['value']: dirdark = EZVARS['inout']['path2-shared-darks']['value'] dirflats = EZVARS['inout']['path2-shared-flats']['value'] dirflats2 = EZVARS['inout']['path2-shared-flats2']['value'] # concatenate images with various overlap and generate sinograms for ctset in ctdirs: print("Working on ctset:" + str(ctset)) index_dir = os.path.basename(os.path.normpath(ctset)) # loading: try: row_flat = np.mean(extract_row( os.path.join(ctset, dirflats), parameters['360overlap_row'])) except: print(f"Problem loading flats in {ctset}") continue try: row_dark = np.mean(extract_row( os.path.join(ctset, dirdark), parameters['360overlap_row'])) except: print(f"Problem loading darks in {ctset}") continue try: row_tomo = extract_row( os.path.join(ctset, EZVARS['inout']['tomo-dir']['value']), parameters['360overlap_row']) except: print(f"Problem loading projections from " f"{os.path.join(ctset, EZVARS['inout']['tomo-dir']['value'])}") continue row_flat2 = None tmpstr = os.path.join(ctset, dirflats2) if os.path.exists(tmpstr): try: row_flat2 = np.mean(extract_row(tmpstr, parameters['360overlap_row'])) except: print(f"Problem loading flats2 in {ctset}") (num_proj, M) = row_tomo.shape print('Flat-field correction...') # Flat-correction tmp_flat = np.tile(row_flat, (num_proj, 1)) if row_flat2 is not None: tmp_flat2 = np.tile(row_flat2, (num_proj, 1)) ramp = np.linspace(0, 1, num_proj) ramp = np.transpose(np.tile(ramp, (M, 1))) tmp_flat = tmp_flat * (1-ramp) + tmp_flat2 * ramp del ramp, tmp_flat2 tmp_dark = np.tile(row_dark, (num_proj, 1)) tomo_ffc = -np.log((row_tomo - tmp_dark)/np.float32(tmp_flat - tmp_dark)) del row_tomo, row_dark, row_flat, tmp_flat, tmp_dark np.nan_to_num(tomo_ffc, copy=False, nan=0.0, posinf=0.0, neginf=0.0) # create interpolated sinogram of flats on the # same row as we use for the projections, then flat/dark correction print('Creating stitched sinograms...') 
sin_tmp_dir = os.path.join(parameters['360overlap_temp_dir'], index_dir, 'sinos') print(sin_tmp_dir) os.makedirs(sin_tmp_dir) for axis in range(parameters['360overlap_lower_limit'], parameters['360overlap_upper_limit']+parameters['360overlap_increment'], parameters['360overlap_increment']): cro = parameters['360overlap_upper_limit'] - axis if axis > M // 2: cro = axis - parameters['360overlap_lower_limit'] A = stitch_float32_output( tomo_ffc[: num_proj//2, :], tomo_ffc[num_proj//2:, ::-1], axis, cro) print(A.shape[1]) tifffile.imwrite(os.path.join( sin_tmp_dir, 'sin-axis-' + str(axis).zfill(4) + '.tif'), A.astype(np.float32)) # perform reconstructions for each sinogram and save to output folder print('Reconstructing slices...') #reco_axis = M-parameters['360overlap_upper_limit'] # equivalently half-width sin_width = get_image_shape(get_filenames(sin_tmp_dir)[0])[-1] sin_height = get_image_shape(get_filenames(sin_tmp_dir)[0])[-2] if parameters['360overlap_doRR']: print("Applying ring removal filter") tmpdir = os.path.join(parameters['360overlap_temp_dir'], index_dir) rrcmd = get_filter2d_sinos_cmd(tmpdir, EZVARS['RR']['sx']['value'], EZVARS['RR']['sy']['value'], sin_height, sin_width) print(rrcmd) os.system(rrcmd) sin_tmp_dir = os.path.join(parameters['360overlap_temp_dir'], index_dir, 'sinos-filt') outname = os.path.join(os.path.join( parameters['360overlap_output_dir'], f"{index_dir}-sli.tif")) cmd = f'tofu tomo --axis {sin_width//2} --sinograms {sin_tmp_dir}' cmd +=' --output '+os.path.join(outname) os.system(cmd) points, maximum = evaluate_images_simp(outname, "msag") print(f"Estimated overlap:" f"{parameters['360overlap_lower_limit'] + parameters['360overlap_increment'] * maximum}") print("Finished processing: " + str(index_dir)) print("****************************************") #shutil.rmtree(parameters['360overlap_temp_dir']) print("Finished processing of all subdirectories in " + str(parameters['360overlap_input_dir'])) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/Helpers/halfacqmode-mpi-stitch.py0000664000175000017500000000253300000000000023527 0ustar00tomastomas00000000000000#!/usr/bin/env python3 import sys import time import tifffile from mpi4py import MPI from tofu.ez.Helpers.stitch_funcs import stitch from tofu.ez.image_read_write import TiffSequenceReader path_to_script, ax, crop, bigtif_name, out_fmt = sys.argv comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() # t0 = time.time() #print(f"{t0:.2f}: Private {rank} of {size} is at your service") tfs = TiffSequenceReader(bigtif_name) npairs = tfs.num_images//2 n_my_pairs = int(npairs/size) + (1 if npairs%size > rank else 0) #print(f'Private {rank} got {n_my_pairs} pairs to process out of total {npairs}') for pair_number in range(n_my_pairs): idx = rank + pair_number * size # print(f'Private {rank} processing pair {idx} - {idx+npairs}') first = tfs.read(idx) second = tfs.read(idx+npairs)[:, ::-1] stitched = stitch(first, second, int(ax), int(crop)) tifffile.imwrite(out_fmt.format(idx), stitched) tfs.close() #print(f"Private {rank} stitched {n_my_pairs} pairs in {time.time()-t0:.2f} s! Am I first?") # Important - release communicator! 
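# Descriptive comment (not in the original source): this worker script is spawned from
# main_360_mp_depth1() via MPI.COMM_WORLD.Spawn, so after the stitching loop it looks up
# the inter-communicator to its parent and disconnects from it; otherwise the parent's own
# Disconnect() would block. When the script is run stand-alone there is presumably no
# parent communicator, the call raises, and the except clause below absorbs it.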
try: parent_comm = comm.Get_parent() parent_comm.Disconnect() except MPI.Exception: pass # # def main(): # comm = MPI.COMM_WORLD # size = comm.Get_size() # rank = comm.Get_rank() # print(f"I am {rank} of {size}") # # if __name__ == "__main__": # main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/Helpers/mview_main.py0000664000175000017500000000743200000000000021342 0ustar00tomastomas00000000000000#!/bin/python import os import numpy from tofu.util import get_filenames import re def check_folders(p, noflats2): if not os.path.exists(p): os.makedirs(p) tmp = p + "/darks" if not os.path.exists(tmp): os.makedirs(tmp) tmp = p + "/flats" if not os.path.exists(tmp): os.makedirs(tmp) if noflats2 == False: tmp = p + "/flats2" if not os.path.exists(tmp): os.makedirs(tmp) tmp = p + "/tomo" if not os.path.exists(tmp): os.makedirs(tmp) def rename_Andor(indir): names = get_filenames(os.path.join(indir, "*.tif")) maxnum = re.match(".*?([0-9]+)$", names[0][:-4]).group(1) n_dgts = len(maxnum) trnc_len = n_dgts + 4 prefix = names[0][:-trnc_len] maxnum = int(maxnum) for name in names: num = int(re.match(".*?([0-9]+)$", name[:-4]).group(1)) maxnum = num if (num > maxnum) else maxnum n_dgts = len(str(maxnum)) lin_fmt = prefix + "{:0" + str(n_dgts) + "}.tif" for name in names: num = re.match(".*?([0-9]+)$", name[:-4]).group(1) if name == lin_fmt.format(int(num)): continue else: cmd = "mv {} {}".format(name, lin_fmt.format(int(num))) os.system(cmd) def main_prep(params): if params['ezmview_no_zero_padding']: rename_Andor(params['ezmview_input_dir']) frames = get_filenames(os.path.join(params['ezmview_input_dir'], "*.tif")) nframes = len(frames) if nframes == 0: tmp = "Check INPUT directory: there are no tif files there" raise ValueError(tmp) # replace first frame with the second to get rid of # corrupted first file in the PCO Edge sequencies # Happened long ago in CamWare ... 
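# Descriptive comment (not in the original source): the command below deletes the first
# frame and replaces it with a copy of the second one, so the (historically sometimes
# corrupted) first PCO Edge frame is discarded while the sequence length stays the same.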
cmd = "rm {}; cp {} {}".format(frames[0], frames[1], frames[0]) os.system(cmd) FFinterval = params["ezmview_num_projections"] int_tot = params['ezmview_num_sets'] # (args.nproj/FFinterval)*args.nviews #int_1view = 1.0 # args.nproj/FFinterval #remainder of a more general FF correction files_in_int = params['ezmview_num_flats'] + params['ezmview_num_darks'] + FFinterval files_input = files_in_int * int_tot if params['ezmview_flats2'] == False: files_input += params['ezmview_num_flats'] #+ params['ezmview_num_darks'] if files_input != nframes: tmp = ( "Sequence length (found {} files) does not match ".format(nframes) + "one calculated from input parameters " + "(expected {} files)".format(files_input) ) raise ValueError(tmp) for i in range(params['ezmview_num_sets']): if params['ezmview_num_sets'] > 1: pout = os.path.join(params['ezmview_input_dir'], "z{:02d}".format(i)) else: pout = params['ezmview_input_dir'] check_folders(pout, params['ezmview_flats2']) # offset to heading flats and darks o = i * files_in_int for i in range(params['ezmview_num_flats']): cmd = "mv {} {}/flats/".format(frames[o + i], pout) os.system(cmd) # print(cmd) o += params['ezmview_num_flats'] for i in range(params['ezmview_num_darks']): cmd = "mv {} {}/darks/".format(frames[o + i], pout) os.system(cmd) # print(cmd) o += params['ezmview_num_darks'] for i in range(params["ezmview_num_projections"]): cmd = "mv {} {}/tomo/".format(frames[o + i], pout) os.system(cmd) # print(cmd) o += params["ezmview_num_projections"] if params['ezmview_flats2']: continue for i in range(params['ezmview_num_flats']): cmd = "cp {} {}/flats2/".format(frames[o + i], pout) os.system(cmd) # print(cmd) print("========== Done ==========") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/Helpers/stitch_funcs.py0000664000175000017500000005006400000000000021702 0ustar00tomastomas00000000000000""" Last modified on Apr 1, 2022 @author: sergei gasilov """ import glob import os import shutil import numpy as np import tifffile from tofu.util import read_image, get_image_shape, get_filenames from tofu.ez.image_read_write import TiffSequenceReader, get_image_dtype import multiprocessing as mp from functools import partial import time import sys try: from mpi4py import MPI except ImportError: print("You must install openmpi/mpi4py in order to stitch half acq. 
mode data", file=sys.stderr) from tofu.ez.params import EZVARS def findCTdirs(root: str, tomo_name: str): """ Walks directories rooted at "Input ctset" location Appends their absolute path to ctdir if they contain a ctset with same name as "tomo" entry in GUI """ lvl0 = os.path.abspath(root) ctdirs = [] for root, dirs, files in os.walk(lvl0): for name in dirs: if name == tomo_name: ctdirs.append(root) ctdirs.sort() return ctdirs, lvl0 def prepare(parameters, dir_type: int, ctdir: str): """ :param parameters: GUI params :param dir_type 1 if CTDir containing Z00-Z0N slices - 2 if parent directory containing CTdirs each containing Z slices: :param ctdir Name of the ctdir - blank string if not using multiple ctdirs: :return: """ hmin, hmax = 0.0, 0.0 if parameters['ezstitch_clip_histo']: if parameters['ezstitch_histo_min'] == parameters['ezstitch_histo_max']: raise ValueError(' - Define hmin and hmax correctly in order to convert to 8bit') else: hmin, hmax = parameters['ezstitch_histo_min'], parameters['ezstitch_histo_max'] start, stop, step = [int(value) for value in parameters['ezstitch_start_stop_step'].split(',')] if not os.path.exists(parameters['ezstitch_output_dir']): os.makedirs(parameters['ezstitch_output_dir']) Vsteps = sorted(os.listdir(os.path.join(parameters['ezstitch_input_dir'], ctdir))) #determine input data type if dir_type == 1: tmp = os.path.join(parameters['ezstitch_input_dir'], Vsteps[0], parameters['ezstitch_type_image'], '*.tif') tmp = sorted(glob.glob(tmp))[0] elif dir_type == 2: tmp = os.path.join(parameters['ezstitch_input_dir'], ctdir, Vsteps[0], parameters['ezstitch_type_image'], '*.tif') tmp = sorted(glob.glob(tmp))[0] indtype_digit, indtype = get_image_dtype(tmp) if parameters['ezstitch_stitch_orthogonal']: for vstep in Vsteps: if dir_type == 1: in_name = os.path.join(parameters['ezstitch_input_dir'], vstep, parameters['ezstitch_type_image']) out_name = os.path.join(parameters['ezstitch_temp_dir'], vstep, parameters['ezstitch_type_image'], 'sli-%04i.tif') elif dir_type == 2: in_name = os.path.join(parameters['ezstitch_input_dir'], ctdir, vstep, parameters['ezstitch_type_image']) out_name = os.path.join(parameters['ezstitch_temp_dir'], ctdir, vstep, parameters['ezstitch_type_image'], 'sli-%04i.tif') cmd = 'tofu sinos --projections {} --output {}'.format(in_name, out_name) cmd += " --y {} --height {} --y-step {}".format(start, stop-start, step) cmd += " --output-bytes-per-file 0" if indtype_digit == '8' or indtype_digit == '16': cmd += f" --output-bitdepth {indtype_digit}" print(cmd) os.system(cmd) time.sleep(10) indir = parameters['ezstitch_temp_dir'] else: indir = parameters['ezstitch_input_dir'] return indir, hmin, hmax, start, stop, step, indtype def exec_sti_mp(start, step, N, Nnew, Vsteps, indir, dx, M, parameters, ramp, hmin, hmax, indtype, ctdir, dir_type, j): index = start+j*step Large = np.empty((Nnew*len(Vsteps)+dx, M), dtype=np.float32) for i, vstep in enumerate(Vsteps[:-1]): if dir_type == 1: tmp = os.path.join(indir, Vsteps[i], parameters['ezstitch_type_image'], '*.tif') tmp1 = os.path.join(indir, Vsteps[i+1], parameters['ezstitch_type_image'], '*.tif') elif dir_type == 2: tmp = os.path.join(indir, ctdir, Vsteps[i], parameters['ezstitch_type_image'], '*.tif') tmp1 = os.path.join(indir, ctdir, Vsteps[i + 1], parameters['ezstitch_type_image'], '*.tif') if parameters['ezstitch_stitch_orthogonal']: tmp = sorted(glob.glob(tmp))[j] tmp1 = sorted(glob.glob(tmp1))[j] else: tmp = sorted(glob.glob(tmp))[index] tmp1 = sorted(glob.glob(tmp1))[index] first = 
read_image(tmp) second = read_image(tmp1) # sample moved downwards if parameters['ezstitch_sample_moved_down']: first, second = np.flipud(first), np.flipud(second) k = np.mean(first[N - dx:, :]) / np.mean(second[:dx, :]) second = second * k a, b, c = i*Nnew, (i+1)*Nnew, (i+2)*Nnew Large[a:b, :] = first[:N-dx, :] Large[b:b+dx, :] = np.transpose(np.transpose(first[N-dx:, :])*(1 - ramp) + np.transpose(second[:dx, :]) * ramp) Large[b+dx:c+dx, :] = second[dx:, :] pout = os.path.join(parameters['ezstitch_output_dir'], ctdir, parameters['ezstitch_type_image']+'-sti-{:>04}.tif'.format(index)) if not parameters['ezstitch_clip_histo']: tifffile.imwrite(pout, Large.astype(indtype)) else: Large = 255.0/(hmax-hmin) * (np.clip(Large, hmin, hmax) - hmin) tifffile.imwrite(pout, Large.astype(np.uint8)) def main_sti_mp(parameters): #Check whether indir is CTdir or parent containing CTdirs #if indir + some z00 subdir + sli + *.tif does not exist then use original subdirs = sorted(os.listdir(parameters['ezstitch_input_dir'])) if os.path.exists(os.path.join(parameters['ezstitch_input_dir'], subdirs[0], parameters['ezstitch_type_image'])): dir_type = 1 ctdir = "" print(" - Using CT directory containing slices") if parameters['ezstitch_stitch_orthogonal']: print(" - Creating orthogonal sections") indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, "") dx = int(parameters['ezstitch_num_overlap_rows']) # second: stitch them Vsteps = sorted(os.listdir(indir)) tmp = glob.glob(os.path.join(indir, Vsteps[0], parameters['ezstitch_type_image'], '*.tif'))[0] first = read_image(tmp) N, M = first.shape Nnew = N - dx ramp = np.linspace(0, 1, dx) J = range((stop - start) // step) pool = mp.Pool(processes=mp.cpu_count()) # ??? IT was OK back in 2.7 but now can crash # if pool size is larger than array being multiprocessed? 
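# Note on the question above (an untested suggestion, not applied here): a Pool with more
# processes than items in J simply leaves the extra workers idle, so a crash is more likely
# caused by per-worker memory use. Capping the pool size, e.g.
# mp.Pool(processes=min(mp.cpu_count(), max(1, len(J)))), would be a harmless safeguard.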
exec_func = partial(exec_sti_mp, start, step, N, Nnew, \ Vsteps, indir, dx, M, parameters, ramp, hmin, hmax, indtype, ctdir, dir_type) print(" - Adjusting and stitching") # start = time.time() pool.map(exec_func, J) print("========== Done ==========") else: second_subdirs = sorted(os.listdir(os.path.join(parameters['ezstitch_input_dir'], subdirs[0]))) if os.path.exists(os.path.join(parameters['ezstitch_input_dir'], subdirs[0], second_subdirs[0], parameters['ezstitch_type_image'])): print(" - Using parent directory containing CT directories, each of which contains slices") dir_type = 2 #For each subdirectory do the same thing for ctdir in subdirs: print("-> Working on " + str(ctdir)) if not os.path.exists(os.path.join(parameters['ezstitch_output_dir'], ctdir)): os.makedirs(os.path.join(parameters['ezstitch_output_dir'], ctdir)) if parameters['ezstitch_stitch_orthogonal']: print(" - Creating orthogonal sections") indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, ctdir) dx = int(parameters['ezstitch_num_overlap_rows']) # second: stitch them Vsteps = sorted(os.listdir(os.path.join(indir, ctdir))) tmp = glob.glob(os.path.join(indir, ctdir, Vsteps[0], parameters['ezstitch_type_image'], '*.tif'))[0] first = read_image(tmp) N, M = first.shape Nnew = N - dx ramp = np.linspace(0, 1, dx) J = range(int((stop - start) / step)) pool = mp.Pool(processes=mp.cpu_count()) exec_func = partial(exec_sti_mp, start, step, N, Nnew, \ Vsteps, indir, dx, M, parameters, ramp, hmin, hmax, indtype, ctdir, dir_type) print(" - Adjusting and stitching") # start = time.time() pool.map(exec_func, J) print("========== Done ==========") # Clear temp directory clear_tmp(parameters) else: print("Invalid input directory") complete_message() def make_buf(tmp, l, a, b): first = read_image(tmp) N, M = first[a:b, :].shape return np.empty((N*l, M), dtype=first.dtype), N, first.dtype def exec_conc_mp(start, step, example_im, l, parameters, zfold, indir, ctdir, j): index = start+j*step Large, N, dtype = make_buf(example_im, l, parameters['ezstitch_first_row'], parameters['ezstitch_last_row']) for i, vert in enumerate(zfold): tmp = os.path.join(indir, ctdir, vert, parameters['ezstitch_type_image'], '*.tif') if parameters['ezstitch_stitch_orthogonal']: fname=sorted(glob.glob(tmp))[j] else: fname=sorted(glob.glob(tmp))[index] frame = read_image(fname)[parameters['ezstitch_first_row']:parameters['ezstitch_last_row'], :] if parameters['ezstitch_sample_moved_down']: Large[i*N:N*(i+1), :] = np.flipud(frame) else: Large[i*N:N*(i+1), :] = frame pout = os.path.join(parameters['ezstitch_output_dir'], ctdir, parameters['ezstitch_type_image']+'-sti-{:>04}.tif'.format(index)) #print "input data type {:}".format(dtype) tifffile.imwrite(pout, Large) def main_conc_mp(parameters): # Check whether indir is CTdir or parent containing CTdirs # if indir + some z00 subdir + sli + *.tif does not exist then use original subdirs = sorted(os.listdir(parameters['ezstitch_input_dir'])) if os.path.exists(os.path.join(parameters['ezstitch_input_dir'], subdirs[0], parameters['ezstitch_type_image'])): dir_type = 1 ctdir = "" print(" - Using CT directory containing slices") if parameters['ezstitch_stitch_orthogonal']: print(" - Creating orthogonal sections") #start = time.time() indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, ctdir) subdirs = [dI for dI in os.listdir(parameters['ezstitch_input_dir']) if os.path.isdir(os.path.join(parameters['ezstitch_input_dir'], dI))] zfold = sorted(subdirs) l = len(zfold) tmp 
= glob.glob(os.path.join(indir, zfold[0], parameters['ezstitch_type_image'], '*.tif')) J = range((stop-start)//step) pool = mp.Pool(processes=mp.cpu_count()) exec_func = partial(exec_conc_mp, start, step, tmp[0], l, parameters, zfold, indir, ctdir) print(" - Concatenating") #start = time.time() pool.map(exec_func, J) #print "Images stitched in {:.01f} sec".format(time.time()-start) print("============ Done ============") else: second_subdirs = sorted(os.listdir(os.path.join(parameters['ezstitch_input_dir'], subdirs[0]))) if os.path.exists(os.path.join(parameters['ezstitch_input_dir'], subdirs[0], second_subdirs[0], parameters['ezstitch_type_image'])): print(" - Using parent directory containing CT directories, each of which contains slices") dir_type = 2 for ctdir in subdirs: print(" == Working on " + str(ctdir) + " ==") if not os.path.exists(os.path.join(parameters['ezstitch_output_dir'], ctdir)): os.makedirs(os.path.join(parameters['ezstitch_output_dir'], ctdir)) if parameters['ezstitch_stitch_orthogonal']: print(" - Creating orthogonal sections") # start = time.time() indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, ctdir) zfold = sorted(os.listdir(os.path.join(indir, ctdir))) l = len(zfold) tmp = glob.glob(os.path.join(indir, ctdir, zfold[0], parameters['ezstitch_type_image'], '*.tif')) J = range((stop - start) // step) pool = mp.Pool(processes=mp.cpu_count()) exec_func = partial(exec_conc_mp, start, step, tmp[0], l, parameters, zfold, indir, ctdir) print(" - Concatenating") # start = time.time() pool.map(exec_func, J) # print "Images stitched in {:.01f} sec".format(time.time()-start) print("============ Done ============") #Clear temp directory clear_tmp(parameters) complete_message() ############################## HALF ACQ ############################## def stitch(first, second, axis, crop): h, w = first.shape if axis > w // 2: axis = w - axis first = np.fliplr(first) second = np.fliplr(second) dx = int(2 * axis + 0.5) tmp = np.copy(first) first = second second = tmp result = np.empty((h, 2 * w - dx), dtype=first.dtype) ramp = np.linspace(0, 1, dx) # Mean values of the overlapping regions must match, which corrects flat-field inconsistency # between the two projections # We clip the values in second so that there are no saturated pixel overflow problems k = np.mean(first[:, w - dx:]) / np.mean(second[:, :dx]) second = np.clip(second * k, np.iinfo(np.uint16).min, np.iinfo(np.uint16).max).astype(np.uint16) result[:, :w - dx] = first[:, :w - dx] result[:, w - dx:w] = first[:, w - dx:] * (1 - ramp) + second[:, :dx] * ramp result[:, w:] = second[:, dx:] return result[:, slice(int(crop), int(2*(w - axis) - crop), 1)] ############################## HALF ACQ ############################## def stitch_float32_output(first, second, axis, crop): print(f"Stitching two halves with axis {axis}, cropping by {crop}") h, w = first.shape if axis > w // 2: axis = w - axis first = np.fliplr(first) second = np.fliplr(second) dx = int(2 * axis + 0.5) tmp = np.copy(first) first = second second = tmp result = np.empty((h, 2 * w - dx), dtype=first.dtype) ramp = np.linspace(0, 1, dx) # Mean values of the overlapping regions must match, which corrects flat-field inconsistency # between the two projections # We clip the values in second so that there are no saturated pixel overflow problems k = np.mean(first[:, w - dx:]) / np.mean(second[:, :dx]) result[:, :w - dx] = first[:, :w - dx] result[:, w - dx:w] = first[:, w - dx:] * (1 - ramp) + second[:, :dx] * ramp result[:, w:] = 
second[:, dx:] * k return result[:, slice(int(crop), int(2*(w - axis) - crop), 1)] def main_360_mp_depth1(indir, outdir, ax, cro): if not os.path.exists(outdir): os.makedirs(outdir) subdirs = [dI for dI in os.listdir(indir) \ if os.path.isdir(os.path.join(indir, dI))] for i, sdir in enumerate(subdirs): print(f"Stitching images in {sdir}") tfs = TiffSequenceReader(os.path.join(indir, sdir)) if tfs.num_images < 2: print("Warning: less than 2 files, skipping this dir") continue else: print(f"{tfs.num_images//2} pairs will be stitched in {sdir}") tfs.close() os.makedirs(os.path.join(outdir, sdir)) out_fmt = os.path.join(outdir, sdir, 'sti-{:>04}.tif') tmp = os.path.dirname(os.path.abspath(__file__)) path_to_script = os.path.join(tmp, "halfacqmode-mpi-stitch.py") if os.path.isfile(path_to_script): tstart = time.time() child_comm = MPI.COMM_WORLD.Spawn( sys.executable, [path_to_script, f"{ax}", f"{cro}", os.path.join(indir, sdir), out_fmt], maxprocs=12) child_comm.Disconnect() print(f"Child finished in {time.time() - tstart} yay!") else: print('Cannot see the script for parallel stitching of bigtiff files') break print("========== Done ==========") def main_360_mp_depth2(parameters): ctdirs, lvl0 = findCTdirs(parameters['360multi_input_dir'], EZVARS['inout']['tomo-dir']['value']) num_sets = len(ctdirs) if num_sets < 1: print(f"Didn't find any CT dirs in the input. Check directory structure and permissions. \n" f"Program expects to see a number of subdirectories in the input each of with \n" f"contains at least one directory with CT projections (currently name set to " f"{EZVARS['inout']['tomo-dir']['value']}. \n"+ f"The tif files in all " \ f" {EZVARS['inout']['tomo-dir']['value']}, " f" {EZVARS['inout']['flats-dir']['value']}, " f" {EZVARS['inout']['darks-dir']['value']} \n" f"subdirectories will be stitched to convert half-acquisition mode scans to ordinary \n" f"180-deg parallel-beam scans") return tmp = len(parameters['360multi_input_dir']) ctdirs_rel_paths = [] for i in range(num_sets): ctdirs_rel_paths.append(ctdirs[i][tmp+1:len(ctdirs[i])]) print(f"Found the {num_sets} directories in the input with relative paths: {ctdirs_rel_paths}") # prepare axis and crop arrays dax = np.round(np.linspace(parameters['360multi_bottom_axis'], parameters['360multi_top_axis'], num_sets)) if parameters['360multi_manual_axis']: #print(parameters['360multi_axis_dict']) dax = np.array(list(parameters['360multi_axis_dict'].values()))[:num_sets] print(f'Overlaps: {dax}') # compute crop: cra = np.max(dax)-dax # Axis on the right ? 
Must open one file to find out >< tmpname = os.path.join(parameters['360multi_input_dir'], ctdirs_rel_paths[0]) subdirs = [dI for dI in os.listdir(tmpname) if os.path.isdir(os.path.join(tmpname, dI))] M = get_image_shape(get_filenames(os.path.join(tmpname, subdirs[0]))[0])[-1] if np.min(dax) > M//2: cra = dax - np.min(dax) print(f'Crop by: {cra}') for i, ctdir in enumerate(ctdirs): print("================================================================") print(" -> Working On: " + str(ctdir)) print(f" axis position {dax[i]}, margin to crop {cra[i]} pixels") main_360_mp_depth1(ctdir, os.path.join(parameters['360multi_output_dir'], ctdirs_rel_paths[i]), dax[i], cra[i]) # print(ctdir, os.path.join(parameters['360multi_output_dir'], ctdirs_rel_paths[i]), dax[i], cra[i]) def clear_tmp(parameters): tmp_dirs = os.listdir(parameters['ezstitch_temp_dir']) for tmp_dir in tmp_dirs: shutil.rmtree(os.path.join(parameters['ezstitch_temp_dir'], tmp_dir)) def check_last_index(axis_list): """ Return the index of item in list immediately before first 'None' type :param axis_list: :return: the index of last non-None value """ last_index = 0 for index, item in enumerate(axis_list): if item == 'None': last_index = index - 1 return last_index last_index = index return last_index def complete_message(): print(" __.-/|") print(" \\`o_O'") print(" =( )= +-----+") print(" U| | FIN |") print(" /\\ /\\ / | +-----+") print(" ) /^\\) ^\\/ _)\\ |") print(" ) /^\\/ _) \\ |") print(" ) _ / / _) \\___|_") print(" /\\ )/\\/ || | )_)\\___,|))") print("< > |(,,) )__) |") print(" || / \\)___)\\") print(" | \\____( )___) )____") print(" \\______(_______;;;)__;;;)") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/RR_external.py0000664000175000017500000001452600000000000020034 0ustar00tomastomas00000000000000#!/usr/bin/env python3 """ Created on Aug 3, 2018 @author: SGasilov Initially it has been simplest median sorting Replaced by non-FFT based methods proposed by Nghia T. Vo and published in https://doi.org/10.1364/OE.26.028396 """ import os import argparse from tofu.util import read_image import numpy as np from tofu.util import get_filenames import multiprocessing as mp from functools import partial from scipy.ndimage import median_filter from scipy.ndimage import binary_dilation from tifffile import imwrite def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--sinos", type=str, help="Input directory") parser.add_argument("--mws", type=int, help="Window size for small rings (sorting algorithm)") parser.add_argument("--mws2", type=int, help="Window size for large rings") parser.add_argument("--snr", type=int, help="Median window size along columns") parser.add_argument("--sort_only", type=int, help="Only sorting or both") return parser.parse_args() def RR_wide_sort(mws, mws2, snr, odir, fname): filt_sin_name = os.path.join(odir, os.path.split(fname)[1]) im = read_image(fname).astype(np.float32) im = remove_large_stripe(im, snr, mws2) im = remove_stripe_based_sorting(im, mws) imwrite(filt_sin_name, im.astype(np.float32)) def RR_sort(mws, odir, fname): filt_sin_name = os.path.join(odir, os.path.split(fname)[1]) imwrite( filt_sin_name, remove_stripe_based_sorting(read_image(fname).astype(np.float32), mws).astype(np.float32), ) def remove_stripe_based_sorting(sinogram, size, dim=1): # taken from sarepy, Author: Nghia T. 
Vo https://doi.org/10.1364/OE.26.028396 """ Remove stripe artifacts in a sinogram using the sorting technique, algorithm 3 in Ref. [1]. Angular direction is along the axis 0. Parameters ---------- sinogram : array_like 2D array. Sinogram image. size : int Window size of the median filter. dim : {1, 2}, optional Dimension of the window. """ sinogram = np.transpose(sinogram) (nrow, ncol) = sinogram.shape list_index = np.arange(0.0, ncol, 1.0) mat_index = np.tile(list_index, (nrow, 1)) mat_comb = np.asarray(np.dstack((mat_index, sinogram))) mat_sort = np.asarray([row[row[:, 1].argsort()] for row in mat_comb]) if dim == 2: mat_sort[:, :, 1] = median_filter(mat_sort[:, :, 1], (size, size)) else: mat_sort[:, :, 1] = median_filter(mat_sort[:, :, 1], (size, 1)) mat_sort_back = np.asarray([row[row[:, 0].argsort()] for row in mat_sort]) return np.transpose(mat_sort_back[:, :, 1]) def detect_stripe(list_data, snr): # taken from sarepy, Author: Nghia T. Vo https://doi.org/10.1364/OE.26.028396 """ Locate stripe positions using Algorithm 4 in Ref. [1]. Parameters ---------- list_data : array_like 1D array. Normalized data. snr : float Ratio used to segment stripes from background noise. """ npoint = len(list_data) list_sort = np.sort(list_data) listx = np.arange(0, npoint, 1.0) ndrop = np.int16(0.25 * npoint) (slope, intercept) = np.polyfit(listx[ndrop : -ndrop - 1], list_sort[ndrop : -ndrop - 1], 1) y_end = intercept + slope * listx[-1] noise_level = np.abs(y_end - intercept) noise_level = np.clip(noise_level, 1e-6, None) val1 = np.abs(list_sort[-1] - y_end) / noise_level val2 = np.abs(intercept - list_sort[0]) / noise_level list_mask = np.zeros(npoint, dtype=np.float32) if val1 >= snr: upper_thresh = y_end + noise_level * snr * 0.5 list_mask[list_data > upper_thresh] = 1.0 if val2 >= snr: lower_thresh = intercept - noise_level * snr * 0.5 list_mask[list_data <= lower_thresh] = 1.0 return list_mask def remove_large_stripe(sinogram, size, snr=3, drop_ratio=0.1, norm=True): # taken from sarepy, Author: Nghia T. Vo https://doi.org/10.1364/OE.26.028396 """ Remove large stripes, algorithm 5 in Ref. [1], by: locating stripes, normalizing to remove full stripes, and using the sorting technique (Ref. [1]) to remove partial stripes. Angular direction is along the axis 0. Parameters ---------- sinogram : array_like 2D array. Sinogram image snr : float Ratio used to segment stripes from background noise. size : int Window size of the median filter. drop_ratio : float, optional Ratio of pixels to be dropped, which is used to reduce the false detection of stripes. norm : bool, optional Apply normalization if True. 
""" sinogram = np.copy(sinogram) # Make it mutable drop_ratio = np.clip(drop_ratio, 0.0, 0.8) (nrow, ncol) = sinogram.shape ndrop = int(0.5 * drop_ratio * nrow) sino_sort = np.sort(sinogram, axis=0) sino_smooth = median_filter(sino_sort, (1, size)) list1 = np.mean(sino_sort[ndrop : nrow - ndrop], axis=0) list2 = np.mean(sino_smooth[ndrop : nrow - ndrop], axis=0) list_fact = np.divide(list1, list2, out=np.ones_like(list1), where=list2 != 0) list_mask = detect_stripe(list_fact, snr) list_mask = np.float32(binary_dilation(list_mask, iterations=1)) mat_fact = np.tile(list_fact, (nrow, 1)) if norm is True: sinogram = sinogram / mat_fact # Normalization sino_tran = np.transpose(sinogram) list_index = np.arange(0.0, nrow, 1.0) mat_index = np.tile(list_index, (ncol, 1)) mat_comb = np.asarray(np.dstack((mat_index, sino_tran))) mat_sort = np.asarray([row[row[:, 1].argsort()] for row in mat_comb]) mat_sort[:, :, 1] = np.transpose(sino_smooth) mat_sort_back = np.asarray([row[row[:, 0].argsort()] for row in mat_sort]) sino_cor = np.transpose(mat_sort_back[:, :, 1]) listx_miss = np.where(list_mask > 0.0)[0] sinogram[:, listx_miss] = sino_cor[:, listx_miss] return sinogram def main(): args = parse_args() sinos = get_filenames(os.path.join(args.sinos, "*.tif")) # create output directory wdir = os.path.split(args.sinos)[0] odir = os.path.join(wdir, "sinos-filt") if not os.path.exists(odir): os.makedirs(odir) pool = mp.Pool(processes=mp.cpu_count()) if args.sort_only: exec_func = partial(RR_sort, args.mws, odir) else: exec_func = partial(RR_wide_sort, args.mws, args.mws2, args.snr, odir) pool.map(exec_func, sinos) if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/ez/__init__.py0000664000175000017500000000000100000000000017325 0ustar00tomastomas00000000000000 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/ctdir_walker.py0000664000175000017500000001644600000000000020264 0ustar00tomastomas00000000000000""" Created on Apr 5, 2018 @author: gasilos """ import os from tofu.ez.params import EZVARS class WalkCTdirs: """ Walks in the directory structure and creates list of paths to CT folders Determines flats before/after and checks that folders contain only tiff files fdt_names = flats/darks/tomo directory names """ def __init__(self, inpath, fdt_names, verb=True): self.lvl0 = os.path.abspath(inpath) self.ctdirs = [] self.types = [] self.ctsets = [] self.typ = [] self.total = 0 self.good = 0 self.verb = verb self._fdt_names = fdt_names self.common_flats = EZVARS['inout']['path2-shared-flats']['value'] self.common_darks = EZVARS['inout']['path2-shared-darks']['value'] self.common_flats2 = EZVARS['inout']['path2-shared-flats2']['value'] self.use_common_flats2 = EZVARS['inout']['shared-flats-after']['value'] def print_tree(self): print("We start in {}".format(self.lvl0)) def findCTdirs(self): """ Walks directories rooted at "Input Directory" location Appends their absolute path to ctdir if they contain a directory with same name as "tomo" entry in GUI """ for root, dirs, files in os.walk(self.lvl0): for name in dirs: if name == self._fdt_names[2]: self.ctdirs.append(root) self.ctdirs = list(set(self.ctdirs)) self.ctdirs.sort() def checkCTdirs(self): """ Determine whether directory is of type 3 or type 4 and store in self.typ with index corresponding to ctdir Type3: Has flats, darks and not flats2 -- or flats==flats2 Type4: 
Has flats, darks and flats2 """ for ctdir in self.ctdirs: # flats/darks and no flats2 or flats2==flats -> type 3 if ( os.path.exists(os.path.join(ctdir, self._fdt_names[1])) and os.path.exists(os.path.join(ctdir, self._fdt_names[0])) and ( not os.path.exists(os.path.join(ctdir, self._fdt_names[3])) or self._fdt_names[1] == self._fdt_names[3] ) ): self.typ.append(3) # flats/darks/flats2 -> type4 elif ( os.path.exists(os.path.join(ctdir, self._fdt_names[1])) and os.path.exists(os.path.join(ctdir, self._fdt_names[0])) and os.path.exists(os.path.join(ctdir, self._fdt_names[3])) ): self.typ.append(4) else: print(os.path.basename(ctdir)) self.typ.append(0) def checkcommonfdt(self): """ Verifies that paths to directories specified by common_flats, common_darks, and common_flats2 exist :return: True if directories exist, False if they do not exist """ for ctdir in self.ctdirs: if self.use_common_flats2 is True: self.typ.append(4) elif self.use_common_flats2 is False: self.typ.append(3) if self.use_common_flats2 is True: if ( os.path.exists(self.common_flats) and os.path.exists(self.common_darks) and os.path.exists(self.common_flats2) ): return True elif self.use_common_flats2 is False: if (os.path.exists(self.common_flats) and os.path.exists(self.common_darks)): return True return False def checkcommonfdtFiles(self): """ Verifies that directories of tomo and common flats/darks/flats contain only .tif files :return: True if directories exist, False if they do not exist """ for i, ctdir in enumerate(self.ctdirs): ctdir_tomo_path = os.path.join(ctdir, self._fdt_names[2]) if not self._checktifs(ctdir_tomo_path): print("Invalid files found in " + str(ctdir_tomo_path)) self.typ[i] = 0 return False if not self._checktifs(self.common_flats): print("Invalid files found in " + str(self.common_flats)) return False if not self._checktifs(self.common_darks): print("Invalid files found in " + str(self.common_darks)) return False if self.use_common_flats2 and not self._checktifs(self.common_flats2): print("Invalid files found in " + str(self.common_flats2)) return False return True def checkCTfiles(self): """ Checks whether each ctdir is of type 3 or 4 by comparing index of self.typ[] to corresponding index of ctdir[] Then for each directory of type 3 or 4 it checks sub-directories contain only .tif files If it contains invalid data then typ[] is set to 0 for corresponding index location """ for i, ctdir in enumerate(self.ctdirs): if ( self.typ[i] == 3 and self._checktifs(os.path.join(ctdir, self._fdt_names[1])) and self._checktifs(os.path.join(ctdir, self._fdt_names[0])) and self._checktifs(os.path.join(ctdir, self._fdt_names[2])) ): continue elif ( self.typ[i] == 4 and self._checktifs(os.path.join(ctdir, self._fdt_names[1])) and self._checktifs(os.path.join(ctdir, self._fdt_names[0])) and self._checktifs(os.path.join(ctdir, self._fdt_names[2])) and self._checktifs(os.path.join(ctdir, self._fdt_names[3])) ): continue else: self.typ[i] = 0 def _checktifs(self, tmpath): """ Checks each whether item in directory tmppath is a .tif file :param tmpath: Path to directory :return: 0 if invalid item found in directory - 1 if no invalid items found in directory """ for i in os.listdir(tmpath): if os.path.isdir(i): print(f"Directory {tmpath} contains a subdirectory") return 0 if i.split(".")[-1] != "tif": print(f"Directory {tmpath} has files which are not tif images or containers") return 0 return 1 def sortbadgoodsets(self): """ Reduces type of all directories to either Good with flats 2 (1) or good without flats2 (0) or 
bad (<0) """ self.total = len(self.ctdirs) self.ctsets = sorted(zip(self.ctdirs, self.typ), key=lambda s: s[0]) self.total = len(self.ctsets) self.good = [int(y) > 2 for x, y in self.ctsets].count(True) tmp = len(self.lvl0) if self.verb: print("Total folders {}, good folders {}".format(self.total, self.good)) print("{:>20}\t{}".format("Path to CT set", "Typ: 0 bad, 3 no flats2, 4 with flats2")) for ctdir in self.ctsets: msg1 = ctdir[0][tmp:] if msg1 == "": msg1 = "." print("{:>20}\t{}".format(msg1, ctdir[1])) # keep paths to directories with good ct data only: self.ctsets = [q for q in self.ctsets if int(q[1] > 0)] def Getlvl0(self): return self.lvl0 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/evaluate_sharpness.py0000664000175000017500000003172000000000000021476 0ustar00tomastomas00000000000000import argparse import glob import multiprocessing import os import time import numpy as np from functools import partial from tofu.util import read_image from scipy.stats import skew, kurtosis def sum_abs_gradient(data): """Sum of absolute gradients.""" return np.sum(np.abs(np.gradient(data))) def mad(data): """Median absolute deviation.""" return np.median(np.abs(data - np.median(data))) def abs_sum(data): """Sum of the absolute values.""" return np.sum(np.abs(data)) def entropy(data, bins=256): """Image entropy.""" hist, bins = np.histogram(data, bins=bins) hist = hist.astype(float) hist /= hist.sum() valid = np.where(hist > 0) return -np.sum(np.dot(hist[valid], np.log2(hist[valid]))) def inverted(func, *args, **kwargs): """Return -func(*args, **kwargs).""" return -func(*args, **kwargs) def filter_data(data, fwhm=32.0): """Filter low frequencies in 1D *data* (needed when the axis is far away by axis evaluation). *fwhm* is the FwhM of the gaussian used to filter out low frequencies in real space. The window is then computed as fft(1 - gauss). """ mean = np.mean(data) sigma = fwhm / (2 * np.sqrt(2 * np.log(2))) # We compute the gaussian in Fourier space, so convert sigma first f_sigma = 1.0 / (2 * np.pi * sigma) x = np.fft.fftfreq(len(data)) fltr = 1 - np.exp(-(x ** 2) / (2 * f_sigma ** 2)) return np.fft.ifft(np.fft.fft(data) * fltr).real + mean METRICS_1D = { "mean": np.mean, "std": np.std, "skew": skew, "kurtosis": kurtosis, "mad": mad, "asum": abs_sum, "min": np.min, "max": np.max, "entropy": entropy, } METRICS_2D = {"sag": sum_abs_gradient} for key in list(METRICS_1D): METRICS_1D["m" + key] = partial(inverted, METRICS_1D[key]) for key in list(METRICS_2D): METRICS_2D["m" + key] = partial(inverted, METRICS_2D[key]) # for key in METRICS_1D.keys(): # METRICS_1D['m' + key] = partial(inverted, METRICS_1D[key]) # for key in METRICS_2D.keys(): # METRICS_2D['m' + key] = partial(inverted, METRICS_2D[key]) def evaluate( image, metrics_1d=None, metrics_2d=None, global_min=None, global_max=None, metrics_1d_kwargs=None, blur_fwhm=None, ): """Evaluate *metrics_1d* which work on a flattened image and *metrics_2d* in an *image* which can either be a file path or an imageIf the metrics are None all the default ones are used. *global_min* and *global_max* are the mean extrema of the whole sequence used to cut off outlier values. Extrema are used only by 1d metrics. *metrics_1d_kwargs* are additional keyword arguments passed to the functions, they are specified in dictioinary {func_name: kwargs}. 
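    Examples
    --------
    Minimal sketch with a synthetic image; *sum_abs_gradient* and the metric
    tables are the ones defined in this module, the array itself is made up.

    >>> import numpy as np
    >>> img = np.random.rand(64, 64).astype(np.float32)
    >>> res = evaluate(img, metrics_1d={"std": np.std}, metrics_2d={"sag": sum_abs_gradient})
    >>> sorted(res.keys())
    ['sag', 'std']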
""" if metrics_1d is None: metrics_1d = METRICS_1D if metrics_2d is None: metrics_2d = METRICS_2D results = {} if type(image) == str: image = read_image(image) if blur_fwhm: from scipy.ndimage import gaussian_filter image = gaussian_filter(image, blur_fwhm / (2 * np.sqrt(2 * np.log(2)))) if global_min is None or global_max is None: flattened = image.flatten() else: # Use global cutoff flattened = image[np.where((image >= global_min) & (image <= global_max))] if metrics_1d is not None: for metric in metrics_1d: kwargs = {} if metrics_1d_kwargs and metric in metrics_1d_kwargs: kwargs = metrics_1d_kwargs[metric] results[metric] = metrics_1d[metric](flattened, **kwargs) if metrics_2d is not None: for metric in metrics_2d: results[metric] = metrics_2d[metric](image) return results def evaluate_metrics(images, out_prefix, *args, **kwargs): """Evaluate many *images* which are either file paths or images. *out_prefix* is the metric results file prefix. Metric names and file extension are appended to it. *args* and *kwargs* are passed to :func:`evaluate`. Except for *fwhm* in *kwargs* which is used to filter low frequencies from the results. """ fwhm = kwargs.pop("fwhm") if "fwhm" in kwargs else None pool = multiprocessing.Pool(processes=multiprocessing.cpu_count()) exec_func = partial(evaluate, *args, **kwargs) results = pool.map(exec_func, images) merged = {} for metric in results[0].keys(): merged[metric] = np.array([result[metric] for result in results]) if fwhm: # Filter out low frequencies merged[metric] = filter_data(merged[metric], fwhm=fwhm) if out_prefix is not None: path = out_prefix + "_" + metric + ".txt" np.savetxt(path, merged[metric], fmt="%g") return merged def process( names, num_images_for_stats=0, metric_names=None, out_prefix=None, fwhm=None, metrcs_1d_kwargs=None, blur_fwhm=None, ): """Process many files given by *names*. *out_prefix* is the output file prefix where the metric results will be written to. *fwhm* is used to filter our low frequencies from the results. *metrics_1d_kwargs* are additional keyword arguments passed to the functions, they are specified in dictioinary {func_name: kwargs}. 
""" if num_images_for_stats: if num_images_for_stats == -1: num_images_for_stats = len(names) extrema_metrics = {"min": np.min, "max": np.max} extrema = evaluate_metrics( names[:num_images_for_stats], None, metrics_1d=extrema_metrics, fwhm=fwhm, blur_fwhm=blur_fwhm, ) global_min = np.mean(extrema["min"]) global_max = np.mean(extrema["max"]) else: global_min = global_max = None metrics_1d, metrics_2d = make_metrics(metric_names) return evaluate_metrics( names, out_prefix, metrics_1d=metrics_1d, metrics_2d=metrics_2d, global_min=global_min, global_max=global_max, fwhm=fwhm, metrics_1d_kwargs=metrcs_1d_kwargs, blur_fwhm=blur_fwhm, ) def main(): args = parse_args() names = sorted(glob.glob(args.input)) if args.dims == 2: axis_length = int(np.sqrt(len(names))) size_str = "{} x {}".format(axis_length, axis_length) else: axis_length = len(names) size_str = str(axis_length) print("Data size: {}".format(size_str)) kwargs = {"entropy": {"bins": args.entropy_num_bins}} for key in kwargs.keys(): kwargs["m" + key] = kwargs[key] st = time.time() results = process( names, num_images_for_stats=args.num_images_for_stats, metric_names=args.metrics, fwhm=args.fwhm, metrcs_1d_kwargs=kwargs, blur_fwhm=args.blur_fwhm, ) if args.verbose: print("Duration: {} s".format(time.time() - st)) x_data = y_data = None for metric, data in results.iteritems(): if x_data is None: x_data = construct_range(args.x_from, args.x_to, len(data), unit=args.x_unit) y_data = construct_range(args.y_from, args.y_to, len(data), unit=args.y_unit) write( args.output, metric, data, axis_length, x_data=x_data, y_data=y_data, save_raw=args.save_raw, save_txt=args.save_txt, save_plot=args.save_plot, ) argmax = np.argmax(data) if args.dims == 2: argmax = np.unravel_index(argmax, (axis_length, axis_length)) y_argmax = y_data[argmax[0]].magnitude x_argmax = x_data[argmax[1]].magnitude retval = (x_argmax, y_argmax) else: x_argmax = x_data[argmax].magnitude retval = x_argmax print(retval) def write( out_dir, metric, data, axis_length, x_data=None, y_data=None, save_raw=False, save_txt=False, save_plot=False, ): out_path = os.path.join(out_dir, metric) if not os.path.exists(out_dir): os.makedirs(out_dir, mode=0o755) if axis_length == len(data): # 1D if save_raw: np.save(out_path + ".npy", data) if save_plot: write_1d_plot(out_path, data, metric, x_data=x_data) else: reshaped = data.reshape(axis_length, axis_length) if save_raw: write_libtiff(out_path + "_raw" + ".tif", reshaped.astype(np.float32)) if save_plot: write_2d_plot(out_path, reshaped, metric, x_data=x_data, y_data=y_data) if save_txt: data = np.array((x_data.magnitude, data)) # Convenient to be read by pgfplots np.savetxt(out_path + ".txt", data.T, fmt="%g", delimiter="\t", comments="", header="x\ty") def write_1d_plot(out_path, data, metric, x_data=None): from matplotlib import pyplot as plt plt.figure() if x_data is not None: plt.plot(x_data.magnitude, data) plt.xlabel(x_data.units) else: plt.plot(data) plt.title(metric) plt.grid() plt.savefig(out_path + ".tif") plt.close() def write_2d_plot(out_path, data, metric, x_data=None, y_data=None): from matplotlib import pyplot as plt, cm plt.figure() plt.imshow(data, cmap=cm.gray) if x_data is not None: x_from = x_data[0].magnitude x_to = x_data[-1].magnitude num_x_ticks = min(data.shape[1], 9) x_locs = np.linspace(-0.5, data.shape[1] - 0.5, num_x_ticks) x_labels = np.linspace(x_from, x_to, num_x_ticks) plt.xticks(x_locs, x_labels) plt.xlabel(x_data.units) if y_data is not None: y_from = y_data[0].magnitude y_to = y_data[-1].magnitude 
num_y_ticks = min(data.shape[0], 9) y_locs = np.linspace(-0.5, data.shape[0] - 0.5, num_y_ticks) y_labels = np.linspace(y_from, y_to, num_y_ticks) plt.yticks(y_locs, y_labels) plt.ylabel(y_data.units) plt.title(metric) plt.savefig(out_path + ".tif") plt.close() def construct_range(start, stop, num, unit=""): start = 0 if start is None else start stop = num if stop is None else stop region = np.linspace(start, stop, num=num, endpoint=False) return q.Quantity(region, unit) def make_metrics(keys): """Buld 1d and 2d metrics dictionaries from *keys*.""" if keys is None: metrics_1d = METRICS_1D metrics_2d = METRICS_2D else: metrics_1d = {key: METRICS_1D[key] for key in keys if key in METRICS_1D} metrics_2d = {key: METRICS_2D[key] for key in keys if key in METRICS_2D} return metrics_1d, metrics_2d def parse_args(): parser = argparse.ArgumentParser( description="Evaluate sharpness metrics based on parameter changes in 3D reconstruction" ) parser.add_argument("input", type=str, help="Input path pattern") parser.add_argument( "dims", type=int, choices=(1, 2), help="Number of scanned parameters in the data set" ) parser.add_argument("--output", type=str, default=".", help="Output directory") parser.add_argument( "--metrics", type=str, nargs="*", choices=METRICS_1D.keys() + METRICS_2D.keys(), help="Metrics to determine (m prefix means -metric)", ) parser.add_argument("--x-from", type=float, help="X data from") parser.add_argument("--x-to", type=float, help="X data to") parser.add_argument("--x-unit", type=str, default="", help="X axis units") parser.add_argument("--y-from", type=float, help="Y data from") parser.add_argument("--y-to", type=float, help="Y data to") parser.add_argument("--y-unit", type=str, default="", help="Y axis units") parser.add_argument( "--num-images-for-stats", type=int, default=0, help=( "If not zero, an " "image sequence is first read and the mean min and max intensities are " "used as a global range of values to work on (-1 means read all images)" ), ) parser.add_argument( "--fwhm", type=float, help="FwhM of 1 - Gauss in real space used to filter out low frequencies.", ) parser.add_argument( "--entropy-num-bins", type=int, default=256, help="Number of bins to use for histogram calculation by entropy", ) parser.add_argument( "--blur-fwhm", type=float, help="FwhM of the Gaussian blur applied to images" ) parser.add_argument("--save-raw", action="store_true", help="Store raw data (1D npy, 2D tiff)?") parser.add_argument("--save-txt", action="store_true", help="Store raw data as text files") parser.add_argument("--save-plot", action="store_true", help="Store plot data") parser.add_argument("--verbose", action="store_true", help="Verbose output") args = parser.parse_args() if (args.x_from is None) ^ (args.x_to is None): raise ValueError("Either both x-from and x-to are set or both are not") if (args.y_from is None) ^ (args.y_to is None): raise ValueError("Either both y-from and y-to are set or both are not") return args if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/find_axis_cmd_gen.py0000664000175000017500000001217700000000000021227 0ustar00tomastomas00000000000000#!/bin/python """ Created on Apr 6, 2018 @author: gasilos """ import glob, os, tifffile import numpy as np from tofu.ez.evaluate_sharpness import process as process_metrics from tofu.ez.util import enquote, make_inpaths from tofu.util import get_filenames, read_image, determine_shape from 
tofu.ez.params import EZVARS from tofu.config import SECTIONS from tofu.ez.tofu_cmd_gen import check_lamino, gpu_optim def find_axis_std(ctset, tmpdir, ax_range, p_width, nviews, wh): indir = make_inpaths(ctset[0], ctset[1]) cmd = 'tofu reco' if EZVARS['advanced']['more-reco-params']['value'] is True: cmd += check_lamino() elif EZVARS['advanced']['more-reco-params']['value'] is False: cmd += " --overall-angle 180" cmd += " --darks {} --flats {} --projections {}".format( indir[0], indir[1], enquote(indir[2]) ) cmd += " --number {}".format(nviews) if EZVARS['COR']['min-std-apply-pr']['value']: cmd += f" --disable-projection-crop --delta 1e-6" \ f" --energy {SECTIONS['retrieve-phase']['energy']['value']} " \ f" --propagation-distance {SECTIONS['retrieve-phase']['propagation-distance']['value'][0]}" \ f" --pixel-size {SECTIONS['retrieve-phase']['pixel-size']['value']} " \ f" --regularization-rate {SECTIONS['retrieve-phase']['regularization-rate']['value']:0.2f}" else: cmd += " --absorptivity --fix-nan-and-inf" if ctset[1] == 4: cmd += " --flats2 {}".format(indir[3]) out_pattern = os.path.join(tmpdir, "axis-search/sli") cmd += " --output {}".format(enquote(out_pattern)) cmd += " --x-region={},{},{}".format(int(-p_width / 2), int(p_width / 2), 1) cmd += " --y-region={},{},{}".format(int(-p_width / 2), int(p_width / 2), 1) image_height = wh[0] ax_range_list = ax_range.split(",") range_min = ax_range_list[0] range_max = ax_range_list[1] step = ax_range_list[2] range_string = str(range_min) + "," + str(range_max) + "," + str(step) cmd += " --region={}".format(range_string) res = [float(num) for num in ax_range.split(",")] cmd += " --output-bytes-per-file 0" cmd += ' --z-parameter center-position-x' cmd += ' --z {}'.format(EZVARS['COR']['search-row']['value'] - int(image_height/2)) cmd += gpu_optim() print(cmd) os.system(cmd) points, maximum = evaluate_images_simp(out_pattern + "*.tif", "msag") return res[0] + res[2] * maximum def find_axis_corr(ctset, vcrop, y, height, multipage): indir = make_inpaths(ctset[0], ctset[1]) """Use correlation to estimate center of rotation for tomography.""" from scipy.signal import fftconvolve def flat_correct(flat, radio): nonzero = np.where(radio != 0) result = np.zeros_like(radio) result[nonzero] = flat[nonzero] / radio[nonzero] # log(1) = 0 result[result <= 0] = 1 return np.log(result) if multipage: with tifffile.TiffFile(get_filenames(indir[2])[0]) as tif: first = tif.pages[0].asarray().astype(float) with tifffile.TiffFile(get_filenames(indir[2])[-1]) as tif: last = tif.pages[-1].asarray().astype(float) with tifffile.TiffFile(get_filenames(indir[0])[-1]) as tif: dark = tif.pages[-1].asarray().astype(float) with tifffile.TiffFile(get_filenames(indir[1])[0]) as tif: flat1 = tif.pages[-1].asarray().astype(float) - dark else: first = read_image(get_filenames(indir[2])[0]).astype(float) last = read_image(get_filenames(indir[2])[-1]).astype(float) dark = read_image(get_filenames(indir[0])[-1]).astype(float) flat1 = read_image(get_filenames(indir[1])[-1]) - dark first = flat_correct(flat1, first - dark) if ctset[1] == 4: if multipage: with tifffile.TiffFile(get_filenames(indir[3])[0]) as tif: flat2 = tif.pages[-1].asarray().astype(float) - dark else: flat2 = read_image(get_filenames(indir[3])[-1]) - dark last = flat_correct(flat2, last - dark) else: last = flat_correct(flat1, last - dark) if vcrop: y_region = slice(y, min(y + height, first.shape[0]), 1) first = first[y_region, :] last = last[y_region, :] width = first.shape[1] first = first - first.mean() last = last - 
last.mean() conv = fftconvolve(first, last[::-1, :], mode="same") center = np.unravel_index(conv.argmax(), conv.shape)[1] return (width / 2.0 + center) / 2.0 # Find midpoint width of image and return its value def find_axis_image_midpoint(height_width): return height_width[1] // 2 def evaluate_images_simp( input_pattern, metric, num_images_for_stats=0, out_prefix=None, fwhm=None, blur_fwhm=None, verbose=False, ): # simplified version of original evaluate_images function # from Tomas's optimize_parameters script names = sorted(glob.glob(input_pattern)) res = process_metrics( names, num_images_for_stats=num_images_for_stats, metric_names=(metric,), out_prefix=out_prefix, fwhm=fwhm, blur_fwhm=blur_fwhm, )[metric] return res, np.argmax(res) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/image_read_write.py0000664000175000017500000001703200000000000021071 0ustar00tomastomas00000000000000import os, glob import numpy as np import tifffile from tifffile import imread, imwrite class InvalidDataSetError(Exception): """ Error to be raised on attempt to read data from empty or non-existing data set """ def validate_files_path(files_path: str, supported_file_types: list) -> bool: """ Validates specified path :param supported_file_types: List of supported extensions :param files_path: Path to validate :return: True if path exists and contains at least one file of supported type, else False """ try: valid_files_list = get_valid_files_list( files_path=files_path, supported_file_types=supported_file_types ) except InvalidDataSetError: return False return len(valid_files_list) > 0 def get_valid_files_list(files_path: str, supported_file_types: list) -> list: """ Get the list of files of supported type in directory :param supported_file_types: List of supported extensions :param files_path: Path to directory with files :return: List of full paths to files """ # Check if directory exists if not os.path.exists(files_path): raise InvalidDataSetError(f"No such directory: {files_path}") files_list = os.listdir(files_path) valid_files_list = [ os.path.join(files_path, file_name) for file_name in files_list if os.path.splitext(file_name)[1] in supported_file_types ] return sorted(valid_files_list) def read_image(image_file_path: str, data_type=np.float32) -> np.ndarray: """ Reads image file to numpy.ndarray of specified type :param data_type: Data type to store the image :param image_file_path: Full path to image to read :return: """ return imread(image_file_path).astype(dtype=data_type) def write_image(image: np.ndarray, target_directory: str, target_name: str, data_type=np.float32): """ Writes image data to file :param image: Image data :param target_directory: Path to directory to write image :param target_name: Target image file name :param data_type: Data type to write the image :return: """ os.makedirs(target_directory, exist_ok=True) data_file_path = os.path.join(target_directory, target_name) imwrite(data_file_path, data=image.astype(dtype=data_type)) def write_all_images(tiff_arr: np.ndarray, target_directory: str, data_type=np.float32): """ Writes all images in numpy array as individual files in a directory :param tiff_arr: Array containing images :param target_directory: Path to directory to write images :param data_type: Data type to write the images :return: """ print("Writing Images to Directory") # We determine the number of leading zeros to append. 
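    # e.g. 150 images: len("150") = 3, so zfill(3 + 1) names them Image_0001.tif ... Image_0150.tif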
# Find number of digits from number of files to write, then add +1 number of leading zeros index = 1 length_str = str(tiff_arr.shape[0]) num_digits = len(length_str) for image in tiff_arr: write_image( image, target_directory, "Image_" + str(index).zfill(num_digits + 1) + ".tif", data_type ) index += 1 print("Finished Writing Images to Directory") def read_all_images( image_files_path: str, supported_image_types: list, data_type=np.float32 ) -> np.ndarray: """ Reads all images of the supported type from specified directory :param supported_image_types: List of supported extensions :param image_files_path: Path to directory with images :param data_type: Data type to store the images :return: 3-dimensional numpy.ndarray of specified type, first index being image index """ valid_files_list = get_valid_files_list( files_path=image_files_path, supported_file_types=supported_image_types ) if len(valid_files_list) == 0: raise InvalidDataSetError( f"Directory {image_files_path} " f"does not contain files of supported types {supported_image_types}" ) data_array = imread(valid_files_list).astype(dtype=data_type) return np.array(data_array) """Image readers for convenient work with multi-page image sequences.""" """ TAKEN STRAIT FROM ufo-kit/Concert (python 2 version) with permission of Tomas""" class FileSequenceReader(object): """Image sequence reader optimized for reading consecutive images. One multi-page image file is not closed after an image is read so that it does not have to be re-opened for reading the next image. The :func:`.close` function must be called explicitly in order to close the last opened image. """ def __init__(self, file_prefix, ext=''): if os.path.isdir(file_prefix): file_prefix = os.path.join(file_prefix, '*' + ext) self._filenames = sorted(glob.glob(file_prefix)) if not self._filenames: raise SequenceReaderError("No files matching `{}' found".format(file_prefix)) self._lengths = {} self._file = None self._filename = None def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() @property def num_images(self): num = 0 for filename in self._filenames: num += self._get_num_images_in_file(filename) return num def read(self, index): if index < 0: # Enables negative indexing index += self.num_images file_index = 0 while index >= 0: if file_index >= len(self._filenames): raise SequenceReaderError('image index greater than sequence length') index -= self._get_num_images_in_file(self._filenames[file_index]) file_index += 1 file_index -= 1 index += self._lengths[self._filenames[file_index]] self._open(self._filenames[file_index]) return self._read_real(index) def _open(self, filename): if self._filename != filename: if self._filename: self.close() self._file = self._open_real(filename) self._filename = filename def close(self): if self._filename: self._close_real() self._file = None self._filename = None def _get_num_images_in_file(self, filename): if filename not in self._lengths: self._open(filename) self._lengths[filename] = self._get_num_images_in_file_real() return self._lengths[filename] def _open_real(self, filename): """Returns an open file.""" raise NotImplementedError def _close_real(self, filename): raise NotImplementedError def _get_num_images_in_file_real(self): raise NotImplementedError def _read_real(self, index): raise NotImplementedError class TiffSequenceReader(FileSequenceReader): def __init__(self, file_prefix, ext='.tif'): super(TiffSequenceReader, self).__init__(file_prefix, ext=ext) def _open_real(self, filename): import 
tifffile return tifffile.TiffFile(filename) def _close_real(self): self._file.close() def _get_num_images_in_file_real(self): return len(self._file.pages) def _read_real(self, index): return self._file.pages[index].asarray() def get_image_dtype(file_prefix): tsr = TiffSequenceReader(file_prefix) tmp = tsr.read(0).dtype tsr.close() if tmp == 'uint8': return '8', 'uint8' elif tmp == 'uint16': return '16', 'uint16' elif tmp == 'float32': return '32', 'float32' else: return tmp class SequenceReaderError(Exception): pass././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/main.py0000664000175000017500000003664100000000000016535 0ustar00tomastomas00000000000000""" Created on Apr 5, 2018 @author: sergei gasilov """ import logging import os import warnings warnings.filterwarnings("ignore") import time from tofu.ez.ctdir_walker import WalkCTdirs from tofu.ez.tofu_cmd_gen import * from tofu.ez.ufo_cmd_gen import * from tofu.ez.find_axis_cmd_gen import * from tofu.ez.util import * from tofu.ez.image_read_write import TiffSequenceReader from tifffile import imwrite from tofu.ez.params import EZVARS from tofu.config import SECTIONS LOG = logging.getLogger(__name__) def get_CTdirs_list(inpath, fdt_names): """ Determines whether directories containing CT data are valid. Returns list of subdirectories with valid CT data :param inpath: Path to the CT directory containing subdirectories with flats/darks/tomo (and flats2 if used) :param fdt_names: Names of the directories which store flats/darks/tomo (and flats2 if used) :return: W.ctsets: List of "good" CTSets and W.lvl0: Path to root of CT sets """ # Constructor call to create WalkCTDirs object W = WalkCTdirs(inpath, fdt_names) # Find any directories containing "tomo" directory W.findCTdirs() # If "Use common flats/darks across multiple experiments" is enabled if EZVARS['inout']['shared-flatsdarks']['value']: logging.debug("Use common darks/flats") logging.debug("Path to darks: " + str(EZVARS['inout']['path2-shared-darks']['value'])) logging.debug("Path to flats: " + str(EZVARS['inout']['path2-shared-flats']['value'])) logging.debug("Path to flats2: " + str(EZVARS['inout']['path2-shared-flats2']['value'])) logging.debug("Use flats2: " + str(EZVARS['inout']['shared-flats-after']['value'])) # Determine whether paths to common flats/darks/flats2 exist if not W.checkcommonfdt(): print("Invalid path to common flats/darks") return W.ctsets, W.lvl0 else: LOG.debug("Paths to common flats/darks exist") # Check whether directories contain only .tif files if not W.checkcommonfdtFiles(): return W.ctsets, W.lvl0 else: # Sort good bad sets W.sortbadgoodsets() return W.ctsets, W.lvl0 # If "Use common flats/darks across multiple experiments" is not enabled else: LOG.debug("Use flats/darks in same directory as tomo") # Check if common flats/darks/flats2 are type 3 or 4 W.checkCTdirs() # Need to check if common flats/darks contain only .tif files W.checkCTfiles() W.sortbadgoodsets() return W.ctsets, W.lvl0 def frmt_ufo_cmds(cmds, ctset, out_pattern, ax, nviews, wh): """formats list of processing commands for a CT set""" # two helper variables to note that PR/FFC has been done at some step swiFFC = True # FFC is always required swiPR = EZVARS['retrieve-phase']['apply-pr']['value'] # PR is an optional operation ####### PREPROCESSING ######### #if we need to use shared flat/darks we have to do it only once so we need to keep track of that #will be set to False in util/make_inpaths as soon as it was 
used add_value_to_dict_entry(EZVARS['inout']['shared-df-used'], False) if EZVARS['filters']['rm_spots']['value']: # copy one flat to tmpdir now as path might change if preprocess is enabled if not EZVARS['inout']['shared-flatsdarks']['value']: tsr = TiffSequenceReader(os.path.join(ctset[0], EZVARS['inout']['flats-dir']['value'])) else: tsr = TiffSequenceReader(os.path.join(ctset[0], EZVARS['inout']['path2-shared-flats']['value'])) flat1 = tsr.read(tsr.num_images - 1) # taking the last flat tsr.close() flat1_file = os.path.join(EZVARS['inout']['tmp-dir']['value'], "flat1.tif") imwrite(flat1_file, flat1) if EZVARS['inout']['preprocess']['value']: cmds.append('echo " - Applying filter(s) to images "') cmds_prepro = get_pre_cmd(ctset, EZVARS['inout']['preprocess-command']['value'], EZVARS['inout']['tmp-dir']['value']) cmds.extend(cmds_prepro) # reset location of input data ctset = (EZVARS['inout']['tmp-dir']['value'], ctset[1]) ################################################### if EZVARS['filters']['rm_spots']['value']: # generate commands to remove sci. spots from projections cmds.append('echo " - Flat-correcting and removing large spots"') cmds_inpaint = get_inp_cmd(ctset, EZVARS['inout']['tmp-dir']['value'], wh[0], nviews) # reset location of input data ctset = (EZVARS['inout']['tmp-dir']['value'], ctset[1]) cmds.extend(cmds_inpaint) swiFFC = False # no need to do FFC anymore ######## PHASE-RETRIEVAL ####### # Do PR separately if sinograms must be generate # todo? also if vertical ROI is defined to speed up the phase retrieval if EZVARS['retrieve-phase']['apply-pr']['value'] and EZVARS['RR']['enable-RR']['value']: # or (SECTIONS['retrieve-phase']['enable-phase']['value'] and EZVARS['inout']['input_ROI']['value']): if swiFFC: # we still need need flat correction #Inpaint No cmds.append('echo " - Phase retrieval with flat-correction"') if EZVARS['flat-correction']['smart-ffc']['value']: cmds.append(get_pr_sinFFC_cmd(ctset)) cmds.append(get_pr_tofu_cmd_sinFFC(ctset)) elif not EZVARS['flat-correction']['smart-ffc']['value']: cmds.append(get_pr_tofu_cmd(ctset)) else: # Inpaint Yes cmds.append('echo " - Phase retrieval from flat-corrected projections"') cmds.extend(get_pr_ufo_cmd(nviews, wh)) swiPR = False # no need to do PR anymore swiFFC = False # no need to do FFC anymore ################# RING REMOVAL ####################### if EZVARS['RR']['enable-RR']['value']: # Generate sinograms first if swiFFC: # we still need to do flat-field correction if EZVARS['flat-correction']['smart-ffc']['value']: # Create flat corrected images using sinFFC cmds.append(get_sinFFC_cmd(ctset)) # Feed the flat corrected images to sino gram generation cmds.append(get_sinos_noffc_cmd(ctset[0], EZVARS['inout']['tmp-dir']['value'], nviews, wh)) elif not EZVARS['flat-correction']['smart-ffc']['value']: cmds.append('echo " - Make sinograms with flat-correction"') cmds.append(get_sinos_ffc_cmd(ctset, EZVARS['inout']['tmp-dir']['value'], nviews, wh)) else: # we do not need flat-field correction cmds.append('echo " - Make sinograms without flat-correction"') cmds.append(get_sinos_noffc_cmd(ctset[0], EZVARS['inout']['tmp-dir']['value'], nviews, wh)) swiFFC = False # Filter sinograms if EZVARS['RR']['use-ufo']['value']: if EZVARS['RR']['ufo-2d']['value']: cmds.append('echo " - Ring removal - ufo 1d stripes filter"') cmds.append(get_filter1d_sinos_cmd(EZVARS['inout']['tmp-dir']['value'], EZVARS['RR']['sx']['value'], nviews)) else: cmds.append('echo " - Ring removal - ufo 2d stripes filter"') 
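                # 2D variant: filter whole sinograms with the horizontal/vertical sigmas
                # EZVARS['RR']['sx'] / EZVARS['RR']['sy'] before they are converted back
                # to projections further below.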
cmds.append(get_filter2d_sinos_cmd(EZVARS['inout']['tmp-dir']['value'], \ EZVARS['RR']['sx']['value'], EZVARS['RR']['sy']['value'], nviews, wh[1])) else: cmds.append('echo " - Ring removal - sarepy filter(s)"') # note - calling an external program, not an ufo-kit script tmp = os.path.dirname(os.path.abspath(__file__)) path_to_filt = os.path.join(tmp, "RR_external.py") if os.path.isfile(path_to_filt): tmp = os.path.join(EZVARS['inout']['tmp-dir']['value'], "sinos") cmdtmp = 'python {} --sinos {} --mws {} --mws2 {} --snr {} --sort_only {}' \ .format(path_to_filt, tmp, EZVARS['RR']['spy-narrow-window']['value'], EZVARS['RR']['spy-wide-window']['value'], EZVARS['RR']['spy-wide-SNR']['value'], int(not EZVARS['RR']['spy-rm-wide']['value'])) cmds.append(cmdtmp) else: cmds.append('echo "Omitting RR because file with filter does not exist"') if not EZVARS['inout']['keep-tmp']['value']: cmds.append("rm -rf {}".format(os.path.join(EZVARS['inout']['tmp-dir']['value'], "sinos"))) # Convert filtered sinograms back to projections cmds.append('echo " - Generating proj from filtered sinograms"') cmds.append(get_sinos2proj_cmd(wh[0])) # reset location of input data ctset = (EZVARS['inout']['tmp-dir']['value'], ctset[1]) # Finally - call to tofu reco cmds.append('echo " - CT with axis {}; ffc:{}, PR:{}"'.format(ax, swiFFC, swiPR)) if EZVARS['flat-correction']['smart-ffc']['value'] and swiFFC: cmds.append(get_sinFFC_cmd(ctset)) cmds.append(get_reco_cmd(ctset, out_pattern, ax, nviews, wh, False, swiPR)) else: # If not using sinFFC cmds.append(get_reco_cmd(ctset, out_pattern, ax, nviews, wh, swiFFC, swiPR)) return nviews, wh #TODO: get rid of fdt_names everywhere - work directly with EZVARS instead def execute_reconstruction(fdt_names): # array with the list of commands cmds = [] # create temporary directory if not os.path.exists(EZVARS['inout']['tmp-dir']['value']): os.makedirs(EZVARS['inout']['tmp-dir']['value']) if EZVARS['inout']['clip_hist']['value']: if SECTIONS['general']['output-minimum']['value'] > SECTIONS['general']['output-maximum']['value']: raise ValueError('hmin must be smaller than hmax to convert to 8bit without contrast inversion') # get list of all good CT directories to be reconstructed print('*********** Analyzing input directory ************') W, lvl0 = get_CTdirs_list(EZVARS['inout']['input-dir']['value'], fdt_names) # W is an array of tuples (path, type) # get list of already reconstructed sets recd_sets = findSlicesDirs(EZVARS['inout']['output-dir']['value']) # initialize command generators # populate list of reconstruction commands print("*********** AXIS INFO ************") for i, ctset in enumerate(W): # ctset is a tuple containing a path and a type (3 or 4) if not already_recd(ctset[0], lvl0, recd_sets): # determine initial number of projections and their shape path2proj = os.path.join(ctset[0], fdt_names[2]) nviews, wh, multipage = get_dims(path2proj) # If EZVARS['COR']['search-method']['value'] == 4 then bypass axis search and use image midpoint if EZVARS['COR']['search-method']['value'] != 4: if (EZVARS['inout']['input_ROI']['value'] and bad_vert_ROI(multipage, path2proj, SECTIONS['reading']['y']['value'], SECTIONS['reading']['height']['value'])): print('{}\t{}'.format('CTset:', ctset[0])) print('{:>30}\t{}'.format('Axis:', 'na')) print('Vertical ROI does not contain any rows.') print("{:>30}\t{}, dimensions: {}".format("Number of projections:", nviews, wh)) continue # Find axis of rotation using auto: correlate first/last projections if EZVARS['COR']['search-method']['value'] == 
1: ax = find_axis_corr(ctset, EZVARS['inout']['input_ROI']['value'], SECTIONS['reading']['y']['value'], SECTIONS['reading']['height']['value'], multipage) # Find axis of rotation using auto: minimize STD of a slice elif EZVARS['COR']['search-method']['value'] == 2: cmds.append("echo \"Cleaning axis-search in tmp directory\"") os.system('rm -rf {}'.format(os.path.join(EZVARS['inout']['tmp-dir']['value'], 'axis-search'))) ax = find_axis_std(ctset, EZVARS['inout']['tmp-dir']['value'], EZVARS['COR']['search-interval']['value'], EZVARS['COR']['patch-size']['value'], nviews, wh) else: ax = EZVARS['COR']['user-defined-ax']['value'] + i * EZVARS['COR']['user-defined-dax']['value'] # If EZVARS['COR']['search-method']['value'] == 4 then bypass axis search and use image midpoint elif EZVARS['COR']['search-method']['value'] == 4: ax = find_axis_image_midpoint(wh) print("Bypassing axis search and using image midpoint: {}".format(ax)) setid = ctset[0][len(lvl0) + 1:] out_pattern = os.path.join(EZVARS['inout']['output-dir']['value'], setid, 'sli/sli') cmds.append('echo ">>>>> PROCESSING {}"'.format(setid)) # rm files in temporary directory first of all to # format paths correctly and to avoid problems # when reconstructing ct sets with variable number of rows or projections cmds.append('echo "Cleaning temporary directory"'.format(setid)) clean_tmp_dirs(EZVARS['inout']['tmp-dir']['value'], fdt_names) # call function which formats commands for this data set nviews, wh = frmt_ufo_cmds(cmds, ctset, out_pattern, ax, nviews, wh) save_params(setid, ax, nviews, wh) print('{}\t{}'.format('CTset:', ctset[0])) print('{:>30}\t{}'.format('Axis:', ax)) print("{:>30}\t{}, dimensions: {}".format("Number of projections:", nviews, wh)) # tmp = "Number of projections: {}, dimensions: {}".format(nviews, wh) # cmds.append("echo \"{}\"".format(tmp)) if EZVARS['nlmdn']['do-after-reco']['value']: logging.debug("Using Non-Local Means Denoising") head, tail = os.path.split(out_pattern) slidir = os.path.dirname(os.path.join(head, 'sli')) nlmdn_output = os.path.join(slidir+"-nlmdn", "sli-nlmdn-%04i.tif") cmds.append(fmt_nlmdn_ufo_cmd(slidir, nlmdn_output)) else: print("{} has been already reconstructed".format(ctset[0])) # execute commands = start reconstruction start = time.time() print("*********** PROCESSING ************") for cmd in cmds: if not EZVARS['inout']['dryrun']['value']: os.system(cmd) else: print(cmd) if not EZVARS['inout']['keep-tmp']['value']: clean_tmp_dirs(EZVARS['inout']['tmp-dir']['value'], fdt_names) print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") print("*** Done. 
Total processing time {} sec.".format(int(time.time() - start))) print("*** Waiting for the next job...........") def already_recd(ctset, indir, recd_sets): x = False if ctset[len(indir) + 1 :] in recd_sets: x = True return x def findSlicesDirs(lvl0): recd_sets = [] for root, dirs, files in os.walk(lvl0): for name in dirs: if name == "sli": recd_sets.append(root[len(lvl0) + 1 :]) return recd_sets ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/params.py0000664000175000017500000002065700000000000017074 0ustar00tomastomas00000000000000# This file is used to share params as a global variable import yaml import os from collections import OrderedDict from tofu.util import restrict_value params = {} def save_parameters(params, file_path): file_out = open(file_path, 'w') yaml.dump(params, file_out) print("Parameters file saved at: " + str(file_path)) EZVARS = OrderedDict() EZVARS['inout'] = { 'input-dir': { 'ezdefault': os.path.join(os.path.expanduser('~'),""), 'type': str, 'help': "TODO"}, 'output-dir': { 'ezdefault': os.path.join(os.path.expanduser('~'),"rec"), 'type': str, 'help': "TODO"}, 'tmp-dir' : { 'ezdefault': os.path.join(os.path.expanduser('~'),"tmp-ezufo"), 'type': str, 'help': "TODO"}, 'darks-dir': { 'ezdefault': "darks", 'type': str, 'help': "TODO"}, 'flats-dir': { 'ezdefault': "flats", 'type': str, 'help': "TODO"}, 'tomo-dir': { 'ezdefault': "tomo", 'type': str, 'help': "TODO"}, 'flats2-dir': { 'ezdefault': "flats2", 'type': str, 'help': "TODO"}, 'bigtiff-output': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'input_ROI': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'clip_hist': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'preprocess': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'preprocess-command': { 'ezdefault': "remove-outliers size=3 threshold=500 sign=1", 'type': str, 'help': "TODO"}, 'output-ROI': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'output-x': { 'ezdefault': 0, 'type': restrict_value((0,None),dtype=int), 'help': "Crop slices: x"}, 'output-width': { 'ezdefault': 0, 'type': restrict_value((0,None),dtype=int), 'help': "Crop slices: width"}, 'output-y': { 'ezdefault': 0, 'type': restrict_value((0,None),dtype=int), 'help': "Crop slices: y"}, 'output-height': { 'ezdefault': 0, 'type': restrict_value((0,None),dtype=int), 'help': "Crop slices: height"}, 'dryrun': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'save-params': { 'ezdefault': True, 'type': bool, 'help': "TODO"}, 'keep-tmp': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'open-viewer': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'shared-flatsdarks': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'path2-shared-darks': { 'ezdefault': "Absolute path to darks", 'type': str, 'help': "TODO"}, 'path2-shared-flats': { 'ezdefault': "Absolute path to flats", 'type': str, 'help': "TODO"}, 'shared-flats-after': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'path2-shared-flats2': { 'ezdefault': "Absolute path to flats2", 'type': str, 'help': "TODO"}, 'shared-df-used': { 'ezdefault': False, 'type': bool, 'help': "Internal variable; must be set to True once " "shared flats/darks were used in the recontruction pipeline"}, } EZVARS['COR'] = { 'search-method': { 'ezdefault': 1, 'type': int, 'help': "TODO"}, 'search-interval': { 'ezdefault': "1010,1030,0.5", 'type': str, 'help': "TODO"}, 'patch-size': { 'ezdefault': 256, 'type': restrict_value((0,None),dtype=int), 'help': 
"Size of reconstructed patch [pixel]"}, 'search-row': { 'ezdefault': 100, 'type': restrict_value((0,None), dtype=int), 'help': "Search in slice from row number"}, 'min-std-apply-pr': { 'ezdefault': False, 'type': bool, 'help': "Will apply phase retreival but only while estimating the axis"}, 'user-defined-ax': { 'ezdefault': 0.0, 'type': restrict_value((0,None),dtype=float), 'help': "Axis is in column No [pixel]"}, 'user-defined-dax': { 'ezdefault': 0.0, 'type': float, 'help': "TODO"}, } EZVARS['retrieve-phase']= { 'apply-pr': { 'default': False, 'ezdefault': False, 'type': bool, 'help': "Applies phase retrieval if checked"} } EZVARS['filters'] = { 'rm_spots': { 'ezdefault': False, 'type': bool, 'help': "TODO-G"}, 'spot-threshold': { 'ezdefault': 1000, 'type': restrict_value((0,None), dtype=float), 'help': "TODO-G"} } EZVARS['RR'] = { 'enable-RR': { 'ezdefault': False, 'type': bool, 'help': "TODO-G"}, 'use-ufo': { 'ezdefault': True, 'type': bool, 'help': "TODO-G"}, 'ufo-2d': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'sx': { 'ezdefault': 3, 'type': restrict_value((0,None),dtype=int), 'help': "ufo ring-removal sigma horizontal (try 3..31)"}, 'sy': { 'ezdefault': 1, 'type': restrict_value((0,None),dtype=int), 'help': "ufo ring-removal sigma vertical (try 1..5)"}, 'spy-narrow-window': { 'ezdefault': 21, 'type': restrict_value((0,None),dtype=int), 'help': "window size"}, 'spy-rm-wide': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'spy-wide-window': { 'ezdefault': 91, 'type': restrict_value((0,None),dtype=int), 'help': "wind"}, 'spy-wide-SNR': { 'ezdefault': 3, 'type': restrict_value((0,None),dtype=int), 'help': "SNR"}, } EZVARS['flat-correction'] = { 'smart-ffc': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'smart-ffc-method': { 'ezdefault': "eigen", 'type': str, 'help': "TODO"}, 'eigen-pco-reps': { 'ezdefault': 4, 'type': restrict_value((0,None),dtype=int), 'help': "Flat Field Correction: Eigen PCO Repetitions"}, 'eigen-pco-downsample': { 'ezdefault': 2, 'type': restrict_value((0,None),dtype=int), 'help': "Flat Field Correction: Eigen PCO Downsample"}, 'downsample': { 'ezdefault': 4, 'type': restrict_value((0,None),dtype=int), 'help': "Flat Field Correction: Downsample"}, 'dark-scale': { 'ezdefault': 1.0, 'type': float, 'help': "Scaling dark"}, #(?) has the same name in SECTION 'flat-scale': { 'ezdefault': 1.0, 'type': float, 'help': "Scaling falt"}, #(?) 
has the same name in SECTION } #TODO ADD CHECKING NLMDN SETTINGS EZVARS['nlmdn'] = { 'do-after-reco': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'input-dir': { 'ezdefault': os.getcwd(), 'type': str, 'help': "TODO"}, 'input-is-1file': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'output_pattern': { 'ezdefault': os.getcwd() + '-nlmfilt', 'type': str, 'help': "TODO"}, 'bigtiff_output': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'search-radius': { 'ezdefault': 10, 'type': int, 'help': "TODO"}, 'patch-radius': { 'ezdefault': 3, 'type': int, 'help': "TODO"}, 'h': { 'ezdefault': 0.0, 'type': float, 'help': "TODO"}, 'sigma': { 'ezdefault': 0.0, 'type': float, 'help': "TODO"}, 'window': { 'ezdefault': 0.0, 'type': float, 'help': "TODO"}, 'fast': { 'ezdefault': True, 'type': bool, 'help': "TODO"}, 'estimate-sigma': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'dryrun': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, } EZVARS['advanced'] = { 'more-reco-params': { 'ezdefault': False, 'type': bool, 'help': "TODO"}, 'parameter-type': { 'ezdefault': "", 'type': str, 'help': "TODO"}, 'enable-optimization': { 'ezdefault': False, 'type': bool, 'help': "TODO" } }././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/tofu_cmd_gen.py0000664000175000017500000004377300000000000020246 0ustar00tomastomas00000000000000#!/bin/python """ Created on Apr 6, 2018 @author: gasilos """ import os import numpy as np from tofu.ez.ufo_cmd_gen import fmt_in_out_path from tofu.ez.params import EZVARS from tofu.config import SECTIONS from tofu.ez.util import make_inpaths, fmt_in_out_path def check_lamino(): cmd = '' if not SECTIONS['cone-beam-weight']['axis-angle-x']['value'][0] == '': cmd += ' --axis-angle-x {}'.format(SECTIONS['cone-beam-weight']['axis-angle-x']['value'][0]) if not SECTIONS['general-reconstruction']['overall-angle']['value'] == '': cmd += ' --overall-angle {}'.format(SECTIONS['general-reconstruction']['overall-angle']['value']) if not SECTIONS['cone-beam-weight']['center-position-z']['value'][0] == '': cmd += ' --center-position-z {}'.format(SECTIONS['cone-beam-weight']['center-position-z']['value'][0]) if not SECTIONS['general-reconstruction']['axis-angle-y']['value'][0] == '': cmd += ' --axis-angle-y {}'.format(SECTIONS['general-reconstruction']['axis-angle-y']['value'][0]) return cmd def gpu_optim(): cmd = '' if SECTIONS['general']['verbose']['value']: cmd += ' --verbose' if EZVARS['advanced']['enable-optimization']['value']: cmd += ' --slice-memory-coeff={}'.format(SECTIONS['general-reconstruction']['slice-memory-coeff']['value']) if not SECTIONS['general-reconstruction']['slices-per-device']['value'] is None: cmd += ' --slices-per-device {}'.format(SECTIONS['general-reconstruction']['slices-per-device']['value']) if not SECTIONS['general-reconstruction']['data-splitting-policy']['value'] is None: cmd += ' --data-splitting-policy {}'.format( SECTIONS['general-reconstruction']['data-splitting-policy']['value']) return cmd def check_8bit(cmd, gray256, bit, hmin, hmax): if gray256: cmd += " --output-bitdepth {}".format(bit) # cmd += " --output-minimum \" {}\" --output-maximum \" {}\""\ # .format(hmin, hmax) cmd += ' --output-minimum " {}" --output-maximum " {}"'.format(hmin, hmax) return cmd def check_vcrop(cmd, vcrop, y, yheight, ystep, ori_height): if vcrop: cmd += " --y {} --height {} --y-step {}".format(y, yheight, ystep) else: cmd += " --height {}".format(ori_height) return cmd def 
check_bigtif(cmd, swi): if not swi: cmd += " --output-bytes-per-file 0" return cmd def get_1step_ct_cmd(ctset, out_pattern, ax, nviews, wh): # direct CT reconstruction from input dir to output dir; # obsolete, replaced by tofu reco, see get_reco_cmd() lower indir = make_inpaths(ctset[0], ctset[1]) # correct location of proj folder in case if prepro was done in_proj_dir, quatsch = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], ctset[0], EZVARS['inout']['tomo-dir']['value'], False) indir[2] = os.path.join(os.path.split(indir[2])[0], os.path.split(in_proj_dir)[1]) # format command cmd = "tofu tomo --absorptivity --fix-nan-and-inf" cmd += " --darks {} --flats {} --projections {}".format(indir[0], indir[1], indir[2]) if ctset[1] == 4: # must be equivalent to len(indir)>3 cmd += " --flats2 {}".format(indir[3]) cmd += " --output {}".format(out_pattern) cmd += " --axis {}".format(ax) cmd += " --offset {}".format(SECTIONS['general-reconstruction']['volume-angle-z']['value'][0]) cmd += " --number {}".format(nviews) if SECTIONS['reading']['step']['value'] > 0.0: cmd += ' --angle {}'.format(SECTIONS['reading']['step']['value']) cmd = check_vcrop(cmd, EZVARS['inout']['input_ROI']['value'], SECTIONS['reading']['y']['value'], SECTIONS['reading']['height']['value'], SECTIONS['reading']['y-step']['value'], wh[0]) cmd = check_8bit(cmd, EZVARS['inout']['clip_hist']['value'], SECTIONS['general']['output-bitdepth']['value'], SECTIONS['general']['output-minimum']['value'], SECTIONS['general']['output-maximum']['value']) cmd = check_bigtif(cmd, EZVARS['inout']['bigtiff-output']['value']) return cmd def get_ct_proj_cmd( out_pattern, ax, nviews, wh): # CT reconstruction from pre-processed and flat-corrected projections in_proj_dir, quatsch = fmt_in_out_path( EZVARS['inout']['tmp-dir']['value'], "obsolete;if-you-need-fix-it", EZVARS['inout']['tomo-dir']['value'], False ) cmd = "tofu tomo --projections {}".format(in_proj_dir) cmd += " --output {}".format(out_pattern) cmd += " --axis {}".format(ax) cmd += " --offset {}".format(SECTIONS['general-reconstruction']['volume-angle-z']['value'][0]) cmd += " --number {}".format(nviews) if SECTIONS['reading']['step']['value'] > 0.0: cmd += ' --angle {}'.format(SECTIONS['reading']['step']['value']) cmd = check_vcrop(cmd, EZVARS['inout']['input_ROI']['value'], SECTIONS['reading']['y']['value'], SECTIONS['reading']['height']['value'], SECTIONS['reading']['y-step']['value'], wh[0]) cmd = check_8bit(cmd, EZVARS['inout']['clip_hist']['value'], SECTIONS['general']['output-bitdepth']['value'], SECTIONS['general']['output-minimum']['value'], SECTIONS['general']['output-maximum']['value']) cmd = check_bigtif(cmd, EZVARS['inout']['bigtiff-output']['value']) return cmd def get_ct_sin_cmd(out_pattern, ax, nviews, wh): sinos_dir = os.path.join(EZVARS['inout']['tmp-dir']['value'], 'sinos-filt') cmd = 'tofu tomo --sinograms {}'.format(sinos_dir) cmd += ' --output {}'.format(out_pattern) cmd += ' --axis {}'.format(ax) cmd += ' --offset {}'.format(SECTIONS['general-reconstruction']['volume-angle-z']['value'][0]) if EZVARS['inout']['input_ROI']['value']: cmd += ' --number {}'.format(int(SECTIONS['reading']['height']['value'] / SECTIONS['reading']['y-step']['value'])) else: cmd += " --number {}".format(wh[0]) cmd += " --height {}".format(nviews) if SECTIONS['reading']['step']['value'] > 0.0: cmd += ' --angle {}'.format(SECTIONS['reading']['step']['value']) cmd = check_8bit(cmd, EZVARS['inout']['clip_hist']['value'], SECTIONS['general']['output-bitdepth']['value'], 
SECTIONS['general']['output-minimum']['value'], SECTIONS['general']['output-maximum']['value']) cmd = check_bigtif(cmd, EZVARS['inout']['bigtiff-output']['value']) return cmd def get_sinos_ffc_cmd(ctset, tmpdir, nviews, wh): indir = make_inpaths(ctset[0], ctset[1]) in_proj_dir, out_pattern = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], ctset[0], EZVARS['inout']['tomo-dir']['value'], False) cmd = 'tofu sinos --absorptivity --fix-nan-and-inf' cmd += ' --darks {} --flats {} '.format(indir[0], indir[1]) if ctset[1] == 4: cmd += " --flats2 {}".format(indir[3]) cmd += " --projections {}".format(in_proj_dir) cmd += " --output {}".format(os.path.join(tmpdir, "sinos/sin-%04i.tif")) cmd += " --number {}".format(nviews) cmd = check_vcrop(cmd, EZVARS['inout']['input_ROI']['value'], SECTIONS['reading']['y']['value'], SECTIONS['reading']['height']['value'], SECTIONS['reading']['y-step']['value'], wh[0]) if not EZVARS['RR']['use-ufo']['value']: # because second RR algorithm does not know how to work with multipage tiffs cmd += " --output-bytes-per-file 0" if not EZVARS['flat-correction']['dark-scale']['value'] == "": cmd += ' --dark-scale {}'.format(EZVARS['flat-correction']['dark-scale']['value']) if not EZVARS['flat-correction']['flat-scale']['value'] == "": cmd += ' --flat-scale {}'.format(EZVARS['flat-correction']['flat-scale']['value']) return cmd def get_sinos_noffc_cmd(ctsetpath, tmpdir, nviews, wh): in_proj_dir, out_pattern = fmt_in_out_path( EZVARS['inout']['tmp-dir']['value'], ctsetpath, EZVARS['inout']['tomo-dir']['value'], False ) cmd = "tofu sinos" cmd += " --projections {}".format(in_proj_dir) cmd += " --output {}".format(os.path.join(tmpdir, "sinos/sin-%04i.tif")) cmd += " --number {}".format(nviews) cmd = check_vcrop(cmd, EZVARS['inout']['input_ROI']['value'], SECTIONS['reading']['y']['value'], SECTIONS['reading']['height']['value'], SECTIONS['reading']['y-step']['value'], wh[0]) if not EZVARS['RR']['use-ufo']['value']: # because second RR algorithm does not know how to work with multipage tiffs cmd += " --output-bytes-per-file 0" return cmd def get_sinos2proj_cmd(proj_height): quatsch, out_pattern = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], 'quatsch', EZVARS['inout']['tomo-dir']['value'], True) cmd = 'tofu sinos' cmd += ' --projections {}'.format(os.path.join(EZVARS['inout']['tmp-dir']['value'], 'sinos-filt')) cmd += ' --output {}'.format(out_pattern) if not EZVARS['inout']['input_ROI']['value']: cmd += ' --number {}'.format(proj_height) else: cmd += ' --number {}'.format(int(SECTIONS['reading']['height']['value'] / SECTIONS['reading']['y-step']['value'])) return cmd def get_sinFFC_cmd(ctset): indir = make_inpaths(ctset[0], ctset[1]) in_proj_dir, out_pattern = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], ctset[0], EZVARS['inout']['tomo-dir']['value']) cmd = 'bmit_sin --fix-nan' cmd += ' --darks {} --flats {} --projections {}'.format(indir[0], indir[1], in_proj_dir) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) cmd += ' --output {}'.format(os.path.dirname(out_pattern)) cmd += ' --method {}'.format(EZVARS['flat-correction']['smart-ffc-method']['value']) cmd += ' --multiprocessing' cmd += ' --eigen-pco-repetitions {}'.format(EZVARS['flat-correction']['eigen-pco-reps']['value']) cmd += ' --eigen-pco-downsample {}'.format(EZVARS['flat-correction']['eigen-pco-downsample']['value']) cmd += ' --downsample {}'.format(EZVARS['flat-correction']['downsample']['value']) return cmd def get_pr_sinFFC_cmd(ctset): indir = make_inpaths(ctset[0], ctset[1]) in_proj_dir, 
out_pattern = fmt_in_out_path( EZVARS['inout']['tmp-dir']['value'], ctset[0], EZVARS['inout']['tomo-dir']['value']) cmd = 'bmit_sin --fix-nan' cmd += ' --darks {} --flats {} --projections {}'.format(indir[0], indir[1], in_proj_dir) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) cmd += ' --output {}'.format(os.path.dirname(out_pattern)) cmd += ' --method {}'.format(EZVARS['flat-correction']['smart-ffc-method']['value']) cmd += ' --multiprocessing' cmd += ' --eigen-pco-repetitions {}'.format(EZVARS['flat-correction']['eigen-pco-reps']['value']) cmd += ' --eigen-pco-downsample {}'.format(EZVARS['flat-correction']['eigen-pco-downsample']['value']) cmd += ' --downsample {}'.format(EZVARS['flat-correction']['downsample']['value']) return cmd def get_pr_tofu_cmd_sinFFC(ctset): # indir will format paths to flats darks and tomo2 correctly even if they were # pre-processed, however path to the input directory with projections # cannot be formatted with that command correctly # indir = make_inpaths(ctset[0], ctset[1]) # so we need a separate "universal" command which considers all previous steps in_proj_dir, out_pattern = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], ctset[0], EZVARS['inout']['tomo-dir']['value']) # Phase retrieval cmd = 'tofu preprocess --delta 1e-6' cmd += ' --energy {} --propagation-distance {}' \ ' --pixel-size {} --regularization-rate {:0.2f}' \ .format(SECTIONS['retrieve-phase']['energy']['value'], SECTIONS['retrieve-phase']['propagation-distance']['value'][0], SECTIONS['retrieve-phase']['pixel-size']['value'], SECTIONS['retrieve-phase']['regularization-rate']['value']) cmd += ' --projections {}'.format(in_proj_dir) cmd += ' --output {}'.format(out_pattern) cmd += ' --projection-crop-after filter' return cmd def get_pr_tofu_cmd(ctset): # indir will format paths to flats darks and tomo2 correctly even if they were # pre-processed, however path to the input directory with projections # cannot be formatted with that command correctly indir = make_inpaths(ctset[0], ctset[1]) # so we need a separate "universal" command which considers all previous steps in_proj_dir, out_pattern = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], ctset[0], EZVARS['inout']['tomo-dir']['value']) cmd = 'tofu preprocess --fix-nan-and-inf --projection-filter none --delta 1e-6' cmd += ' --darks {} --flats {} --projections {}'.format(indir[0], indir[1], in_proj_dir) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) cmd += ' --output {}'.format(out_pattern) cmd += ' --energy {} --propagation-distance {}' \ ' --pixel-size {} --regularization-rate {:0.2f}' \ .format(SECTIONS['retrieve-phase']['energy']['value'], SECTIONS['retrieve-phase']['propagation-distance']['value'][0], SECTIONS['retrieve-phase']['pixel-size']['value'], SECTIONS['retrieve-phase']['regularization-rate']['value']) if not EZVARS['flat-correction']['dark-scale']['value'] is None: cmd += ' --dark-scale {}'.format(EZVARS['flat-correction']['dark-scale']['value']) if not EZVARS['flat-correction']['flat-scale']['value'] is None: cmd += ' --flat-scale {}'.format(EZVARS['flat-correction']['flat-scale']['value']) return cmd def get_reco_cmd(ctset, out_pattern, ax, nviews, wh, ffc, pr): # direct CT reconstruction from input dir to output dir; # or CT reconstruction after preprocessing only indir = make_inpaths(ctset[0], ctset[1]) # correct location of proj folder in case if prepro was done in_proj_dir, quatsch = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], ctset[0], EZVARS['inout']['tomo-dir']['value'], False) cmd = 
'tofu reco' # Laminography ? if EZVARS['advanced']['more-reco-params']['value'] is True: cmd += check_lamino() elif EZVARS['advanced']['more-reco-params']['value'] is False: cmd += ' --overall-angle 180' ############## cmd += ' --projections {}'.format(in_proj_dir) cmd += ' --output {}'.format(out_pattern) if ffc: cmd += ' --fix-nan-and-inf' cmd += ' --darks {} --flats {}'.format(indir[0], indir[1]) if ctset[1] == 4: # must be equivalent to len(indir)>3 cmd += ' --flats2 {}'.format(indir[3]) if not pr: cmd += ' --absorptivity' if not EZVARS['flat-correction']['dark-scale']['value'] is None: cmd += ' --dark-scale {}'.format(EZVARS['flat-correction']['dark-scale']['value']) if not EZVARS['flat-correction']['flat-scale']['value'] is None: cmd += ' --flat-scale {}'.format(EZVARS['flat-correction']['flat-scale']['value']) if pr: cmd += ( " --disable-projection-crop" " --delta 1e-6" " --energy {} --propagation-distance {}" " --pixel-size {} --regularization-rate {:0.2f}" \ .format(SECTIONS['retrieve-phase']['energy']['value'], SECTIONS['retrieve-phase']['propagation-distance']['value'][0], SECTIONS['retrieve-phase']['pixel-size']['value'], SECTIONS['retrieve-phase']['regularization-rate']['value']) ) cmd += " --center-position-x {}".format(ax) # if args.nviews==0: cmd += " --number {}".format(nviews) # elif args.nviews>0: # cmd += ' --number {}'.format(args.nviews) cmd += ' --volume-angle-z {:0.5f}'.format(SECTIONS['general-reconstruction']['volume-angle-z']['value'][0]) # rows-slices to be reconstructed # full ROI b = int(np.ceil(wh[0] / 2.0)) a = -int(wh[0] / 2.0) c = 1 if EZVARS['inout']['input_ROI']['value']: if EZVARS['RR']['enable-RR']['value']: h2 = SECTIONS['reading']['height']['value'] / SECTIONS['reading']['y-step']['value'] / 2.0 b = np.ceil(h2) a = -int(h2) else: h2 = int(wh[0] / 2.0) a = SECTIONS['reading']['y']['value'] - h2 b = SECTIONS['reading']['y']['value'] + SECTIONS['reading']['height']['value'] - h2 c = SECTIONS['reading']['y-step']['value'] cmd += ' --region={},{},{}'.format(a, b, c) # crop of reconstructed slice in the axial plane b = wh[1] / 2 if EZVARS['inout']['output-ROI']['value']: if EZVARS['inout']['output-x']['value'] != 0 or EZVARS['inout']['output-width']['value'] != 0: cmd += ' --x-region={},{},{}'.format(EZVARS['inout']['output-x']['value'] - b, EZVARS['inout']['output-x']['value'] + EZVARS['inout']['output-width']['value'] - b, 1) if EZVARS['inout']['output-y']['value'] != 0 or EZVARS['inout']['output-height']['value'] != 0: cmd += ' --y-region={},{},{}'.format(EZVARS['inout']['output-y']['value'] - b, EZVARS['inout']['output-y']['value'] + EZVARS['inout']['output-height']['value'] - b, 1) # cmd = check_vcrop(cmd, EZVARS['inout']['input_ROI']['value'], SECTIONS['reading']['y']['value'], SECTIONS['reading']['height']['value'], SECTIONS['reading']['y-step']['value'], wh[0]) cmd = check_8bit(cmd, EZVARS['inout']['clip_hist']['value'], SECTIONS['general']['output-bitdepth']['value'], SECTIONS['general']['output-minimum']['value'], SECTIONS['general']['output-maximum']['value']) cmd = check_bigtif(cmd, EZVARS['inout']['bigtiff-output']['value']) # Optimization cmd += gpu_optim() return cmd ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/ufo_cmd_gen.py0000664000175000017500000002430300000000000020046 0ustar00tomastomas00000000000000#!/bin/python """ Created on Apr 6, 2018 @author: gasilos """ import os from tofu.util import next_power_of_two from tofu.ez.params import EZVARS from 
tofu.config import SECTIONS from tofu.ez.util import enquote, make_inpaths, fmt_in_out_path def make_outpaths(lvl0, flats2): """ Creates a list of paths to flats/darks/tomo directories in tmp data only used in one place to format paths in the temporary directory :param lvl0: Root of directory containing flats/darks/tomo :param flats2: The type of directory: 3 contains flats/darks/tomo 4 contains flats/darks/tomo/flats2 :return: List of paths to the filtered darks/flats/tomo and flats2 (if used) """ indir = [] for i in [EZVARS['inout']['darks-dir']['value'], EZVARS['inout']['flats-dir']['value'], EZVARS['inout']['tomo-dir']['value']]: indir.append(os.path.join(lvl0, i)) if flats2 - 3: indir.append(os.path.join(lvl0, EZVARS['inout']['flats2-dir']['value'])) return indir def check_vcrop(cmd, vcrop, y, yheight, ystep): if vcrop: cmd += " --y {} --height {} --y-step {}".format(y, yheight, ystep) return cmd def check_bigtif(cmd, swi): if not swi: cmd += " bytes-per-file=0" return cmd def get_pr_ufo_cmd(nviews, wh): in_proj_dir, out_pattern = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], "quatsch", EZVARS['inout']['tomo-dir']['value']) cmds = [] pad_width = next_power_of_two(wh[1] + 50) pad_height = next_power_of_two(wh[0] + 50) pad_x = (pad_width - wh[1]) / 2 pad_y = (pad_height - wh[0]) / 2 cmd = 'ufo-launch read path={} height={} number={}'.format(in_proj_dir, wh[0], nviews) cmd += ' ! pad x={} width={} y={} height={}'.format(pad_x, pad_width, pad_y, pad_height) cmd += ' addressing-mode=clamp_to_edge' cmd += ' ! fft dimensions=2 ! retrieve-phase' cmd += ' energy={} distance={} pixel-size={} regularization-rate={:0.2f}' \ .format(SECTIONS['retrieve-phase']['energy']['value'], SECTIONS['retrieve-phase']['propagation-distance']['value'][0], SECTIONS['retrieve-phase']['pixel-size']['value'], SECTIONS['retrieve-phase']['regularization-rate']['value']) cmd += ' ! ifft dimensions=2 crop-width={} crop-height={}' \ .format(pad_width, pad_height) cmd += ' ! crop x={} width={} y={} height={}'.format(pad_x, wh[1], pad_y, wh[0]) cmd += ' ! opencl kernel=\'absorptivity\' ! opencl kernel=\'fix_nan_and_inf\' !' cmd += ' write filename={}'.format(enquote(out_pattern)) cmds.append(cmd) if not EZVARS['inout']['keep-tmp']['value']: cmds.append('rm -rf {}'.format(in_proj_dir)) return cmds def get_filter1d_sinos_cmd(tmpdir, RR, nviews): sin_in = os.path.join(tmpdir, 'sinos') out_pattern = os.path.join(tmpdir, 'sinos-filt/sin-%04i.tif') pad_height = next_power_of_two(nviews + 500) pad_y = (pad_height - nviews) / 2 cmd = 'ufo-launch read path={}'.format(sin_in) cmd += ' ! pad y={} height={}'.format(pad_y, pad_height) cmd += ' addressing-mode=clamp_to_edge' cmd += ' ! transpose ! fft dimensions=1 ! filter-stripes1d strength={}'.format(RR) cmd += ' ! ifft dimensions=1 ! transpose' cmd += ' ! crop y={} height={}'.format(pad_y, nviews) cmd += ' ! write filename={}'.format(enquote(out_pattern)) return cmd def get_filter2d_sinos_cmd(tmpdir, sig_hor, sig_ver, nviews, w): sin_in = os.path.join(tmpdir, "sinos") out_pattern = os.path.join(tmpdir, "sinos-filt/sin-%04i.tif") pad_height = next_power_of_two(nviews + 500) pad_y = (pad_height - nviews) / 2 pad_width = next_power_of_two(w + 500) pad_x = (pad_width - w) / 2 cmd = "ufo-launch read path={}".format(sin_in) cmd += " ! pad x={} width={} y={} height={}".format(pad_x, pad_width, pad_y, pad_height) cmd += " addressing-mode=mirrored_repeat" cmd += " ! fft dimensions=2 ! filter-stripes horizontal-sigma={} vertical-sigma={}".format( sig_hor, sig_ver ) cmd += " ! 
ifft dimensions=2 crop-width={} crop-height={}".format(pad_width, pad_height) cmd += " ! crop x={} width={} y={} height={}".format(pad_x, w, pad_y, nviews) cmd += " ! write filename={}".format(enquote(out_pattern)) return cmd def get_pre_cmd( ctset, pre_cmd, tmpdir): indir = make_inpaths(ctset[0], ctset[1]) outdir = make_outpaths(tmpdir, ctset[1]) # add index to the name of the output directory with projections # if enabled preprocessing is always the first step outdir[2] = os.path.join(tmpdir, "proj-step1") # we also must create this directory to format paths correctly if not os.path.exists(outdir[2]): os.makedirs(outdir[2]) cmds = [] for i, fol in enumerate(indir): in_pattern = os.path.join(fol, "*.tif") out_pattern = os.path.join(outdir[i], "frame-%04i.tif") cmds.append("ufo-launch") cmds[i] += " read path={} ! ".format(enquote(in_pattern)) cmds[i] += pre_cmd cmds[i] += " ! write filename={}".format(enquote(out_pattern)) return cmds def get_inp_cmd(ctset, tmpdir, N, nviews): indir = make_inpaths(ctset[0], ctset[1]) cmds = [] ######### CREATE MASK ######### flat1_file = os.path.join(tmpdir, "flat1.tif") mask_file = os.path.join(tmpdir, "mask.tif") # generate mask cmd = 'tofu find-large-spots --images {}'.format(flat1_file) cmd += ' --spot-threshold {} --gauss-sigma {}'.format( SECTIONS['find-large-spots']['spot-threshold']['value'], SECTIONS['find-large-spots']['gauss-sigma']['value']) cmd += ' --output {} --output-bytes-per-file 0'.format(mask_file) cmds.append(cmd) ######### FLAT-CORRECT ######### in_proj_dir, out_pattern = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], ctset[0], EZVARS['inout']['tomo-dir']['value']) if EZVARS['flat-correction']['smart-ffc']['value']: cmd = 'bmit_sin --fix-nan' cmd += ' --darks {} --flats {}'.format(indir[0], indir[1]) cmd += ' --projections {}'.format(in_proj_dir) cmd += ' --output {}'.format(os.path.dirname(out_pattern)) cmd += ' --multiprocessing' #cmd += ' --output {}'.format(out_pattern) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) # Add options for eigen-pco-repetitions etc. cmd += ' --eigen-pco-repetitions {}'.format(EZVARS['flat-correction']['eigen-pco-reps']['value']) cmd += ' --eigen-pco-downsample {}'.format(EZVARS['flat-correction']['eigen-pco-downsample']['value']) cmd += ' --downsample {}'.format(EZVARS['flat-correction']['downsample']['value']) #if not SECTIONS['retrieve-phase']['enable-phase']['value']: # cmd += ' --absorptivity' ???? # Todo: check if takes neglog? or only computes transmission? 
# in case of latter add absorptivity and fix nans cmds.append(cmd) elif not EZVARS['flat-correction']['smart-ffc']['value']: cmd = 'tofu flatcorrect --fix-nan-and-inf' cmd += ' --darks {} --flats {}'.format(indir[0], indir[1]) cmd += ' --projections {}'.format(in_proj_dir) cmd += ' --output {}'.format(out_pattern) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) if not EZVARS['retrieve-phase']['apply-pr']['value']: cmd += ' --absorptivity --fix-nan-and-inf' if not EZVARS['flat-correction']['dark-scale']['value'] == "": cmd += ' --dark-scale {}'.format(EZVARS['flat-correction']['dark-scale']['value']) if not EZVARS['flat-correction']['flat-scale']['value'] == "": cmd += ' --flat-scale {}'.format(EZVARS['flat-correction']['flat-scale']['value']) cmds.append(cmd) if not EZVARS['inout']['keep-tmp']['value'] and EZVARS['inout']['preprocess']['value']: cmds.append('rm -rf {}'.format(indir[0])) cmds.append('rm -rf {}'.format(indir[1])) cmds.append('rm -rf {}'.format(in_proj_dir)) if len(indir) > 3: cmds.append("rm -rf {}".format(indir[3])) ######### INPAINT ######### in_proj_dir, out_pattern = fmt_in_out_path(EZVARS['inout']['tmp-dir']['value'], ctset[0], EZVARS['inout']['tomo-dir']['value']) cmd = "ufo-launch [read path={} height={} number={}".format(in_proj_dir, N, nviews) cmd += ", read path={}]".format(mask_file) cmd += " ! horizontal-interpolate ! " cmd += "write filename={}".format(enquote(out_pattern)) cmds.append(cmd) if not EZVARS['inout']['keep-tmp']['value']: cmds.append("rm -rf {}".format(in_proj_dir)) return cmds def get_crop_sli(out_pattern): cmd = 'ufo-launch read path={}/*.tif ! '.format(os.path.dirname(out_pattern)) cmd += 'crop x={} width={} y={} height={} ! '. \ format(EZVARS['inout']['output-x']['value'], EZVARS['inout']['output-width']['value'], EZVARS['inout']['output-y']['value'], EZVARS['inout']['output-height']['value']) cmd += 'write filename={}'.format(out_pattern) if EZVARS['inout']['clip_hist']['value']: cmd += ' bits=8 rescale=False' return cmd def fmt_nlmdn_ufo_cmd(inpath: str, outpath: str): """ :param inp: Path to input directory before NLMDN applied :param out: Path to output directory after NLMDN applied :return: """ cmd = 'ufo-launch read path={}'.format(inpath) cmd += ' ! non-local-means patch-radius={}'.format(EZVARS['nlmdn']['patch-radius']['value']) cmd += ' search-radius={}'.format(EZVARS['nlmdn']['search-radius']['value']) cmd += ' h={}'.format(EZVARS['nlmdn']['h']['value']) cmd += ' sigma={}'.format(EZVARS['nlmdn']['sigma']['value']) cmd += ' window={}'.format(EZVARS['nlmdn']['window']['value']) cmd += ' fast={}'.format(EZVARS['nlmdn']['fast']['value']) cmd += ' estimate-sigma={}'.format(EZVARS['nlmdn']['estimate-sigma']['value']) cmd += ' ! 
write filename={}'.format(enquote(outpath)) if not EZVARS['nlmdn']['bigtiff_output']['value']: cmd += " bytes-per-file=0 tiff-bigtiff=False" if EZVARS['inout']['clip_hist']['value']: cmd += f" bits={SECTIONS['general']['output-bitdepth']['value']} rescale=False" return cmd././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414600.0 ufo-tofu-0.13.0/tofu/ez/util.py0000664000175000017500000004462500000000000016567 0ustar00tomastomas00000000000000""" Created on Apr 20, 2020 @author: gasilos """ import os, glob, tifffile from tofu.ez.params import EZVARS from tofu.config import SECTIONS from tofu.ez.yaml_in_out import read_yaml, write_yaml from tofu.util import get_filenames, get_first_filename, get_image_shape, read_image, restrict_value, tupleize from PyQt5.QtCore import QRegExp from PyQt5.QtGui import QRegExpValidator import argparse def get_dims(pth): # get number of projections and projections dimensions first_proj = get_first_filename(pth) multipage = False try: shape = get_image_shape(first_proj) except: raise ValueError("Failed to determine size and number of projections in {}".format(pth)) if len(shape) == 2: # single page input return len(get_filenames(pth)), [shape[-2], shape[-1]], multipage elif len(shape) == 3: # multipage input nviews = 0 for i in get_filenames(pth): nviews += get_image_shape(i)[0] multipage = True return nviews, [shape[-2], shape[-1]], multipage return -6, [-6, -6] def bad_vert_ROI(multipage, path2proj, y, height): if multipage: with tifffile.TiffFile(get_filenames(path2proj)[0]) as tif: proj = tif.pages[0].asarray().astype(float) else: proj = read_image(get_filenames(path2proj)[0]).astype(float) y_region = slice(y, min(y + height, proj.shape[0]), 1) proj = proj[y_region, :] if proj.shape[0] == 0: return True else: return False def make_copy_of_flat(flatdir, flat_copy_name, dryrun): first_flat_file = get_first_filename(flatdir) try: shape = get_image_shape(first_flat_file) except: raise ValueError("Failed to determine size and number of flats in {}".format(flatdir)) cmd = "" if len(shape) == 2: last_flat_file = get_filenames(flatdir)[-1] cmd = "cp {} {}".format(last_flat_file, flat_copy_name) else: flat = read_image(get_filenames(flatdir)[-1])[-1] if dryrun: cmd = 'echo Will save a copy of flat into "{}"'.format(flat_copy_name) else: tifffile.imwrite(flat_copy_name, flat) # something isn't right in this logic? It used to work but then # stopped to create a copy of flat correctly. 
Going to point to all flats simply return cmd def clean_tmp_dirs(tmpdir, fdt_names): tmp_pattern = ["proj", "sino", "mask", "flat", "dark", "radi"] tmp_pattern += fdt_names # clean directories in tmpdir if their names match pattern if os.path.exists(tmpdir): for filename in os.listdir(tmpdir): if filename[:4] in tmp_pattern: os.system("rm -rf {}".format(os.path.join(tmpdir, filename))) def make_inpaths(lvl0, flats2): """ Creates a list of paths to flats/darks/tomo directories :param lvl0: Root of directory containing flats/darks/tomo :param flats2: The type of directory: 3 contains flats/darks/tomo 4 contains flats/darks/tomo/flats2 :return: List of abs paths to the directories containing darks/flats/tomo and flats2 (if used) """ indir = [] # If using flats/darks/flats2 in same dir as tomo # or darks/flats were processed and are already in temporary directory if not EZVARS['inout']['shared-flatsdarks']['value'] or \ EZVARS['inout']['shared-df-used']['value']: for i in [EZVARS['inout']['darks-dir']['value'], EZVARS['inout']['flats-dir']['value'], EZVARS['inout']['tomo-dir']['value']]: indir.append(os.path.join(lvl0, i)) if flats2 - 3: indir.append(os.path.join(lvl0, EZVARS['inout']['flats2-dir']['value'])) return indir # If using common flats/darks/flats2 across multiple reconstructions # and that is the first occasion when they are required elif EZVARS['inout']['shared-flatsdarks']['value'] and \ not EZVARS['inout']['shared-df-used']['value']: indir.append(EZVARS['inout']['path2-shared-darks']['value']) indir.append(EZVARS['inout']['path2-shared-flats']['value']) indir.append(os.path.join(lvl0, EZVARS['inout']['tomo-dir']['value'])) if EZVARS['inout']['shared-flats-after']['value']: indir.append(EZVARS['inout']['path2-shared-flats2']['value']) if (EZVARS['COR']['search-method']['value'] != 1) and (EZVARS['COR']['search-method']['value'] != 2): # if axis search is using shared darks/flats, we still have to use them once more for ffc add_value_to_dict_entry(EZVARS['inout']['shared-df-used'], True) return indir def fmt_in_out_path(tmpdir, indir, raw_proj_dir_name, croutdir=True): # suggests input and output path to directory with proj # depending on number of processing steps applied so far li = sorted(glob.glob(os.path.join(tmpdir, "proj-step*"))) proj_dirs = [d for d in li if os.path.isdir(d)] Nsteps = len(proj_dirs) in_proj_dir, out_proj_dir = "qqq", "qqq" if Nsteps == 0: # no projections in temporary directory in_proj_dir = os.path.join(indir, raw_proj_dir_name) out_proj_dir = "proj-step1" elif Nsteps > 0: # there are directories proj-stepX in tmp dir in_proj_dir = proj_dirs[-1] out_proj_dir = "{}{}".format(in_proj_dir[:-1], Nsteps + 1) else: raise ValueError("Something is wrong with in/out filenames") # physically create output directory tmp = os.path.join(tmpdir, out_proj_dir) if croutdir and not os.path.exists(tmp): os.makedirs(tmp) # return names of input directory and output pattern with abs path return in_proj_dir, os.path.join(tmp, "proj-%04i.tif") def enquote(string, escape=False): addition = '\\"' if escape else '"' return addition + string + addition def extract_values_from_dict(dict): """Return a list of values to be saved as a text file""" new_dict = {} for key1 in dict.keys(): new_dict[key1] = {} for key2 in dict[key1].keys(): dict_entry = dict[key1][key2] if 'value' in dict_entry: new_dict[key1][key2] = {} value_type = type(dict_entry['value']) #print(key1, key2, dict_entry) if dict_entry['value'] is None: new_dict[key1][key2]['value'] = None elif value_type is list or 
value_type is tuple: new_dict[key1][key2]['value'] = str(reverse_tupleize()(dict_entry['value'])) else: new_dict[key1][key2]['value'] = dict_entry['value'] return new_dict def import_values_from_dict(dict, imported_dict): """Import a list of values from an imported dictionary""" for key1 in imported_dict.keys(): for key2 in imported_dict[key1].keys(): add_value_to_dict_entry(dict[key1][key2],imported_dict[key1][key2]['value']) def export_values(filePath): """Export the values of EZVARS and SECTIONS as a YAML file""" combined_dict = {} combined_dict['sections'] = extract_values_from_dict(SECTIONS) combined_dict['ezvars'] = extract_values_from_dict(EZVARS) print("Exporting values to: " + str(filePath)) #print(combined_dict) write_yaml(filePath, combined_dict) print("Finished exporting") def import_values(filePath): """Import EZVARS and SECTIONS from a YAML file""" print("Importing values from: " +str(filePath)) yaml_data = dict(read_yaml(filePath)) import_values_from_dict(EZVARS,yaml_data['ezvars']) import_values_from_dict(SECTIONS,yaml_data['sections']) print("Finished importing") #print(yaml_data) def import_values_from_params(self, params): """ Import parameter values into their corresponding dictionary entries """ print("Entering parameter values into dictionary entries") map_param_to_dict_entries = self.createMapFromParamsToDictEntry() for p in params: dict_entry = map_param_to_dict_entries[str(p)] add_value_to_dict_entry(dict_entry, params[str(p)], False) def export_values(filePath): """Export the values of EZVARS and SECTIONS as a YAML file""" combined_dict = {} combined_dict['sections'] = extract_values_from_dict(SECTIONS) combined_dict['ezvars'] = extract_values_from_dict(EZVARS) print("Exporting values to: " + str(filePath)) #print(combined_dict) write_yaml(filePath, combined_dict) print("Finished exporting") def import_values(filePath): """Import EZVARS and SECTIONS from a YAML file""" print("Importing values from: " +str(filePath)) yaml_data = dict(read_yaml(filePath)) import_values_from_dict(EZVARS,yaml_data['ezvars']) import_values_from_dict(SECTIONS,yaml_data['sections']) print("Finished importing") #print(yaml_data) def save_params(ctsetname, ax, nviews, wh): if not EZVARS['inout']['dryrun']['value'] and not os.path.exists(EZVARS['inout']['output-dir']['value']): os.makedirs(EZVARS['inout']['output-dir']['value']) tmp = os.path.join(EZVARS['inout']['output-dir']['value'], ctsetname) if not EZVARS['inout']['dryrun']['value'] and not os.path.exists(tmp): os.makedirs(tmp) if not EZVARS['inout']['dryrun']['value'] and EZVARS['inout']['save-params']['value']: # Dump the params .yaml file try: yaml_output_filepath = os.path.join(tmp, "parameters.yaml") export_values(yaml_output_filepath) except FileNotFoundError: print("Something went wrong when exporting the .yaml parameters file") # Dump the reco.params output file fname = os.path.join(tmp, 'reco.params') f = open(fname, 'w') f.write('*** General ***\n') f.write('Input directory {}\n'.format(EZVARS['inout']['input-dir']['value'])) if ctsetname == '': ctsetname = '.' 
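    # Hedged usage sketch, not executed here (the path below is hypothetical):
    # the YAML round trip defined above can also be driven directly, independently
    # of save_params, e.g.
    #     export_values('/tmp/parameters.yaml')   # dump current EZVARS/SECTIONS values
    #     import_values('/tmp/parameters.yaml')   # load them back into the dictionaries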
f.write('CT set {}\n'.format(ctsetname)) if EZVARS['COR']['search-method']['value'] == 1 or EZVARS['COR']['search-method']['value'] == 2: f.write('Center of rotation {} (auto estimate)\n'.format(ax)) else: f.write('Center of rotation {} (user defined)\n'.format(ax)) f.write('Dimensions of projections {} x {} (height x width)\n'.format(wh[0], wh[1])) f.write('Number of projections {}\n'.format(nviews)) f.write('*** Preprocessing ***\n') tmp = 'None' if EZVARS['inout']['preprocess']['value']: tmp = EZVARS['inout']['preprocess-command']['value'] f.write(' '+tmp+'\n') f.write('*** Image filters ***\n') if EZVARS['filters']['rm_spots']['value']: f.write(' Remove large spots enabled\n') f.write(' threshold {}\n'.format(SECTIONS['find-large-spots']['spot-threshold']['value'])) f.write(' sigma {}\n'.format(SECTIONS['find-large-spots']['gauss-sigma']['value'])) else: f.write(' Remove large spots disabled\n') if EZVARS['retrieve-phase']['apply-pr']['value']: f.write(' Phase retrieval enabled\n') f.write(' energy {} keV\n'.format(SECTIONS['retrieve-phase']['energy']['value'])) f.write(' pixel size {:0.1f} um\n'.format(SECTIONS['retrieve-phase']['pixel-size']['value'] * 1e6)) f.write(' sample-detector distance {} m\n'.format(SECTIONS['retrieve-phase']['propagation-distance']['value'][0])) f.write(' delta/beta ratio {}\n'.format(SECTIONS['retrieve-phase']['regularization-rate']['value'])) else: f.write(' Phase retrieval disabled\n') f.write('*** Ring removal ***\n') if EZVARS['RR']['enable-RR']['value']: if EZVARS['RR']['use-ufo']['value']: tmp = '2D' if EZVARS['RR']['ufo-2d']['value']: tmp = '1D' f.write(' RR with ufo {} stripes filter\n'.format(tmp)) f.write(f' sigma horizontal {EZVARS["RR"]["sx"]["value"]}') f.write(f' sigma vertical {EZVARS["RR"]["sy"]["value"]}') else: if EZVARS['RR']['spy-rm-wide']['value']: tmp = ' RR with ufo sarepy remove wide filter, ' tmp += 'window {}, SNR {}\n'.format( EZVARS['RR']['spy-wide-window']['value'], EZVARS['RR']['spy-wide-SNR']['value']) f.write(tmp) f.write(' ' 'RR with ufo sarepy sorting filter, window {}\n'. 
format(EZVARS['RR']['spy-narrow-window']['value']) ) else: f.write('RR disabled\n') f.write('*** Region of interest ***\n') if EZVARS['inout']['input_ROI']['value']: f.write('Vertical ROI defined\n') f.write(' first row {}\n'.format(SECTIONS['reading']['y']['value'])) f.write(' height {}\n'.format(SECTIONS['reading']['height']['value'])) f.write(' reconstruct every {}th row\n'.format(SECTIONS['reading']['y-step']['value'])) else: f.write('Vertical ROI: all rows\n') if EZVARS['inout']['output-ROI']['value']: f.write('ROI in slice plane defined\n') f.write(' x {}\n'.format(EZVARS['inout']['output-x']['value'])) f.write(' width {}\n'.format(EZVARS['inout']['output-width']['value'])) f.write(' y {}\n'.format(EZVARS['inout']['output-y']['value'])) f.write(' height {}\n'.format(EZVARS['inout']['output-height']['value'])) else: f.write('ROI in slice plane not defined\n') f.write('*** Reconstructed values ***\n') if EZVARS['inout']['clip_hist']['value']: f.write(' {} bit\n'.format(SECTIONS['general']['output-bitdepth']['value'])) f.write(' Min value in 32-bit histogram {}\n'.format(SECTIONS['general']['output-minimum']['value'])) f.write(' Max value in 32-bit histogram {}\n'.format(SECTIONS['general']['output-maximum']['value'])) else: f.write(' 32bit, histogram untouched\n') f.write('*** Optional reco parameters ***\n') if SECTIONS['general-reconstruction']['volume-angle-z']['value'][0] > 0: f.write(' Rotate volume by: {:0.3f} deg\n'.format(SECTIONS['general-reconstruction']['volume-angle-z']['value'][0])) f.close() ### ALL The following was added by Philmo Gu. I moved it to tofu/ez/utils. . # The important function def add_value_to_dict_entry(dict_entry, value): """Add a value to a dictionary entry. An empty string will insert the ezdefault value""" if 'action' in dict_entry: # no 'type' can be defined in dictionary entries with 'action' key dict_entry['value'] = bool(value) return elif value == '' or value == None: # takes default value if empty string or null if dict_entry['ezdefault'] is None: dict_entry['value'] = dict_entry['ezdefault'] else: dict_entry['value'] = dict_entry['type'](dict_entry['ezdefault']) else: try: dict_entry['value'] = dict_entry['type'](value) except argparse.ArgumentTypeError: # Outside of range of type dict_entry['value'] = dict_entry['type'](value, clamp=True) except ValueError: # int can't convert string with decimal (e.g. 
"1.0" -> 1) dict_entry['value'] = dict_entry['type'](float(value)) # Few things are helpful but most are not used or not fully implemented def get_ascii_validator(): """Returns a validator that only allows the input of visible ASCII characters""" regexp = "[-A-Za-z0-9_]*" return QRegExpValidator(QRegExp(regexp)) def get_alphabet_lowercase_validator(): """Returns a validator that only allows the input of lowercase ASCII characters""" regexp = "[a-z]*" return QRegExpValidator(QRegExp(regexp)) def get_int_validator(): """Returns a validator that only allows the input of integers""" # Note: QIntValidator allows commas, which is undesirable regexp = "[\-]?[0-9]*" return QRegExpValidator(QRegExp(regexp)) def get_double_validator(): """Returns a validator that only allows the input of floating point number""" # Note: QDoubleValidator allows commas before period, which is undesirable regexp = "[\-]?[0-9]*[.]?[0-9]*" return QRegExpValidator(QRegExp(regexp)) def get_tuple_validator(): """Returns a validator that only allows a tuple of floating point numbers""" regexp = "[-0-9,.]*" return QRegExpValidator(QRegExp(regexp)) def load_values_from_ezdefault(dict): """Add or replace values from ezdefault in a dictionary""" for key1 in dict.keys(): for key2 in dict[key1].keys(): dict_entry = dict[key1][key2] if 'ezdefault' in dict_entry: add_value_to_dict_entry(dict_entry, '') # Add default value def restrict_tupleize(limits, num_items=None, conv=float, dtype=tuple): """Convert a string of numbers separated by commas to tuple with *dtype* and make sure it is within *limits* (included) specified as tuple (min, max). If one of the limits values is None it is ignored.""" def check(value=None, clamp=False): if value is None: return limits results = tupleize(num_items, conv, dtype)(value) for v in results: restrict_value(limits, dtype=conv)(v, clamp) return results return check def reverse_tupleize(num_items=None, conv=float): """Convert a tuple into a comma-separted string of *value*""" def combine_to_string(value): """Combine a tuple of numbers into a comma-separated string""" result = "" if num_items and len(result) != num_items: # A certain number of output is expected raise argparse.ArgumentTypeError('Expected {} items'.format(num_items)) if (len(value) == 0): # No tuple to convert into string return result # Tuple with non-zero lengthh for v in value: result = result + "," + str(conv(v)) result = result[1:] # Remove the erroneous first period return result return combine_to_string ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/ez/yaml_in_out.py0000664000175000017500000000072300000000000020120 0ustar00tomastomas00000000000000import yaml import logging LOG = logging.getLogger(__name__) def read_yaml(filePath): with open(filePath) as f: data = yaml.load(f, Loader=yaml.FullLoader) LOG.debug("Imported YAML file:") LOG.debug(data) return data def write_yaml(filePath, params): try: file = open(filePath, "w") except FileNotFoundError: LOG.debug("No filename given") else: yaml.dump(params, file) file.close() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/find_large_spots.py0000664000175000017500000001077500000000000020515 0ustar00tomastomas00000000000000import logging import glob import os from gi.repository import Ufo from tofu.util import ( get_filtering_padding, set_node_props, determine_shape, read_image, setup_read_task, setup_padding ) from 
tofu.tasks import get_task, get_writer LOG = logging.getLogger(__name__) def find_large_spots_median(args): import numpy as np import skimage.morphology as sm import tifffile from skimage.filters import median from scipy.ndimage import binary_fill_holes if os.path.isfile(args.images): filenames = [args.images] else: filenames = sorted(glob.glob(os.path.join(args.images, '*.*'))) if not filenames: raise RuntimeError("No images found in `{}'".format(args.images)) image = read_image(filenames[0]) if image.ndim == 3: image = np.mean(image, axis=0) mask = np.zeros_like(image, dtype=np.uint8) med = median(image, [np.ones(args.median_width)]) # First, pixels which are too bright are marked mask[image > args.spot_threshold] = 1 # Then the ones which are way brighter than the neighborhood mask[np.abs(image.astype(float) - med) > args.grow_threshold] = 1 mask = binary_fill_holes(mask) mask = sm.dilation(mask, sm.disk(args.dilation_disk_radius)) tifffile.imsave(args.output, mask.astype(np.float32)) def find_large_spots(args): graph = Ufo.TaskGraph() sched = Ufo.FixedScheduler() reader = get_task('read') writer = get_writer(args) if args.gauss_sigma and args.blurred_output: broadcast = Ufo.CopyTask() blurred_writer = get_task('write') if hasattr(blurred_writer.props, 'bytes_per_file'): blurred_writer.props.bytes_per_file = 0 if hasattr(blurred_writer.props, 'tiff_bigtiff'): blurred_writer.props.tiff_bigtiff = False blurred_writer.props.filename = args.blurred_output find = get_task('find-large-spots') set_node_props(find, args) find.props.addressing_mode = args.find_large_spots_padding_mode set_node_props(reader, args) setup_read_task(reader, args.images, args) if args.gauss_sigma: width, height = determine_shape(args, path=args.images) pad = get_task('pad') crop = get_task('crop') if args.vertical_sigma: pad_width = 0 pad_height = get_filtering_padding(height) fft = get_task('fft', dimensions=2) ifft = get_task('ifft', dimensions=2) filter_stripes = get_task( 'filter-stripes', vertical_sigma=args.gauss_sigma, horizontal_sigma=0.0 ) graph.connect_nodes(reader, pad) if args.transpose_input: transpose = get_task('transpose') itranspose = get_task('transpose') graph.connect_nodes(pad, transpose) graph.connect_nodes(transpose, fft) else: graph.connect_nodes(pad, fft) graph.connect_nodes(fft, filter_stripes) graph.connect_nodes(filter_stripes, ifft) if args.transpose_input: graph.connect_nodes(ifft, itranspose) graph.connect_nodes(itranspose, crop) else: graph.connect_nodes(ifft, crop) last = crop else: reader_2 = get_task('read') set_node_props(reader_2, args) setup_read_task(reader_2, args.images, args) opencl = get_task('opencl', kernel='diff', filename='opencl.cl') gauss_size = int(10 * args.gauss_sigma) pad_width = pad_height = gauss_size LOG.debug("Gauss size: %d", gauss_size) blur = get_task('blur', sigma=args.gauss_sigma, size=gauss_size) graph.connect_nodes_full(reader, opencl, 0) graph.connect_nodes(reader_2, pad) graph.connect_nodes(pad, blur) graph.connect_nodes(blur, crop) graph.connect_nodes_full(crop, opencl, 1) last = opencl setup_padding(pad, width, height, args.find_large_spots_padding_mode, crop=crop, pad_width=pad_width, pad_height=pad_height) if args.blurred_output: graph.connect_nodes(last, broadcast) graph.connect_nodes(broadcast, blurred_writer) source = broadcast else: source = last graph.connect_nodes(source, find) else: graph.connect_nodes(reader, find) graph.connect_nodes(find, writer) sched.run(graph) ././@PaxHeader0000000000000000000000000000003400000000000011452 
xustar000000000000000028 mtime=1698416097.7697759 ufo-tofu-0.13.0/tofu/flow/0000775000175000017500000000000000000000000015556 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760159.0 ufo-tofu-0.13.0/tofu/flow/__init__.py0000664000175000017500000000000000000000000017655 0ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1698416097.7697759 ufo-tofu-0.13.0/tofu/flow/composites/0000775000175000017500000000000000000000000017743 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760159.0 ufo-tofu-0.13.0/tofu/flow/composites/ffc-links.cm0000664000175000017500000002210200000000000022135 0ustar00tomastomas00000000000000{ "name": "CFlatFieldCorrect", "caption": "CFlatFieldCorrect", "models": { "Flat Field Correct": { "model": { "caption": "Flat Field Correct", "properties": { "fix-nan-and-inf": [ true, true ], "absorption-correct": [ true, true ], "sinogram-input": [ false, false ], "dark-scale": [ 1.0, false ], "flat-scale": [ 1.0, false ] } }, "visible": true, "position": { "x": 1253.0, "y": 490.0 }, "name": "flat_field_correct" }, "Read 2": { "model": { "caption": "Read 2", "properties": { "path": [ ".", true ], "start": [ 0, false ], "number": [ 4294967295, true ], "step": [ 1, false ], "y": [ 0, false ], "height": [ 0, false ], "y-step": [ 1, false ], "convert": [ true, false ], "raw-width": [ 0, false ], "raw-height": [ 0, false ], "raw-bitdepth": [ 0, false ], "raw-pre-offset": [ 0, false ], "raw-post-offset": [ 0, false ], "type": [ "unspecified", false ], "retries": [ 0, false ], "retry-timeout": [ 1, false ] } }, "visible": true, "position": { "x": 417.0, "y": 504.0 }, "name": "read" }, "Average": { "model": { "caption": "Average", "properties": { "number": [ 4294967295, true ] } }, "visible": true, "position": { "x": 822.0, "y": 508.0 }, "name": "average" }, "Read 3": { "model": { "caption": "Read 3", "properties": { "path": [ ".", true ], "start": [ 0, false ], "number": [ 4294967295, true ], "step": [ 1, false ], "y": [ 0, false ], "height": [ 0, false ], "y-step": [ 1, false ], "convert": [ true, false ], "raw-width": [ 0, false ], "raw-height": [ 0, false ], "raw-bitdepth": [ 0, false ], "raw-pre-offset": [ 0, false ], "raw-post-offset": [ 0, false ], "type": [ "unspecified", false ], "retries": [ 0, false ], "retry-timeout": [ 1, false ] } }, "visible": true, "position": { "x": 413.0, "y": 735.0 }, "name": "read" }, "Average 2": { "model": { "caption": "Average 2", "properties": { "number": [ 4294967295, true ] } }, "visible": true, "position": { "x": 822.0, "y": 741.0 }, "name": "average" }, "Read": { "model": { "caption": "Read", "properties": { "path": [ ".", true ], "start": [ 0, false ], "number": [ 23212, true ], "step": [ 1, false ], "y": [ 0, false ], "height": [ 0, false ], "y-step": [ 1, false ], "convert": [ true, false ], "raw-width": [ 0, false ], "raw-height": [ 0, false ], "raw-bitdepth": [ 0, false ], "raw-pre-offset": [ 0, false ], "raw-post-offset": [ 0, false ], "type": [ "unspecified", false ], "retries": [ 0, false ], "retry-timeout": [ 1, false ] } }, "visible": true, "position": { "x": 418.0, "y": 245.0 }, "name": "read" } }, "connections": [ [ "Read", 0, "Flat Field Correct", 0 ], [ "Average", 0, "Flat Field Correct", 1 ], [ "Average 2", 0, "Flat Field Correct", 2 ], [ "Read 2", 0, "Average", 0 ], [ "Read 3", 0, "Average 2", 
0 ] ], "links": [ [ [ "Read 2", "number" ], [ "Average", "number" ] ], [ [ "Read 3", "number" ], [ "Average 2", "number" ] ] ] } ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760159.0 ufo-tofu-0.13.0/tofu/flow/composites/pr.cm0000664000175000017500000001151600000000000020711 0ustar00tomastomas00000000000000{ "name": "CPhaseRetrieve", "caption": "CPhaseRetrieve", "models": { "Fft": { "model": { "caption": "Fft", "properties": { "auto-zeropadding": [ true, true ], "dimensions": [ 2, true ], "size-x": [ 1, true ], "size-y": [ 1, true ], "size-z": [ 1, true ] } }, "visible": true, "position": { "x": 112.0, "y": 245.0 }, "name": "fft" }, "Ifft": { "model": { "caption": "Ifft", "properties": { "dimensions": [ 2, true ], "crop-width": [ -1, true ], "crop-height": [ -1, true ] } }, "visible": true, "position": { "x": 772.0, "y": 250.0 }, "name": "ifft" }, "Retrieve Phase": { "model": { "caption": "Retrieve Phase", "num-inputs": 1, "properties": { "method": [ "tie", true ], "energy": [ 20.0, true ], "distance": [ 0.0, true ], "distance-x": [ 0.0, true ], "distance-y": [ 0.0, true ], "pixel-size": [ 7.500000265281415e-07, true ], "regularization-rate": [ 2.5, true ], "thresholding-rate": [ 0.10000000149011612, true ], "frequency-cutoff": [ 3.4028234663852886e+38, true ], "output-filter": [ false, true ] } }, "visible": true, "position": { "x": 544.0, "y": 515.0 }, "name": "retrieve_phase" }, "Pad": { "model": { "caption": "Pad", "properties": { "width": [ 0, true ], "height": [ 0, true ], "x": [ 0, true ], "y": [ 0, true ], "addressing-mode": [ "clamp", true ] } }, "visible": true, "position": { "x": 0.0, "y": 570.0 }, "name": "pad" } }, "connections": [ [ "Pad", 0, "Fft", 0 ], [ "Fft", 0, "Retrieve Phase", 0 ], [ "Retrieve Phase", 0, "Ifft", 0 ] ], "links": [ [ [ "Fft", "dimensions" ], [ "Ifft", "dimensions" ] ], [ [ "Fft", "size-x" ], [ "Pad", "width" ] ], [ [ "Fft", "size-y" ], [ "Pad", "height" ] ] ] } ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760159.0 ufo-tofu-0.13.0/tofu/flow/config.json0000664000175000017500000000677100000000000017731 0ustar00tomastomas00000000000000{ "models": { "average": { "hidden-properties": [ "number" ] }, "flat-field-correct": { "port-captions": { "input": { "0": "radios", "1": "darks", "2": "flats" }, "output": { "0": "" } }, "hidden-properties": [ "sinogram-input", "dark-scale", "flat-scale" ] }, "general-backproject": { "hidden-properties": [ "z", "burst", "source-position-x", "source-position-y", "source-position-z", "detector-position-x", "detector-position-y", "detector-position-z", "detector-angle-x", "detector-angle-y", "detector-angle-z", "axis-angle-x", "axis-angle-y", "axis-angle-z", "volume-angle-x", "volume-angle-y", "volume-angle-z", "compute-type", "result-type", "store-type", "addressing-mode", "gray-map-min", "gray-map-max" ], "range-properties": { "region": [3, true], "x-region": [3, true], "y-region": [3, true], "center-position-x": [null, true], "center-position-z": [null, true], "source-position-x": [null, true], "source-position-y": [null, true], "source-position-z": [null, true], "detector-position-x": [null, true], "detector-position-y": [null, true], "detector-position-z": [null, true], "detector-angle-x": [null, true], "detector-angle-y": [null, true], "detector-angle-z": [null, true], "axis-angle-x": [null, true], "axis-angle-y": [null, true], "axis-angle-z": [null, true], "volume-angle-x": [null, true], "volume-angle-y": [null, 
true], "volume-angle-z": [null, true] } }, "horizontal-interpolate": { "port-captions": { "input": { "0": "image", "1": "mask" }, "output": { "0": "" } } }, "read": { "hidden-properties": [ "start", "step", "y", "height", "y-step", "convert", "raw-width", "raw-height", "raw-bitdepth", "raw-pre-offset", "raw-post-offset", "type", "retries", "retry-timeout" ] }, "write": { "hidden-properties": [ "counter-start", "counter-step", "bytes-per-file", "append", "bits", "minimum", "maximum", "rescale", "jpeg-quality", "tiff-bigtiff" ] } } } ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/flow/execution.py0000664000175000017500000002025600000000000020140 0ustar00tomastomas00000000000000import gi import logging import networkx as nx gi.require_version('Ufo', '0.0') from gi.repository import Ufo from PyQt5.QtCore import QObject, pyqtSignal from qtpynodeeditor import PortType from threading import Thread from tofu.flow.models import ARRAY_DATA_TYPE, UFO_DATA_TYPE, UfoTaskModel from tofu.flow.util import FlowError LOG = logging.getLogger(__name__) class UfoExecutor(QObject): """Class holding GPU resources and organizing UFO graph execution.""" number_of_inputs_changed = pyqtSignal(int) # Number of inputs has been determined processed_signal = pyqtSignal(int) # Image has been processed execution_started = pyqtSignal() # Graph execution started execution_finished = pyqtSignal() # Graph execution finished exception_occured = pyqtSignal(str) def __init__(self): super().__init__(parent=None) self._resources = Ufo.Resources() self._reset() # If True only log the exception and emit the signal but don't re-raise it in the executing # thread self.swallow_run_exceptions = False def _reset(self): self._aborted = False self._schedulers = [] self.num_generated = 0 def abort(self): LOG.debug('Execution aborted') try: self._aborted = True for scheduler in self._schedulers: scheduler.abort() finally: self.execution_finished.emit() def on_processed(self, ufo_task): self.processed_signal.emit(self.num_generated) self.num_generated += 1 def setup_ufo_graph(self, graph, gpu=None, region=None, signalling_model=None): ufo_graph = Ufo.TaskGraph() ufo_tasks = {} for source, dest, ports in graph.edges.data(): if hasattr(source, 'create_ufo_task') and hasattr(dest, 'create_ufo_task'): if dest not in ufo_tasks: ufo_tasks[dest] = dest.create_ufo_task(region=region) if source not in ufo_tasks: ufo_tasks[source] = source.create_ufo_task(region=region) ufo_graph.connect_nodes_full(ufo_tasks[source], ufo_tasks[dest], ports[PortType.input]) LOG.debug(f'{source.name}->{dest.name}@{ports[PortType.input]}') if source == signalling_model: ufo_tasks[source].connect('generated', self.on_processed) if gpu is not None: for task in ufo_tasks.values(): if task.uses_gpu(): task.set_proc_node(gpu) return ufo_graph def _run_ufo_graph(self, ufo_graph, use_fixed_scheduler): LOG.debug(f'Executing graph, fixed scheduler: {use_fixed_scheduler}') try: scheduler = Ufo.FixedScheduler() if use_fixed_scheduler else Ufo.Scheduler() self._schedulers.append(scheduler) scheduler.set_resources(self._resources) scheduler.run(ufo_graph) LOG.info(f'Execution time: {scheduler.props.time} s') except Exception as e: # Do not continue execution of other batches self._aborted = True LOG.error(e, exc_info=True) self.exception_occured.emit(str(e)) if not self.swallow_run_exceptions: raise e def check_graph(self, graph): """ Check that *graph* starts with an UfoTaskModel and ends with either 
that or an UfoModel but no UfoTaskModel successor exists (there can be only one UFO path in the graph). """ roots = [n for n in graph.nodes if graph.in_degree(n) == 0] leaves = [n for n in graph.nodes if graph.out_degree(n) == 0] for root in roots: for leave in leaves: for path in nx.simple_paths.all_simple_paths(graph, root, leave): if not isinstance(path[0], UfoTaskModel): raise FlowError('Flow must start with an UFO node') ufo_ended = False for (i, succ) in enumerate(path[1:]): model = path[i] edge_data = graph.get_edge_data(model, succ) if len(edge_data) > 1: # There cannot be multiple edges between nodes raise FlowError('Multiple edges not allowed but detected ' 'between {model} and {succ}') out_index = edge_data[0]['output'] # We don't need to check if input data type is ARRAY_DATA_TYPE because # UFO_DATA_TYPE cannot be connected to ARRAY_DATA_TYPE in the scene if ufo_ended: # From now on only non-UFO tasks are allowed if model.data_type['output'][out_index] != ARRAY_DATA_TYPE: raise FlowError('After a non-UFO node cannot come another UFO node') elif model.data_type['output'][out_index] != UFO_DATA_TYPE: # Output is non-UFO, UFO ends here ufo_ended = True def run(self, graph): self._reset() self.check_graph(graph) gpus = self._resources.get_gpu_nodes() num_inputs = -1 signalling_model = None for model in graph.nodes: if graph.in_degree(model) == 0: if 'number' in model: current = model['number'] if current > num_inputs: num_inputs = current signalling_model = model batches = [[(None, None)]] gpu_splitting_model = None gpu_splitting_models = get_gpu_splitting_models(graph) if len(gpu_splitting_models) > 1: # There cannot be multiple splitting models raise FlowError('Only one gpu splitting model is allowed') elif gpu_splitting_models: gpu_splitting_model = gpu_splitting_models[0] batches = gpu_splitting_model.split_gpu_work(self._resources.get_gpu_nodes()) for model in graph.nodes: # Reset internal model state if hasattr(model, 'reset_batches'): model.reset_batches() LOG.debug(f'{len(batches)} batches: {batches}') if signalling_model: self.number_of_inputs_changed.emit(len(batches) * num_inputs) LOG.debug(f'Number of inputs: {len(batches) * num_inputs}, defined ' f'by {signalling_model}') def execute_batches(): self.execution_started.emit() try: for (i, parallel_batch) in enumerate(batches): LOG.info(f'starting batch {i}: {parallel_batch}') threads = [] for gpu_index, region in parallel_batch: if self._aborted: break gpu = None if gpu_index is None else gpus[gpu_index] ufo_graph = self.setup_ufo_graph(graph, gpu=gpu, region=region, signalling_model=signalling_model) t = Thread(target=self._run_ufo_graph, args=(ufo_graph, len(gpu_splitting_models) > 0)) t.daemon = True threads.append(t) t.start() for t in threads: t.join() if self._aborted: break except Exception as e: LOG.error(e, exc_info=True) self.exception_occured.emit(str(e)) raise e finally: self.execution_finished.emit() gt = Thread(target=execute_batches) gt.daemon = True gt.start() def get_gpu_splitting_models(graph): gpu_splitting_models = [] for model in graph.nodes: if isinstance(model, UfoTaskModel) and model.can_split_gpu_work: gpu_splitting_models.append(model) return gpu_splitting_models ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/flow/filedirdialog.py0000664000175000017500000000131300000000000020724 0ustar00tomastomas00000000000000import os from PyQt5.QtWidgets import QFileDialog class FileDirDialog(QFileDialog): """ A workaround for 
being able to select both files and directories. Source: https://stackoverflow.com/questions/27520304/qfiledialog-that-accepts-a-single-file-or-a-single-directory """ def __init__(self, parent=None): super().__init__(parent=parent) self.setOption(QFileDialog.DontUseNativeDialog) self.setFileMode(QFileDialog.Directory) self.currentChanged.connect(self._selected) def _selected(self, name): if os.path.isdir(name): self.setFileMode(QFileDialog.Directory) else: self.setFileMode(QFileDialog.ExistingFile) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/flow/main.py0000664000175000017500000005727300000000000017072 0ustar00tomastomas00000000000000import json import logging import os import pathlib import sys from PyQt5.QtCore import Qt, QObject, QPoint, pyqtSignal from PyQt5.QtWidgets import (QApplication, QFileDialog, QWidget, QVBoxLayout, QMenuBar, QMessageBox, QProgressBar, QMainWindow, QStyle) from qtpynodeeditor import DataModelRegistry, FlowView import xdg.BaseDirectory from tofu.flow.execution import UfoExecutor from tofu.flow.models import (BaseCompositeModel, get_composite_model_classes_from_json, get_composite_model_classes, get_ufo_model_classes, ImageViewerModel, UfoGeneralBackprojectModel, UfoMemoryOutModel, UfoOpenCLModel, UfoReadModel, UfoRetrievePhaseModel, UfoWriteModel) from tofu.flow.scene import UfoScene from tofu.flow.propertylinkswidget import PropertyLinks from tofu.flow.runslider import RunSlider from tofu.flow.util import FlowError LOG = logging.getLogger(__name__) class ApplicationWindow(QMainWindow): def __init__(self, ufo_scene): super().__init__() self.ufo_scene = ufo_scene self.property_links_widget = PropertyLinks(ufo_scene.node_model, ufo_scene.property_links_model, parent=self) self.run_slider = RunSlider(parent=self) self.executor = UfoExecutor() self.console = None self.run_slider_key = (None, None) self.last_dirs = {'scene': None, 'composite': None} self._creating_composite = False self._expanding_composite = False central_widget = QWidget() self.setCentralWidget(central_widget) main_layout = QVBoxLayout(central_widget) self.flow_view = FlowView(self.ufo_scene) self.progress_bar = QProgressBar() self.progress_bar.setMinimum(0) menu_bar = QMenuBar() flow_menu = menu_bar.addMenu('Flow') new_action = flow_menu.addAction("New") new_action.setShortcut('Ctrl+N') new_action.triggered.connect(self.on_new) save_action = flow_menu.addAction("Save") save_action.setShortcut('Ctrl+S') save_action.triggered.connect(self.on_save) save_json_action = flow_menu.addAction("Save json") save_json_action.setShortcut('Ctrl+J') save_json_action.triggered.connect(self.on_save_json) load_action = flow_menu.addAction("Open") load_action.setShortcut('Ctrl+O') load_action.triggered.connect(self.on_open) self.run_action = flow_menu.addAction(self.style().standardIcon(QStyle.SP_MediaPlay), 'Run') self.run_action.setShortcut('Ctrl+R') self.run_action.triggered.connect(self.on_run) abort_action = flow_menu.addAction(self.style().standardIcon(QStyle.SP_MediaStop), 'Abort') abort_action.setShortcut('Ctrl+Shift+X') abort_action.triggered.connect(self.executor.abort) exit_action = flow_menu.addAction('Exit') exit_action.setShortcut('Ctrl+Q') exit_action.triggered.connect(self.close) # Nodes submenu selection_menu = menu_bar.addMenu('Nodes') selection_menu.setToolTipsVisible(True) selection_menu.aboutToShow.connect(self.on_selection_menu_about_to_show) self.skip_action = selection_menu.addAction('Skip Toggle') 
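    # Hedged usage sketch for the FileDirDialog defined in filedirdialog.py above
    # (variable names are hypothetical; nothing here is invoked): the dialog behaves
    # like a QFileDialog whose file mode follows the current selection, so a single
    # dialog can return either a file or a directory, e.g.
    #     dialog = FileDirDialog(parent=self)
    #     if dialog.exec_():
    #         chosen = dialog.selectedFiles()[0]   # path to a file or a directory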
self.skip_action.setShortcut('S') self.skip_action.triggered.connect(self.ufo_scene.skip_nodes) auto_fill_action = selection_menu.addAction('Auto fill') auto_fill_action.triggered.connect(self.ufo_scene.auto_fill) copy_action = selection_menu.addAction("Duplicate") copy_action.setShortcut('Ctrl+Shift+D') copy_action.triggered.connect(self.ufo_scene.copy_nodes) # Composite create_composite_action = selection_menu.addAction("Create Composite") create_composite_action.setShortcut('Ctrl+Shift+C') create_composite_action.triggered.connect(self.on_create_composite) import_composites_action = selection_menu.addAction("Import Composites") import_composites_action.setToolTip('Import one or more composite nodes ' 'from a file or files') import_composites_action.setShortcut('Ctrl+I') import_composites_action.triggered.connect(self.on_import_composites) self.export_composite_action = selection_menu.addAction("Export Composite") self.export_composite_action.triggered.connect(self.on_export_composite) self.edit_composite_action = selection_menu.addAction("Edit Composite") self.edit_composite_action.triggered.connect(self.on_edit_composite) self.expand_composite_action = selection_menu.addAction("Expand Composite") self.expand_composite_action.setShortcut('Ctrl+Shift+E') self.expand_composite_action.triggered.connect(self.on_expand_composite) view_menu = menu_bar.addMenu('View') reset_view_action = view_menu.addAction("Reset Zoom") reset_view_action.setShortcut('Ctrl+0') reset_view_action.triggered.connect(self.on_reset_view) property_links_action = view_menu.addAction("Link Properties") property_links_action.setShortcut('Ctrl+L') property_links_action.triggered.connect(self.on_property_links_action) console_action = view_menu.addAction("Open Python Console") console_action.setShortcut('Ctrl+Shift+P') console_action.triggered.connect(self.on_console_action) run_slider_action = view_menu.addAction("Run Slider") run_slider_action.setShortcut('Ctrl+Shift+S') run_slider_action.triggered.connect(self.on_run_slider_action) self.fix_run_slider = view_menu.addAction("Fix Run Slider") self.fix_run_slider.setCheckable(True) self.fix_run_slider.setShortcut('Ctrl+Alt+Shift+S') main_layout.addWidget(menu_bar) main_layout.addWidget(self.flow_view) main_layout.addWidget(self.progress_bar) main_layout.setContentsMargins(0, 0, 0, 0) main_layout.setSpacing(0) self.resize(1280, 1000) # Signals self.executor.exception_occured.connect(self.on_exception_occured) self.executor.execution_finished.connect(self.on_execution_finished) self.executor.number_of_inputs_changed.connect(self.on_number_of_inputs_changed) self.executor.processed_signal.connect(self.on_processed) self.ufo_scene.node_deleted.connect(self.on_node_deleted) self.ufo_scene.nodes_duplicated.connect(self.on_nodes_duplicated) self.ufo_scene.item_focus_in.connect(self.on_item_focus_in) self.run_slider.value_changed.connect(self.on_run_slider_value_changed) self.setWindowTitle('tofu flow') def on_save(self): if self.last_dirs['scene']: path = self.last_dirs['scene'] else: path = xdg.BaseDirectory.save_data_path('tofu', 'flows') if not os.path.exists(path): os.makedirs(path) file_name, _ = QFileDialog.getSaveFileName(self, "Select File Name", str(path), "Flow Scene Files (*.flow)") if file_name: self.last_dirs['scene'] = os.path.dirname(file_name) self.ufo_scene.save(file_name) def on_new(self): self.run_slider.reset() self.ufo_scene.clear_scene() self.setWindowTitle('tofu flow') def on_open(self): if self.last_dirs['scene']: path = self.last_dirs['scene'] else: path 
= xdg.BaseDirectory.save_data_path('tofu', 'flows') if not os.path.exists(path): path = pathlib.Path.home() file_name, _ = QFileDialog.getOpenFileName(self, "Open Flow Scene", str(path), "Flow Scene Files (*.flow)") if file_name: self.last_dirs['scene'] = os.path.dirname(file_name) self.ufo_scene.load(file_name) self.run_slider.reset() self.setWindowTitle(file_name) def on_exception_occured(self, text): msg = QMessageBox(parent=self) msg.setIcon(QMessageBox.Critical) msg.setText(text) msg.setWindowTitle("Error") msg.exec_() def on_number_of_inputs_changed(self, value): self.progress_bar.setMaximum(value) def on_processed(self, value): self.progress_bar.setValue(value + 1) def on_node_deleted(self, node): slider_model, prop_name = self.run_slider_key if slider_model: if (isinstance(node.model, BaseCompositeModel) and node.model.is_model_inside(slider_model) and not (self._expanding_composite or self._creating_composite)): self.run_slider.reset() self.run_slider_key = (None, None) elif node.model == slider_model and not self._creating_composite: self.run_slider.reset() self.run_slider_key = (None, None) def on_nodes_duplicated(self, selected_nodes, new_nodes): min_y = float('inf') y_1 = float('-inf') for node in selected_nodes: height = node.model.embedded_widget().height() y = node.graphics_object.y() if y < min_y: min_y = y if y + height > y_1: y_1 = y + height for node in selected_nodes: dy = node.graphics_object.y() - min_y new_pos = QPoint(int(node.graphics_object.x()), int(dy + y_1 + 100)) new_nodes[node].graphics_object.setPos(new_pos) def on_item_focus_in(self, item, prop_name, caption, model): if not self.fix_run_slider.isChecked() or not self.run_slider.view_item: if self.run_slider.setup(item): self.run_slider_key = (model, prop_name) self.run_slider.setWindowTitle(f'{caption}->{prop_name}') def on_selection_menu_about_to_show(self): composites = False num_selected = len(self.ufo_scene.selected_nodes()) for node in self.ufo_scene.selected_nodes(): if isinstance(node.model, BaseCompositeModel): composites = True break self.edit_composite_action.setEnabled(num_selected == 1 and composites) self.export_composite_action.setEnabled(num_selected == 1 and composites) self.expand_composite_action.setEnabled(composites) self.skip_action.setEnabled(self.ufo_scene.selected_nodes() != []) def on_edit_composite(self): if self.ufo_scene.is_selected_one_composite(): # Check again in case this was invoked by the keyboard shortcut node = self.ufo_scene.selected_nodes()[0] node.model.edit_in_window(self) def on_create_composite(self): self._creating_composite = True try: path = None prop_name = self.run_slider_key[1] if self.run_slider_key[0]: for node in self.ufo_scene.selected_nodes(): if isinstance(node.model, BaseCompositeModel): if node.model.is_model_inside(self.run_slider_key[0]): path = node.model.get_path_from_model(self.run_slider_key[0]) elif node.model == self.run_slider_key[0]: path = [self.run_slider_key[0]] composite_model = self.ufo_scene.create_composite().model if path: str_path = [model.caption for model in path] new_model = composite_model.get_model_from_path(str_path) new_view_item = new_model.get_view_item(prop_name) # Do not make complete setup, that would reset limits, just update the view item self.run_slider.view_item = new_view_item self.run_slider_key = (new_model, prop_name) title = '->'.join([composite_model.caption] + str_path + [prop_name]) self.run_slider.setWindowTitle(title) finally: self._creating_composite = False def on_expand_composite(self): 
self._expanding_composite = True try: slider_model, prop_name = self.run_slider_key for node in self.ufo_scene.selected_nodes(): if isinstance(node.model, BaseCompositeModel): if slider_model: str_path = None if node.model.is_model_inside(slider_model): str_path = [model.caption for model in node.model.get_path_from_model(slider_model)] new_nodes = self.ufo_scene.expand_composite(node)[0] # Pass the new node to the run slider if it was contained in this composite if slider_model and str_path: if slider_model.caption in new_nodes: # runslider linked to a simple node after expanstion slider_model = new_nodes[slider_model.caption].model self.run_slider_key = (slider_model, prop_name) new_view_item = slider_model.get_view_item(prop_name) # Do not make complete setup, that would reset limits, just update the # view item self.run_slider.view_item = new_view_item self.run_slider.setWindowTitle(f'{slider_model.caption}->{prop_name}') else: # runslider linked to another composite node (nesting) after expanstion for node in new_nodes.values(): if isinstance(node.model, BaseCompositeModel): if node.model.contains_path(str_path[2:]): new_model = node.model.get_model_from_path(str_path[2:]) self.run_slider_key = (new_model, prop_name) new_view_item = new_model.get_view_item(prop_name) # Do not make complete setup, that would reset limits, just # update the view item self.run_slider.view_item = new_view_item title = '->'.join(str_path[1:] + [prop_name]) self.run_slider.setWindowTitle(title) self.run_slider_key = (new_model, prop_name) break finally: self._expanding_composite = False def on_import_composites(self): if self.last_dirs['composite']: path = self.last_dirs['composite'] else: path = xdg.BaseDirectory.save_data_path('tofu', 'flows', 'composites') if not os.path.exists(path): path = pathlib.Path.home() file_names, _ = QFileDialog.getOpenFileNames(self, "Select File Names", str(path), "Composite Model Files (*.cm)") if not file_names: return self.last_dirs['composite'] = os.path.dirname(file_names[0]) overwriting = {} for file_name in file_names: LOG.debug(f'Loading composite from {file_name}') with open(file_name, 'r') as f: state = json.load(f) for model in get_composite_model_classes_from_json(state): if model.name in self.ufo_scene.registry.registered_model_creators(): overwriting[model.name] = os.path.basename(file_name) self.ufo_scene.registry.register_model(model, category='Composite', registry=self.ufo_scene.registry) if overwriting: msg = QMessageBox(parent=self) msg.setIcon(QMessageBox.Warning) msg.setText('Composite nodes with same names detected. 
Files from which '
                        'the nodes have been loaded are listed in details.')
            msg.setDetailedText('\n'.join([f'Node name "{name}" from file "{file_name}"'
                                           for (name, file_name) in overwriting.items()]))
            msg.setWindowTitle('Warning')
            msg.exec_()

    def export_composite(self, node, file_name):
        state = node.model.save()
        with open(file_name, 'w') as f:
            json.dump(state, f, indent=4)

    def on_export_composite(self):
        if not self.ufo_scene.is_selected_one_composite():
            # Check again in case this was invoked by the keyboard shortcut
            return
        if self.last_dirs['composite']:
            path = self.last_dirs['composite']
        else:
            path = xdg.BaseDirectory.save_data_path('tofu', 'flows', 'composites')
            if not os.path.exists(path):
                os.makedirs(path)
        file_name, _ = QFileDialog.getSaveFileName(self, "Select File Name", str(path),
                                                   "Composite Model Files (*.cm)")
        if file_name:
            self.last_dirs['composite'] = os.path.dirname(file_name)
            if not file_name.endswith('.cm'):
                file_name += '.cm'
            self.export_composite(self.ufo_scene.selected_nodes()[0], file_name)

    def on_reset_view(self):
        for view in self.ufo_scene.views():
            transform = view.transform()
            transform.reset()
            view.setTransform(transform)

    def on_property_links_action(self):
        self.property_links_widget.show()
        # Make sure it goes to the front if it is currently buried under other windows
        self.property_links_widget.raise_()

    def on_console_action(self):
        if self.console:
            self.console.show()
            return
        try:
            from pyqtconsole.console import PythonConsole
            from pyqtconsole.highlighter import format
            self.console = PythonConsole(formats={
                'keyword': format('darkBlue', 'bold')
            })
            self.console.setWindowFlag(Qt.SubWindow, True)
            self.console.ctrl_d_exits_console(True)
            self.console.push_local_ns('scene', self.ufo_scene)
            self.console.resize(640, 480)
            self.console.show()
            self.console.eval_queued()
        except ImportError as e:
            LOG.error(e, exc_info=True)
            self.on_exception_occured(str(e))

    def on_run_slider_action(self):
        if not self.run_slider.view_item:
            msg = QMessageBox(parent=self)
            msg.setIcon(QMessageBox.Information)
            msg.setText('Click on an input field in the flow to connect the slider')
            msg.exec_()
        else:
            self.run_slider.show()
            # Make sure it goes to the front if it is currently buried under other windows
            self.run_slider.raise_()

    def on_run_slider_value_changed(self, value):
        if self.run_action.isEnabled():
            self.on_run()

    def on_run(self):
        graphs = self.ufo_scene.get_simple_node_graphs()
        if len(graphs) != 1:
            raise FlowError('Scene must contain one fully connected graph')
        if not self.ufo_scene.is_fully_connected():
            raise FlowError('Not all node ports are connected')
        self.executor.run(graphs[0])
        self.run_action.setEnabled(False)
        self.ufo_scene.set_enabled(False)

    def on_save_json(self):
        graphs = self.ufo_scene.get_simple_node_graphs()
        if len(graphs) != 1:
            raise FlowError('Scene must contain one fully connected graph')
        if not self.ufo_scene.is_fully_connected():
            raise FlowError('Not all node ports are connected')
        if not self.ufo_scene.are_all_ufo_tasks(graphs=graphs):
            raise FlowError('Flow contains other than pure UFO nodes (nodes with different '
                            'data types, e.g. Memory Out or Image Viewer)')
        ufo_graph = self.executor.setup_ufo_graph(graphs[0])
        if self.last_dirs['scene']:
            path = self.last_dirs['scene']
        else:
            path = xdg.BaseDirectory.save_data_path('tofu', 'flows')
            if not os.path.exists(path):
                os.makedirs(path)
        file_name, _ = QFileDialog.getSaveFileName(self, "Select File Name", str(path),
                                                   "json-File (*.json)")
        if file_name:
            self.last_dirs['scene'] = os.path.dirname(file_name)
            if not file_name.endswith('.json'):
                file_name += '.json'
            ufo_graph.save_to_json(file_name)

    def on_execution_finished(self):
        self.progress_bar.reset()
        self.run_action.setEnabled(True)
        self.ufo_scene.set_enabled(True)


class GlobalExceptionHandler(QObject):
    """
    Intercept exceptions, log them and inform the user if they are UI-related.

    Emit a signal when the error message should be shown to the user so that e.g. a message can be
    shown in the main thread.
    """
    exception_occured = pyqtSignal(str)

    def excepthook(self, exc_type, exc_value, exc_traceback):
        LOG.error(exc_value, exc_info=(exc_type, exc_value, exc_traceback))
        if issubclass(exc_type, FlowError):
            self.exception_occured.emit(str(exc_value))


def get_filled_registry():
    registry = DataModelRegistry()
    for model in get_ufo_model_classes():
        category = 'Processing'
        if model.num_ports['input'] == 0:
            category = 'Input'
        if model.num_ports['output'] == 0:
            category = 'Output'
        registry.register_model(model, category=category, scrollable=True)
    registry.register_model(UfoGeneralBackprojectModel, category='Processing')
    registry.register_model(UfoOpenCLModel, category='Processing')
    registry.register_model(UfoRetrievePhaseModel, category='Processing')
    registry.register_model(UfoMemoryOutModel, category='Data')
    registry.register_model(ImageViewerModel, category='Output')
    registry.register_model(UfoWriteModel, category='Output')
    registry.register_model(UfoReadModel, category='Input')
    for models in get_composite_model_classes():
        for model in models:
            if model.name not in registry.registered_model_creators():
                registry.register_model(model, category='Composite', registry=registry)

    return registry


def main():
    app = QApplication(sys.argv)
    scene = UfoScene(registry=get_filled_registry())
    main_window = ApplicationWindow(scene)
    # Exception interception
    exception_handler = GlobalExceptionHandler()
    exception_handler.exception_occured.connect(main_window.on_exception_occured)
    # Do not use threading.excepthook because it needs at least python 3.8, i.e. all exceptions in
    # threads have to be handled properly (logged, signal emitted so that a message can be
    # displayed in the main thread to the user, see tofu.flow.execution for example).
    sys.excepthook = exception_handler.excepthook
    main_window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()


ufo-tofu-0.13.0/tofu/flow/models.py

"""
All classes needed for :class:`qtpynodeeditor.NodeDataModel` implementation of UFO and composite
tasks.
""" import gi import glob import json import logging import networkx as nx import numpy as np import pkg_resources import os import re gi.require_version('Ufo', '0.0') from gi.repository import Ufo from PyQt5 import QtCore from PyQt5.QtCore import QObject, Qt, pyqtSignal from PyQt5.QtGui import QDoubleValidator, QValidator from PyQt5.QtWidgets import (QCheckBox, QComboBox, QGroupBox, QInputDialog, QLabel, QLineEdit, QScrollArea, QWidget, QFileDialog, QFormLayout, QVBoxLayout, QMenu) from qtpynodeeditor import (NodeData, NodeDataModel, NodeDataType, FlowScene, FlowView, Port, PortType, opposite_port) from threading import Lock from tofu.flow.util import (CompositeConnection, FlowError, get_config_key, MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE, saved_kwargs) from tofu.flow.filedirdialog import FileDirDialog LOG = logging.getLogger(__name__) UFO_PLUGIN_MANAGER = Ufo.PluginManager() UFO_DATA_TYPE = NodeDataType(id="UfoBuffer", name=None) ARRAY_DATA_TYPE = NodeDataType(id="NumpyArray", name=None) class UfoIntValidator(QValidator): """Combined int and unsigned int validator.""" def __init__(self, minimum, maximum, parent=None): super().__init__(parent=parent) self.minimum = minimum self.maximum = maximum def bottom(self): return self.minimum def top(self): return self.maximum def validate(self, input_str, pos): try: if self.minimum <= int(input_str) <= self.maximum: result = (QValidator.Acceptable, input_str, pos) else: result = (QValidator.Intermediate, input_str, pos) except ValueError: if not input_str or input_str == '-' and self.minimum < 0: result = (QValidator.Intermediate, input_str, pos) else: result = (QValidator.Invalid, input_str, pos) return result class UfoRangeValidator(QValidator): """ Range separated by comma validator. *num_items* specifies how many numbers must be in the string. *is_float* specifies if the numbers are floating point (integer or unsigned integer otherwise). """ def __init__(self, num_items=None, is_float=True, parent=None): super().__init__(parent=parent) self.num_items = num_items self.is_float = is_float def validate(self, input_str, pos): float_regexp = r'[+-]|[+-]?(\d+(\.\d*)?|\.\d*)([eE][+-]?\d*)?' numbers = input_str.split(',') intermediate = False if self.num_items is not None and len(numbers) > self.num_items: # Incorrect number of items return (QValidator.Invalid, input_str, pos) for (i, number) in enumerate(numbers): number = number.lower().strip() if ('e' in number or '.' in number) and not self.is_float: # Integer expected return (QValidator.Invalid, input_str, pos) if self.is_float: try: float(number) except: if (not number or re.fullmatch(float_regexp, number)): # Partial floating point number (e.g. ends with "e") intermediate = True continue else: return (QValidator.Invalid, input_str, pos) else: try: int(number) except: if not number or number == '-': intermediate = True continue else: return (QValidator.Invalid, input_str, pos) if intermediate or (self.num_items is not None and len(numbers) < self.num_items): # Not enough arguments received or some numbers are incomplete return (QValidator.Intermediate, input_str, pos) return (QValidator.Acceptable, input_str, pos) class ViewItem(QObject): property_changed = pyqtSignal(QObject) def __init__(self, widget, default_value=None, tooltip=''): super().__init__(parent=None) self.widget = widget self.focus_info = False if tooltip: self.widget.setToolTip(tooltip) if default_value is not None: self.set(default_value) def on_changed(self, *args): """ Only user interaction must emit signals in the descendants. 
Signal is emitted only if the user input is valid. """ try: self.get() self.property_changed.emit(self) except: LOG.debug(f'{self}: invalid input') def get(self): ... def set(self, value): ... class CheckBoxViewItem(ViewItem): def __init__(self, checked=False, tooltip=''): widget = QCheckBox() super().__init__(widget, default_value=checked, tooltip=tooltip) widget.clicked.connect(self.on_changed) def get(self): return self.widget.isChecked() def set(self, value): self.widget.setChecked(value) class ComboBoxViewItem(ViewItem): def __init__(self, items, default_value=None, tooltip=''): widget = QComboBox() for item in items: widget.addItem(item) super().__init__(widget, default_value=default_value, tooltip=tooltip) widget.activated.connect(self.on_changed) def get(self): return self.widget.currentText() def set(self, value): self.widget.setCurrentText(value) class FocusInterceptQLineEdit(QLineEdit): focus_in = pyqtSignal(QObject) def focusInEvent(self, event): self.focus_in.emit(self) return super().focusInEvent(event) class QLineEditViewItem(ViewItem): focus_in = pyqtSignal(QObject) def __init__(self, default_value=None, tooltip='', intercept_focus=False): if intercept_focus: widget = FocusInterceptQLineEdit() widget.focus_in.connect(self.on_focus_in) else: widget = QLineEdit() super().__init__(widget, default_value=default_value, tooltip=tooltip) if intercept_focus: self.focus_info = True widget.textEdited.connect(self.on_changed) def on_focus_in(self, widget): self.focus_in.emit(self) def get(self): return self.widget.text() def set(self, value): self.widget.setText(str(value)) class NumberQLineEditViewItem(QLineEditViewItem): def __init__(self, minimum, maximum, default_value=None, tooltip=''): if default_value < minimum or default_value > maximum: raise ValueError(f'default value {default_value} not in limits [{minimum}, {maximum}]') tooltip += ' (range: {} - {})'.format(minimum, maximum) super().__init__(default_value=default_value, tooltip=tooltip, intercept_focus=True) validator = QDoubleValidator(float(minimum), float(maximum), 100) self.widget.setValidator(validator) def get(self): return float(super().get()) class IntQLineEditViewItem(QLineEditViewItem): def __init__(self, minimum, maximum, default_value=None, tooltip=''): if default_value < minimum or default_value > maximum: raise ValueError(f'default value {default_value} not in limits [{minimum}, {maximum}]') tooltip += ' (range: {} - {})'.format(minimum, maximum) super().__init__(default_value=default_value, tooltip=tooltip, intercept_focus=True) validator = UfoIntValidator(minimum, maximum) self.widget.setValidator(validator) def get(self): return int(super().get()) class RangeQLineEditViewItem(QLineEditViewItem): def __init__(self, default_value='', tooltip='', num_items=None, is_float=True): super().__init__(default_value=default_value, tooltip=tooltip, intercept_focus=True) validator = UfoRangeValidator(num_items=num_items, is_float=is_float) self.widget.setValidator(validator) def set(self, values): text = ','.join([str(value) for value in values]) if values else '' self.widget.setText(text) def get(self): text = super().get() if text: values = [float(num) for num in text.split(',')] else: values = [] return values def get_ufo_qline_edit_item(glib_prop, default_value, range_num_items=None, range_is_float=True): if glib_prop.value_type.name == 'GValueArray': item = RangeQLineEditViewItem(tooltip=glib_prop.blurb, default_value=default_value, num_items=range_num_items, is_float=range_is_float) elif glib_prop.value_type.name 
in ['gdouble', 'gfloat']: item = NumberQLineEditViewItem(glib_prop.minimum, glib_prop.maximum, default_value=default_value, tooltip=glib_prop.blurb) elif hasattr(glib_prop, 'minimum') and hasattr(glib_prop, 'maximum'): item = IntQLineEditViewItem(glib_prop.minimum, glib_prop.maximum, default_value=default_value, tooltip=glib_prop.blurb) else: item = QLineEditViewItem(default_value=str(default_value), tooltip=glib_prop.blurb) return item class PropertyViewRecord: """Attribute-access to a view's item.""" def __init__(self, view_item, label, visible): self.view_item = view_item self.label = label self.visible = visible def __str__(self): return repr(self) def __repr__(self): fmt = 'PropertyViewRecord(widget={}, visible={})' return fmt.format(self.view_item.widget, self.visible) class MultiPropertyViewRecord: """Attribute-access to a multiple property view's item.""" def __init__(self, model, widget, visible): self.model = model self.widget = widget self.visible = visible def __str__(self): return repr(self) def __repr__(self): fmt = 'MultiPropertyViewRecord(model={}, widget={}, visible={})' return fmt.format(self.model, self.widget, self.visible) class PropertyView(QWidget): property_changed = pyqtSignal(str, object) item_focus_in = pyqtSignal(ViewItem, str) def __init__(self, properties=None, parent=None, scrollable=True): super().__init__(parent=parent) form_layout = QFormLayout() form_layout.setVerticalSpacing(0) self._properties = {} if properties: for (name, (item, active)) in properties.items(): if name in self._properties: raise ValueError("Item '{}' already exists".format(name)) # Set the parent properly, so that set_property_visible won't try to show the item # widget and the label in their own windows before the view is shown item.widget.setParent(self) label = QLabel(name, parent=self) form_layout.addRow(label, item.widget) self._properties[name] = PropertyViewRecord(item, label, active) self.set_property_visible(name, active) item.property_changed.connect(self.on_property_changed) if item.focus_info: item.focus_in.connect(self.on_item_focus_in) if scrollable: widget = QWidget() widget.setLayout(form_layout) scroll = QScrollArea() scroll.setWidget(widget) scroll.setWidgetResizable(True) main_layout = QVBoxLayout() main_layout.addWidget(scroll) self.setLayout(main_layout) else: self.setLayout(form_layout) @property def property_names(self): return self._properties.keys() def get_property(self, name): return self._properties[name].view_item.get() def set_property(self, name, value): return self._properties[name].view_item.set(value) def get_record(self, name): return self._properties[name] def on_property_changed(self, item): # Get item's name for (name, record) in self._properties.items(): if item == record.view_item: break self.property_changed.emit(name, item.get()) def on_item_focus_in(self, view_item): for (name, it) in self._properties.items(): if it.view_item.widget == view_item.widget: self.item_focus_in.emit(view_item, name) break def is_property_visible(self, name): return self._properties[name].visible def set_property_visible(self, name, visible): self._properties[name].view_item.widget.setVisible(visible) self._properties[name].label.setVisible(visible) self._properties[name].visible = visible def restore_properties(self, values): for prop in self._properties: if prop not in values: LOG.debug(f'Property {prop} not stored, using default') continue value, visible = values[prop] self.set_property(prop, value) self.set_property_visible(prop, visible) def 
export_properties(self): values = {} for prop in self._properties: values[prop] = [self.get_property(prop), self.is_property_visible(prop)] return values def contextMenuEvent(self, event): contextMenu = QMenu(self) actions = {} for name in list(self._properties.keys()): action = contextMenu.addAction(name) action.setCheckable(True) action.setChecked(self._properties[name].visible) actions[action] = name contextMenu.addSeparator() show_all_action = contextMenu.addAction('Show All') hide_all_action = contextMenu.addAction('Hide All') action = contextMenu.exec_(self.mapToGlobal(event.pos())) if action: if action in actions: name = actions[action] checked = action.isChecked() self.set_property_visible(name, checked) elif action == show_all_action: for name in self._properties.keys(): self.set_property_visible(name, True) elif action == hide_all_action: for name in self._properties.keys(): self.set_property_visible(name, False) class MultiPropertyView(QWidget): def __init__(self, groups, parent=None): super().__init__(parent=parent) self._group_box_layout = QVBoxLayout() main_layout = QVBoxLayout() widget = QWidget() widget.setLayout(self._group_box_layout) scroll = QScrollArea() scroll.setWidget(widget) scroll.setWidgetResizable(True) self.setLayout(main_layout) main_layout.addWidget(scroll) self._groups = {} for (model, visible) in groups.items(): if isinstance(model, PropertyModel): model_widget = QGroupBox(model.caption) layout = QVBoxLayout() model_widget.setLayout(layout) layout.addWidget(model.embedded_widget()) else: model_widget = QLabel(model.caption, parent=self) record = MultiPropertyViewRecord(model, model_widget, visible) self._groups[model.caption] = record self._group_box_layout.addWidget(model_widget) self.set_group_visible(model.caption, visible) def __getitem__(self, key): return self._groups[key].model def __contains__(self, key): return key in self._groups def __iter__(self): return iter(self._groups) def export_groups(self): values = {} for name in self._groups: state = self._groups[name].model.save() values[name] = {'model': state, 'visible': self._groups[name].visible} return values def restore_groups(self, values): for name in values: self[name].restore(values[name]['model']) self.set_group_visible(name, values[name]['visible']) def set_group_visible(self, name, visible): self._groups[name].widget.setVisible(visible) self._groups[name].visible = visible def is_group_visible(self, name): return self._groups[name].visible def contextMenuEvent(self, event): contextMenu = QMenu(self) actions = {} for name in list(self._groups.keys()): action = contextMenu.addAction(name) action.setCheckable(True) action.setChecked(self._groups[name].visible) actions[action] = name contextMenu.addSeparator() show_all_action = contextMenu.addAction('Show All') hide_all_action = contextMenu.addAction('Hide All') action = contextMenu.exec_(self.mapToGlobal(event.pos())) if action: if action in actions: name = actions[action] checked = action.isChecked() self.set_group_visible(name, checked) elif action == show_all_action: for name in self._groups.keys(): self.set_group_visible(name, True) elif action == hide_all_action: for name in self._groups.keys(): self.set_group_visible(name, False) class UfoModel(NodeDataModel): """The root parent of all other models in tofu flow.""" data_type = UFO_DATA_TYPE item_focus_in = pyqtSignal(QObject, str, str, NodeDataModel) def __init__(self, style=None, parent=None): super().__init__(style=style, parent=parent) # This is the caption model wants to have when 
it's instantiated, however, it might # get a different caption from the scene because the captions must be unique within self.base_caption = self.caption self.skip = False def restore(self, state, restore_caption=False): if restore_caption: self.caption = state.get('caption', self.caption) def save(self): return {'caption': self.caption} def double_clicked(self, parent): ... def __repr__(self): return f'UfoModel({self.caption})' def __str__(self): return repr(self) class PropertyModel(UfoModel): property_changed = pyqtSignal(UfoModel, str, object) def __init__(self, style=None, parent=None, scrollable=True): """*properties* is a dictionary of name: ViewItem items.""" super().__init__(style=style, parent=parent) properties = self.make_properties() if properties: self.properties = list(properties.keys()) self._view = PropertyView(properties=properties, scrollable=scrollable) self._view.property_changed.connect(self.on_property_changed) self._view.item_focus_in.connect(self.on_item_focus_in) else: self.properties = [] self._view = None def __getitem__(self, key): return self._view.get_property(key) def __setitem__(self, key, value): return self._view.set_property(key, value) def __contains__(self, key): return key in self.properties def __iter__(self): return iter(self.properties) def get_view_item(self, name): return self._view.get_record(name).view_item def on_property_changed(self, name, value): self.property_changed.emit(self, name, value) def on_item_focus_in(self, item, name): self.item_focus_in.emit(item, name, self.caption, self) def make_properties(self): """*properties* is a dictionary of name: ViewItem items.""" return {} def copy_properties(self): properties = self.make_properties() for (name, (item, active)) in properties.items(): item.set(self[name]) properties[name][-1] = self._view.is_property_visible(name) return properties def auto_fill(self): """Automatically fill properties (e.g. number of files, etc.)""" ... 
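    # Editor's illustrative sketch (not part of the original source): how a PropertyModel
    # subclass is typically wired up. make_properties() returns a {name: [ViewItem, visible]}
    # mapping, and property values are then read and written through item access, which
    # delegates to the underlying view items. "ExamplePathModel" is a hypothetical name used
    # only for illustration.
    #
    #     class ExamplePathModel(PropertyModel):
    #         name = 'example_path'
    #
    #         def make_properties(self):
    #             path_item = QLineEditViewItem(default_value='', tooltip='Input path')
    #             return {'path': [path_item, True]}
    #
    #     model = ExamplePathModel(scrollable=False)
    #     model['path'] = '/data/scan'   # forwards to the view item's set()
    #     current = model['path']        # forwards to the view item's get()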
def resizable(self): return True def embedded_widget(self) -> QWidget: return self._view if self._view else None def restore(self, state, restore_caption=True): self._view.restore_properties(state['properties']) super().restore(state, restore_caption=restore_caption) def save(self): state = super().save() state['properties'] = self._view.export_properties() return state class UfoTaskModel(PropertyModel): caption_visible = True def __init__(self, task_name, style=None, parent=None, scrollable=True): self._task_name = task_name self.caption = ' '.join([item[0].upper() + item[1:] for item in self.name.split('_')]) self.needs_fixed_scheduler = False self.can_split_gpu_work = False super().__init__(style=style, parent=parent, scrollable=scrollable) def make_properties(self): hidden_properties = get_config_key('models', self._task_name, 'hidden-properties') range_properties = get_config_key('models', self._task_name, 'range-properties', default={}) properties = {} ufo_task = UFO_PLUGIN_MANAGER.get_task(self._task_name) for prop in ufo_task.list_properties(): if prop.name == 'num-processed': continue default_value = getattr(ufo_task.props, prop.name) if prop.value_type.name == 'gboolean': item = CheckBoxViewItem(checked=default_value, tooltip=prop.blurb) elif hasattr(prop, 'enum_class'): items = [name.value_nick for name in default_value.__enum_values__.values()] item = ComboBoxViewItem(items, default_value=default_value.value_nick, tooltip=prop.blurb) else: range_num_items, range_is_float = range_properties.get(prop.name, (None, True)) item = get_ufo_qline_edit_item(prop, default_value=default_value, range_num_items=range_num_items, range_is_float=range_is_float) visible = True if hidden_properties and prop.name in hidden_properties: visible = False properties[prop.name] = [item, visible] return properties def create_ufo_task(self, region=None): if self.expects_multiple_inputs and region is None: raise UfoModelError(f'{self.caption} expects multiple inputs ' 'but there is no node with such capability in the flow') ufo_task = UFO_PLUGIN_MANAGER.get_task(self._task_name) self._setup_ufo_task(ufo_task, region=region) return ufo_task def _setup_ufo_task(self, ufo_task, region=None): for prop in self: setattr(ufo_task.props, prop, self[prop]) def reset_batches(self): """ In case the model can process batches and has internal state depending on them, this is where it can be re-set. 
""" pass @property def uses_gpu(self): return UFO_PLUGIN_MANAGER.get_task(self._task_name).uses_gpu() @property def expects_multiple_inputs(self): return False def get_ufo_model_class(ufo_task_name): # Use this to determine inputs and outputs but create a new object in the constructor in order # to enable multiple instances having different parameter values _ufo_task = UFO_PLUGIN_MANAGER.get_task(ufo_task_name) ufo_task_num_inputs = _ufo_task.get_num_inputs() ufo_task_num_outputs = int(_ufo_task.get_mode() & Ufo.TaskMode.SINK == 0) class UfoAutoModel(UfoTaskModel): name = ufo_task_name.replace('-', '_') def __init__(self, style=None, parent=None, scrollable=True): self.num_ports = {PortType.input: ufo_task_num_inputs, PortType.output: ufo_task_num_outputs} self.data_type = {} self.port_caption = {} self.port_caption_visible = {} for port_type in (PortType.input, PortType.output): self.data_type[port_type] = {} self.port_caption[port_type] = {} self.port_caption_visible[port_type] = {} for i in range(self.num_ports[port_type]): port_captions = get_config_key('models', ufo_task_name, 'port-captions') if port_captions: port_caption = port_captions[port_type][str(i)] port_caption_visible = True if port_caption else False else: port_caption = '' port_caption_visible = False self.data_type[port_type][i] = UFO_DATA_TYPE self.port_caption[port_type][i] = port_caption self.port_caption_visible[port_type][i] = port_caption_visible self.ufo_task = None super().__init__(ufo_task_name, style=style, parent=parent, scrollable=scrollable) return UfoAutoModel class BaseCompositeModel(UfoModel): # Move functionality which can go here from CompositeModel here data_type = UFO_DATA_TYPE def __init__(self, models, connections, links=None, registry=None, style=None, parent=None): if registry is None: # This has to be keyword argument because of the qtpynodeeditor's node creation # mechanism, but the argument is actually required raise AttributeError('registry must be provided') super().__init__(style=style, parent=parent) # Nodes in the edit pop-up window self.window_parent = None self._property_links_model = None self._links = [] if links is None else links self._slave_property_links = [] self._window_nodes = {} self._other_scene = None self._other_view = None self.num_ports = {PortType.input: 0, PortType.output: 0} self.data_type = {PortType.input: {}, PortType.output: {}} self.port_caption = {PortType.input: {}, PortType.output: {}} self.port_caption_visible = {PortType.input: {}, PortType.output: {}} groups = {} self._registry = registry self._models = {} # Internal connections self._connections = connections # Composite port to subnode port mapping self._inside_ports = {} # Subnode port to composite port mapping self._outside_ports = {} for (name, state, visible, position) in models: # Don't use the deafault registry creation because embedded PropertyModel must have # scrollable set to False cls, orig_kwargs = registry.registered_model_creators()[name] # Don't mess with the original dictionary kwargs = {orig_key: orig_value for (orig_key, orig_value) in orig_kwargs.items()} if issubclass(cls, PropertyModel): kwargs['scrollable'] = False if 'num-inputs' in state: kwargs['num_inputs'] = state['num-inputs'] model = cls(**kwargs) model.restore(state) self._models[model] = position groups[model] = visible model.item_focus_in.connect(self.on_item_focus_in) for port_type in ['input', 'output']: for index in range(model.num_ports[port_type]): side = (model.caption, port_type, index) if not any([conn.contains(*side) 
for conn in connections]): i = self.num_ports[port_type] self.data_type[port_type][i] = UFO_DATA_TYPE port_caption = model.caption if model.port_caption[port_type][index]: port_caption += ':' + model.port_caption[port_type][index] self.port_caption[port_type][i] = port_caption self.port_caption_visible[port_type][i] = True self._inside_ports[(port_type, i)] = (model, port_type, index) self._outside_ports[side] = (port_type, i) self.num_ports[port_type] += 1 self._view = MultiPropertyView(groups) def __getitem__(self, key): return self._view[key] def __contains__(self, key): return key in self._view def __iter__(self): return iter(self._view) def __repr__(self): return f'Composite(caption={self.caption}, models={sorted(list(iter(self._view)))})' def __str__(self): return repr(self) def get_outside_port(self, unique_name, port_type, port_index): return self._outside_ports[(unique_name, port_type, port_index)] def get_model_and_port_index(self, port_type, port_index): model, spt, index = self._inside_ports[(port_type, port_index)] return (model, index) def embedded_widget(self) -> QWidget: return self._view if self._view else None def resizable(self): return True def on_item_focus_in(self, item, name, caption, model): self.item_focus_in.emit(item, name, self.caption + '->' + caption, model) @property def is_editing(self): """Is wubwindow open.""" return self._window_nodes != {} @property def property_links_model(self): return self._property_links_model @property_links_model.setter def property_links_model(self, plm): self._property_links_model = plm for model in self._models: if isinstance(model, BaseCompositeModel): model.property_links_model = plm def contains_path(self, path): """Is there a caption *path* inside this model.""" model = self for caption in path: if caption in model: model = model[caption] else: return False return True def get_model_from_path(self, path): """*path* is caption path (str).""" model = self for caption in path: model = model[caption] return model def is_model_inside(self, model): """Return True if *model* is inside at any level.""" paths = self.get_leaf_paths() for path in paths: for item in path: if item == model: return True return False def get_path_from_model(self, model): """*model* must be inside this composite model.""" paths = self.get_leaf_paths() for path in paths: for (i, item) in enumerate(path): if item == model: return path[:i + 1] raise KeyError(f'{model} not inside') def get_descendant_graph(self, in_subwindow=False): """ Get all descendant models recursively in case there are composite models inside this model. If *in_subwindow* is True, return models shown to the user in the subwindow, otherwise the ones created at class instantiation. For composites inside this one, if *in_subwindow* is True return the subwindow models, but if it's not being edited instead raising an exception, return the internal models. 
""" if in_subwindow and not self.is_editing: raise ValueError('in_subwindow True but no subwindow open') graph = nx.DiGraph() def descend(parent): if in_subwindow and parent.is_editing: models = [node.model for node in parent._window_nodes.values()] else: models = [parent[key] for key in parent] for model in models: graph.add_edge(parent, model) if isinstance(model, BaseCompositeModel): descend(model) descend(self) return graph def get_leaf_paths(self, in_subwindow=False): graph = self.get_descendant_graph(in_subwindow=in_subwindow) leaves = [node for node in graph.nodes if graph.out_degree(node) == 0] paths = [] for leaf in leaves: paths.append(list(nx.simple_paths.all_simple_paths(graph, self, leaf))[0]) return paths def restore(self, state, restore_caption=True): self._connections = [CompositeConnection(*args) for args in state['connections']] self._view.restore_groups(state['models']) super().restore(state, restore_caption=restore_caption) def restore_links(self, node): if self.property_links_model: row = self.property_links_model.rowCount() for items in self._links: # A row can be restored only if no property from the state is in the link model # yet row_ok = True for str_path in items: prop_name = str_path[-1] model = self.get_model_from_path(str_path[:-1]) if self.property_links_model.find_items([node, model, prop_name], [NODE_ROLE, MODEL_ROLE, PROPERTY_ROLE]): LOG.info(f'{str_path[-2]}->{prop_name} already in property links') row_ok = False break if row_ok: for (i, str_path) in enumerate(items): model = self.get_model_from_path(str_path[:-1]) self.property_links_model.add_item(node, model, str_path[-1], row, i) row += 1 def save(self): state = {'name': self.name, 'caption': self.caption} state['models'] = self._view.export_groups() for (model, position) in self._models.items(): state['models'][model.caption]['position'] = position # This is necessary for creating models from saved files state['models'][model.caption]['name'] = model.name state['connections'] = [conn.save() for conn in self._connections] if self.property_links_model: state['links'] = [] paths = self.get_leaf_paths() models = [path[-1] for path in paths] items = self.property_links_model.get_model_links(models) for row in items.values(): # First item in the row is this model, skip it state['links'].append([str_path[1:] for str_path in row]) return state def on_connection_created(self, connection): self._other_scene.connection_deleted.disconnect(self.on_connection_deleted) self._other_scene.delete_connection(connection) self._other_scene.connection_deleted.connect(self.on_connection_deleted) def on_connection_deleted(self, connection): self._other_scene.connection_created.disconnect(self.on_connection_created) self._other_scene.restore_connection(connection.__getstate__()) self._other_scene.connection_created.connect(self.on_connection_created) def double_clicked(self, parent): self.edit_in_window(parent=parent) def on_other_scene_double_clicked(self, node): node.model.double_clicked(self._other_view) def expand_into_graph(self, graph): """Expand to submodels in a *graph*, which is a networkx.DiGraph instance.""" name_to_model = {} for model in self._models: LOG.debug(f'Adding node {model.name}') graph.add_node(model) name_to_model[model.caption] = model for conn in self._connections: source = name_to_model[conn.from_unique_name] dest = name_to_model[conn.to_unique_name] LOG.debug(f'Adding edge {source.name}@{conn.from_port_index} -> ' f'{dest.name}@{conn.to_port_index}') graph.add_edge(source, dest, 
input=conn.to_port_index, output=conn.from_port_index) def _expand_into_scene(self, scene, original_nodes=None, restore_captions=False): # unique name to node instance mapping name_to_node = {} for model in self._models: if original_nodes and model.caption in original_nodes: node = scene.restore_node(original_nodes[model.caption]) else: with saved_kwargs(scene.registry, model.__getstate__()): if restore_captions: node = scene.create_node(model.__class__) else: # This is the main scene, links restoration takes place in expand_into_scene # for all nodes including composites node = scene.create_node(model.__class__, restore_links=False) if isinstance(model, PropertyModel) or isinstance(model, BaseCompositeModel): node.model.restore(model.save(), restore_caption=restore_captions) if isinstance(node.model, BaseCompositeModel): node.model.property_links_model = self.property_links_model else: node.model.restore(model.save()) name_to_node[model.caption] = node if self._models[model] is not None: node.position = (self._models[model]['x'], self._models[model]['y']) for conn in self._connections: f_node = name_to_node[conn.from_unique_name] t_node = name_to_node[conn.to_unique_name] f_port = f_node[PortType.output][conn.from_port_index] t_port = t_node[PortType.input][conn.to_port_index] scene.create_connection(f_port, t_port, check_cycles=False) return name_to_node def add_slave_links(self): self._slave_property_links = [] if not self.property_links_model: return for node in self._window_nodes.values(): if isinstance(node.model, BaseCompositeModel): paths = node.model.get_leaf_paths(in_subwindow=node.model._window_nodes != {}) else: paths = [[node.model]] # Propagate all signals from leaves to the original models for path in paths: str_path = [m.caption for m in path] new_model = path[-1] orig_model = self.get_model_from_path(str_path) # Create a link from this node's model instances to the original root # models in the link model (there can be other composites along the way to # the root root_model = self.property_links_model.get_root_model(orig_model) if root_model: prop_names = self.property_links_model.get_model_properties(root_model) for prop_name in prop_names: if (new_model, prop_name) not in self._slave_property_links: # In order to remove slaves when the subwindow is closed, register # the slaves with respect to the most nested composite registering_model = path[-2] if len(path) > 1 else self if registering_model.is_editing: registering_model._slave_property_links.append((new_model, prop_name)) registering_model.property_links_model.add_silent(new_model, prop_name, root_model, prop_name) if registering_model.window_parent: # If the registering model has a parent, register also the # models in it's internal model view new_model = registering_model[path[-1].caption] registering_model = registering_model.window_parent registering_model._slave_property_links.append((new_model, prop_name)) registering_model.property_links_model.add_silent(new_model, prop_name, root_model, prop_name) def edit_in_window(self, parent=None): self._other_scene = FlowScene(registry=self._registry) self._other_scene.node_double_clicked.connect(self.on_other_scene_double_clicked) self._window_nodes = self._expand_into_scene(self._other_scene, restore_captions=True) # Store references to parent composites for node in self._window_nodes.values(): if isinstance(node.model, BaseCompositeModel): node.model.window_parent = self # Property links have to be registered with respect to the top composite model because # it's 
property model's property model is registered in property links window_parent = self while window_parent.window_parent: window_parent = window_parent.window_parent window_parent.add_slave_links() # Disable manipulation because the number of ports is fixed, so we can't e.g. internally # connect two nodes and delete the newly occupied port from the composite node self._other_scene.allow_node_creation = False self._other_scene.allow_node_deletion = False # There is no allow_connection_creation/deletion, so take care of it here self._other_scene.connection_created.connect(self.on_connection_created) self._other_scene.connection_deleted.connect(self.on_connection_deleted) self._other_view = FlowView(self._other_scene, parent=parent) self._other_view.setWindowFlag(Qt.Window, True) self._other_view.closeEvent = self.view_close_event self._other_view.setWindowTitle(self.name) self._other_view.resize(900, 600) self._other_view.show() def view_close_event(self, event): for node in self._window_nodes.values(): # Clse all composite children recursively first if isinstance(node.model, BaseCompositeModel) and node.model.is_editing: node.model._other_view.close() node.model.window_parent = None for (unique_name, node) in self._window_nodes.items(): self._view[unique_name].restore(node.model.save()) if self.property_links_model: for (model, prop_name) in self._slave_property_links: self.property_links_model.remove_silent(model, prop_name) self._slave_property_links = [] self._window_nodes = {} self._other_scene = None self._other_view = None def expand_into_scene(self, scene, composite_node, original_nodes=None): """ Expand this node into *scene* and replace *composite_node*'s connections with connections going straight into its subnodes. Also create connections internal to this node and update property links. *original_nodes* is a dictionary in form {caption: node_state} which will be used for positioning of the replacing nodes (scene.restore_node instead of scene.create_node will be called). 
""" assert self.property_links_model is not None # Connections to external nodes connections = [] # name_to_node is in format caption: new node dictionary # Internal connections are handled in _expand_into_scene name_to_node = self._expand_into_scene(scene, original_nodes=original_nodes, restore_captions=False) for port_type in [PortType.input, PortType.output]: for index, port in composite_node[port_type].items(): if port.connections: connection = port.connections[0] outside_port = connection.valid_ports[opposite_port(port_type)] internal_model, pt, pi = self._inside_ports[(port_type, index)] connections.append((outside_port, name_to_node[internal_model.caption][pt][pi])) # Update property links for (subcaption, subnode) in name_to_node.items(): if isinstance(subnode.model, BaseCompositeModel): # Get all leaf PropertyModel instances paths = subnode.model.get_leaf_paths() else: paths = [[subnode.model]] # In case selected node is composite, replace all leaf node links for path in paths: str_path = [model.caption for model in path] # Captions might have changed if subnode captions were equal to other captions # in the scene and the composite node which is being replaced contains still the # old ones old_str_path = [subcaption] + str_path[1:] old_model = composite_node.model.get_model_from_path(old_str_path) self.property_links_model.replace_item(subnode, path[-1], old_model) subnode.graphics_object.setSelected(True) scene.remove_node(composite_node) # Create outside connections only after the composite node has been deleted to prevent # creating multiple connections per input port in the outside nodes for outside, inside in connections: scene.create_connection(outside, inside, check_cycles=False) return name_to_node, connections def get_composite_model_class(composite_name, models, connections, links=None): if not composite_name: raise UfoModelError('composite name must be specified') class CompositeModel(BaseCompositeModel): name = composite_name data_type = UFO_DATA_TYPE def __init__(self, style=None, parent=None, registry=None): super().__init__(models, connections, links=links, registry=registry, style=style, parent=parent) model = CompositeModel model.caption_visible = True model.caption = composite_name return model class UfoGeneralBackprojectModel(UfoTaskModel): name = 'general_backproject' num_ports = {PortType.input: 1, PortType.output: 1} data_type = UFO_DATA_TYPE def __init__(self, style=None, parent=None, scrollable=True): super().__init__('general-backproject', style=style, parent=parent, scrollable=scrollable) self.needs_fixed_scheduler = True self.can_split_gpu_work = True def make_properties(self): properties = super().make_properties() slice_memory_coeff = NumberQLineEditViewItem(0.01, 1., default_value=0.8, tooltip='Portion of used GPU memory') properties['slice-memory-coeff'] = [slice_memory_coeff, False] return properties def split_gpu_work(self, gpus): from tofu.genreco import make_runs, DTYPE_CL_SIZE def check_region(region): if not len(np.arange(*self[region])): raise UfoModelError(f'Invalid {region} {self[region]}') # Check if ranges are OK check_region('region') check_region('x-region') check_region('y-region') gpu_indices = range(len(gpus)) bpp = DTYPE_CL_SIZE[self['store-type']] runs = make_runs(gpus, gpu_indices, self['x-region'], self['y-region'], self['region'], bpp, slice_memory_coeff=self['slice-memory-coeff']) return runs def _setup_ufo_task(self, ufo_task, region=None): separate = ['region', 'slice-memory-coeff'] task_props = [prop for prop in self if prop 
not in separate] for prop in task_props: setattr(ufo_task.props, prop, self[prop]) # Set region separately in case there are multiple inputs current_region = self['region'] if region is None else region setattr(ufo_task.props, 'region', current_region) class UfoVaryingInputModel(UfoTaskModel): """Base class for models which can have varying number if inputs.""" def __init__(self, task_name, style=None, parent=None, scrollable=True, num_inputs=None, dialog_title='Number of inputs', dialog_label='Number of inputs:'): if not num_inputs: num_inputs, ok = QInputDialog.getInt(parent, dialog_title, dialog_label, value=1, min=1, max=10, step=1) if not ok: raise UfoModelError('Number of inputs must be specified') self.num_ports = {PortType.input: num_inputs, PortType.output: 1} self.data_type = {PortType.output: {0: UFO_DATA_TYPE}} self.port_caption = {PortType.output: {0: ''}} self.port_caption_visible = {PortType.output: {0: False}} self.data_type[PortType.input] = {} self.port_caption[PortType.input] = {} self.port_caption_visible[PortType.input] = {} for i in range(num_inputs): self.data_type[PortType.input][i] = UFO_DATA_TYPE self.port_caption[PortType.input][i] = '' self.port_caption_visible[PortType.input][i] = False super().__init__(task_name, style=style, parent=parent, scrollable=scrollable) def save(self): state = super().save() state['num-inputs'] = self.num_ports['input'] return state class UfoOpenCLModel(UfoVaryingInputModel): name = 'opencl' def __init__(self, style=None, parent=None, scrollable=True, num_inputs=None): super().__init__('opencl', style=style, parent=parent, scrollable=scrollable, num_inputs=num_inputs) def _setup_ufo_task(self, ufo_task, region=None): for prop in self: if prop in ['filename', 'source']: # opencl task really needs NULL value = self[prop] if self[prop] else None else: value = self[prop] setattr(ufo_task.props, prop, value) class UfoReadModel(UfoTaskModel): name = 'read' num_ports = {PortType.input: 0, PortType.output: 1} data_type = UFO_DATA_TYPE def __init__(self, style=None, parent=None, scrollable=True): super().__init__('read', style=style, parent=parent, scrollable=scrollable) def auto_fill(self): import glob import imageio if os.path.isdir(self['path']): paths = sorted(glob.glob(os.path.join(self['path'], '*'))) else: paths = [self['path']] num_images = 0 for path in paths: try: num_images += len(imageio.get_reader(path)) except: LOG.error(f"Error reading '{path}'") if not num_images: raise UfoModelError(f"No images found in {self['path']}") self['number'] = num_images def double_clicked(self, parent): current_path = self['path'] if not os.path.isdir(current_path): current_path = os.path.dirname(current_path) if not current_path: current_path = QtCore.QDir.homePath() dialog = FileDirDialog() if dialog.exec_(): self['path'] = dialog.selectedFiles()[0] def _setup_ufo_task(self, ufo_task, region=None): for prop in self: if prop != 'raw-bitdepth' or self['raw-bitdepth']: setattr(ufo_task.props, prop, self[prop]) class UfoRetrievePhaseModel(UfoVaryingInputModel): name = 'retrieve_phase' def __init__(self, style=None, parent=None, scrollable=True, num_inputs=None): super().__init__('retrieve-phase', style=style, parent=parent, scrollable=scrollable, dialog_title='Multi-distance Setup', dialog_label='Number of distances:', num_inputs=num_inputs) def make_properties(self): properties = super().make_properties() # Override distance property based on how many inputs we expect tooltip = properties['distance'][0].widget.toolTip() item = 
RangeQLineEditViewItem(tooltip=tooltip, default_value=[], num_items=self.num_ports['input'], is_float=True) properties['distance'] = [item, True] if self.num_ports['input'] > 1: properties['method'][0].set('ctf_multidistance') properties['method'][0].widget.setEnabled(False) properties['distance-x'][0].widget.setEnabled(False) properties['distance-y'][0].widget.setEnabled(False) return properties class UfoWriteModel(UfoTaskModel): name = 'write' num_ports = {PortType.input: 1, PortType.output: 0} data_type = UFO_DATA_TYPE def __init__(self, style=None, parent=None, scrollable=True): super().__init__('write', style=style, parent=parent, scrollable=scrollable) def double_clicked(self, parent): current_path = os.path.dirname(self['filename']) if not current_path: current_path = QtCore.QDir.homePath() file_name, _ = QFileDialog.getSaveFileName(None, "Select File Name", current_path) if file_name: self['filename'] = file_name @property def expects_multiple_inputs(self): return '{region}' in self['filename'] def _setup_ufo_task(self, ufo_task, region=None): if region is not None and not self.expects_multiple_inputs: raise UfoModelError('Write got region without enabling multiple inputs. ' 'Add {region} somewhere in the "filename" field to enable it.') super()._setup_ufo_task(ufo_task, region=region) filename = self['filename'] if region is not None and self.expects_multiple_inputs: filename = filename.format(region=region[0]) setattr(ufo_task.props, 'filename', filename) class _Batch(QObject): finished = pyqtSignal(int) def __init__(self, ufo_task, shape, batch_id): super().__init__(parent=None) self.batch_id = batch_id self.data = np.empty(shape, dtype=np.float32) ptr = self.data.__array_interface__['data'][0] ufo_task.props.pointer = ptr ufo_task.props.max_size = self.data.nbytes ufo_task.connect('processed', self._on_processed) self.num_processed = 0 def _on_processed(self, ufo_task): self.num_processed += 1 if self.num_processed == self.data.shape[0]: self.finished.emit(self.batch_id) class UfoMemoryOutModel(UfoTaskModel): name = 'memory_out' num_ports = {PortType.input: 1, PortType.output: 1} data_type = {PortType.input: {0: UFO_DATA_TYPE}, PortType.output: {0: ARRAY_DATA_TYPE}} port_caption = {PortType.input: {0: ''}, PortType.output: {0: ''}} port_caption_visible = {PortType.input: {0: False}, PortType.output: {0: False}} def __init__(self, style=None, parent=None, scrollable=True): self._lock = Lock() self.reset_batches() super().__init__('memory-out', style=style, parent=parent, scrollable=scrollable) @property def expects_multiple_inputs(self): return self['number'] == '{region}' def make_properties(self): width_item = IntQLineEditViewItem(0, 1000000, default_value=0, tooltip='Input width') height_item = IntQLineEditViewItem(0, 1000000, default_value=0, tooltip='Input height') depth_item = IntQLineEditViewItem(0, 1000000, default_value=1, tooltip='Input depth (for 2D images should be 1)') number_item = QLineEditViewItem(default_value=1, tooltip='Number of inputs') properties = {'width': [width_item, True], 'height': [height_item, True], 'depth': [depth_item, True], 'number': [number_item, True]} return properties def consume_batch(self, batch_id): def consume(current_batch): LOG.debug(f'{self.caption}: consuming {current_batch.batch_id} (caller {batch_id})') self._current_data = current_batch.data self.data_updated.emit(0) # Free memory up self._batches[self._expecting_id] = None with self._lock: if self._expecting_id == batch_id: consume(self._batches[self._expecting_id]) 
self._expecting_id += 1 while self._expecting_id in self._waiting_list: consume(self._batches[self._expecting_id]) del self._waiting_list[self._waiting_list.index(self._expecting_id)] self._expecting_id += 1 else: LOG.debug(f'{self.caption}: putting {batch_id} on waiting list') self._waiting_list.append(batch_id) def out_data(self, port: int) -> NodeData: LOG.debug(f'{self.caption}: out_data shape:' f'{None if self._current_data is None else self._current_data.shape}') return self._current_data def reset_batches(self): self._batches = [] self._waiting_list = [] self._expecting_id = 0 self._current_data = None def _setup_ufo_task(self, ufo_task, region=None): if region is not None and not self.expects_multiple_inputs: raise UfoModelError('Memory Out got region without enabling multiple inputs. ' 'Type {region} in the "number" field to enable it.') number = int(self['number']) if region is None else len(np.arange(*region)) shape = (number, self['height'], self['width']) with self._lock: batch = _Batch(ufo_task, shape, len(self._batches)) self._batches.append(batch) batch.finished.connect(self.consume_batch) class ImageViewerModel(UfoModel): name = 'image_viewer' caption = 'Image Viewer' num_ports = {PortType.input: 1, PortType.output: 0, } data_type = ARRAY_DATA_TYPE def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._node_data = None from tofu.flow.viewer import ImageViewer self._widget = ImageViewer() self._reset = True def embedded_widget(self): return self._widget def resizable(self): return True def double_clicked(self, parent): try: if self._widget.images is not None and not self._widget.popup_visible: import pyqtgraph self._widget.popup() except ImportError: LOG.debug('pyqtgraph not installed, not popping up') def set_in_data(self, data: NodeData, port: Port): if data is not None: if self._reset: self._widget.images = data self._reset = False else: self._widget.append(data) def reset_batches(self): self._reset = True def cleanup(self): self._widget.cleanup() def get_ufo_model_classes(names=None): all_names = set(UFO_PLUGIN_MANAGER.get_all_task_names()) # stamp causes a gobject unref warning blacklist = set(['general-backproject', 'memory-in', 'memory-out', 'opencl', 'read', 'retrieve-phase', 'stamp', 'write']) all_names = list(all_names - blacklist) return (get_ufo_model_class(name) for name in names or all_names) def get_composite_model_classes_from_json(state): """ Get composite model classes from their json representation. This is recursive in case a user creates a composite inside the scene, then adds nodes and creates another composite with the first one inside and doesn't export explicitly the first one. The order of returned classes is bottom -> up, i.e. first the classes which have striclty non-composite submodels are returned and the top level class is last. 
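    As a rough illustration only (the field names below mirror what this
    function reads from *state*; the concrete keys and values are made up),
    the json of a composite with two plain submodels looks roughly like::

        {
            "name": "my_composite",
            "models": {
                "Read":  {"name": "read",  "model": {...}, "visible": true, "position": {...}},
                "Write": {"name": "write", "model": {...}, "visible": true, "position": {...}}
            },
            "connections": [["Read", 0, "Write", 0]],
            "links": []
        }

    A nested composite is recognized by a submodel whose "model" entry itself
    contains both "models" and "connections"; it is descended into first,
    which is what produces the bottom -> up ordering described above.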
""" classes = [] def go_down(current): connections = [CompositeConnection(*args) for args in current['connections']] submodels = [] for (key, model) in current['models'].items(): if 'models' in model['model'] and 'connections' in model['model']: go_down(current['models'][key]['model']) # models are tuples (name, state, visible, position) submodels.append((model['name'], model['model'], model['visible'], model['position'])) classes.append(get_composite_model_class(current['name'], submodels, connections, links=current.get('links', None))) go_down(state) return classes def get_composite_model_classes(): import xdg.BaseDirectory composite_lists = [] paths = [pkg_resources.resource_filename(__name__, 'composites'), xdg.BaseDirectory.save_data_path('tofu', 'flows', 'composites')] for path in paths: file_names = sorted(glob.glob(os.path.join(path, '*.cm'))) for file_name in file_names: LOG.debug(f'Loading composite from {file_name}') try: with open(file_name, 'r') as f: state = json.load(f) composite_lists.append(get_composite_model_classes_from_json(state)) except Exception as e: LOG.error(e, exc_info=True) return composite_lists class UfoModelError(FlowError): pass ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/flow/propertylinksmodels.py0000664000175000017500000003711300000000000022266 0ustar00tomastomas00000000000000import logging from PyQt5.QtCore import QDataStream, pyqtSignal from PyQt5.QtGui import QStandardItemModel, QStandardItem from tofu.flow.models import PropertyModel, BaseCompositeModel from tofu.flow.util import MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE LOG = logging.getLogger(__name__) def _decode_mime_data(data): byte_array = data.data('application/x-sourcetreemodelindex') ds = QDataStream(byte_array) row = ds.readInt32() column = ds.readInt32() internal_id = ds.readUInt64() return (row, column, internal_id) def _data_from_tree_index(index): """ Traverse parents up to the root and get the root node, model and it's property from *index*, which must be a property record (leaf in the tree). """ prop_name = index.data() index = index.parent() model = index.data(role=MODEL_ROLE) while index.data(role=NODE_ROLE) is None and index.isValid(): index = index.parent() node = index.data(role=NODE_ROLE) return (node, model, prop_name) def _get_string_path(node, model, prop_name): if isinstance(node.model, BaseCompositeModel): path = node.model.get_path_from_model(model) else: path = [model] str_path = [model.caption for model in path] str_path.append(prop_name) return str_path class NodeTreeModel(QStandardItemModel): """Tree model representing nodes in the scene.""" def add_node(self, node): item = self._add_model(node.model) if item: item.setData(node, role=NODE_ROLE) def remove_node(self, node): for j in range(self.rowCount()): item = self.item(j, 0) if item and item.data(role=NODE_ROLE) == node: self.removeRow(j) break def clear(self): """In PyQt5, clear doesn't emit the rowsAboutToBeRemoved signal and this does effectively the same. 
""" self.removeRows(0, self.rowCount()) self.removeColumns(0, self.columnCount()) self.rowCount(), self.columnCount() def set_nodes(self, nodes): self.clear() for node in nodes: self.add_node(node) def _add_model(self, flow_model, parent=None): if not parent: parent = self.invisibleRootItem() item = None if (isinstance(flow_model, PropertyModel) or isinstance(flow_model, BaseCompositeModel)): item = QStandardItem(flow_model.caption) item.setData(flow_model, role=MODEL_ROLE) item.setEditable(False) if isinstance(flow_model, PropertyModel): for prop in sorted(flow_model): prop_item = QStandardItem(prop) prop_item.setEditable(False) item.appendRow(prop_item) else: for submodel_name in sorted(flow_model): self._add_model(flow_model[submodel_name], parent=item) if item: parent.appendRow(item) return item class PropertyLinksModel(QStandardItemModel): """Links model representing property links between nodes in the scene.""" restored = pyqtSignal() def __init__(self, node_model): super().__init__() self._silent = {} self._slaves = {} self._node_model = node_model self._node_model.rowsAboutToBeRemoved.connect(self.on_node_rows_about_to_be_removed) def __contains__(self, key): for column in range(self.columnCount()): if self.findItems(key, column=column): return True return False def clear(self): for j in range(self.rowCount()): for i in range(self.columnCount()): self.remove_item(self.indexFromItem(self.item(j, i))) super().clear() def find_items(self, data_list, roles): result = [] for j in range(self.rowCount()): for i in range(self.columnCount()): item = self.item(j, i) if item: success = True for (data, role) in zip(data_list, roles): if item.data(role=role) != data: success = False break if success: result.append(item) return result def get_model_links(self, models): """ Get links between *models*. Return dict {row index: [str_path, ...]}, where *str_path* is the path from the topmost model (in case of composites along the way) to the property name. """ items = {} for model in models: for item in self.find_items([model], [MODEL_ROLE]): str_path = item.text().split('->') if item.row() not in items: items[item.row()] = [str_path] else: items[item.row()].append(str_path) return items def get_root_model(self, model): root_model = None items = self.find_items([model], [MODEL_ROLE]) if items: root_model = items[0].data(role=MODEL_ROLE) else: for (silent_model, prop_name) in self._silent: if silent_model == model: root_model = self._silent[(silent_model, prop_name)][0] return root_model def get_model_properties(self, model): items = self.find_items([model], [MODEL_ROLE]) return [item.data(role=PROPERTY_ROLE) for item in items] def add_item(self, node, model, prop_name, row, column, insert=False): """ Add item where *node* is the root node (can be composite), *model* is the leaf model (there can be composites above if the leaf is nested) and *prop_name* is the property name. *row* and *column* determine the table cell to which to add the item or replace an old item with the new one. If *insert* is True, insert a new row at *row*. 
""" str_path = '->'.join(_get_string_path(node, model, prop_name)) if str_path in self: raise ValueError(f'{str_path} already inside') item = QStandardItem(str_path) item.setData(model, role=MODEL_ROLE) item.setData(prop_name, role=PROPERTY_ROLE) item.setData(node, role=NODE_ROLE) item.setEditable(False) if row == -1: row = self.rowCount() if column == -1: # +1 to find an empty cell even if the row is full for i in range(self.columnCount() + 1): if self.item(row, i) is None: column = i break LOG.debug(f'Add item {node.model.caption}({item.data(role=MODEL_ROLE)}):' f'{item.data(role=PROPERTY_ROLE)} at ({row}, {column})') if insert: self.insertRow(row, item) else: self.setItem(row, column, item) # In case the composite is being edit in a subwindow, connect the slave nodes from the # subsecene if isinstance(node.model, BaseCompositeModel): node.model.add_slave_links() model.property_changed.connect(self.on_property_changed) def remove_item(self, index): flow_model = index.data(role=MODEL_ROLE) if not flow_model: # Empty cell return property_name = index.data(role=PROPERTY_ROLE) flow_model.property_changed.disconnect(self.on_property_changed) self.setItem(index.row(), index.column(), None) # Remove all associated slaves root_key = (flow_model, property_name) if root_key in self._slaves: for slave_key in tuple(self._slaves[root_key]): self.remove_silent(*slave_key) def add_silent(self, model, property_name, root, root_property_name): key = (model, property_name) if key in self._silent: return model.property_changed.connect(self.on_property_changed) root_key = (root, root_property_name) if not self.find_items(root_key, (MODEL_ROLE, PROPERTY_ROLE)): raise ValueError(f'{model} not in property links') self._silent[key] = root_key if root_key not in self._slaves: self._slaves[root_key] = [key] else: self._slaves[root_key].append(key) LOG.debug(f'Slave {root}->{root_property_name} -> {model}->{property_name} added') def remove_silent(self, model, property_name): key = (model, property_name) if key not in self._silent: # Already removed, e.g. 
by deleting an item by del key while some composite windows were # still opened return model.property_changed.disconnect(self.on_property_changed) root_key = self._silent[key] index = self._slaves[root_key].index(key) del self._slaves[root_key][index] if not self._slaves[root_key]: del self._slaves[root_key] del self._silent[key] LOG.debug(f'Slave {model}->{property_name} removed') def replace_item(self, node, new_model, old_model): for j in range(self.rowCount()): for i in range(self.columnCount()): item = self.item(j, i) if item and item.data(role=MODEL_ROLE) == old_model: # Don't break, replace all properties of *old_model* prop_name = item.data(role=PROPERTY_ROLE) slaves = tuple(self._slaves.get((old_model, prop_name), [])) self.remove_item(self.indexFromItem(item)) self.add_item(node, new_model, prop_name, j, i) for (slave_model, slave_property_name) in slaves: self.add_silent(slave_model, slave_property_name, new_model, prop_name) def on_node_rows_about_to_be_removed(self, parent, first, last): for k in range(first, last + 1): node = self._node_model.item(k, 0).data(role=NODE_ROLE) for j in range(self.rowCount()): for i in range(self.columnCount()): item = self.item(j, i) if item and item.data(role=NODE_ROLE) == node: self.remove_item(self.indexFromItem(item)) self.compact() def canDropMimeData(self, data, action, row, column, parent): can_drop = False if data.hasFormat('application/x-sourcetreemodelindex'): src_row, src_column, src_internal_id = _decode_mime_data(data) src_model_index = self._node_model.createIndex(src_row, src_column, src_internal_id) # src_model_index is the property, it's parent is the model node, flow_model, property_name = _data_from_tree_index(src_model_index) str_path = '->'.join(_get_string_path(node, flow_model, property_name)) can_drop = str_path not in self if parent.isValid(): # Parent itself can be an empty cell, so use the first column which is for sure # occupied since the parent is valid (row exists and we are not between rows) first_item = self.item(parent.row(), 0) parent_model = first_item.data(role=MODEL_ROLE) parent_property_name = first_item.data(role=PROPERTY_ROLE) if not type(flow_model[property_name]) is type(parent_model[parent_property_name]): # Data can be dropped only if the types of properties match can_drop = False return can_drop def dropMimeData(self, data, action, row, column, parent): src_row, src_column, src_internal_id = _decode_mime_data(data) src_model_index = self._node_model.createIndex(src_row, src_column, src_internal_id) node, flow_model, property_name = _data_from_tree_index(src_model_index) if parent.isValid(): row = parent.row() insert = False else: insert = True # drops never replace items and column=-1 means "find an empty cell" self.add_item(node, flow_model, property_name, row, -1, insert=insert) return True def save(self): state = [] for j in range(self.rowCount()): row_state = [] for i in range(self.columnCount()): item = self.item(j, i) if not item: continue node = item.data(role=NODE_ROLE) model = item.data(role=MODEL_ROLE) prop_name = item.data(role=PROPERTY_ROLE) str_path = _get_string_path(node, model, prop_name) row_state.append([node.id, str_path]) state.append(row_state) return state def restore(self, state, nodes): self.clear() for (j, row) in enumerate(state): for (i, (node_id, path)) in enumerate(row): node = nodes[node_id] # Last path entry is the property name if isinstance(node.model, BaseCompositeModel): flow_model = node.model.get_model_from_path(path[1:-1]) else: flow_model = node.model 
self.add_item(node, flow_model, path[-1], j, i) self.restored.emit() def compact(self): # Shift rows to the left for j in range(self.rowCount()): filled = [] for i in range(self.columnCount()): if self.item(j, i): filled.append(self.takeItem(j, i)) for (i, item) in enumerate(filled): self.setItem(j, i, item) # Check empty rows for j in range(self.rowCount())[::-1]: is_empty = True for i in range(self.columnCount()): if self.item(j, i): is_empty = False if is_empty: self.removeRow(j) # Check empty columns for i in range(self.columnCount())[::-1]: is_empty = True for j in range(self.rowCount()): if self.item(j, i): is_empty = False if is_empty: self.removeColumn(i) def on_property_changed(self, sig_model, sig_property_name, value): LOG.debug(f'on_property_changed: {sig_model}, {sig_model.caption}, ' f'{sig_property_name}, {value}') sig_key = (sig_model, sig_property_name) if sig_key in self._silent: # pyqtSignal came from a composite subwindow, get root model from the silent slave root_key = self._silent[sig_key] root_key[0][root_key[1]] = value else: root_key = (sig_model, sig_property_name) row = -1 for j in range(self.rowCount()): for i in range(self.columnCount()): item = self.item(j, i) if (item and item.data(role=MODEL_ROLE) == root_key[0] and item.data(role=PROPERTY_ROLE) == root_key[1]): row = j break if row != -1: break for i in range(self.columnCount()): item = self.item(row, i) if item: model = item.data(role=MODEL_ROLE) property_name = item.data(role=PROPERTY_ROLE) if root_key != (model, property_name): model[property_name] = value # Notify all slaves key = (model, property_name) if key in self._slaves: for (slave_model, slave_property_name) in self._slaves[key]: if (slave_model, slave_property_name) != (sig_model, sig_property_name): slave_model[slave_property_name] = value ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/flow/propertylinkswidget.py0000664000175000017500000000702300000000000022263 0ustar00tomastomas00000000000000from PyQt5.QtCore import QMimeData, Qt, QDataStream, QByteArray, QIODevice, QModelIndex from PyQt5.QtGui import QDrag from PyQt5.QtWidgets import QAbstractItemView, QLabel, QTableView, QTreeView, QVBoxLayout, QWidget def _encode_mime_data(index: QModelIndex): """Encode item in *index* into :class:`QMimeData`.""" mime_data = QMimeData() data = QByteArray() stream = QDataStream(data, QIODevice.WriteOnly) try: stream.writeInt32(index.row()) stream.writeInt32(index.column()) stream.writeUInt64(index.internalId()) finally: stream.device().close() mime_data.setData("application/x-sourcetreemodelindex", data) return mime_data class PropertyLinksView(QTableView): """Table view for displaying node property links.""" def keyPressEvent(self, event): if event.key() == Qt.Key_Delete: model = self.model() for index in self.selectedIndexes(): model.remove_item(index) model.compact() class NodesView(QTreeView): """Tree view displaying nodes in the scene.""" def get_drag_index(self): selected = self.selectedIndexes() if not selected: return index = selected[0] if index.child(0, 0).row() != -1: return return index def mouseMoveEvent(self, event): """All that a mouse *event* can do is start a drag and drop operation.""" index = self.get_drag_index() if not index: return drag = QDrag(self) mime_data = _encode_mime_data(index) drag.setMimeData(mime_data) drag.exec_(Qt.CopyAction) return True class PropertyLinks(QWidget): """Widget displaying nodes in the scene and their property links in one 
window.""" def __init__(self, node_model, table_model, parent=None): super().__init__(parent=parent, flags=Qt.Window) self.setWindowTitle('Property Links') self.resize(600, 800) self._treeview = NodesView() self._treeview.setHeaderHidden(True) self._treeview.setAlternatingRowColors(True) self._treeview.setDragEnabled(True) self._treeview.setAcceptDrops(False) self._treeview.setModel(node_model) node_model.itemChanged.connect(self.on_node_model_changed) self._table_view = PropertyLinksView() self._table_view.setDragDropOverwriteMode(False) self._table_view.setDragDropMode(QAbstractItemView.DropOnly) table_model.itemChanged.connect(self.on_table_model_changed) table_model.rowsInserted.connect(self.on_table_model_rows_inserted) table_model.restored.connect(self.on_table_model_restored) self._table_view.setModel(table_model) main_layout = QVBoxLayout() main_layout.addWidget(self._treeview) main_layout.addWidget(QLabel('Drag properties from above to the area below')) main_layout.addWidget(self._table_view) self.setLayout(main_layout) def show(self): self._table_view.resizeColumnsToContents() self._treeview.sortByColumn(0, Qt.AscendingOrder) super().show() def on_table_model_changed(self, item): self._table_view.resizeColumnToContents(item.column()) def on_table_model_rows_inserted(self, index, start, stop): self._table_view.resizeColumnToContents(0) def on_table_model_restored(self): self._table_view.resizeColumnsToContents() def on_node_model_changed(self, item): self._treeview.sortByColumn(0, Qt.AscendingOrder) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/flow/runslider.py0000664000175000017500000001605300000000000020144 0ustar00tomastomas00000000000000from functools import partial from PyQt5.QtCore import Qt, pyqtSignal, QTimer from PyQt5 import QtGui from PyQt5.QtWidgets import QGridLayout, QLineEdit, QWidget, QSlider from tofu.flow.models import IntQLineEditViewItem, RangeQLineEditViewItem, UfoIntValidator from tofu.flow.util import FlowError class RunSlider(QWidget): value_changed = pyqtSignal(float) def __init__(self, parent=None): super().__init__(parent=parent, flags=Qt.Window) self.setWindowFlag(Qt.WindowStaysOnTopHint) self.setMaximumHeight(20) self.setMinimumWidth(600) self.min_edit = QLineEdit() self.min_edit.setToolTip('Minimum') self.min_edit.setMaximumWidth(80) self.min_edit.editingFinished.connect(self.on_min_edit_editing_finished) self.current_edit = QLineEdit() self.current_edit.setToolTip('Current value') self.current_edit.editingFinished.connect(self.on_current_edit_editing_finished) self.max_edit = QLineEdit() self.max_edit.setToolTip('Maximum') self.max_edit.setMaximumWidth(80) self.max_edit.editingFinished.connect(self.on_max_edit_editing_finished) self.slider = QSlider(orientation=Qt.Horizontal) self.slider.setMinimum(0) self.slider.setMaximum(100) self.slider.valueChanged.connect(self.on_slider_value_changed) main_layout = QGridLayout() main_layout.addWidget(self.current_edit, 0, 0, 1, 3, Qt.AlignHCenter) main_layout.addWidget(self.min_edit, 1, 0) main_layout.addWidget(self.slider, 1, 1) main_layout.addWidget(self.max_edit, 1, 2) self.setLayout(main_layout) self.view_item = None self.real_minimum = 0 self.real_maximum = 100 self.real_span = 100 self.type = None self._last_value = None self.setEnabled(False) def _update_range(self, current=None): self.real_span = self.real_maximum - self.real_minimum if current is not None: self.slider.blockSignals(True) 
self.slider.setValue(int(round((current - self.real_minimum) / self.real_span * 100))) self.slider.blockSignals(False) def get_real_value(self): # First convert possible exponents to float (in case UFO has huge defaults set) return self.type(float(self.current_edit.text())) def set_widget_value(self): value = self.get_real_value() self._last_value = value if isinstance(self.view_item, RangeQLineEditViewItem): value = [value] self.view_item.set(value) # Notify linked widgets self.view_item.property_changed.emit(self.view_item) def set_current_validator(self): if self.type == int: validator = UfoIntValidator(self.real_minimum, self.real_maximum) else: validator = QtGui.QDoubleValidator(self.real_minimum, self.real_maximum, 1000) self.current_edit.setValidator(validator) def setup(self, view_item): if self.view_item == view_item: return False current = view_item.get() if isinstance(view_item, RangeQLineEditViewItem): if len(current) > 1: return False self.type = float current = current[0] d_current = 0.1 * abs(current) if current else 100 self.real_minimum = current - d_current self.real_maximum = current + d_current else: self.type = int if isinstance(view_item, IntQLineEditViewItem) else float self.real_minimum = view_item.widget.validator().bottom() self.real_maximum = view_item.widget.validator().top() self.view_item = view_item self._update_range(current=current) _set_number(self.min_edit, self.real_minimum) _set_number(self.max_edit, self.real_maximum) _set_number(self.current_edit, current) self._last_value = current self.setEnabled(True) self.set_current_validator() return True def reset(self): self.real_minimum = 0 self.real_maximum = 100 self.real_span = 100 self._last_value = None self.type = None self.min_edit.setText('') self.max_edit.setText('') self.current_edit.setText('') self.setWindowTitle('') self.view_item = None self.setEnabled(False) def on_slider_value_changed(self, value): def delayed_update(init_value): current_value = self.slider.value() if init_value == current_value: self.set_widget_value() self.value_changed.emit(real_value) if self.view_item: real_value = self.slider.value() / 100 * self.real_span + self.real_minimum self.current_edit.setText('{:g}'.format(self.type(real_value))) func = partial(delayed_update, value) QTimer.singleShot(100, func) def on_current_edit_editing_finished(self): if not self.view_item: return try: value = self.type(self.current_edit.text()) except ValueError: raise RunSliderError('Not a number') if value == self._last_value: # Nothing new, do not emit value_changed signal in case the app is closing return self.slider.blockSignals(True) self.slider.setValue(int(round((value - self.real_minimum) / self.real_span * 100))) self.slider.blockSignals(False) self.set_widget_value() self.value_changed.emit(value) def on_min_edit_editing_finished(self): if not self.view_item: return try: value = self.type(self.min_edit.text()) except ValueError: raise RunSliderError('Not a number') if value >= self.real_maximum: raise RunSliderError('Minimum must be smaller than maximum') current = self.get_real_value() self.real_minimum = value if current < self.real_minimum: current = self.real_minimum self.current_edit.setText('{:g}'.format(current)) self.set_widget_value() self.value_changed.emit(current) self._update_range(current=current) self.set_current_validator() def on_max_edit_editing_finished(self): if not self.view_item: return try: value = self.type(self.max_edit.text()) except ValueError: raise RunSliderError('Not a number') if value <= 
self.real_minimum: raise RunSliderError('Maximum must be greater than minimum') current = self.get_real_value() self.real_maximum = value if current > self.real_maximum: current = self.real_maximum self.current_edit.setText('{:g}'.format(current)) self.set_widget_value() self.value_changed.emit(current) self._update_range(current=current) self.set_current_validator() def _set_number(edit, number): edit.setText('{:g}'.format(number)) class RunSliderError(FlowError): pass ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/flow/scene.py0000664000175000017500000004276500000000000017243 0ustar00tomastomas00000000000000import logging import numpy as np import networkx as nx from PyQt5.QtCore import pyqtSignal, QObject from PyQt5.QtWidgets import QInputDialog from qtpynodeeditor import FlowScene, NodeDataModel, PortType, opposite_port from tofu.flow.models import (BaseCompositeModel, ImageViewerModel, PropertyModel, UFO_DATA_TYPE, get_composite_model_class, get_composite_model_classes_from_json) from tofu.flow.util import CompositeConnection, FlowError, saved_kwargs from tofu.flow.propertylinksmodels import PropertyLinksModel, NodeTreeModel LOG = logging.getLogger(__name__) class UfoScene(FlowScene): nodes_duplicated = pyqtSignal(list, dict) # view item, its name and model name item_focus_in = pyqtSignal(QObject, str, str, NodeDataModel) def __init__(self, registry=None, style=None, parent=None, allow_node_creation=True, allow_node_deletion=True): super().__init__(registry=registry, style=style, parent=parent, allow_node_creation=allow_node_creation, allow_node_deletion=allow_node_deletion) self._composite_nodes = {} self._selected_nodes_on_disabled = [] self.node_model = NodeTreeModel() self.node_model.setColumnCount(1) self.property_links_model = PropertyLinksModel(self.node_model) self.style_collection.node.opacity = 1 self.style_collection.connection.use_data_defined_colors = True self.node_double_clicked.connect(self.on_node_double_clicked) def __getstate__(self): state = super().__getstate__() state['property-links'] = self.property_links_model.save() return state def __setstate__(self, doc): for node in doc['nodes']: model = node['model'] if 'models' in model and 'connections' in model: # First register the composite model models = get_composite_model_classes_from_json(model) for model in models: self.registry.register_model(model, category='Composite', registry=self.registry) # Restore the scene super().__setstate__(doc) # and the property link models and widgets if 'property-links' in doc: self.node_model.set_nodes(self.nodes.values()) self.property_links_model.restore(doc['property-links'], self.nodes) def create_node(self, data_model, restore_links=True): """Overrides :class:`FlowScene` in order to create a node with *data_model* with a unique caption. 
""" LOG.debug(f'Create node with model {data_model}') node = super().create_node(data_model) self._setup_new_node(node) if restore_links and isinstance(node.model, BaseCompositeModel): node.model.restore_links(node) return node def restore_node(self, node_json): LOG.debug(f"Restore node with model {node_json['model']['name']}") with saved_kwargs(self.registry, node_json['model']): node = super().restore_node(node_json) self._setup_new_node(node) return node def on_item_focus_in(self, view_item, prop_name, caption, model): self.item_focus_in.emit(view_item, prop_name, caption, model) def _setup_new_node(self, node): self._set_unique_caption(node) self.node_model.add_node(node) if isinstance(node.model, BaseCompositeModel): node.model.property_links_model = self.property_links_model node.model.item_focus_in.connect(self.on_item_focus_in) def _set_unique_caption(self, new_node): caption = new_node.model.caption captions = [node.model.caption for node in self.nodes.values() if node != new_node] if caption in captions: fmt = new_node.model.base_caption + ' {}' i = 2 while fmt.format(i) in captions: i += 1 caption = fmt.format(i) new_node.model.caption = caption def remove_node(self, node): if hasattr(node.model, 'cleanup'): node.model.cleanup() if (isinstance(node.model, BaseCompositeModel) and node.model.name in self._composite_nodes): del self._composite_nodes[node.model.name] self.node_model.remove_node(node) super().remove_node(node) def is_selected_one_composite(self): result = False nodes = self.selected_nodes() if len(nodes) == 1: result = isinstance(nodes[0].model, BaseCompositeModel) return result def skip_nodes(self): selected_nodes = self.selected_nodes() # First check if the selected nodes may be skipped for node in selected_nodes: if (node.model.num_ports[PortType.input] != 1 or node.model.num_ports[PortType.output] != 1): raise FlowError('Only nodes with one input and one output can be skipped') ports = list(node.state.ports) if ports[0].data_type != UFO_DATA_TYPE or ports[1].data_type != UFO_DATA_TYPE: raise FlowError('Only tasks with UFO input and output can be skipped') # And only if all is fine, then skip them for node in selected_nodes: node.model.skip = not node.model.skip opacity = 0.5 if node.model.skip else 1 node.state.input_connections[0].graphics_object.setOpacity(opacity) node.state.output_connections[0].graphics_object.setOpacity(opacity) node.graphics_object.setOpacity(opacity) def auto_fill(self): for node in self.nodes.values(): if isinstance(node.model, BaseCompositeModel): paths = node.model.get_leaf_paths() else: paths = [[node.model]] for path in paths: model = path[-1] if isinstance(model, PropertyModel): model.auto_fill() def copy_nodes(self): new_nodes = {} selected_nodes = self.selected_nodes() # Create nodes for node in selected_nodes: new_node = self.create_node(node.model) new_nodes[node] = new_node values = node.model.save() new_node.model.restore(values, restore_caption=False) # Create connections for node, new_node in new_nodes.items(): for connection in self.connections: port = connection.ports[0] in_index = port.index out_index = connection.ports[1].index if port.node == node: other_node = connection.ports[1].node if other_node in new_nodes: # Other node has been also selected self.create_connection_by_index(new_node, in_index, new_nodes[other_node], out_index, None) self.nodes_duplicated.emit(selected_nodes, new_nodes) def create_composite(self): composite_name, ok = QInputDialog.getText(None, 'Create Composite Node', 'Name:') if not ok: return if 
composite_name in self.registry.registered_model_creators(): raise FlowError(f'Composite node with name "{composite_name}" has already ' 'been registered') self._composite_nodes[composite_name] = {} connection_replacements = [] models = [] connections = [] selected_nodes = self.selected_nodes() for node in selected_nodes: unique_name = node.model.caption models.append((node.model.name, node.model.save(), True, node.__getstate__()['position'])) self._composite_nodes[composite_name][unique_name] = node.__getstate__() # Connections assigned_ports = [] x = [] y = [] for node in selected_nodes: x.append(node.position.x()) y.append(node.position.y()) for port_type in ['input', 'output']: for index, port in node[port_type].items(): if port.connections: # We allow only one connection conn = port.connections[0] other_port = conn.ports[0] if conn.ports[1] == port else conn.ports[1] other = conn.get_node(opposite_port(port_type)) if (other in selected_nodes and port not in assigned_ports and other_port not in assigned_ports): # Connection reaches to a node outside selection if port_type == PortType.input: to_node_name = node.model.caption to_node_index = index from_node_name = other.model.caption from_node_index = other_port.index else: to_node_name = other.model.caption to_node_index = other_port.index from_node_name = node.model.caption from_node_index = index conn = CompositeConnection(from_node_name, from_node_index, to_node_name, to_node_index) connections.append(conn) assigned_ports.append(port) if other not in selected_nodes: inside = (node.model.caption, port_type, index) connection_replacements.append((other_port, inside)) # Get links which will be internal to the newly created model node_models = [] for selected_node in self.selected_nodes(): if isinstance(selected_node.model, BaseCompositeModel): paths = selected_node.model.get_leaf_paths() else: paths = [[selected_node.model]] node_models += [path[-1] for path in paths] internal_links = list(self.property_links_model.get_model_links(node_models).values()) composite = get_composite_model_class(composite_name, models, connections, links=internal_links) self.registry.register_model(composite, category='Composite', registry=self.registry) node = self.create_node(composite, restore_links=False) for selected_node in selected_nodes: if isinstance(selected_node.model, BaseCompositeModel): # Get all leaf PropertyModel instances paths = selected_node.model.get_leaf_paths() else: paths = [[selected_node.model]] # In case selected node is composite, replace all leaf node links for path in paths: new_model = node.model.get_model_from_path([model.caption for model in path]) self.property_links_model.replace_item(node, new_model, path[-1]) self.remove_node(selected_node) for outside_port, inside in connection_replacements: port_type, index = node.model.get_outside_port(*inside) self.create_connection(outside_port, node[port_type][index], check_cycles=False) # Put the new composite node to the average of x and y position of the selected nodes node.position = (np.mean(x), np.mean(y)) node.graphics_object.setSelected(True) return node def on_node_double_clicked(self, node): views = self.views() if views: node.model.double_clicked(views[0]) def expand_composite(self, node): name = node.model.name original_nodes = self._composite_nodes.get(name, None) return node.model.expand_into_scene(self, node, original_nodes=original_nodes) def is_fully_connected(self): """Are all the ports in all nodes connected?""" def are_ports_connected(node, port_type): for port in 
node[port_type].values(): if not port.connections: return False return True for node in self.nodes.values(): if not are_ports_connected(node, 'input'): return False if not are_ports_connected(node, 'output'): return False return True def are_all_ufo_tasks(self, graphs=None): """If all inputs and outputs of all models in all *graphs* have `UfoBuffer` data type, return True. If *graphs* are not specified, they are created from the scene. """ if graphs is None: graphs = self.get_simple_node_graphs() for graph in graphs: for model in graph.nodes: for port_type in ['input', 'output']: for data_type in model.data_type[port_type].values(): if data_type.id != 'UfoBuffer': return False return True def get_simple_node_graphs(self): """ Get a graph from the scene without composite nodes which can be directly used byt the execution. """ def get_composite(graph): """Get first found composite model.""" for model in graph.nodes: if isinstance(model, BaseCompositeModel): return model def replace_edge(graph, composite, edges, port_type): """Replace interface edges (going in or out from the composite model).""" for edge in edges: ports = graph.edges[edge] other = edge[0] if port_type == PortType.input else edge[1] model, index = composite.get_model_and_port_index(port_type, ports[port_type]) if model not in graph: graph.add_node(model) if port_type == PortType.input: source = other dest = model input_port = index output_port = ports[PortType.output] else: source = model dest = other input_port = ports[PortType.input] output_port = index LOG.debug(f'Adding edge {source.name}@{output_port} -> {dest.name}@{input_port}') graph.add_edge(source, dest, input=input_port, output=output_port) def replace_composite(graph, composite): composite.expand_into_graph(graph) edges = graph.in_edges(composite, keys=True) replace_edge(graph, composite, edges, PortType.input) edges = graph.out_edges(composite, keys=True) replace_edge(graph, composite, edges, PortType.output) graph.remove_node(composite) # Initial graph with composite nodes. We need a multigraph because composite nodes may have # many outputs which can lead to a same destination node. 
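        # A minimal, self-contained sketch of why a plain DiGraph would not do
        # (illustrative only, not part of the scene code): two outputs of one
        # composite feeding the same destination node need two parallel edges
        # that differ only in their port attributes, which a MultiDiGraph keeps
        # apart by edge key:
        #
        #     import networkx as nx
        #     g = nx.MultiDiGraph()
        #     g.add_edge('composite', 'write', input=0, output=0)
        #     g.add_edge('composite', 'write', input=1, output=1)
        #     # g.edges(keys=True, data=True) yields both
        #     # ('composite', 'write', 0, {'input': 0, 'output': 0}) and
        #     # ('composite', 'write', 1, {'input': 1, 'output': 1})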
graph = nx.MultiDiGraph() for node in self.nodes.values(): if not node.model.skip: graph.add_node(node.model) for conn in self.connections: p_dest, p_source = conn.ports if p_dest.node.model.skip: LOG.debug(f'Skiping connection {p_source.node.model.name} -> ' f'{p_dest.node.model.name}') continue while p_source.node.model.skip: LOG.debug(f'Skiping connection {p_source.node.model.name} -> ' f'{p_dest.node.model.name}') previous_conn = p_source.node.state.input_connections[0] previous_node = previous_conn.output_node p_source = list(previous_node.state.output_ports)[0] graph.add_edge(p_source.node.model, p_dest.node.model, input=p_dest.index, output=p_source.index) # Expand composite nodes until there are only simple ones left model = get_composite(graph) while model: LOG.debug(f'Replacing composite {model.name}') replace_composite(graph, model) model = get_composite(graph) components = nx.weakly_connected_components(graph) return [nx.subgraph(graph, component) for component in components] def set_enabled(self, enabled): selected_nodes = self.selected_nodes() self.allow_node_creation = enabled self.allow_node_deletion = enabled for node in self.nodes.values(): if not isinstance(node.model, ImageViewerModel): node.graphics_object.setEnabled(enabled) if enabled: if node in self._selected_nodes_on_disabled: node.graphics_object.setSelected(True) else: if node in selected_nodes: self._selected_nodes_on_disabled.append(node) for conn in self.connections: conn._graphics_object.setEnabled(enabled) if enabled: self._selected_nodes_on_disabled = [] ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/flow/util.py0000664000175000017500000000443200000000000017110 0ustar00tomastomas00000000000000import contextlib import json import pkg_resources from PyQt5.QtCore import Qt from qtpynodeeditor import PortType MODEL_ROLE = Qt.UserRole + 1 PROPERTY_ROLE = MODEL_ROLE + 1 NODE_ROLE = PROPERTY_ROLE + 1 with open(pkg_resources.resource_filename(__name__, 'config.json')) as f: ENTRIES = json.load(f) def get_config_key(*keys, default=None): current = ENTRIES.get(keys[0], default) if current != default and len(keys) > 1: for key in keys[1:]: current = current.get(key, default) if current == default: break return current @contextlib.contextmanager def saved_kwargs(registry, state): """ Tell the registry to use the number of saved inputs for model creation but only for one model creation, i.e. reset the context afterward. 
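    For example, the scene's ``restore_node`` uses it roughly like::

        with saved_kwargs(self.registry, node_json['model']):
            node = super().restore_node(node_json)

    so a model saved with a 'num-inputs' entry is recreated with the same
    number of input ports and the registry kwargs are cleaned up afterwards.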
""" if 'num-inputs' in state: kwargs = registry.registered_model_creators()[state['name']][1] kwargs['num_inputs'] = state['num-inputs'] try: yield finally: if 'num-inputs' in state: del kwargs['num_inputs'] class CompositeConnection: def __init__(self, from_unique_name, from_port_index, to_unique_name, to_port_index): if from_unique_name == to_unique_name: raise ValueError('from_unique_name and to_unique_name must be different') self.from_unique_name = from_unique_name self.from_port_index = from_port_index self.to_unique_name = to_unique_name self.to_port_index = to_port_index def contains(self, unique_name, port_type, port_index): is_from = is_to = False if port_type == PortType.output: is_from = (unique_name == self.from_unique_name and port_index == self.from_port_index) else: is_to = (unique_name == self.to_unique_name and port_index == self.to_port_index) return is_from or is_to def save(self): return [self.from_unique_name, self.from_port_index, self.to_unique_name, self.to_port_index] def __str__(self): return repr(self) def __repr__(self): fmt = 'Connection({}@{} -> {}@{})' return fmt.format(self.from_unique_name, self.from_port_index, self.to_unique_name, self.to_port_index) class FlowError(Exception): pass ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/flow/viewer.py0000664000175000017500000004510000000000000017431 0ustar00tomastomas00000000000000import logging import numpy as np import os from PyQt5 import QtGui from PyQt5.QtCore import Qt from PyQt5.QtWidgets import QFileDialog, QGridLayout, QLabel, QLineEdit, QMenu, QWidget, QSlider from tofu.flow.util import FlowError LOG = logging.getLogger(__name__) class ScreenImage: """On-screen image representation.""" def __init__(self, image=None): self._black_point = None self._white_point = None self.minimum = None self.maximum = None self.image = image @property def image(self): return self._image @image.setter def image(self, image): """ Keep the minimum, maximum, black and white points as they are so that images don't flicker when going through a sequence. """ self._image = image if self._image is not None: self._image = image.astype(np.float32) if self.minimum is None: self.minimum = np.nanmin(self._image) if self.maximum is None: self.maximum = np.nanmax(self._image) if self.black_point is None: self.black_point = self.minimum if self.white_point is None: self.white_point = self.maximum @property def white_point(self): return self._white_point @white_point.setter def white_point(self, value): if self.black_point is not None and value < self.black_point: raise ImageViewingError('White point cannot be smaller than black point') self._white_point = value @property def black_point(self): return self._black_point @black_point.setter def black_point(self, value): if self.white_point is not None and value > self.white_point: raise ImageViewingError('Black point cannot be greater than white point') self._black_point = value def reset(self): """Reset black and white points.""" if self._image is not None: self.minimum = np.nanmin(self._image) self.maximum = np.nanmax(self._image) self._black_point = self.minimum self._white_point = self.maximum def auto_levels(self, percentile=0.1): """ Compute cumulative histogram normalized to [0, 100] and truncate gray values which fall below *percentile* or above 100 - *percentile*. 
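        With the default *percentile* of 0.1 this keeps the gray values between
        the 0.1 % and 99.9 % marks of the cumulative histogram, so, as an
        illustrative example, a handful of hot pixels sitting far above the
        bulk of a projection no longer stretch the white point and wash out the
        displayed contrast.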
""" hist, bins = np.histogram(self._image, bins=256) cumsum = np.cumsum(hist) / float(np.sum(hist)) * 100 valid = bins[np.where((cumsum > percentile) & (cumsum < 100 - percentile))] if len(valid): self.black_point = valid[0] self.white_point = valid[-1] else: self.black_point = self.white_point = self._image[0, 0] def set_black_point_normalized(self, value): """Set black point according to *value*, where value is from interval [0, 255].""" native = self.convert_normalized_value_to_native(value) if native > self.white_point: raise ImageViewingError('Black point cannot be greater than white point') self.black_point = native def set_white_point_normalized(self, value): """Set white point according to *value*, where value is from interval [0, 255].""" native = self.convert_normalized_value_to_native(value) if native < self.black_point: raise ImageViewingError('White point cannot be smaller than white point') self.white_point = native def convert_normalized_value_to_native(self, value): """Convert *value* from interval [0, 255] to the gray value in the image.""" if value < 0 or value > 255: raise ImageViewingError('Normalized value must be in interval [0, 255]') span = self.maximum - self.minimum return value / 255 * span + self.minimum def convert_native_value_to_normalized(self, value): """Convert gray value in the image to a normalized value in interval [0, 255].""" if value < self.minimum or value > self.maximum: raise ImageViewingError(f'Value must be in interval [{self.minimum}, {self.maximum}]') span = self.maximum - self.minimum return (value - self.minimum) / span * 255 if span > 0 else 0 def get_pixmap(self, downsampling=1): """Get :class:`QPixmap` for display.""" if self.black_point is None or self.white_point is None: raise ImageViewingError('Image has not been set') image = self.image[::downsampling, ::downsampling] - self.black_point if self.white_point - self.black_point > 0: image = np.clip(image * 255 / (self.white_point - self.black_point), 0, 255) image = image.astype(np.uint8) qim = QtGui.QImage(image, image.shape[1], image.shape[0], image[0].nbytes, QtGui.QImage.Format.Format_Grayscale8) return QtGui.QPixmap.fromImage(qim) class ImageLabel(QLabel): """QLabel holding the image data.""" def __init__(self, screen_image=None, parent=None): super().__init__(parent=parent) self.screen_image = screen_image def updateImage(self): if self.screen_image and self.screen_image.image is not None: hd = self.screen_image.image.shape[1] // self.width() vd = self.screen_image.image.shape[0] // self.height() downsampling = max(min(hd, vd), 1) pixmap = self.screen_image.get_pixmap(downsampling=downsampling) self.setPixmap(pixmap.scaled(self.width(), self.height(), Qt.KeepAspectRatio)) def resizeEvent(self, event): self.updateImage() class ImageViewer(QWidget): edit_height = 16 edit_width = 100 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._images = None self._last_save_dir = '.' 
# Pyqtgraph popped up window self._pg_window = None self.screen_image = ScreenImage() self.new_image_auto_levels = True self.label = ImageLabel(self.screen_image) self.label.setAlignment(Qt.AlignVCenter | Qt.AlignCenter) self.slider_edit = QLineEdit() self.slider_edit.setFixedSize(self.edit_width, self.edit_height) self.slider_edit.returnPressed.connect(self.on_slider_edit_return_pressed) self.slider = QSlider(Qt.Horizontal) validator = QtGui.QIntValidator(0, self.slider.maximum()) self.slider_edit.setValidator(validator) self.slider.valueChanged.connect(self.on_slider_value_changed) self.min_slider = QSlider(Qt.Horizontal) self.max_slider = QSlider(Qt.Horizontal) self.min_slider_edit = QLineEdit() self.min_slider_edit.setFixedSize(self.edit_width, self.edit_height) self.max_slider_edit = QLineEdit() self.max_slider_edit.setFixedSize(self.edit_width, self.edit_height) self.min_slider.setMinimum(0) self.max_slider.setMinimum(0) self.min_slider.setMaximum(255) self.max_slider.setMaximum(255) self.max_slider.setValue(255) self.min_slider.valueChanged.connect(self.on_min_slider_value_changed) self.max_slider.valueChanged.connect(self.on_max_slider_value_changed) self.min_slider_edit.returnPressed.connect(self.on_min_slider_edit_return_pressed) self.max_slider_edit.returnPressed.connect(self.on_max_slider_edit_return_pressed) # Tooltips self.slider.setToolTip('Image index in sequence') self.slider_edit.setToolTip(self.slider.toolTip()) self.min_slider.setToolTip('Black point') self.min_slider_edit.setToolTip(self.min_slider.toolTip()) self.max_slider.setToolTip('White point') self.max_slider_edit.setToolTip(self.min_slider.toolTip()) mainLayout = QGridLayout() mainLayout.addWidget(self.label, 0, 0, 1, 2) mainLayout.addWidget(self.slider_edit, 1, 0) mainLayout.addWidget(self.slider, 1, 1) mainLayout.addWidget(self.min_slider_edit, 2, 0) mainLayout.addWidget(self.min_slider, 2, 1) mainLayout.addWidget(self.max_slider_edit, 3, 0) mainLayout.addWidget(self.max_slider, 3, 1) self.setLayout(mainLayout) def contextMenuEvent(self, event): contextMenu = QMenu(self) reset_action = contextMenu.addAction('Reset') auto_levels_action = contextMenu.addAction('Auto Levels') new_image_auto_levels = contextMenu.addAction('Auto Levels on New Image') new_image_auto_levels.setCheckable(True) new_image_auto_levels.setChecked(self.new_image_auto_levels) pop_action = None save_action = None try: import pyqtgraph if self._images is not None and not self.popup_visible: pop_action = contextMenu.addAction('Pop Up') except: LOG.debug('pyqtgraph not installed, pop up option disabled') try: import imageio if self._images is not None: save_action = contextMenu.addAction('Save') except: LOG.debug('imageio not installed, save option disabled') action = contextMenu.exec_(self.mapToGlobal(event.pos())) if not action: return if action == save_action: file_name, _ = QFileDialog.getSaveFileName(None, "Select File Name", self._last_save_dir, "Images (*.tif *.png *.jpg)") if file_name: if not os.path.splitext(file_name)[1]: file_name += '.tif' self._last_save_dir = os.path.dirname(file_name) if self._images.shape[0] == 1: imageio.imsave(file_name, self._images[0]) else: if os.path.splitext(file_name)[1] != '.tif': raise ImageViewingError('3D data can be stored only in tif format') # bigtiff size from tifffile imageio.volsave(file_name, self._images, bigtiff=self._images.nbytes > 2 ** 32 - 2 ** 25) elif action == reset_action: self.reset_clim() elif action == auto_levels_action: self.reset_clim(auto=True) elif action == 
new_image_auto_levels: self.new_image_auto_levels = action.isChecked() elif action == pop_action: self.popup() @property def images(self): return self._images @images.setter def images(self, images): was_none = self._images is None self._images = images if self._images is None: self.screen_image.image = None self.set_enabled_adjustments(False) return self.set_enabled_adjustments(True) if self._images.ndim == 2: self._images = self._images[np.newaxis, :, :] if self._images.shape[0] == 1: self.slider.hide() self.slider_edit.hide() else: self.slider.setMaximum(len(self._images) - 1) self.slider.show() self.slider_edit.show() self.slider_edit.setText('0') self.slider.blockSignals(True) self.slider.setValue(0) self.slider.blockSignals(False) if self._pg_window is not None: self._update_pg_window_images() self._update_pg_window_index() self.screen_image.image = self._images[0] if was_none or self.new_image_auto_levels: self.reset_clim(auto=True) else: self.label.updateImage() validator = self.min_slider_edit.validator() if validator is None: validator = QtGui.QDoubleValidator(self.screen_image.minimum, self.screen_image.maximum, 100) self.min_slider_edit.setValidator(validator) self.max_slider_edit.setValidator(validator) else: validator.setRange(self.screen_image.minimum, self.screen_image.maximum, 100) self.slider_edit.validator().setTop(self.slider.maximum()) if self.label.width() < 256 or self.label.height() < 256: self.label.resize(256, 256) def append(self, images): if self.images is None: self.images = images else: if images.ndim == 2: images = images[np.newaxis, :, :] if images.shape[1:] != self.images.shape[1:]: raise ImageViewingError('Appended images have different shape ' f'{images.shape[1:]} than the displayed ones ' f'{self.images.shape[1:]}') self.images = np.concatenate((self.images, images)) def set_enabled_adjustments(self, enabled): self.slider.setEnabled(enabled) self.slider_edit.setEnabled(enabled) self.min_slider.setEnabled(enabled) self.min_slider_edit.setEnabled(enabled) self.max_slider.setEnabled(enabled) self.max_slider_edit.setEnabled(enabled) def reset_clim(self, auto=False): self.screen_image.reset() if auto: self.screen_image.auto_levels() self.min_slider_edit.setText('{:g}'.format(self.screen_image.black_point)) self.max_slider_edit.setText('{:g}'.format(self.screen_image.white_point)) self.set_slider_value(self.min_slider, self.screen_image.black_point) self.set_slider_value(self.max_slider, self.screen_image.white_point) self.label.updateImage() self._update_pg_window_lut() @property def popup_visible(self): return self._pg_window and self._pg_window.isVisible() def popup(self): import pyqtgraph pyqtgraph.setConfigOptions(antialias=True, imageAxisOrder='row-major') if self._pg_window is not None: if not self._pg_window.isVisible(): self._pg_window.show() return def on_pg_window_time_changed(index, time): self._set_index(index) self.slider.blockSignals(True) self.slider_edit.setText(str(index)) self.slider.setValue(index) self.slider.blockSignals(False) def on_pg_window_levels_changed(hist_item): minimum, maximum = hist_item.getLevels() if (self.screen_image.minimum <= minimum <= self.screen_image.maximum and self.screen_image.minimum <= maximum <= self.screen_image.maximum): self.min_slider_edit.setText('{:g}'.format(minimum)) self.set_slider_value(self.min_slider, minimum) self.max_slider_edit.setText('{:g}'.format(maximum)) self.set_slider_value(self.max_slider, maximum) self.screen_image.black_point = minimum self.screen_image.white_point = maximum 
self.label.updateImage() def pg_mouse_moved(ev): if self._pg_window.imageItem.sceneBoundingRect().contains(ev): pos = self._pg_window.imageItem.mapFromScene(ev) x = int(pos.x() + 0.5) y = int(pos.y() + 0.5) self._pg_window.view.setTitle('x={}, y={}, I={:g}'.format(x, y, self._pg_window.imageItem.image[y, x])) else: self._pg_window.view.setTitle('') self._pg_window = pyqtgraph.ImageView(view=pyqtgraph.PlotItem()) self._pg_window.imageItem.scene().sigMouseMoved.connect(pg_mouse_moved) self._pg_window.setWindowFlag(Qt.SubWindow, True) self._update_pg_window_images() self._update_pg_window_index() self._update_pg_window_lut() self._pg_window.show() self._pg_window.sigTimeChanged.connect(on_pg_window_time_changed) self._pg_window.ui.histogram.item.sigLevelsChanged.connect(on_pg_window_levels_changed) def cleanup(self): if self._pg_window: self._pg_window.close() self._pg_window = None def _set_index(self, index): self.screen_image.image = self.images[index] self.label.updateImage() def _update_pg_window_images(self): if self.images.shape[0] == 1: im_to_set = self.images[0] else: im_to_set = self.images self._pg_window.setImage(im_to_set, autoLevels=False) def _update_pg_window_index(self): if self._images.shape[0] > 1 and self._pg_window is not None: self._pg_window.blockSignals(True) self._pg_window.setCurrentIndex(self.slider.value()) self._pg_window.blockSignals(False) def _update_pg_window_lut(self): if self._pg_window is not None: self._pg_window.ui.histogram.item.blockSignals(True) self._pg_window.setLevels(self.screen_image.black_point, self.screen_image.white_point) self._pg_window.ui.histogram.item.blockSignals(False) def on_slider_value_changed(self, value): self._set_index(value) self.slider_edit.setText(str(value)) self._update_pg_window_index() def on_slider_edit_return_pressed(self): self.slider.setValue(int(self.slider_edit.text())) def on_min_slider_edit_return_pressed(self): value = float(self.min_slider_edit.text()) if value < self.screen_image.white_point: self.screen_image.black_point = value self.set_slider_value(self.min_slider, value) self.label.updateImage() self._update_pg_window_lut() def on_max_slider_edit_return_pressed(self): value = float(self.max_slider_edit.text()) if value > self.screen_image.black_point: self.screen_image.white_point = value self.set_slider_value(self.max_slider, value) self.label.updateImage() self._update_pg_window_lut() def on_min_slider_value_changed(self, value): self.screen_image.set_black_point_normalized(value) self.min_slider_edit.setText('{:g}'.format(self.screen_image.black_point)) self.label.updateImage() self._update_pg_window_lut() def on_max_slider_value_changed(self, value): self.screen_image.set_white_point_normalized(value) self.max_slider_edit.setText('{:g}'.format(self.screen_image.white_point)) self.label.updateImage() self._update_pg_window_lut() def set_slider_value(self, slider, value): slider.blockSignals(True) slider.setValue(int(self.screen_image.convert_native_value_to_normalized(value))) slider.blockSignals(False) class ImageViewingError(FlowError): pass ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414587.0 ufo-tofu-0.13.0/tofu/genreco.py0000664000175000017500000007747300000000000016625 0ustar00tomastomas00000000000000"""General projection-based reconstruction for tomographic/laminographic cone/parallel beam data sets. 
""" import copy import itertools import logging import os import time import numpy as np from multiprocessing.pool import ThreadPool from threading import Event, Thread from gi.repository import Ufo from .preprocess import create_preprocessing_pipeline from .util import (fbp_filtering_in_phase_retrieval, get_filtering_padding, get_reconstructed_cube_shape, get_reconstruction_regions, get_filenames, determine_shape, get_scarray_value, Vector) from .tasks import get_task, get_writer LOG = logging.getLogger(__name__) DTYPE_CL_SIZE = {'float': 4, 'double': 8, 'half': 2, 'uchar': 1, 'ushort': 2, 'uint': 4} def genreco(args): st = time.time() if is_output_single_file(args): try: import ufo.numpy except ImportError: LOG.error('You must install ufo python support (in ufo-core/python) to be able to write single-file output') return if (args.energy is not None and args.propagation_distance is not None and not (args.projection_margin or args.disable_projection_crop)): LOG.warning('Phase retrieval without --projection-margin specification or ' '--disable-projection-crop may cause convolution artifacts') _fill_missing_args(args) _convert_angles_to_rad(args) set_projection_filter_scale(args) x_region, y_region, z_region = get_reconstruction_regions(args, store=True, dtype=float) vol_shape = get_reconstructed_cube_shape(x_region, y_region, z_region) bpp = DTYPE_CL_SIZE[args.store_type] num_voxels = vol_shape[0] * vol_shape[1] * vol_shape[2] vol_nbytes = num_voxels * bpp resources = [Ufo.Resources()] gpus = np.array(resources[0].get_gpu_nodes()) gpu_indices = np.array(args.gpus or list(range(len(gpus)))) if min(gpu_indices) < 0 or max(gpu_indices) > len(gpus) - 1: raise ValueError('--gpus contains invalid indices') gpus = gpus[gpu_indices] duration = 0 for i, gpu in enumerate(gpus): print('Max mem for {}: {:.2f} GB'.format(i, gpu.get_info(0) / 2. 
** 30)) runs = make_runs(gpus, gpu_indices, x_region, y_region, z_region, bpp, slices_per_device=args.slices_per_device, slice_memory_coeff=args.slice_memory_coeff, data_splitting_policy=args.data_splitting_policy, num_gpu_threads=args.num_gpu_threads) for i in range(len(runs[0]) - 1): resources.append(Ufo.Resources()) LOG.info('Number of passes: %d', len(runs)) LOG.debug('GPUs and regions:') for regions in runs: LOG.debug('%s', str(regions)) for i, regions in enumerate(runs): duration += _run(resources, args, x_region, y_region, regions, i, vol_nbytes) num_gupdates = num_voxels * args.number * 1e-9 total_duration = time.time() - st LOG.debug('UFO duration: %.2f s', duration) LOG.debug('Total duration: %.2f s', total_duration) LOG.debug('UFO performance: %.2f GUPS', num_gupdates / duration) LOG.debug('Total performance: %.2f GUPS', num_gupdates / total_duration) def make_runs(gpus, gpu_indices, x_region, y_region, z_region, bpp, slices_per_device=None, slice_memory_coeff=0.8, data_splitting_policy='one', num_gpu_threads=1): gpu_indices = np.array(gpu_indices) def _add_region(runs, gpu_index, current, to_process, z_start, z_step): current_per_thread = current // num_gpu_threads for i in range(num_gpu_threads): if i + 1 == num_gpu_threads: current_per_thread += current % num_gpu_threads z_end = z_start + current_per_thread * z_step runs[-1].append((gpu_indices[gpu_index], [z_start, z_end, z_step])) z_start = z_end return z_start, z_end, to_process - current z_start, z_stop, z_step = z_region y_start, y_stop, y_step = y_region x_start, x_stop, x_step = x_region slice_width, slice_height, num_slices = get_reconstructed_cube_shape(x_region, y_region, z_region) if slices_per_device: slices_per_device = [slices_per_device for i in range(len(gpus))] else: slices_per_device = get_num_slices_per_gpu(gpus, slice_width, slice_height, bpp, slice_memory_coeff=slice_memory_coeff) max_slices_per_pass = sum(slices_per_device) if not max_slices_per_pass: raise RuntimeError('None of the available devices has enough memory to store any slices') num_full_passes = num_slices // max_slices_per_pass LOG.debug('Number of slices: %d', num_slices) LOG.debug('Slices per device %s', slices_per_device) LOG.debug('Maximum slices on all GPUs per pass: %d', max_slices_per_pass) LOG.debug('Number of passes with full workload: %d', num_slices // max_slices_per_pass) sorted_indices = np.argsort(slices_per_device)[-np.count_nonzero(slices_per_device):] runs = [] z_start = z_region[0] to_process = num_slices # Create passes where all GPUs are fully loaded for j in range(num_full_passes): runs.append([]) for i in sorted_indices: z_start, z_end, to_process = _add_region(runs, i, slices_per_device[i], to_process, z_start, z_step) if to_process: if data_splitting_policy == 'one': # Fill the last pass by maximizing the workload per GPU runs.append([]) for i in sorted_indices[::-1]: if not to_process: break current = min(slices_per_device[i], to_process) z_start, z_end, to_process = _add_region(runs, i, current, to_process, z_start, z_step) else: # Fill the last pass by maximizing the number of GPUs which will work num_gpus = len(sorted_indices) runs.append([]) for j, i in enumerate(sorted_indices): # Current GPU will either process the maximum number of slices it can. If the number # of slices per GPU based on even division between them cannot saturate the GPU, use # this number. This way the work will be split evenly between the GPUs. 
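# --- Illustrative sketch, not part of tofu ----------------------------------
# The branch above (any data_splitting_policy other than 'one') amounts to
# ceiling-dividing the leftover slices among the remaining GPUs, capped by each
# GPU's capacity.  The helper and the numbers below are hypothetical and only
# restate that logic in isolation.
def _split_remaining_evenly(to_process, capacities):
    shares = []
    num_gpus = len(capacities)
    for j, capacity in enumerate(capacities):
        if not to_process:
            break
        # Take at most the GPU's capacity, otherwise the ceiling of an even
        # share of what is still left.
        current = max(min(capacity, (to_process - 1) // (num_gpus - j) + 1), 1)
        shares.append(current)
        to_process -= current
    return shares

# Example: 10 leftover slices on three GPUs which could each hold 8 come out as
# 4 + 3 + 3 (all GPUs busy) instead of the 8 + 2 + 0 of the 'one' policy.
assert _split_remaining_evenly(10, [8, 8, 8]) == [4, 3, 3]

# The capacity itself comes from GPU memory (get_num_slices_per_gpu below):
# e.g. an 8 GiB card and 2048 x 2048 float32 slices with slice_memory_coeff=0.8
# give floor(8 * 2**30 * 0.8 / (2048 * 2048 * 4)) = 409 slices per pass.
# -----------------------------------------------------------------------------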
current = max(min(slices_per_device[i], (to_process - 1) // (num_gpus - j) + 1), 1) z_start, z_end, to_process = _add_region(runs, i, current, to_process, z_start, z_step) if not to_process: break return runs def get_num_slices_per_gpu(gpus, width, height, bpp, slice_memory_coeff=0.8): num_slices = [] slice_size = width * height * bpp for i, gpu in enumerate(gpus): max_mem = gpu.get_info(Ufo.GpuNodeInfo.GLOBAL_MEM_SIZE) num_slices.append(int(np.floor(max_mem * slice_memory_coeff / slice_size))) return num_slices def _run(resources, args, x_region, y_region, regions, run_number, vol_nbytes): """Execute one pass on all possible GPUs with slice ranges given by *regions*. Use separate thread per GPU and optimize the read projection regions. """ executors = [] writer = None last = None if is_output_single_file(args): import tifffile bigtiff = vol_nbytes > 2 ** 32 - 2 ** 25 LOG.debug('Writing BigTiff: %s', bigtiff) dirname = os.path.dirname(args.output) if dirname and not os.path.exists(dirname): os.makedirs(dirname) writer = tifffile.TiffWriter(args.output, append=run_number != 0, bigtiff=bigtiff) for index in range(len(regions)): gpu_index, region = regions[index] region_index = run_number * len(resources) + index executors.append( Executor( resources[index], args, region, x_region, y_region, gpu_index, region_index, writer=writer ) ) if last: # Chain up waiting events of subsequent executors executors[-1].wait_event = last.finished last = executors[-1] def start_one(index): return executors[index].process() st = time.time() try: with ThreadPool(processes=len(regions)) as pool: try: pool.map(start_one, list(range(len(regions)))) except KeyboardInterrupt: LOG.info('Processing interrupted') for executor in executors: executor.abort() finally: if writer: writer.close() LOG.debug('Writer closed') return time.time() - st def setup_graph(args, graph, x_region, y_region, region, source=None, gpu=None, do_output=True, index=0, make_reader=True): backproject = get_task('general-backproject', processing_node=gpu) if do_output: if args.dry_run: sink = get_task('null', processing_node=gpu, download=True) else: sink = get_writer(args) sink.props.filename = '{}-{:>03}-%04i.tif'.format(args.output, index) width = args.width height = args.height if args.transpose_input: tmp = width width = height height = tmp if args.projection_filter != 'none' and args.projection_crop_after == 'backprojection': # Take projection padding into account if fbp_filtering_in_phase_retrieval(args): padding = args.retrieval_padded_width - width padding_from = 'phase retrieval' else: padding = get_filtering_padding(width) padding_from = 'default backproject' args.center_position_x = [pos + padding / 2 for pos in args.center_position_x] if args.z_parameter == 'center-position-x': region = [region[0] + padding / 2, region[1] + padding / 2, region[2]] LOG.debug('center-position-x after padding: %g (from %s)', args.center_position_x[0], padding_from) backproject.props.parameter = args.z_parameter if args.burst: backproject.props.burst = args.burst backproject.props.z = args.z backproject.props.region = region backproject.props.x_region = x_region backproject.props.y_region = y_region backproject.props.center_position_x = args.center_position_x backproject.props.center_position_z = args.center_position_z backproject.props.source_position_x = args.source_position_x backproject.props.source_position_y = args.source_position_y backproject.props.source_position_z = args.source_position_z backproject.props.detector_position_x = 
args.detector_position_x backproject.props.detector_position_y = args.detector_position_y backproject.props.detector_position_z = args.detector_position_z backproject.props.detector_angle_x = args.detector_angle_x backproject.props.detector_angle_y = args.detector_angle_y backproject.props.detector_angle_z = args.detector_angle_z backproject.props.axis_angle_x = args.axis_angle_x backproject.props.axis_angle_y = args.axis_angle_y backproject.props.axis_angle_z = args.axis_angle_z backproject.props.volume_angle_x = args.volume_angle_x backproject.props.volume_angle_y = args.volume_angle_y backproject.props.volume_angle_z = args.volume_angle_z backproject.props.num_projections = args.number backproject.props.compute_type = args.compute_type backproject.props.result_type = args.result_type backproject.props.store_type = args.store_type backproject.props.overall_angle = args.overall_angle backproject.props.addressing_mode = args.genreco_padding_mode backproject.props.gray_map_min = args.slice_gray_map[0] backproject.props.gray_map_max = args.slice_gray_map[1] source = create_preprocessing_pipeline(args, graph, source=source, processing_node=gpu, cone_beam_weight=not args.disable_cone_beam_weight, make_reader=make_reader) if source: graph.connect_nodes(source, backproject) else: source = backproject if do_output: graph.connect_nodes(backproject, sink) last = sink else: last = backproject return (source, last) def is_output_single_file(args): filename = args.output.lower() return not args.dry_run and (filename.endswith('.tif') or filename.endswith('.tiff')) def set_projection_filter_scale(args): is_parallel = np.all(np.isinf(args.source_position_y)) magnification = (args.source_position_y[0] - args.detector_position_y[0]) / \ args.source_position_y[0] args.projection_filter_scale = 1. if is_parallel: if np.any(np.array(args.axis_angle_x)): LOG.debug('Adjusting filter for parallel beam laminography') args.projection_filter_scale = 0.5 * np.cos(args.axis_angle_x[0]) else: args.projection_filter_scale = 0.5 args.projection_filter_scale /= magnification ** 2 if np.all(np.array(args.axis_angle_x) == 0): LOG.debug('Adjusting filter for cone beam tomography') args.projection_filter_scale /= magnification def _fill_missing_args(args): (width, height) = determine_shape(args, args.projections, store=False) if args.transpose_input: tmp = width width = height height = tmp args.center_position_x = (args.center_position_x or [width / 2.]) args.center_position_z = (args.center_position_z or [height / 2.]) if not args.overall_angle: args.overall_angle = 360. 
LOG.info('Overall angle not specified, using 360 deg') if not args.number: if len(args.axis_angle_z) > 1: LOG.debug("--number not specified, using length of --axis-angle-z: %d", len(args.axis_angle_z)) args.number = len(args.axis_angle_z) else: num_files = len(get_filenames(args.projections)) if not num_files: raise RuntimeError("No files found in `{}'".format(args.projections)) LOG.debug("--number not specified, using number of files matching " "--projections pattern: %d", num_files) args.number = num_files if args.dry_run: if not args.number: raise ValueError('--number must be specified by --dry-run') determine_shape(args, args.projections, store=True) LOG.info('Dummy data W x H x N: {} x {} x {}'.format(args.width, args.height, args.number)) return args def _convert_angles_to_rad(args): names = ['detector_angle', 'axis_angle', 'volume_angle'] coords = ['x', 'y', 'z'] angular_z_params = [x[0].replace('_', '-') + '-' + x[1] for x in itertools.product(names, coords)] args.overall_angle = np.deg2rad(args.overall_angle) if args.z_parameter in angular_z_params: LOG.debug('Converting z parameter values to radians') args.region = _convert_list_to_rad(args.region) for name in names: for coord in coords: full_name = name + '_' + coord values = getattr(args, full_name) setattr(args, full_name, _convert_list_to_rad(values)) def _convert_list_to_rad(values): return np.deg2rad(np.array(values)).tolist() def _are_values_equal(values): return np.all(np.array(values) == values[0]) class Executor(object): """Reconstructs one region. :param writer: if not None, we'll be writing to a file shared with other executors and need to use *wait_event* to make sure we write our region when the previous executors are finished. """ def __init__(self, resources, args, region, x_region, y_region, gpu_index, region_index, writer=None): self.resources = resources self.args = args self.region = region self.gpu_index = gpu_index self.x_region = x_region self.y_region = y_region self.region_index = region_index self.writer = writer self.output = Ufo.OutputTask() if self.writer else None self.scheduler = None self.wait_event = None self.finished = Event() self.abort_requested = False def process(self): self.scheduler = Ufo.FixedScheduler() if hasattr(self.scheduler.props, 'enable_tracing'): LOG.debug("Use tracing: {}".format(self.args.enable_tracing)) self.scheduler.props.enable_tracing = self.args.enable_tracing self.scheduler.set_resources(self.resources) graph = Ufo.TaskGraph() gpu = self.scheduler.get_resources().get_gpu_nodes()[self.gpu_index] geometry = CTGeometry(self.args) if (len(self.args.center_position_z) == 1 and np.modf(self.args.center_position_z[0])[0] == 0 and geometry.is_simple_parallel_tomo): LOG.info('Simple tomography with integer z center, changing to center_position_z + 0.5 ' 'to avoid interpolation') geometry.args.center_position_z = (geometry.args.center_position_z[0] + 0.5,) if not self.args.disable_projection_crop: if not self.args.dry_run and (self.args.y or self.args.height or self.args.transpose_input): LOG.debug('--y or --height or --transpose-input specified, ' 'not optimizing projection region') else: geometry.optimize_args(region=self.region) opt_args = geometry.args if self.args.dry_run: source = get_task('dummy-data', number=self.args.number, width=self.args.width, height=self.args.height) else: source = None last = setup_graph(opt_args, graph, self.x_region, self.y_region, self.region, source=source, gpu=gpu, index=self.region_index, make_reader=True, do_output=self.writer is None)[-1] 
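# --- Illustrative sketch, not part of tofu ----------------------------------
# When a single output file is used, the executors created in _run above are
# chained through their `finished' events: `consume' below waits for the
# previous executor before writing, so regions are appended to the shared TIFF
# in order even though the per-GPU graphs run concurrently.  A minimal
# standalone version of that chaining (names are made up):
import threading

def _write_in_order(num_workers):
    events = [threading.Event() for _ in range(num_workers)]
    written = []

    def worker(index):
        # ... per-GPU reconstruction would happen here ...
        if index > 0:
            events[index - 1].wait()   # previous region must be written first
        written.append(index)          # stands in for writer.save(...)
        events[index].set()            # unblock the next worker in the chain

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return written

assert _write_in_order(4) == [0, 1, 2, 3]
# -----------------------------------------------------------------------------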
if self.writer: graph.connect_nodes(last, self.output) LOG.debug('Device: %d, region: %s', self.gpu_index, self.region) thread = Thread(target=self.scheduler.run, args=(graph,)) thread.setDaemon(True) thread.start() if self.writer: self.consume() thread.join() return self.scheduler.props.time def consume(self): import ufo.numpy if self.wait_event: LOG.debug('Executor of region %s waiting for writing', self.region) self.wait_event.wait() for i in np.arange(*self.region): if self.abort_requested: LOG.debug('Abort requested in writing of region %s', self.region) return buf = self.output.get_output_buffer() self.writer.save(ufo.numpy.asarray(buf)) self.output.release_output_buffer(buf) self.finished.set() LOG.debug('Executor of region %s finished writing', self.region) def abort(self): self.abort_requested = True if self.scheduler: self.scheduler.abort() class CTGeometry(object): def __init__(self, args): self.args = copy.deepcopy(args) determine_shape(self.args, self.args.projections, store=True) get_reconstruction_regions(self.args, store=True, dtype=float) self.args.center_position_x = (self.args.center_position_x or [self.args.width / 2.]) self.args.center_position_z = (self.args.center_position_z or [self.args.height / 2.]) @property def is_parallel(self): return np.all(np.isinf(self.args.source_position_y)) @property def is_detector_rotated(self): return (np.any(self.args.detector_angle_x) or np.any(self.args.detector_angle_y) or np.any(self.args.detector_angle_z)) @property def is_axis_rotated(self): return (np.any(self.args.axis_angle_x) or np.any(self.args.axis_angle_y) or np.any(self.args.axis_angle_z)) @property def is_volume_rotated(self): return (np.any(self.args.volume_angle_x) or np.any(self.args.volume_angle_y) or np.any(self.args.volume_angle_z)) @property def is_center_position_x_constant(self): return _are_values_equal(self.args.center_position_x) @property def is_center_position_z_constant(self): return _are_values_equal(self.args.center_position_z) @property def is_center_constant(self): return self.is_center_position_x_constant and self.is_center_position_z_constant @property def is_simple_parallel_tomo(self): return (not (self.is_axis_rotated or self.is_detector_rotated or self.is_volume_rotated) and self.is_parallel and self.is_center_constant) def optimize_args(self, region=None): xmin, ymin, xmax, ymax = self.compute_height(region=region) center_position_z = np.array(self.args.center_position_z) - ymin self.args.center_position_z = center_position_z.tolist() self.args.y = int(ymin) self.args.height = int(ymax - ymin) LOG.debug('Optimization for region: %s', region or self.args.region) LOG.debug('Optimized X: %d - %d, Z: %d - %d', xmin, xmax, ymin, ymax) LOG.debug('Optimized Z: %d', self.args.y) LOG.debug('Optimized height: %d', self.args.height) LOG.debug('Optimized center_position_z: %g - %g', self.args.center_position_z[0], self.args.center_position_z[-1]) def compute_height(self, region=None): extrema = [] if not region: region = self.args.region if self.is_simple_parallel_tomo: # Simple parallel beam tomography, thus compute only the horizontal crop at rotations # which are multiples of 45 degrees LOG.debug('Computing optimal projection region from 4 angles') projs_per_45 = self.args.number / self.args.overall_angle * np.pi / 4 stop = 4 if self.args.overall_angle <= np.pi else 8 indices = projs_per_45 * np.arange(1, stop, 2) indices = np.round(indices).astype(int).tolist() else: LOG.debug('Computing optimal projection region from all angles') indices = 
list(range(self.args.number)) for i in indices: extrema_0 = self._compute_one_parameter(region[0], i) extrema_1 = self._compute_one_parameter(region[1], i) extrema.append(extrema_0) extrema.append(extrema_1) minima = np.min(extrema, axis=0) maxima = np.max(extrema, axis=0) if maxima[-1] == minima[2]: # Don't let height be 0 maxima[-1] += 1 result = tuple(minima[::2]) + tuple(maxima[1::2]) return result def _compute_one_parameter(self, param_value, index): source_position = np.array([get_scarray_value(self.args.source_position_x, index), get_scarray_value(self.args.source_position_y, index), get_scarray_value(self.args.source_position_z, index)]) axis = Vector(x_angle=get_scarray_value(self.args.axis_angle_x, index), y_angle=get_scarray_value(self.args.axis_angle_y, index), z_angle=get_scarray_value(self.args.axis_angle_z, index), position=[get_scarray_value(self.args.center_position_x, index), 0, get_scarray_value(self.args.center_position_z, index)]) detector = Vector(x_angle=get_scarray_value(self.args.detector_angle_x, index), y_angle=get_scarray_value(self.args.detector_angle_y, index), z_angle=get_scarray_value(self.args.detector_angle_z, index), position=[get_scarray_value(self.args.detector_position_x, index), get_scarray_value(self.args.detector_position_y, index), get_scarray_value(self.args.detector_position_z, index)]) volume_angle = Vector(x_angle=get_scarray_value(self.args.volume_angle_x, index), y_angle=get_scarray_value(self.args.volume_angle_y, index), z_angle=get_scarray_value(self.args.volume_angle_z, index)) z = self.args.z if self.args.z_parameter == 'z': z = param_value elif self.args.z_parameter == 'axis-angle-x': axis.x_angle = param_value elif self.args.z_parameter == 'axis-angle-y': axis.y_angle = param_value elif self.args.z_parameter == 'axis-angle-z': axis.z_angle = param_value elif self.args.z_parameter == 'volume-angle-x': volume_angle.x_angle = param_value elif self.args.z_parameter == 'volume-angle-y': volume_angle.y_angle = param_value elif self.args.z_parameter == 'volume-angle-z': volume_angle.z_angle = param_value elif self.args.z_parameter == 'detector-angle-x': detector.x_angle = param_value elif self.args.z_parameter == 'detector-angle-y': detector.y_angle = param_value elif self.args.z_parameter == 'detector-angle-z': detector.z_angle = param_value elif self.args.z_parameter == 'detector-position-x': detector.position[0] = param_value elif self.args.z_parameter == 'detector-position-y': detector.position[1] = param_value elif self.args.z_parameter == 'detector-position-z': detector.position[2] = param_value elif self.args.z_parameter == 'source-position-x': source_position[0] = param_value elif self.args.z_parameter == 'source-position-y': source_position[1] = param_value elif self.args.z_parameter == 'source-position-z': source_position[2] = param_value elif self.args.z_parameter == 'center-position-x': axis.position[0] = param_value elif self.args.z_parameter == 'center-position-z': axis.position[2] = param_value else: raise RuntimeError("Unknown z parameter '{}'".format(self.args.z_parameter)) points = get_extrema(self.args.x_region, self.args.y_region, z) if self.args.z_parameter != 'z': points_upper = get_extrema(self.args.x_region, self.args.y_region, z + 1) points = np.hstack((points, points_upper)) tomo_angle = float(index) / self.args.number * self.args.overall_angle xe, ye = compute_detector_pixels(points, source_position, axis, volume_angle, detector, tomo_angle) return compute_detector_region(xe, ye, (self.args.height, self.args.width), 
overhead=self.args.projection_margin) def project(points, source, detector_normal, detector_offset): """Project *points* onto a detector.""" x, y, z = points source_extended = np.tile(source[:, np.newaxis], [1, points.shape[1]]) detector_normal_extended = np.tile(detector_normal[:, np.newaxis], [1, points.shape[1]]) denom = np.sum((points - source_extended) * detector_normal_extended, axis=0) if np.isinf(source[1]): # Parallel beam if np.any(detector_normal != np.array([0., -1, 0])): # Detector is not perpendicular, compute translation along the beam direction, # otherwise don't compute anything because voxels are mapped directly # to detector coordinates points[1, :] = - (detector_offset + detector_normal[0] * points[0, :] + detector_normal[2] * points[2, :]) / detector_normal[1] projected = points else: # Cone beam u = -(detector_offset + np.dot(source, detector_normal)) / denom u = np.tile(u, [3, 1]) projected = source_extended + (points - source_extended) * u return projected def compute_detector_pixels(points, source_position, axis, volume_rotation, detector, tomo_angle): """*points* are a list of points along x-direcion, thus the array has height 3. *source_position* is a 3-vector, *axis*, *volume_rotation* and *detector* are util.Vector instances. """ # Rotate the axis detector_normal = np.array((0, -1, 0), dtype=float) detector_normal = rotate_z(detector.z_angle, detector_normal) detector_normal = rotate_y(detector.y_angle, detector_normal) detector_normal = rotate_x(detector.x_angle, detector_normal) # Compute d from ax + by + cz + d = 0 detector_offset = -np.dot(detector.position, detector_normal) if np.isinf(source_position[1]): # Parallel beam voxels = points else: # Apply magnification voxels = -points * source_position[1] / (detector.position[1] - source_position[1]) # Rotate the volume voxels = rotate_z(volume_rotation.z_angle, voxels) voxels = rotate_y(volume_rotation.y_angle, voxels) voxels = rotate_x(volume_rotation.x_angle, voxels) # Rotate around the axis voxels = rotate_z(tomo_angle, voxels) # Rotate the volume voxels = rotate_z(axis.z_angle, voxels) voxels = rotate_y(axis.y_angle, voxels) voxels = rotate_x(axis.x_angle, voxels) # Get the projected pixel projected = project(voxels, source_position, detector_normal, detector_offset) if np.any(detector_normal != np.array([0., -1, 0])): # Detector is not perpendicular projected -= np.array([detector.position]).T # Reverse rotation => reverse order of transformation matrices and negative angles projected = rotate_x(-detector.x_angle, projected) projected = rotate_y(-detector.y_angle, projected) projected = rotate_z(-detector.z_angle, projected) x = projected[0, :] + axis.position[0] - 0.5 y = projected[2, :] + axis.position[2] - 0.5 return x, y def compute_detector_region(x, y, shape, overhead=2): """*overhead* specifies how much margin is taken into account around the computed area.""" def _compute_outlier(extremum_func, values): if extremum_func == min: round_func = np.floor sgn = -1 else: round_func = np.ceil sgn = +1 return int(round_func(extremum_func(values)) + sgn * overhead) x_min = min(shape[1], max(0, _compute_outlier(min, x))) y_min = min(shape[0], max(0, _compute_outlier(min, y))) x_max = max(0, min(shape[1], _compute_outlier(max, x))) y_max = max(0, min(shape[0], _compute_outlier(max, y))) return (x_min, x_max, y_min, y_max) def get_extrema(x_region, y_region, z): def get_extrema(region): return (region[0], region[1]) product = itertools.product(get_extrema(x_region), get_extrema(y_region), [z]) return 
np.array(list(product), dtype=float).T.copy() def rotate_x(angle, point): cos = np.cos(angle) sin = np.sin(angle) matrix = np.identity(3) matrix[1, 1] = cos matrix[1, 2] = -sin matrix[2, 1] = sin matrix[2, 2] = cos return np.dot(matrix, point) def rotate_y(angle, point): cos = np.cos(angle) sin = np.sin(angle) matrix = np.identity(3) matrix[0, 0] = cos matrix[0, 2] = sin matrix[2, 0] = -sin matrix[2, 2] = cos return np.dot(matrix, point) def rotate_z(angle, point): cos = np.cos(angle) sin = np.sin(angle) matrix = np.identity(3) matrix[0, 0] = cos matrix[0, 1] = -sin matrix[1, 0] = sin matrix[1, 1] = cos return np.dot(matrix, point) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/gui.py0000664000175000017500000005535500000000000015762 0ustar00tomastomas00000000000000import sys import os import logging import numpy as np import tifffile import pkg_resources from argparse import ArgumentParser from contextlib import contextmanager from tofu import reco, config, util, __version__ try: import tofu.vis.qt from PyQt5 import QtGui, QtCore, uic, QtWidgets except ImportError: raise ImportError("Cannot import modules for GUI, please install PyQt4 and pyqtgraph") LOG = logging.getLogger(__name__) def set_last_dir(path, line_edit, last_dir): if os.path.exists(str(path)): line_edit.clear() line_edit.setText(path) last_dir = str(line_edit.text()) return last_dir def get_filtered_filenames(path, exts=['.tif', '.edf']): result = [] try: for ext in exts: result += [os.path.join(path, f) for f in os.listdir(path) if f.endswith(ext)] except OSError: return [] return sorted(result) @contextmanager def spinning_cursor(): QtWidgets.QApplication.setOverrideCursor(QtWidgets.QCursor(QtCore.Qt.WaitCursor)) yield QtWidgets.QApplication.restoreOverrideCursor() class CallableHandler(logging.Handler): def __init__(self, func): logging.Handler.__init__(self) self.func = func def emit(self, record): self.func(self.format(record)) class ApplicationWindow(QtWidgets.QMainWindow): def __init__(self, app, params): QtWidgets.QMainWindow.__init__(self) self.params = params self.app = app ui_file = pkg_resources.resource_filename(__name__, 'gui.ui') self.ui = uic.loadUi(ui_file, self) self.ui.show() self.ui.setAttribute(QtCore.Qt.WA_DeleteOnClose) self.ui.tab_widget.setCurrentIndex(0) self.ui.slice_dock.setVisible(False) self.ui.volume_dock.setVisible(False) self.ui.axis_view_widget.setVisible(False) self.slice_viewer = None self.volume_viewer = None self.overlap_viewer = tofu.vis.qt.OverlapViewer() self.get_values_from_params() try: import pyqtgraph.opengl as gl except ImportError: LOG.info("OpenGL not available, volume viewer disabled") self.ui.show_volume_button.setEnabled(False) log_handler = CallableHandler(self.on_log_record) log_handler.setLevel(logging.DEBUG) log_handler.setFormatter(logging.Formatter('%(name)s: %(message)s')) root_logger = logging.getLogger('') root_logger.setLevel(logging.DEBUG) root_logger.handlers = [log_handler] self.ui.input_path_button.setToolTip('Path to projections or sinograms') self.ui.proj_button.setToolTip('Denote if path contains projections') self.ui.y_step.setToolTip(self.get_help('reading', 'y-step')) self.ui.method_box.setToolTip(self.get_help('tomographic-reconstruction', 'method')) self.ui.axis_spin.setToolTip(self.get_help('tomographic-reconstruction', 'axis')) self.ui.angle_step.setToolTip(self.get_help('reconstruction', 'angle')) 
self.ui.angle_offset.setToolTip(self.get_help('tomographic-reconstruction', 'offset')) self.ui.oversampling.setToolTip(self.get_help('dfi', 'oversampling')) self.ui.iterations_sart.setToolTip(self.get_help('ir', 'num-iterations')) self.ui.relaxation.setToolTip(self.get_help('sart', 'relaxation-factor')) self.ui.output_path_button.setToolTip(self.get_help('general', 'output')) self.ui.ffc_box.setToolTip(self.get_help('gui', 'ffc-correction')) self.ui.interpolate_button.setToolTip('Interpolate between two sets of flat fields') self.ui.darks_path_button.setToolTip(self.get_help('flat-correction', 'darks')) self.ui.flats_path_button.setToolTip(self.get_help('flat-correction', 'flats')) self.ui.flats2_path_button.setToolTip(self.get_help('flat-correction', 'flats2')) self.ui.path_button_0.setToolTip(self.get_help('gui', 'deg0')) self.ui.path_button_180.setToolTip(self.get_help('gui', 'deg180')) self.ui.input_path_button.clicked.connect(self.on_input_path_clicked) self.ui.sino_button.clicked.connect(self.on_sino_button_clicked) self.ui.proj_button.clicked.connect(self.on_proj_button_clicked) self.ui.region_box.clicked.connect(self.on_region_box_clicked) self.ui.method_box.currentIndexChanged.connect(self.change_method) self.ui.axis_spin.valueChanged.connect(self.change_axis_spin) self.ui.angle_step.valueChanged.connect(self.change_angle_step) self.ui.output_path_button.clicked.connect(self.on_output_path_clicked) self.ui.ffc_box.clicked.connect(self.on_ffc_box_clicked) self.ui.interpolate_button.clicked.connect(self.on_interpolate_button_clicked) self.ui.darks_path_button.clicked.connect(self.on_darks_path_clicked) self.ui.flats_path_button.clicked.connect(self.on_flats_path_clicked) self.ui.flats2_path_button.clicked.connect(self.on_flats2_path_clicked) self.ui.ffc_options.currentIndexChanged.connect(self.change_ffc_options) self.ui.reco_button.clicked.connect(self.on_reconstruct) self.ui.path_button_0.clicked.connect(self.on_path_0_clicked) self.ui.path_button_180.clicked.connect(self.on_path_180_clicked) self.ui.show_slices_button.clicked.connect(self.on_show_slices_clicked) self.ui.show_volume_button.clicked.connect(self.on_show_volume_clicked) self.ui.run_button.clicked.connect(self.on_compute_center) self.ui.save_action.triggered.connect(self.on_save_as) self.ui.clear_action.triggered.connect(self.on_clear) self.ui.clear_output_dir_action.triggered.connect(self.on_clear_output_dir_clicked) self.ui.open_action.triggered.connect(self.on_open_from) self.ui.close_action.triggered.connect(self.close) self.ui.about_action.triggered.connect(self.on_about) self.ui.extrema_checkbox.clicked.connect(self.on_remove_extrema_clicked) self.ui.overlap_opt.currentIndexChanged.connect(self.on_overlap_opt_changed) self.ui.input_path_line.textChanged.connect(self.on_input_path_changed) self.ui.y_step.valueChanged.connect(lambda value: self.change_value('y_step', value)) self.ui.angle_offset.valueChanged.connect(lambda value: self.change_value('offset', value)) self.ui.oversampling.valueChanged.connect(lambda value: self.change_value('oversampling', value)) self.ui.iterations_sart.valueChanged.connect(lambda value: self.change_value('num_iterations', value)) self.ui.relaxation.valueChanged.connect(lambda value: self.change_value('relaxation_factor', value)) self.ui.output_path_line.textChanged.connect(lambda value: self.change_value('output', str(self.ui.output_path_line.text()))) self.ui.darks_path_line.textChanged.connect(lambda value: self.change_value('darks', str(self.ui.darks_path_line.text()))) 
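# --- Illustrative sketch, not part of tofu ----------------------------------
# The textChanged/valueChanged connections around here all follow one pattern:
# a Qt signal is bound to a lambda which copies the widget's current value onto
# the matching attribute of `params' (change_value is a thin setattr wrapper).
# A minimal standalone version of that binding, with hypothetical names:
class _ExampleParams(object):
    pass

def _bind(params, name, getter):
    return lambda *ignored: setattr(params, name, getter())

_example_params = _ExampleParams()
_on_text_changed = _bind(_example_params, 'darks', lambda: '/data/darks')
_on_text_changed()   # in the GUI this is triggered by the textChanged signal
assert _example_params.darks == '/data/darks'
# -----------------------------------------------------------------------------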
self.ui.flats_path_line.textChanged.connect(lambda value: self.change_value('flats', str(self.ui.flats_path_line.text()))) self.ui.flats2_path_line.textChanged.connect(lambda value: self.change_value('flats2', str(self.ui.flats2_path_line.text()))) self.ui.fix_naninf_box.clicked.connect(lambda value: self.change_value('fix_nan_and_inf', self.ui.fix_naninf_box.isChecked())) self.ui.absorptivity_box.clicked.connect(lambda value: self.change_value('absorptivity', self.ui.absorptivity_box.isChecked())) self.ui.path_line_0.textChanged.connect(lambda value: self.change_value('deg0', str(self.ui.path_line_0.text()))) self.ui.path_line_180.textChanged.connect(lambda value: self.change_value('deg180', str(self.ui.path_line_180.text()))) self.ui.overlap_layout.addWidget(self.overlap_viewer) self.overlap_viewer.slider.valueChanged.connect(self.on_axis_slider_changed) def on_log_record(self, record): self.ui.text_browser.append(record) def get_values_from_params(self): self.ui.input_path_line.setText(self.params.sinograms or self.params.projections or '.') self.ui.output_path_line.setText(self.params.output or '') self.ui.darks_path_line.setText(self.params.darks or '') self.ui.flats_path_line.setText(self.params.flats or '') self.ui.flats2_path_line.setText(self.params.flats2 or '') self.ui.path_line_0.setText(self.params.deg0) self.ui.path_line_180.setText(self.params.deg180) self.ui.y_step.setValue(self.params.y_step if self.params.y_step else 1) self.ui.axis_spin.setValue(self.params.axis if self.params.axis else 0.0) self.ui.angle_step.setValue(self.params.angle if self.params.angle else 0.0) self.ui.angle_offset.setValue(self.params.offset if self.params.offset else 0.0) self.ui.oversampling.setValue(self.params.oversampling if self.params.oversampling else 0) self.ui.iterations_sart.setValue(self.params.num_iterations if self.params.num_iterations else 0) self.ui.relaxation.setValue(self.params.relaxation_factor if self.params.relaxation_factor else 0.0) if self.params.projections is not None: self.ui.proj_button.setChecked(True) self.ui.sino_button.setChecked(False) self.on_proj_button_clicked() else: self.ui.proj_button.setChecked(False) self.ui.sino_button.setChecked(True) self.on_sino_button_clicked() if self.params.method == "fbp": self.ui.method_box.setCurrentIndex(0) elif self.params.method == "dfi": self.ui.method_box.setCurrentIndex(1) elif self.params.method == "sart": self.ui.method_box.setCurrentIndex(2) self.change_method() if self.params.y_step > 1 and self.sino_button.isChecked(): self.ui.region_box.setChecked(True) else: self.ui.region_box.setChecked(False) self.ui.on_region_box_clicked() ffc_enabled = bool(self.params.flats) and bool(self.params.darks) and self.proj_button.isChecked() self.ui.ffc_box.setChecked(ffc_enabled) self.ui.preprocessing_container.setVisible(ffc_enabled) self.ui.interpolate_button.setChecked(bool(self.params.flats2) and ffc_enabled) self.ui.fix_naninf_box.setChecked(self.params.fix_nan_and_inf) self.ui.absorptivity_box.setChecked(self.params.absorptivity) if self.params.reduction_mode.lower() == "average": self.ui.ffc_options.setCurrentIndex(0) else: self.ui.ffc_options.setCurrentIndex(1) def change_method(self): self.params.method = str(self.ui.method_box.currentText()).lower() is_dfi = self.params.method == 'dfi' is_sart = self.params.method == 'sart' for w in (self.ui.oversampling_label, self.ui.oversampling): w.setVisible(is_dfi) for w in (self.ui.relaxation, self.ui.relaxation_label, self.ui.iterations_sart, self.ui.iterations_sart_label): 
w.setVisible(is_sart) def get_help(self, section, name): help = config.SECTIONS[section][name]['help'] return help def change_value(self, name, value): setattr(self.params, name, value) def on_sino_button_clicked(self): self.on_input_path_changed() self.ui.ffc_box.setEnabled(False) self.ui.preprocessing_container.setVisible(False) def on_proj_button_clicked(self): self.on_input_path_changed() self.ui.ffc_box.setEnabled(True) self.ui.preprocessing_container.setVisible(self.ffc_box.isChecked()) self.ui.region_box.setEnabled(False) self.ui.region_box.setChecked(False) self.on_region_box_clicked() def on_region_box_clicked(self): self.ui.y_step.setEnabled(self.ui.region_box.isChecked()) if self.ui.region_box.isChecked(): self.params.y_step = self.ui.y_step.value() else: self.params.y_step = 1 def on_input_path_changed(self): if self.ui.sino_button.isChecked(): self.params.sinograms = str(self.ui.input_path_line.text()) self.params.projections = None else: self.params.sinograms = None self.params.projections = str(self.ui.input_path_line.text()) def on_input_path_clicked(self, checked): directory = self.params.projections or self.params.sinograms path = self.get_path(directory, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.input_path_line, self.params.last_dir) def change_axis_spin(self): if self.ui.axis_spin.value() == 0: self.params.axis = None else: self.params.axis = self.ui.axis_spin.value() def change_angle_step(self): if self.ui.angle_step.value() == 0: self.params.angle = None else: self.params.angle = self.ui.angle_step.value() def on_output_path_clicked(self, checked): path = self.get_path(self.params.output, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.output_path_line, self.params.last_dir) def on_clear_output_dir_clicked(self): with spinning_cursor(): output_absfiles = get_filtered_filenames(str(self.ui.output_path_line.text())) for f in output_absfiles: os.remove(f) def on_ffc_box_clicked(self): checked = self.ui.ffc_box.isChecked() self.ui.preprocessing_container.setVisible(checked) self.params.ffc_correction = checked def on_interpolate_button_clicked(self): checked = self.ui.interpolate_button.isChecked() self.ui.flats2_path_line.setEnabled(checked) self.ui.flats2_path_button.setEnabled(checked) def change_ffc_options(self): self.params.reduction_mode = str(self.ui.ffc_options.currentText()).lower() def on_darks_path_clicked(self, checked): path = self.get_path(self.params.darks, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.darks_path_line, self.params.last_dir) def on_flats_path_clicked(self, checked): path = self.get_path(self.params.flats, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.flats_path_line, self.params.last_dir) def on_flats2_path_clicked(self, checked): path = self.get_path(self.params.flats2, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.flats2_path_line, self.params.last_dir) def get_path(self, directory, last_dir): return QtWidgets.QFileDialog.getExistingDirectory(self, '.', last_dir or directory) def get_filename(self, directory, last_dir): # Thanks to Lisa D. 
for pointing out that a tuple is returned in PyQT5 filename, _ = QtWidgets.QFileDialog.getOpenFileName(self, '.', last_dir or directory) return filename def on_path_0_clicked(self, checked): path = self.get_filename(self.params.deg0, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.path_line_0, self.params.last_dir) def on_path_180_clicked(self, checked): path = self.get_filename(self.params.deg180, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.path_line_180, self.params.last_dir) def on_open_from(self): config_file, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Open ...', self.params.last_dir) parser = ArgumentParser() params = config.Params(sections=config.TOMO_PARAMS + ('gui',)) parser = params.add_arguments(parser) self.params = parser.parse_known_args(config.config_to_list(config_name=config_file))[0] self.get_values_from_params() def on_about(self): message = "GUI is part of ufo-reconstruct {}.".format(__version__) QtWidgets.QMessageBox.about(self, "About ufo-reconstruct", message) def on_save_as(self): if os.path.exists(self.params.last_dir): config_file = str(self.params.last_dir + "/reco.conf") else: config_file = str(os.getenv('HOME') + "reco.conf") save_config = QtWidgets.QFileDialog.getSaveFileName(self, 'Save as ...', config_file) if save_config: sections = config.TOMO_PARAMS + ('gui',) config.write(save_config, args=self.params, sections=sections) def on_clear(self): self.ui.axis_view_widget.setVisible(False) self.ui.input_path_line.setText('.') self.ui.output_path_line.setText('.') self.ui.darks_path_line.setText('.') self.ui.flats_path_line.setText('.') self.ui.flats2_path_line.setText('.') self.ui.path_line_0.setText('.') self.ui.path_line_180.setText('.') self.ui.fix_naninf_box.setChecked(True) self.ui.absorptivity_box.setChecked(True) self.ui.sino_button.setChecked(True) self.ui.proj_button.setChecked(False) self.ui.region_box.setChecked(False) self.ui.ffc_box.setChecked(False) self.ui.interpolate_button.setChecked(False) self.ui.y_step.setValue(1) self.ui.axis_spin.setValue(0) self.ui.angle_step.setValue(0) self.ui.angle_offset.setValue(0) self.ui.oversampling.setValue(0) self.ui.ffc_options.setCurrentIndex(0) self.ui.text_browser.clear() self.ui.method_box.setCurrentIndex(0) self.params.enable_cropping = False self.params.reduction_mode = "average" self.params.fix_nan_and_inf = True self.params.absorptivity = True self.params.show_2d = False self.params.show_3d = False self.params.angle = None self.params.axis = None self.on_region_box_clicked() self.on_ffc_box_clicked() self.on_interpolate_button_clicked() def on_reconstruct(self): with spinning_cursor(): self.ui.centralWidget.setEnabled(False) self.repaint() self.app.processEvents() input_images = get_filtered_filenames(str(self.ui.input_path_line.text())) if not input_images: self.gui_warn("No data found in {}".format(str(self.ui.input_path_line.text()))) self.ui.centralWidget.setEnabled(True) return shape = util.get_image_shape(input_images[0]) self.params.width = shape[-1] self.params.height = shape[-2] self.params.ffc_correction = self.params.ffc_correction and self.ui.proj_button.isChecked() if not (self.params.output.endswith('.tif') or self.params.output.endswith('.tiff')): self.params.output = os.path.join(self.params.output, 'slice-%05i.tif') if self.params.y_step > 1: self.params.angle *= self.params.y_step if self.params.ffc_correction: flats_files = get_filtered_filenames(str(self.ui.flats_path_line.text())) self.params.num_flats = len(flats_files) else: 
self.params.num_flats = 0 self.params.darks = None self.params.flats = None self.params.flats2 = self.ui.flats2_path_line.text() if self.ui.interpolate_button.isChecked() else '' self.params.oversampling = self.ui.oversampling.value() if self.params.method == 'dfi' else None if self.params.method == 'sart': self.params.max_iterations = self.ui.iterations_sart.value() self.params.relaxation_factor = self.ui.relaxation.value() if self.params.angle is None: self.gui_warn("Missing argument for Angle step (rad)") else: try: reco.tomo(self.params) except Exception as e: self.gui_warn(str(e)) self.ui.centralWidget.setEnabled(True) self.params.angle = self.ui.angle_step.value() def on_show_slices_clicked(self): path = str(self.ui.output_path_line.text()) filenames = get_filtered_filenames(path) if not self.slice_viewer: self.slice_viewer = tofu.vis.qt.ImageViewer(filenames) self.slice_dock.setWidget(self.slice_viewer) self.ui.slice_dock.setVisible(True) else: self.slice_viewer.load_files(filenames) def on_show_volume_clicked(self): if not self.volume_viewer: step = int(self.ui.reduction_box.currentText()) self.volume_viewer = tofu.vis.qt.VolumeViewer(parent=self, step=step) self.volume_dock.setWidget(self.volume_viewer) self.ui.volume_dock.setVisible(True) path = str(self.ui.output_path_line.text()) filenames = get_filtered_filenames(path) self.volume_viewer.load_files(filenames) def on_compute_center(self): first_name = str(self.ui.path_line_0.text()) second_name = str(self.ui.path_line_180.text()) with tifffile.TiffFile(first_name) as tif: first = tif.pages[0].asarray().astype(float) with tifffile.TiffFile(second_name) as tif: second = tif.pages[-1].asarray().astype(float) if self.params.ffc_correction: # FIXME: we should of course use the pipelines we have ... 
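# --- Illustrative sketch, not part of tofu ----------------------------------
# The block below applies the standard flat-field normalization
#     corrected = (projection - <dark>) / (<flat> - <dark>)
# where <.> is the mean over the reference frames.  A tiny self-contained
# example with synthetic numbers (shapes and values are made up):
import numpy as np

def _flat_field_correct(projection, flats, darks):
    dark = np.mean(darks, axis=0)
    flat = np.mean(flats, axis=0) - dark
    return (projection - dark) / flat

# A flat beam of 100 counts, 10 counts of dark signal and a pixel transmitting
# 50 % of the beam comes out as 0.5 after correction.
_proj = np.full((4, 4), 10.0 + 0.5 * 90.0)
_flats = np.full((3, 4, 4), 100.0)
_darks = np.full((3, 4, 4), 10.0)
assert np.allclose(_flat_field_correct(_proj, _flats, _darks), 0.5)
# -----------------------------------------------------------------------------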
flat_files = get_filtered_filenames(str(self.ui.flats_path_line.text())) dark_files = get_filtered_filenames(str(self.ui.darks_path_line.text())) flats = np.array([tifffile.TiffFile(x).asarray().astype(float) for x in flat_files]) darks = np.array([tifffile.TiffFile(x).asarray().astype(float) for x in dark_files]) dark = np.mean(darks, axis=0) flat = np.mean(flats, axis=0) - dark first = (first - dark) / flat second = (second - dark) / flat self.axis = reco.compute_rotation_axis(first, second) self.height, self.width = first.shape w2 = self.width / 2.0 position = w2 + (w2 - self.axis) * 2.0 self.overlap_viewer.set_images(first, second) self.overlap_viewer.set_position(position) self.ui.img_size.setText('width = {} | height = {}'.format(self.width, self.height)) def on_remove_extrema_clicked(self, val): self.ui.overlap_viewer.remove_extrema = val def on_overlap_opt_changed(self, index): self.ui.overlap_viewer.subtract = index == 0 self.ui.overlap_viewer.update_image() def on_axis_slider_changed(self): val = self.overlap_viewer.slider.value() w2 = self.width / 2.0 self.axis = w2 + (w2 - val) / 2 self.ui.axis_num.setText('{} px'.format(self.axis)) self.ui.axis_spin.setValue(self.axis) def gui_warn(self, message): QtWidgets.QMessageBox.warning(self, "Warning", message) def main(params): app = QtWidgets.QApplication(sys.argv) ApplicationWindow(app, params) sys.exit(app.exec_()) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760160.0 ufo-tofu-0.13.0/tofu/gui.ui0000664000175000017500000013777600000000000015757 0ustar00tomastomas00000000000000 mainWindow 0 0 1018 1081 0 0 Tomoviewer true 0 0 541 761 0 0 530 0 PreferDefault true Qt::TabFocus Qt::LeftToRight 1 true 0 0 0 0 Reconstruction Input 0 0 Projections true 0 0 Sinograms true true false 0 0 1 500 0 0 false false Region (y-step): Qt::Horizontal 40 20 0 0 Do flat-field correction 0 0 Path: 0 0 0 0 Browse … 0 0 Flat-field correction 0 0 Average Median Options Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter Use absorptivity Remove NaN and Inf 0 0 Interpolate Qt::Horizontal 40 20 Method: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter true Darks: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter true true Browse … true 0 Flats: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter Last flats: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter true true Browse … Browse … Reconstruction 6 0 0 Angle step (rad): Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 50 false Method: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 50 false FBP DFI SART 0 0 10 Qt::Horizontal 40 20 0 0 Angle offset (rad): Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Reconstruct 0 0 Axis (pixel): Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 8192.000000000000000 0 0 10 0 0 Max iterations: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 0 Relaxation factor: 0.000000000000000 Qt::Horizontal 40 20 true 99 0 true 0 0 Oversampling: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter Output 0 0 Path: 0 0 0 0 Browse ... Reduction: 1 1 2 4 8 Qt::Horizontal 40 20 Show Volume Show Slices 0 0 Log 0 0 QFrame::Sunken 0 0 0 Center of rotation 0 0 Input Options: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter Method: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Browse ... 0 0 180° projection: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Browse ... 
0 0 In case of multi-page input, last image in the file is used Remove extrema 0 0 In case of multi-page input, first image in the file is used 0 0 0° projection: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Subtraction overlap Addition overlap 0 0 Run Output 0 0 Center: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Size: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 0 0 75 true 0 0 0 0 0 0 0 0 0 1018 22 0 0 0 0 File 0 0 Edit Help QDockWidget::DockWidgetFloatable|QDockWidget::DockWidgetMovable 2 0 0 QDockWidget::DockWidgetFloatable|QDockWidget::DockWidgetMovable 2 Open Save as... Open ... Qt::WindowShortcut Save as ... Quit Ctrl+Q Clear Remove old slices Remove old slices in output directory About ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/inpaint.py0000664000175000017500000002357100000000000016633 0ustar00tomastomas00000000000000import gi import logging import numpy as np gi.require_version('Ufo', '0.0') from gi.repository import Ufo from tofu.tasks import get_memory_in, get_task, get_writer from tofu.util import ( determine_shape, make_subargs, run_scheduler, set_node_props, setup_read_task, setup_padding, ) LOG = logging.getLogger(__name__) SELECT_SRC = """ kernel void select_simple (global float *image, global float *mask, global float *output) { const size_t idx = get_global_id (1) * get_global_size (0) + get_global_id (0); output[idx] = mask[idx] > 0.0f ? 0.0f : image[idx]; } kernel void select_guidance (global float *image, global float *mask, global float *guidance, global float *output) { const size_t idx = get_global_id (1) * get_global_size (0) + get_global_id (0); output[idx] = mask[idx] > 0.0f ? guidance[idx] : image[idx]; } """ ADD_CONSTANT_SRC = """ kernel void add_constant (global float *image, global float *value, global float *output) { const size_t idx = get_global_id (1) * get_global_size (0) + get_global_id (0); output[idx] = image[idx] + value[0]; } """ def _make_discrete_inverse_laplace(width, height): """Make discrete Laplace deconvolution kernel special for this use case, where we do not care about the (0, 0) frequency becuase the kernel is going to be applied on Laplace-filtered data, which has zero mean. """ f = np.fft.fftfreq(width) g = np.fft.fftfreq(height) f, g = np.meshgrid(f, g) # From discrete Laplace and time shift: F[f''(x, y)] = -4 F[f(x, y)] # + F[f(x + 1, y)] + F[f(x - 1, y)] + F[f(x, y + 1)] + F[f(x, y - 1)] = the result below when we # use the time shift property of the Fourier transform. kernel = 2 * (np.cos(2 * np.pi * f) + np.cos(2 * np.pi * g) - 2) # Make this invertible by simply setting the (0, 0) frequency to 1 instead of making sure that # after the inversion it is 0. We can afford this becuase we know the input to filtering will be # Laplace-filtered -> zero mean -> (0, 0) frequency = 0. kernel[0, 0] = 1 return (1 / kernel).astype(np.float32) def prepare_border_smoothing(padded_width, padded_height): """ The use case here is mainly the removal of the cross at (0, 0) in the power spectrum by masking out the borders of the image, i.e. the gradients are forced to go to zeros at the borders and thus removing the sharp transitions when we consider the periodicity assumed by the DFT. *padded_width* and *padded_height* are the width and height of the FFT-padding, not the original image shape. One should use `mirrored_repeat' padding mode on the input images to get the FFT-padded image, compute the mask here and use it for inpainting. 
""" mask = np.ones((padded_width, padded_height), dtype=np.float32) mask[1:-1, 1:-1] = 0 mem_in_task = get_memory_in(mask) return mem_in_task def _get_gradient_task(finite_difference_type, direction): return get_task( "gradient", finite_difference_type=finite_difference_type, direction=direction, addressing_mode="repeat", ) def create_inpaint_pipeline( args, graph, processing_node=None ): """ Create tasks needed for inpainting and connect them. The pipeline has three inputs and one output, which is the inpainted image. Based on :cite:`MOREL2012342`. """ determine_shape(args, path=args.projections, store=True, do_raise=True) if not args.inpaint_padded_width: args.inpaint_padded_width = args.width if not args.inpaint_padded_height: args.inpaint_padded_height = args.height do_pad = args.inpaint_padded_width != args.width and args.inpaint_padded_height != args.height use_guidance = not args.harmonize_borders and args.guidance_image LOG.debug("inpaint padding on: %s", do_pad) LOG.debug("inpaint using guidance image: %s", use_guidance) copy_projections = Ufo.CopyTask() copy_mask = Ufo.CopyTask() copy_guidance = Ufo.CopyTask() if use_guidance else None if do_pad: # Padding pad_projections = get_task("pad") pad_mask = get_task("pad") pad_guidance = get_task("pad") for pad_task in (pad_projections, pad_mask, pad_guidance): setup_padding( pad_task, args.width, args.height, args.inpaint_padding_mode, pad_width=args.inpaint_padded_width - args.width, pad_height=args.inpaint_padded_height - args.height, centered=False ) graph.connect_nodes(pad_projections, copy_projections) graph.connect_nodes(pad_mask, copy_mask) if use_guidance: graph.connect_nodes(pad_guidance, copy_guidance) else: pad_guidance = None inputs = (pad_projections, pad_mask, pad_guidance) else: inputs = (copy_projections, copy_mask, copy_guidance) # First gradient is forward and the second backward -> we get exactly the discrete Laplace after # the two passes. 
gx = _get_gradient_task("forward", "horizontal") gy = _get_gradient_task("forward", "vertical") ggx = _get_gradient_task("backward", "horizontal") ggy = _get_gradient_task("backward", "vertical") fft_task = get_task("fft", dimensions=2) ifft_task = get_task("ifft", dimensions=2) add_ggx_ggy = get_task("opencl", kernel="add", dimensions=2) mul_task = get_task("opencl", kernel="multiply", halve_width=False, dimensions=2) select_kernel = "select_guidance" if use_guidance else "select_simple" select_gx = get_task("opencl", source=SELECT_SRC, kernel=select_kernel, dimensions=2) select_gy = get_task("opencl", source=SELECT_SRC, kernel=select_kernel, dimensions=2) # We are computing discrete gradients -> Laplace must also be discrete lap_kernel = _make_discrete_inverse_laplace( args.inpaint_padded_width, args.inpaint_padded_height ) # Multiply interleaved complex array -> a * z = a * Re[z] + j * a * Im[z] mem_in_task = get_memory_in(lap_kernel + 1j * lap_kernel) if args.preserve_mean: mean_task = get_task("measure", axis=-1, metric="mean") add_constant = get_task( "opencl", source=ADD_CONSTANT_SRC, kernel="add_constant", dimensions=2 ) # First derivative graph.connect_nodes(copy_projections, gx) graph.connect_nodes(copy_projections, gy) # Select guidance or zeros where mask >= 0 graph.connect_nodes_full(gx, select_gx, 0) graph.connect_nodes_full(gy, select_gy, 0) graph.connect_nodes_full(copy_mask, select_gx, 1) graph.connect_nodes_full(copy_mask, select_gy, 1) if use_guidance: guidance_gx = _get_gradient_task("forward", "horizontal") guidance_gy = _get_gradient_task("forward", "vertical") graph.connect_nodes(copy_guidance, guidance_gx) graph.connect_nodes(copy_guidance, guidance_gy) graph.connect_nodes_full(guidance_gx, select_gx, 2) graph.connect_nodes_full(guidance_gy, select_gy, 2) # Second derivative graph.connect_nodes(select_gx, ggx) graph.connect_nodes(select_gy, ggy) # Sum -> Laplacian graph.connect_nodes_full(ggx, add_ggx_ggy, 0) graph.connect_nodes_full(ggy, add_ggx_ggy, 1) # Deconvolve with Laplacian graph.connect_nodes(add_ggx_ggy, fft_task) graph.connect_nodes_full(fft_task, mul_task, 0) graph.connect_nodes_full(mem_in_task, mul_task, 1) graph.connect_nodes(mul_task, ifft_task) if args.preserve_mean: # Get the mean back to the one of the input image graph.connect_nodes(copy_projections, mean_task) graph.connect_nodes_full(ifft_task, add_constant, 0) graph.connect_nodes_full(mean_task, add_constant, 1) last = add_constant else: last = ifft_task outputs = (last,) return (inputs, outputs) def run(args): """Usage with tofu: create readers, the pipeline and run it.""" if args.harmonize_borders: if args.mask_image: LOG.warning( "--mask-image has no effect when --harmonize-borders is specified" ) if args.guidance_image: LOG.warning( "--guidance-image has no effect when --harmonize-borders is specified" ) if args.inpaint_padding_mode != "mirrored_repeat": LOG.warning( "Padding mode should be `mirrored_repeat' for smooth transitions between " "true image borders and padded borders" ) elif not args.mask_image: raise ValueError("One of --mask-image or --harmonize-borders must be specified") # Reading reader = get_task("read") roi_args = make_subargs(args, ['y', 'height', 'y_step']) set_node_props(reader, args) setup_read_task(reader, args.projections, args) out_task = get_writer(args) graph = Ufo.TaskGraph() ((input_projections, input_mask, input_guidance), (last,)) = create_inpaint_pipeline( args, graph, ) if args.harmonize_borders: mask_reader = prepare_border_smoothing( 
args.inpaint_padded_width, args.inpaint_padded_height ) else: mask_reader = get_task("read") set_node_props(mask_reader, roi_args) setup_read_task(mask_reader, args.mask_image, args) graph.connect_nodes(reader, input_projections) graph.connect_nodes(mask_reader, input_mask) if not args.harmonize_borders and args.guidance_image: guidance_reader = get_task("read") set_node_props(guidance_reader, roi_args) setup_read_task(guidance_reader, args.guidance_image, args) graph.connect_nodes(guidance_reader, input_guidance) graph.connect_nodes(last, out_task) # CopyTask works only with FixedScheduler sched = Ufo.FixedScheduler() run_scheduler(sched, graph) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/lamino.py0000664000175000017500000002124500000000000016444 0ustar00tomastomas00000000000000"""Laminographic reconstruction.""" import logging import numpy as np from multiprocessing import Queue, Process from tofu.preprocess import create_preprocessing_pipeline from tofu.util import (get_filtering_padding, determine_shape, get_filenames, get_reconstruction_regions, get_reconstructed_cube_shape) from tofu.tasks import get_task, get_writer LOG = logging.getLogger(__name__) def lamino(params): """Laminographic reconstruction utilizing all GPUs.""" LOG.info('Z parameter: {}'.format(params.z_parameter)) prepare_angular_arguments(params) params.projection_filter_scale = np.sin(np.deg2rad(params.lamino_angle)) # For now we need to make a workaround for the memory leak, which means we need to execute # the passes in separate processes to clean up the low level code. For that we also need to # call the region-splitting in a separate function. # TODO: Simplify after the memory leak fix! queue = Queue() proc = Process(target=_create_runs, args=(params, queue,)) proc.start() proc.join() x_region, y_region, regions, num_gpus = queue.get() for i in range(0, len(regions), num_gpus): z_subregion = regions[i:min(i + num_gpus, len(regions))] LOG.info('Computing slices {}..{}'.format(z_subregion[0][0], z_subregion[-1][1])) proc = Process(target=_run, args=(params, x_region, y_region, z_subregion, i // num_gpus)) proc.start() proc.join() def prepare_angular_arguments(params): if not params.overall_angle: params.overall_angle = 360. LOG.info('Overall angle not specified, using 360 deg') if not params.angle: if params.dry_run: if not params.number: raise ValueError('--number must be specified by --dry-run') num_files = params.number else: num_files = len(get_filenames(params.projections)) if not num_files: raise RuntimeError("No files found in `{}'".format(params.projections)) params.angle = params.overall_angle / num_files * params.step LOG.info('Angle not specified, calculating from ' + '{} projections and step {}: {} deg'.format(num_files, params.step, params.angle)) determine_shape(params, params.projections, store=True) if not params.number: params.number = int(np.round(np.abs(params.overall_angle / params.angle))) if params.dry_run: LOG.info('Dummy data W x H x N: {} x {} x {}'.format(params.width, params.height, params.number)) def _create_runs(params, queue): """Workaround function to get the number of gpus and compute regions. gi.repository must always be called in a separate process, otherwise the resources return None gpus. """ #TODO: remove the whole function after memory leak fix! 
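# --- Illustrative sketch, not part of tofu ----------------------------------
# As noted above, this function runs in its own Process as a workaround for a
# low-level memory leak (see the comments in lamino()), because probing the
# GPUs through gi.repository has to happen in a separate process; the result
# travels back through a multiprocessing Queue.  A minimal standalone version
# of that pattern (worker and values are hypothetical):
import multiprocessing

def _example_worker(value, queue):
    # ... probe the GPUs / split the regions here ...
    queue.put(value * 2)

def _run_isolated(target, *args):
    queue = multiprocessing.Queue()
    proc = multiprocessing.Process(target=target, args=args + (queue,))
    proc.start()
    result = queue.get()   # drain the queue before join() so it cannot block
    proc.join()
    return result

if __name__ == '__main__':
    assert _run_isolated(_example_worker, 21) == 42
# -----------------------------------------------------------------------------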
from gi.repository import Ufo scheduler = Ufo.FixedScheduler() gpus = scheduler.get_resources().get_gpu_nodes() num_gpus = len(gpus) x_region, y_region, regions = _split_regions(params, gpus) LOG.info('Using {} GPUs in {} passes'.format(min(len(regions), num_gpus), len(regions))) queue.put((x_region, y_region, regions, num_gpus)) def _run(params, x_region, y_region, regions, index): """Execute one pass on all possible GPUs with slice ranges given by *regions*.""" from gi.repository import Ufo pm = Ufo.PluginManager() graph = Ufo.TaskGraph() scheduler = Ufo.FixedScheduler() gpus = scheduler.get_resources().get_gpu_nodes() num_gpus = len(gpus) broadcast = Ufo.CopyTask() source = _setup_source(params, pm, graph) graph.connect_nodes(source, broadcast) for i, region in enumerate(regions): subindex = index * num_gpus + i _setup_graph(pm, graph, subindex, x_region, y_region, region, params, broadcast, gpu=gpus[i]) scheduler.run(graph) duration = scheduler.props.time LOG.info('Execution time: {} s'.format(duration)) return duration def _setup_source(params, pm, graph): from tofu.preprocess import create_flat_correct_pipeline from tofu.util import set_node_props, setup_read_task if params.dry_run: source = pm.get_task('dummy-data') source.props.number = params.number source.props.width = params.width source.props.height = params.height elif params.darks and params.flats: source = create_flat_correct_pipeline(params, graph) else: source = pm.get_task('read') set_node_props(source, params) setup_read_task(source, params.projections, params) return source def _setup_graph(pm, graph, index, x_region, y_region, region, params, source, gpu=None): backproject = get_task('lamino-backproject', processing_node=gpu) slicer = get_task('slice', processing_node=gpu) writer = get_writer(params) if not params.dry_run: writer.props.filename = '{}-{:>03}-%04i.tif'.format(params.output, index) # parameters backproject.props.num_projections = params.number backproject.props.overall_angle = np.deg2rad(params.overall_angle) backproject.props.lamino_angle = np.deg2rad(params.lamino_angle) backproject.props.roll_angle = np.deg2rad(params.roll_angle) backproject.props.x_region = x_region backproject.props.y_region = y_region backproject.props.z = params.z backproject.props.addressing_mode = params.lamino_padding_mode backproject.props.parameter = params.z_parameter if params.projection_crop_after == 'backprojection': padding = get_filtering_padding(params.width) else: padding = 0 if params.z_parameter in ['lamino-angle', 'roll-angle']: region = [np.deg2rad(reg) for reg in region] if params.z_parameter == 'x-center': # Take projection padding into account region = [region[0] + padding / 2, region[1] + padding / 2, region[2]] backproject.props.region = region backproject.props.center = (params.axis[0] + padding / 2, params.axis[1]) LOG.debug('x center after padding: %g', backproject.props.center[0]) graph.connect_nodes(backproject, slicer) graph.connect_nodes(slicer, writer) if params.only_bp: first = backproject graph.connect_nodes(source, backproject) else: first = create_preprocessing_pipeline(params, graph, source=source, processing_node=gpu) graph.connect_nodes(first, backproject) return first def _split_regions(params, gpus): """Split processing between *gpus* by specifying the number of slices processed per GPU.""" x_region, y_region, z_region = get_reconstruction_regions(params) z_start, z_stop, z_step = z_region y_start, y_stop, y_step = y_region x_start, x_stop, x_step = x_region slice_width, slice_height, 
num_slices = get_reconstructed_cube_shape(x_region, y_region, z_region) if params.slices_per_device: num_slices_per_gpu = params.slices_per_device else: num_slices_per_gpu = _compute_num_slices(gpus, slice_width, slice_height) if num_slices_per_gpu > num_slices: num_slices_per_gpu = num_slices LOG.info('Using {} slices per GPU'.format(num_slices_per_gpu)) z_starts = np.arange(z_start, z_stop, z_step * num_slices_per_gpu) regions = [] for start in z_starts: regions.append((start, min(z_stop, start + z_step * num_slices_per_gpu), z_step)) return x_region, y_region, regions def _compute_num_slices(gpus, width, height): """Determine number of slices which can be calculated per-device based on *gpus*, slice *width* and *height*. """ from gi.repository import Ufo # Make sure the double buffering works with room for intermediate steps # TODO: compute this precisely safety_coeff = 3. # Use the weakest one, if heterogenous systems emerge, measure the performance and # reconsider memories = [gpu.get_info(Ufo.GpuNodeInfo.GLOBAL_MEM_SIZE) for gpu in gpus] i = np.argmin(memories) max_allocatable = gpus[i].get_info(Ufo.GpuNodeInfo.MAX_MEM_ALLOC_SIZE) if max_allocatable * safety_coeff <= memories[i]: # Don't waste resources max_memory = max_allocatable else: max_memory = memories[i] / safety_coeff if max_memory > 2 ** 32: # Current NVIDIA implementation allows only 4 GB max_memory = 2 ** 32 max_memory /= safety_coeff num_slices = int(np.floor(max_memory / (width * height * 4))) LOG.info('GPU memory used per GPU: {:.2f} GB'.format(max_memory / 2. ** 30)) return num_slices ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/preprocess.py0000664000175000017500000004353700000000000017362 0ustar00tomastomas00000000000000"""Flat field correction.""" import sys import logging from gi.repository import Ufo from tofu.util import (fbp_filtering_in_phase_retrieval, get_filenames, set_node_props, make_subargs, determine_shape, setup_read_task, setup_padding, next_power_of_two, run_scheduler) from tofu.tasks import get_task, get_writer LOG = logging.getLogger(__name__) def create_flat_correct_pipeline(args, graph, processing_node=None): """ Create flat field correction pipeline. All the settings are provided in *args*. *graph* is used for making the connections. Returns the flat field correction task which can be used for further pipelining. 
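
    A minimal usage sketch (it mirrors run_flat_correct below; *args* is assumed to provide the
    usual reader/writer options, in particular --projections, --darks and --flats):

        graph = Ufo.TaskGraph()
        ffc = create_flat_correct_pipeline(args, graph)
        graph.connect_nodes(ffc, get_writer(args))
        run_scheduler(Ufo.Scheduler(), graph)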
""" pm = Ufo.PluginManager() if args.projections is None or args.flats is None or args.darks is None: raise RuntimeError("You must specify --projections, --flats and --darks.") reader = get_task('read') dark_reader = get_task('read') flat_before_reader = get_task('read') ffc = get_task('flat-field-correct', processing_node=processing_node, dark_scale=args.dark_scale, flat_scale=args.flat_scale, absorption_correct=args.absorptivity, fix_nan_and_inf=args.fix_nan_and_inf) mode = args.reduction_mode.lower() roi_args = make_subargs(args, ['y', 'height', 'y_step']) set_node_props(reader, args) set_node_props(dark_reader, roi_args) set_node_props(flat_before_reader, roi_args) for r, path in ((reader, args.projections), (dark_reader, args.darks), (flat_before_reader, args.flats)): setup_read_task(r, path, args) LOG.debug("Doing flat field correction using reduction mode `{}'".format(mode)) if args.flats2: flat_after_reader = get_task('read') setup_read_task(flat_after_reader, args.flats2, args) set_node_props(flat_after_reader, roi_args) num_files = len(get_filenames(args.projections)) can_read = len(list(range(args.start, num_files, args.step))) number = args.number if args.number else num_files num_read = min(can_read, number) flat_interpolate = get_task('interpolate', processing_node=processing_node, number=num_read) if args.resize: LOG.debug("Resize input data by factor of {}".format(args.resize)) proj_bin = get_task('bin', processing_node=processing_node, size=args.resize) dark_bin = get_task('bin', processing_node=processing_node, size=args.resize) flat_bin = get_task('bin', processing_node=processing_node, size=args.resize) graph.connect_nodes(reader, proj_bin) graph.connect_nodes(dark_reader, dark_bin) graph.connect_nodes(flat_before_reader, flat_bin) reader, dark_reader, flat_before_reader = proj_bin, dark_bin, flat_bin if args.flats2: flat_bin = get_task('bin', processing_node=processing_node, size=args.resize) graph.connect_nodes(flat_after_reader, flat_bin) flat_after_reader = flat_bin if mode == 'median': dark_stack = get_task('stack', processing_node=processing_node, number=len(get_filenames(args.darks))) dark_reduced = get_task('flatten', processing_node=processing_node, mode='median') flat_before_stack = get_task('stack', processing_node=processing_node, number=len(get_filenames(args.flats))) flat_before_reduced = get_task('flatten', processing_node=processing_node, mode='median') graph.connect_nodes(dark_reader, dark_stack) graph.connect_nodes(dark_stack, dark_reduced) graph.connect_nodes(flat_before_reader, flat_before_stack) graph.connect_nodes(flat_before_stack, flat_before_reduced) if args.flats2: flat_after_stack = get_task('stack', processing_node=processing_node, number=len(get_filenames(args.flats2))) flat_after_reduced = get_task('flatten', processing_node=processing_node, mode='median') graph.connect_nodes(flat_after_reader, flat_after_stack) graph.connect_nodes(flat_after_stack, flat_after_reduced) elif mode == 'average': dark_reduced = get_task('average', processing_node=processing_node) flat_before_reduced = get_task('average', processing_node=processing_node) graph.connect_nodes(dark_reader, dark_reduced) graph.connect_nodes(flat_before_reader, flat_before_reduced) if args.flats2: flat_after_reduced = get_task('average', processing_node=processing_node) graph.connect_nodes(flat_after_reader, flat_after_reduced) else: raise ValueError('Invalid reduction mode') graph.connect_nodes_full(reader, ffc, 0) graph.connect_nodes_full(dark_reduced, ffc, 1) if args.flats2: 
graph.connect_nodes_full(flat_before_reduced, flat_interpolate, 0) graph.connect_nodes_full(flat_after_reduced, flat_interpolate, 1) graph.connect_nodes_full(flat_interpolate, ffc, 2) else: graph.connect_nodes_full(flat_before_reduced, ffc, 2) return ffc def create_phase_retrieval_pipeline(args, graph, processing_node=None): LOG.debug('Creating phase retrieval pipeline') pm = Ufo.PluginManager() # Retrieve phase phase_retrieve = get_task('retrieve-phase', processing_node=processing_node) pad_phase_retrieve = get_task('pad', processing_node=processing_node) crop_phase_retrieve = get_task('crop', processing_node=processing_node) fft_phase_retrieve = get_task('fft', processing_node=processing_node) ifft_phase_retrieve = get_task('ifft', processing_node=processing_node) calculate = get_task('calculate', processing_node=processing_node) width = args.width height = args.height default_padded_width = next_power_of_two(width + 64) default_padded_height = next_power_of_two(height + 64) if not args.retrieval_padded_width: args.retrieval_padded_width = default_padded_width if not args.retrieval_padded_height: args.retrieval_padded_height = default_padded_height fmt = 'Phase retrieval padding: {}x{} -> {}x{}' LOG.debug(fmt.format(width, height, args.retrieval_padded_width, args.retrieval_padded_height)) x = (args.retrieval_padded_width - width) // 2 y = (args.retrieval_padded_height - height) // 2 pad_phase_retrieve.props.x = x pad_phase_retrieve.props.y = y pad_phase_retrieve.props.width = args.retrieval_padded_width pad_phase_retrieve.props.height = args.retrieval_padded_height pad_phase_retrieve.props.addressing_mode = args.retrieval_padding_mode crop_phase_retrieve.props.y = y crop_phase_retrieve.props.height = height if ( args.projection_crop_after == 'filter' or not fbp_filtering_in_phase_retrieval(args) ): crop_phase_retrieve.props.x = x crop_phase_retrieve.props.width = width phase_retrieve.props.method = args.retrieval_method phase_retrieve.props.energy = args.energy if len(args.propagation_distance) == 1: phase_retrieve.props.distance = [args.propagation_distance[0]] else: phase_retrieve.props.distance_x = args.propagation_distance[0] phase_retrieve.props.distance_y = args.propagation_distance[1] phase_retrieve.props.pixel_size = args.pixel_size phase_retrieve.props.regularization_rate = args.regularization_rate phase_retrieve.props.thresholding_rate = args.thresholding_rate phase_retrieve.props.frequency_cutoff = args.frequency_cutoff fft_phase_retrieve.props.dimensions = 2 ifft_phase_retrieve.props.dimensions = 2 if args.delta is not None: import numpy as np lam = 6.62606896e-34 * 299792458 / (args.energy * 1.60217733e-16) thickness_conversion = -lam / (2 * np.pi * args.delta) else: thickness_conversion = 1 if fbp_filtering_in_phase_retrieval(args): LOG.debug('Fusing phase retrieval and FBP filtering') fltr = get_task('filter', processing_node=processing_node) fltr.props.filter = args.projection_filter fltr.props.scale = args.projection_filter_scale fltr.props.cutoff = args.projection_filter_cutoff graph.connect_nodes(phase_retrieve, fltr) graph.connect_nodes(fltr, ifft_phase_retrieve) else: graph.connect_nodes(phase_retrieve, ifft_phase_retrieve) if args.retrieval_method == 'tie' and args.tie_approximate_logarithm: # a = 2 / 10^R, b = -10^R / 2, c = thickness_conversion # t = Taylor series point, T = t * (1 - ln(t)) # a * b = -1 (from above) # ----------------------------------------------------- # We will use the Taylor expansion to the 1st order of the logarithm which TIE needs: # -ln 
(a * retrieved) * b * c ~ b * c / t * [T - a * retrieved] # T - a * retrieved = T - a * F^-1{F(I) * kernel} = T - F^-1{aF(I) * kernel} # for u, v = 0: aF(I) * kernel = F(I)(0, 0) because kernel(0, 0) = 1 / a, thus: # T - F^-1{aF(I) * kernel} = -F^-1{aF(I - T) * kernel}, because for u, v = 0 we have: # -aF(I - T) * kernel = - F(I - T) = T - F(I)(0, 0) (and the rest of the frequencies is # unaffected by "T". # further: -aF(I - T) * kernel = F[a(T - I)] * kernel # bring in b, c and t and we have F[abc/t(T - I)] * kernel, with ab=-1 we end up with: # F[c/t(I - T)] * kernel, so the approximation of the logarithm does not need any change in # the TIE kernel itself, we may just transform the input image and use the rest of the # pipeline as usual. if args.delta is None: # Do not multiply by one expression = "v - 1" else: expression = "{} * (v - {})".format( # c / t thickness_conversion / args.tie_approximate_point, # t * (1 - ln(t)) args.tie_approximate_point * (1 - np.log(args.tie_approximate_point)) ) calculate.props.expression = expression LOG.debug("Phase contrast conversion expression (log approximation): `%s'", expression) graph.connect_nodes(calculate, pad_phase_retrieve) first = calculate last = crop_phase_retrieve else: if args.retrieval_method == 'tie': expression = '(isinf (v) || isnan (v) || (v <= 0)) ? 0.0f : ' if args.tie_approximate_logarithm: # ln(x) ~ x - 1 at a=1 expression += '(1.0f - {} * v) * {}' else: expression += '-log ({} * v) * {}' # first term: 2 for 0.5 factor in ufo-filters and alpha = 10^-R, so divide by 10^R # second term: The following converts the TIE result to the actual phase, which when # multiplied by the thickness_conversion gives the projected thickness thickness_conversion *= -10 ** args.regularization_rate / 2 expression = expression.format(2 / 10 ** args.regularization_rate, thickness_conversion) else: expression = '(isinf (v) || isnan (v)) ? 
0.0f : v * {}'.format(thickness_conversion) LOG.debug("Phase contrast conversion expression: `%s'", expression) calculate.props.expression = expression graph.connect_nodes(crop_phase_retrieve, calculate) first = pad_phase_retrieve last = calculate graph.connect_nodes(pad_phase_retrieve, fft_phase_retrieve) graph.connect_nodes(fft_phase_retrieve, phase_retrieve) graph.connect_nodes(ifft_phase_retrieve, crop_phase_retrieve) return (first, last) def run_flat_correct(args): graph = Ufo.TaskGraph() sched = Ufo.Scheduler() pm = Ufo.PluginManager() out_task = get_writer(args) flat_task = create_flat_correct_pipeline(args, graph) graph.connect_nodes(flat_task, out_task) run_scheduler(sched, graph) def create_sinogram_pipeline(args, graph): """Create sinogram generating pipeline based on arguments from *args*.""" pm = Ufo.PluginManager() sinos = pm.get_task('transpose-projections') if args.number: region = (args.start, args.start + args.number, args.step) num_projections = len(list(range(*region))) else: num_projections = len(get_filenames(args.projections)) sinos.props.number = num_projections if args.darks and args.flats: start = create_flat_correct_pipeline(args, graph) else: start = get_task('read') start.props.path = args.projections set_node_props(start, args) graph.connect_nodes(start, sinos) return sinos def run_sinogram_generation(args): """Make the sinograms with arguments provided by *args*.""" if not args.height: args.height = determine_shape(args, args.projections)[1] - args.y step = args.y_step * args.pass_size if args.pass_size else args.height starts = list(range(args.y, args.y + args.height, step)) + [args.y + args.height] def generate_partial(append=False): graph = Ufo.TaskGraph() sched = Ufo.Scheduler() args.output_append = append writer = get_writer(args) sinos = create_sinogram_pipeline(args, graph) graph.connect_nodes(sinos, writer) return run_scheduler(sched, graph) for i in range(len(starts) - 1): args.y = starts[i] args.height = starts[i + 1] - starts[i] if not generate_partial(append=i != 0): # We were interrupted break def create_projection_filtering_pipeline(args, graph, processing_node=None): pm = Ufo.PluginManager() pad = get_task('pad', processing_node=processing_node) fft = get_task('fft', processing_node=processing_node) ifft = get_task('ifft', processing_node=processing_node) fltr = get_task('filter', processing_node=processing_node) if args.projection_crop_after == 'filter': crop = get_task('crop', processing_node=processing_node) else: crop = None padding_width = setup_padding(pad, args.width, args.height, args.projection_padding_mode, crop=crop)[0] fft.props.dimensions = 1 ifft.props.dimensions = 1 fltr.props.filter = args.projection_filter fltr.props.scale = args.projection_filter_scale fltr.props.cutoff = args.projection_filter_cutoff graph.connect_nodes(pad, fft) graph.connect_nodes(fft, fltr) graph.connect_nodes(fltr, ifft) if crop: graph.connect_nodes(ifft, crop) last = crop else: last = ifft return (pad, last) def create_preprocessing_pipeline(args, graph, source=None, processing_node=None, cone_beam_weight=True, make_reader=True): """If *make_reader* is True, create a read task if *source* is None and no dark and flat fields are given. 
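
    A minimal usage sketch (it mirrors run_preprocessing below; *args* is assumed to provide the
    usual tofu input/output options):

        graph = Ufo.TaskGraph()
        current = create_preprocessing_pipeline(args, graph)
        graph.connect_nodes(current, get_writer(args))
        run_scheduler(Ufo.Scheduler(), graph)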
""" import numpy as np if not (args.width and args.height): width, height = determine_shape(args, args.projections) if not width: raise RuntimeError("Could not determine width from the input") if not args.width: args.width = width if not args.height: args.height = height - args.y LOG.debug('Image width x height: %d x %d', args.width, args.height) current = None if source: current = source elif args.darks and args.flats: current = create_flat_correct_pipeline(args, graph, processing_node=processing_node) else: if make_reader: current = get_task('read') set_node_props(current, args) if not args.projections: raise RuntimeError('--projections not set') setup_read_task(current, args.projections, args) if args.absorptivity: absorptivity = get_task('calculate', processing_node=processing_node) absorptivity.props.expression = 'v <= 0 ? 0.0f : -log(v)' if current: graph.connect_nodes(current, absorptivity) current = absorptivity if args.transpose_input: transpose = get_task('transpose') if current: graph.connect_nodes(current, transpose) current = transpose tmp = args.width args.width = args.height args.height = tmp if cone_beam_weight and not np.all(np.isinf(args.source_position_y)): # Cone beam projection weight LOG.debug('Enabling cone beam weighting') weight = get_task('cone-beam-projection-weight', processing_node=processing_node) weight.props.source_distance = (-np.array(args.source_position_y)).tolist() weight.props.detector_distance = args.detector_position_y weight.props.center_position_x = args.center_position_x or [args.width / 2. + (args.width % 2) * 0.5] weight.props.center_position_z = args.center_position_z or [args.height / 2. + (args.height % 2) * 0.5] weight.props.axis_angle_x = args.axis_angle_x if current: graph.connect_nodes(current, weight) current = weight if args.energy is not None and args.propagation_distance is not None: pr_first, pr_last = create_phase_retrieval_pipeline(args, graph, processing_node=processing_node) if current: graph.connect_nodes(current, pr_first) current = pr_last if args.projection_filter != 'none' and not fbp_filtering_in_phase_retrieval(args): pf_first, pf_last = create_projection_filtering_pipeline(args, graph, processing_node=processing_node) if current: graph.connect_nodes(current, pf_first) current = pf_last return current def run_preprocessing(args): graph = Ufo.TaskGraph() sched = Ufo.Scheduler() pm = Ufo.PluginManager() out_task = get_writer(args) current = create_preprocessing_pipeline(args, graph) graph.connect_nodes(current, out_task) run_scheduler(sched, graph) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/reco.py0000664000175000017500000002701100000000000016112 0ustar00tomastomas00000000000000import os import logging import glob import tempfile import sys import numpy as np from gi.repository import Ufo from tofu.preprocess import create_flat_correct_pipeline from tofu.util import (set_node_props, setup_read_task, get_filenames, read_image, determine_shape, setup_padding, run_scheduler) from tofu.tasks import get_task, get_writer LOG = logging.getLogger(__name__) pm = Ufo.PluginManager() def get_dummy_reader(params): if params.width is None and params.height is None: raise RuntimeError("You have to specify --width and --height when generating data.") width, height = params.width, params.height reader = get_task('dummy-data', width=width, height=height, number=params.number or 1) return reader, width, height def get_file_reader(params): reader = 
pm.get_task('read') set_node_props(reader, params) return reader def get_projection_reader(params): reader = get_file_reader(params) setup_read_task(reader, params.projections, params) width, height = determine_shape(params, params.projections) return reader, width, height def get_sinogram_reader(params): reader = get_file_reader(params) setup_read_task(reader, params.sinograms, params) width, height = determine_shape(params, path=params.sinograms) return reader, width, height def tomo(params): # Create reader and writer if params.projections and params.sinograms: raise RuntimeError("Cannot specify both --projections and --sinograms.") if params.projections is None and params.sinograms is None: reader, width, height = get_dummy_reader(params) else: if params.projections: reader, width, height = get_projection_reader(params) else: reader, width, height = get_sinogram_reader(params) axis = params.axis or width / 2.0 if params.projections and params.resize: width /= params.resize height /= params.resize axis /= params.resize LOG.debug("Input dimensions: {}x{} pixels".format(width, height)) writer = get_writer(params) # Setup graph depending on the chosen method and input data g = Ufo.TaskGraph() if params.projections is not None: if params.number: count = len(list(range(params.start, params.start + params.number, params.step))) else: count = len(get_filenames(params.projections)) LOG.debug("Number of projections: {}".format(count)) sino_output = get_task('transpose-projections', number=count) if params.darks and params.flats: g.connect_nodes(create_flat_correct_pipeline(params, g), sino_output) else: g.connect_nodes(reader, sino_output) if height: # Sinogram height is the one needed for further padding height = count else: sino_output = reader if params.method == 'fbp': fft = get_task('fft', dimensions=1) ifft = get_task('ifft', dimensions=1) fltr = get_task('filter', filter=params.projection_filter, cutoff=params.projection_filter_cutoff) bp = get_task('backproject', axis_pos=axis) last_node = bp if params.angle: bp.props.angle_step = params.angle if params.offset: bp.props.angle_offset = params.offset if width and height: # Pad the image with its extent to prevent reconstuction ring pad = get_task('pad') crop = get_task('crop') if params.projection_crop_after == 'filter': crop_after_filter = crop else: crop_after_filter = None padding_width = setup_padding(pad, width, height, params.projection_padding_mode, crop=crop_after_filter)[0] LOG.debug("Padding input to: {}x{} pixels".format(pad.props.width, pad.props.height)) g.connect_nodes(sino_output, pad) g.connect_nodes(pad, fft) g.connect_nodes(fft, fltr) g.connect_nodes(fltr, ifft) if crop_after_filter: g.connect_nodes(ifft, crop) g.connect_nodes(crop, bp) else: bp.props.axis_pos = axis + padding_width / 2 crop.props.x = padding_width // 2 crop.props.y = padding_width // 2 crop.props.width = width crop.props.height = width g.connect_nodes(ifft, bp) g.connect_nodes(bp, crop) last_node = crop else: if params.crop_width: ifft.props.crop_width = int(params.crop_width) LOG.debug("Cropping to {} pixels".format(ifft.props.crop_width)) g.connect_nodes(sino_output, fft) g.connect_nodes(fft, fltr) g.connect_nodes(fltr, ifft) g.connect_nodes(ifft, bp) g.connect_nodes(last_node, writer) if params.method in ('sart', 'sirt', 'sbtv', 'asdpocs'): projector = pm.get_task_from_package('ir', 'parallel-projector') projector.set_properties(model='joseph', is_forward=False) projector.set_properties(axis_position=axis) projector.set_properties(step=params.angle if 
                                      params.angle else np.pi / 180.0)
        method = pm.get_task_from_package('ir', params.method)
        method.set_properties(projector=projector, num_iterations=params.num_iterations)
        if params.method in ('sart', 'sirt'):
            method.set_properties(relaxation_factor=params.relaxation_factor)
        if params.method == 'asdpocs':
            minimizer = pm.get_task_from_package('ir', 'sirt')
            method.set_properties(df_minimizer=minimizer)
        if params.method == 'sbtv':
            # FIXME: the lambda keyword prevents the following
            # assignment ...
            # method.props.lambda = params.lambda
            method.set_properties(mu=params.mu)
        g.connect_nodes(sino_output, method)
        g.connect_nodes(method, writer)

    if params.method == 'dfi':
        oversampling = params.oversampling or 1
        pad = get_task('zeropad', center_of_rotation=axis, oversampling=oversampling)
        fft = get_task('fft', dimensions=1, auto_zeropadding=0)
        dfi = get_task('dfi-sinc')
        ifft = get_task('ifft', dimensions=2)
        swap_forward = get_task('swap-quadrants')
        swap_backward = get_task('swap-quadrants')

        if params.angle:
            dfi.props.angle_step = params.angle

        g.connect_nodes(sino_output, pad)
        g.connect_nodes(pad, fft)
        g.connect_nodes(fft, dfi)
        g.connect_nodes(dfi, swap_forward)
        g.connect_nodes(swap_forward, ifft)
        g.connect_nodes(ifft, swap_backward)

        if width:
            crop = get_task('crop')
            crop.set_properties(from_center=True, width=width, height=width)
            g.connect_nodes(swap_backward, crop)
            g.connect_nodes(crop, writer)
        else:
            g.connect_nodes(swap_backward, writer)

    scheduler = Ufo.Scheduler()
    if hasattr(scheduler.props, 'enable_tracing'):
        LOG.debug("Use tracing: {}".format(params.enable_tracing))
        scheduler.props.enable_tracing = params.enable_tracing

    if not run_scheduler(scheduler, g):
        return

    duration = scheduler.props.time
    LOG.info("Execution time: {} s".format(duration))

    return duration


def estimate_center(params):
    if params.estimate_method == 'reconstruction':
        axis = estimate_center_by_reconstruction(params)
    else:
        axis = estimate_center_by_correlation(params)

    return axis


def estimate_center_by_reconstruction(params):
    if params.projections is not None:
        raise RuntimeError("Cannot estimate axis from projections")

    sinos = sorted(glob.glob(os.path.join(params.sinograms, '*.tif')))

    if not sinos:
        raise RuntimeError("No sinograms found in {}".format(params.sinograms))

    # Use a sinogram that probably has some interesting data
    filename = sinos[len(sinos) // 2]
    sinogram = read_image(filename)
    initial_width = sinogram.shape[1]
    m0 = np.mean(np.sum(sinogram, axis=1))

    center = initial_width / 2.0
    width = initial_width / 2.0
    new_center = center
    tmp_dir = tempfile.mkdtemp()
    tmp_output = os.path.join(tmp_dir, 'slice-0.tif')

    params.sinograms = filename
    params.output = os.path.join(tmp_dir, 'slice-%i.tif')

    def heaviside(A):
        return (A >= 0.0) * 1.0

    def get_score(guess, m0):
        # Run reconstruction with new guess
        params.axis = guess
        tomo(params)

        # Analyse reconstructed slice
        result = read_image(tmp_output)
        Q_IA = float(np.sum(np.abs(result)) / m0)
        Q_IN = float(-np.sum(result * heaviside(-result)) / m0)
        LOG.info("Q_IA={}, Q_IN={}".format(Q_IA, Q_IN))
        return Q_IA

    def best_center(center, width):
        trials = [center + (width / 4.0) * x for x in range(-2, 3)]
        scores = [(guess, get_score(guess, m0)) for guess in trials]
        LOG.info(scores)
        best = sorted(scores, key=lambda x: x[1])
        return best[0][0]

    for i in range(params.num_iterations):
        LOG.info("Estimate iteration: {}".format(i))
        new_center = best_center(new_center, width)
        LOG.info("Currently best center: {}".format(new_center))
        width /= 2.0

    try:
        os.remove(tmp_output)
os.removedirs(tmp_dir) except OSError: LOG.info("Could not remove {} or {}".format(tmp_output, tmp_dir)) return new_center def estimate_center_by_correlation(params): """Use correlation to estimate center of rotation for tomography.""" def flat_correct(flat, radio): nonzero = np.where(radio != 0) result = np.zeros_like(radio) result[nonzero] = flat[nonzero] / radio[nonzero] # log(1) = 0 result[result <= 0] = 1 return np.log(result) first = read_image(get_filenames(params.projections)[0]).astype(float) last_index = params.start + params.number if params.number else -1 last = read_image(get_filenames(params.projections)[last_index]).astype(float) if params.darks and params.flats: dark = read_image(get_filenames(params.darks)[0]).astype(float) flat = read_image(get_filenames(params.flats)[0]) - dark first = flat_correct(flat, first - dark) last = flat_correct(flat, last - dark) height = params.height if params.height else -1 y_region = slice(params.y, min(params.y + height, first.shape[0]), params.y_step) first = first[y_region, :] last = last[y_region, :] return compute_rotation_axis(first, last) def compute_rotation_axis(first_projection, last_projection): """ Compute the tomographic rotation axis based on cross-correlation technique. *first_projection* is the projection at 0 deg, *last_projection* is the projection at 180 deg. """ from scipy.signal import fftconvolve width = first_projection.shape[1] first_projection = first_projection - first_projection.mean() last_projection = last_projection - last_projection.mean() # The rotation by 180 deg flips the image horizontally, in order # to do cross-correlation by convolution we must also flip it # vertically, so the image is transposed and we can apply convolution # which will act as cross-correlation convolved = fftconvolve(first_projection, last_projection[::-1, :], mode='same') center = np.unravel_index(convolved.argmax(), convolved.shape)[1] return (width / 2.0 + center) / 2 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/tasks.py0000664000175000017500000000630500000000000016312 0ustar00tomastomas00000000000000import logging from gi.repository import Ufo LOG = logging.getLogger(__name__) PLUGIN_MANAGER = Ufo.PluginManager() def get_task(name, processing_node=None, **kwargs): task = PLUGIN_MANAGER.get_task(name) task.set_properties(**kwargs) if processing_node and task.uses_gpu(): LOG.debug("Assigning task '%s' to node %d", name, processing_node.get_index()) task.set_proc_node(processing_node) return task def get_writer(params): if 'dry_run' in params and params.dry_run: LOG.debug("Discarding data output") return get_task('null', download=True) outname = params.output LOG.debug("Writing output to {}".format(outname)) writer = get_task('write', filename=outname, rescale=params.output_rescale) writer.props.append = params.output_append if params.output_bitdepth != 32: writer.props.bits = params.output_bitdepth if params.output_minimum is not None: writer.props.minimum = params.output_minimum if params.output_maximum is not None: writer.props.maximum = params.output_maximum if params.output_minimum is not None or params.output_maximum is not None: LOG.info('--output-minimum or --output-maximum specified, turning --output-rescale on') writer.props.rescale = True if hasattr (writer.props, 'bytes_per_file'): writer.props.bytes_per_file = params.output_bytes_per_file if hasattr(writer.props, 'tiff_bigtiff'): writer.props.tiff_bigtiff = params.output_bytes_per_file > 
2 ** 32 - 2 ** 25 return writer def get_memory_in(array): import numpy as np if array.ndim != 2: raise ValueError("Only 2D images are supported") if array.dtype != np.float32 and array.dtype != np.complex64: raise ValueError("Only images with float32 or complex64 data type are supported") is_complex = array.dtype == np.complex64 in_task = get_task('memory-in') in_task.props.complex_layout = is_complex in_task.props.pointer = array.__array_interface__['data'][0] in_task.props.width = 2 * array.shape[1] if is_complex else array.shape[1] in_task.props.height = array.shape[0] in_task.props.number = 1 in_task.props.bitdepth = 32 # We need to extend the survival of *array* beyond this function to the point when the graph is # executed, otherwise it will be destroyed and UFO will try to get data from freed memory. Thus, # attach it to the task which actually needs it, because when that one is garbage collected then # the array may be as well. in_task.np_array = array return in_task def get_memory_out(width, height): import numpy as np array = np.empty((height, width), dtype=np.float32) out_task = get_task('memory-out') out_task.props.pointer = array.__array_interface__['data'][0] out_task.props.max_size = array.nbytes # We need to extend the survival of *array* beyond this function to the point when the graph is # executed, otherwise it will be destroyed and UFO will try to get data from freed memory. Thus, # attach it to the task which actually needs it, because when that one is garbage collected then # the array may be as well. out_task.np_array = array return out_task ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1698416097.773776 ufo-tofu-0.13.0/tofu/tests/0000775000175000017500000000000000000000000015751 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760162.0 ufo-tofu-0.13.0/tofu/tests/__init__.py0000664000175000017500000000000000000000000020050 0ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/tests/conftest.py0000664000175000017500000000304400000000000020151 0ustar00tomastomas00000000000000import pytest from PyQt5.QtWidgets import QInputDialog from tofu.flow.main import get_filled_registry from tofu.flow.scene import UfoScene from tofu.flow.propertylinksmodels import PropertyLinksModel, NodeTreeModel @pytest.fixture(scope='function') def nodes(monkeypatch): reg = get_filled_registry() scene = UfoScene(reg) nodes = {} # Composite node for name in ['read', 'pad']: model_cls = reg.create(name) node = scene.create_node(model_cls) node.graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes['cpm'] = scene.create_composite() nodes['cpm'].graphics_object.setSelected(False) # Simple nodes for i in range(5): name = f'read_{i}' if i else 'read' model_cls = reg.create('read') nodes[name] = scene.create_node(model_cls) model_cls = reg.create('image_viewer') nodes['image_viewer'] = scene.create_node(model_cls) model_cls = reg.create('average') nodes['average'] = scene.create_node(model_cls) return nodes @pytest.fixture(scope='function') def scene(): reg = get_filled_registry() return UfoScene(reg) @pytest.fixture(scope='function') def scene_with_composite(nodes): return UfoScene(nodes['cpm'].model._registry) @pytest.fixture(scope='function') def node_model(): model = NodeTreeModel() model.setColumnCount(1) 
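    # The property-link helpers (see flow_util.get_index_from_treemodel) address items through
    # column 0 only, so a single column suffices.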
return model @pytest.fixture(scope='function') def link_model(node_model): model = PropertyLinksModel(node_model) return model ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760162.0 ufo-tofu-0.13.0/tofu/tests/flow_util.py0000664000175000017500000000166600000000000020340 0ustar00tomastomas00000000000000def populate_link_model(link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] records = [[read, read.model, 'number'], [read_2, read_2.model, 'height'], [composite, composite.model['Read'], 'y']] for (i, (node, model, prop)) in enumerate(records): link_model.add_item(node, model, prop, 0, i) return records def get_index_from_treemodel(node_model, row, prop_name): item = node_model.item(row, 0) i = 0 prop_item = item.child(i) while prop_item.text() != prop_name: i += 1 prop_item = item.child(i) return node_model.indexFromItem(prop_item) def add_nodes_to_scene(scene, model_names=None): if not model_names: model_names = ['read'] nodes = [] for name in model_names: model_cls = scene.registry.create(name) nodes.append(scene.create_node(model_cls)) return nodes ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760161.0 ufo-tofu-0.13.0/tofu/tests/test_flow_execution.py0000664000175000017500000001143700000000000022422 0ustar00tomastomas00000000000000import pytest from tofu.flow.execution import get_gpu_splitting_models, UfoExecutor from tofu.flow.main import get_filled_registry from tofu.flow.scene import UfoScene @pytest.fixture(scope='function') def scene(): reg = get_filled_registry() scene = UfoScene(reg) for name in ['dummy_data', 'pad', 'null']: # Set nodes as scene attributes for convenience setattr(scene, name, scene.create_node(reg.create(name))) scene.create_connection(scene.dummy_data['output'][0], scene.pad['input'][0]) scene.create_connection(scene.pad['output'][0], scene.null['input'][0]) return scene @pytest.fixture(scope='function') def executor(): return UfoExecutor() class TestUfoExecutor: def test_init(self, executor): ... 
def test_reset(self, executor): assert not executor._aborted assert executor._schedulers == [] assert executor.num_generated == 0 def test_abort(self, executor): self.called = False def slot(): self.called = True executor.execution_finished.connect(slot) executor.abort() assert self.called def test_on_processed(self, executor): self.num_generated = 0 def slot(): self.num_generated += 1 executor.processed_signal.connect(slot) executor.on_processed(None) executor.on_processed(None) assert self.num_generated == executor.num_generated == 2 def test_setup_ufo_graph(self, qtbot, scene, executor): graph = scene.get_simple_node_graphs()[0] gpus = executor._resources.get_gpu_nodes() assert gpus executor.setup_ufo_graph(graph, gpu=gpus[0], region=None, signalling_model=scene.dummy_data.model) def test_run_ufo_graph(self, qtbot, scene, executor): graph = scene.get_simple_node_graphs()[0] gpus = executor._resources.get_gpu_nodes() assert gpus ufo_graph = executor.setup_ufo_graph(graph, gpu=gpus[0], region=None, signalling_model=scene.dummy_data.model) # Run with default scheduler executor._run_ufo_graph(ufo_graph, False) # Run with fixed scheduler executor._run_ufo_graph(ufo_graph, True) # def test_check_graph(self, qtbot, scene, executor): # # TODO: implement this when memory-in is implemented and there is something to test def test_run(self, qtbot, scene, executor): def on_num_inputs_changed(number): self.num_inputs = number def on_processed(number): self.num_processed = number def on_execution_started(): self.started = True def on_execution_finished(): self.finished = True def on_exception_occured(): self.exception = True scene.dummy_data.model['number'] = 10 graph = scene.get_simple_node_graphs()[0] self.num_inputs = 0 self.num_processed = 0 self.started = False self.finished = False self.exception = None executor.number_of_inputs_changed.connect(on_num_inputs_changed) executor.processed_signal.connect(on_processed) executor.execution_started.connect(on_execution_started) executor.execution_finished.connect(on_execution_finished) executor.exception_occured.connect(on_exception_occured) with qtbot.waitSignal(signal=executor.execution_finished, timeout=100000): executor.run(graph) assert self.num_inputs == scene.dummy_data.model['number'] assert self.num_processed == scene.dummy_data.model['number'] assert self.started assert self.finished assert self.exception is None scene.remove_node(scene.dummy_data) # Create a reader and point it to a nonexistent path so that it raises an exception and # check that this exception has been processed byt the executor setattr(scene, 'read', scene.create_node(scene.registry.create('read'))) scene.create_connection(scene.read['output'][0], scene.pad['input'][0]) # Make sure the path is nonsense scene.read.model['path'] = '/dfasf/fsdafsdaf/asd/asf' scene.read.model['number'] = 10 graph = scene.get_simple_node_graphs()[0] executor.swallow_run_exceptions = True with qtbot.waitSignal(signal=executor.execution_finished): executor.run(graph) assert self.exception def test_get_gpu_splitting_models(qtbot, scene, executor): graph = scene.get_simple_node_graphs()[0] assert len(get_gpu_splitting_models(graph)) == 0 scene.clear_scene() scene.create_node(scene.registry.create('general_backproject')) graph = scene.get_simple_node_graphs()[0] assert len(get_gpu_splitting_models(graph)) == 1 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 
ufo-tofu-0.13.0/tofu/tests/test_flow_main.py0000664000175000017500000005116000000000000021340 0ustar00tomastomas00000000000000import glob import os import pathlib import pkg_resources import pytest import sys import xdg.BaseDirectory from PyQt5.QtWidgets import QFileDialog, QInputDialog, QMessageBox from tofu.flow.execution import UfoExecutor from tofu.flow.main import ApplicationWindow, get_filled_registry, GlobalExceptionHandler from tofu.flow.scene import UfoScene from tofu.flow.util import FlowError from tofu.tests.flow_util import add_nodes_to_scene @pytest.fixture(scope='function') def app_window(qtbot, scene): window = ApplicationWindow(scene) qtbot.addWidget(window) return window class TestApplicationWindow: def test_init(self, qtbot, app_window): assert app_window.ufo_scene assert app_window.executor def test_on_save(self, monkeypatch, app_window): def getSaveFileNameDefault(inst, header, path, fltr): return (os.path.join(path, 'flow.flow'), True) def getSaveFileName(inst, header, path, fltr): return (os.path.join('foo', 'bar', 'flow.flow'), True) # Don't actually write to disk monkeypatch.setattr(UfoScene, "save", lambda *args: None) # Default directory monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault) app_window.on_save() directory = xdg.BaseDirectory.save_data_path('tofu', 'flows') assert os.path.exists(directory) assert app_window.last_dirs['scene'] == directory # When user picks a different directory it must be remembered monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileName) app_window.on_save() assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar') # And used the next time monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault) app_window.on_save() assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar') def test_on_open(self, monkeypatch, app_window): def getOpenFileNameDefault(inst, header, path, fltr): return (os.path.join(path, 'flow.flow'), True) def getOpenFileName(inst, header, path, fltr): return (os.path.join('foo', 'bar', 'flow.flow'), True) # Don't actually read from disk monkeypatch.setattr(UfoScene, "load", lambda *args: None) # Default directory monkeypatch.setattr(QFileDialog, "getOpenFileName", getOpenFileNameDefault) app_window.on_open() directory = xdg.BaseDirectory.save_data_path('tofu', 'flows') if not os.path.exists(directory): directory = pathlib.Path.home() assert app_window.last_dirs['scene'] == directory # When user picks a different directory it must be remembered monkeypatch.setattr(QFileDialog, "getOpenFileName", getOpenFileName) app_window.on_open() assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar') # And used the next time monkeypatch.setattr(QFileDialog, "getOpenFileName", getOpenFileNameDefault) app_window.on_open() assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar') def test_on_exception_occured(self, qtbot, monkeypatch, app_window): def exec_(inst): self.message_shown = True self.message_shown = False monkeypatch.setattr(QMessageBox, "exec_", exec_) app_window.on_exception_occured('foo') assert self.message_shown def test_on_number_of_inputs_changed(self, qtbot, app_window): app_window.on_number_of_inputs_changed(123) assert app_window.progress_bar.maximum() == 123 def test_on_processed(self, qtbot, app_window): app_window.on_number_of_inputs_changed(100) app_window.on_processed(10) assert app_window.progress_bar.value() == 11 def test_on_nodes_duplicated(self, qtbot, app_window): node = 
add_nodes_to_scene(app_window.ufo_scene)[0] node.graphics_object.setSelected(True) app_window.ufo_scene.copy_nodes() nodes = list(app_window.ufo_scene.nodes.values()) assert nodes[0].graphics_object.pos().y() != nodes[1].graphics_object.pos().y() def test_on_selection_menu_about_to_show(self, qtbot, monkeypatch, app_window): # Nothing selected app_window.on_selection_menu_about_to_show() assert not app_window.edit_composite_action.isEnabled() assert not app_window.expand_composite_action.isEnabled() assert not app_window.export_composite_action.isEnabled() # Only non-composite nodes nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'average', 'null']) app_window.on_selection_menu_about_to_show() assert not app_window.edit_composite_action.isEnabled() assert not app_window.expand_composite_action.isEnabled() assert not app_window.export_composite_action.isEnabled() # One composite monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) for i in range(2): nodes[i].graphics_object.setSelected(True) app_window.ufo_scene.create_composite() app_window.on_selection_menu_about_to_show() assert app_window.edit_composite_action.isEnabled() assert app_window.expand_composite_action.isEnabled() assert app_window.export_composite_action.isEnabled() # More composites monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm_2', True)) app_window.ufo_scene.clearSelection() nodes[-1].graphics_object.setSelected(True) app_window.ufo_scene.create_composite() for node in app_window.ufo_scene.nodes.values(): node.graphics_object.setSelected(True) app_window.on_selection_menu_about_to_show() assert not app_window.edit_composite_action.isEnabled() assert app_window.expand_composite_action.isEnabled() assert not app_window.export_composite_action.isEnabled() def test_skip_action(self, qtbot, app_window): # No nodes selected, menu item must be disabled app_window.on_selection_menu_about_to_show() assert not app_window.skip_action.isEnabled() # Add some nodes, conect them and disable one nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'average', 'null']) app_window.ufo_scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) app_window.ufo_scene.create_connection(nodes[1]['output'][0], nodes[2]['input'][0]) average = nodes[1] average.graphics_object.setSelected(True) app_window.on_selection_menu_about_to_show() # Nodes selected, menu item must be enabled assert app_window.skip_action.isEnabled() def test_on_edit_composite(self, qtbot, scene_with_composite, app_window): app_window.ufo_scene = scene_with_composite node = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm'])[0] node.graphics_object.setSelected(True) app_window.on_edit_composite() qtbot.addWidget(node.model._other_view) assert node.model.is_editing def test_on_create_composite(self, qtbot, monkeypatch, scene_with_composite, app_window): monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'pad']) # Link a model to the slider model = nodes[0].model view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', 'Read', model) # Create a composite for node in app_window.ufo_scene.nodes.values(): node.graphics_object.setSelected(True) app_window.on_create_composite() composite = list(app_window.ufo_scene.nodes.values())[0].model slider_model, prop_name = app_window.run_slider_key assert slider_model == 
composite.get_model_from_path(['Read']) assert prop_name == 'number' def test_on_item_focus_in(self, qtbot, app_window, scene_with_composite): read, pad = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'pad']) # Simple node model = read.model view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', model.caption, model) slider_model, prop_name = app_window.run_slider_key assert slider_model == model assert prop_name == 'number' app_window.fix_run_slider.setChecked(False) model = pad.model view_item = model._view._properties['y'].view_item app_window.on_item_focus_in(view_item, 'y', model.caption, model) slider_model, prop_name = app_window.run_slider_key assert slider_model == model assert prop_name == 'y' # Focus gets another widget, but the run slider must be linked to the one focused before the # fix option is checked app_window.fix_run_slider.setChecked(True) model = read.model view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', model.caption, model) slider_model, prop_name = app_window.run_slider_key assert slider_model == pad.model assert prop_name == 'y' def test_on_node_deleted(self, qtbot, monkeypatch, app_window, scene_with_composite): app_window.ufo_scene = scene_with_composite cpm, cpm_2, read = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm', 'cpm', 'read']) # Simple node model = read.model view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', model.caption, model) # remove in the scene doesn't seem to emit the signal, so use the window app_window.on_node_deleted(read) slider_model, prop_name = app_window.run_slider_key assert slider_model is None assert prop_name is None # Composite node model = cpm.model.get_model_from_path(['Read']) view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', 'cpm->Read', model) # remove in the scene doesn't seem to emit the signal, so use the window app_window.on_node_deleted(cpm) slider_model, prop_name = app_window.run_slider_key assert slider_model is None assert prop_name is None # Nested composite node cpm_2.graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('parent', True)) app_window.on_create_composite() node = app_window.ufo_scene.selected_nodes()[0] model = node.model.get_model_from_path(['cpm 2', 'Read']) view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', 'parent->cpm 2->Read', model) # remove in the scene doesn't seem to emit the signal, so use the window app_window.on_node_deleted(node) slider_model, prop_name = app_window.run_slider_key assert slider_model is None assert prop_name is None def test_on_expand_composite(self, qtbot, scene_with_composite, app_window): app_window.ufo_scene = scene_with_composite nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm', 'cpm']) for node in nodes: node.graphics_object.setSelected(True) app_window.on_expand_composite() captions = {node.model.caption for node in app_window.ufo_scene.nodes.values()} assert captions == {'Read 2', 'Pad 2', 'Read', 'Pad'} # Run slider # Create yet another composite and select a reader inside node = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm'])[0] model = node.model.get_model_from_path(['Read']) view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', 'cpm->Read', 
model) node.graphics_object.setSelected(True) app_window.on_expand_composite() # After expansion, the reader's index will be 3 slider_model, prop_name = app_window.run_slider_key assert slider_model.caption == 'Read 3' assert prop_name == 'number' def test_on_import_composites(self, qtbot, monkeypatch, app_window): tests_directory = pkg_resources.resource_filename(__name__, 'composites') def getOpenFileNamesDefault(inst, header, path, fltr): # Let's pretend there are files file_names = [os.path.join(path, 'foo.cm')] return (file_names, True) def getOpenFileNames(inst, header, path, fltr): file_names = sorted(glob.glob(os.path.join(tests_directory, '*.cm'))) return (file_names, True) def exec_(inst): self.message_shown = True monkeypatch.setattr(QMessageBox, "exec_", exec_) # Nothing opened, nothing happens monkeypatch.setattr(QFileDialog, "getOpenFileNames", lambda *args: ([], True)) app_window.on_import_composites() # Default directory monkeypatch.setattr(QFileDialog, "getOpenFileNames", getOpenFileNamesDefault) directory = xdg.BaseDirectory.save_data_path('tofu', 'flows', 'composites') if not os.path.exists(directory): directory = pathlib.Path.home() try: app_window.on_import_composites() except FileNotFoundError: # We don't care if there are files, just the last_dirs setting is important pass assert app_window.last_dirs['composite'] == directory # It's possible to open more than one at a time monkeypatch.setattr(QFileDialog, "getOpenFileNames", getOpenFileNames) app_window.on_import_composites() assert 'cmp' in app_window.ufo_scene.registry.registered_model_creators() assert 'cmp_2' in app_window.ufo_scene.registry.registered_model_creators() # When user picks a different directory it must be remembered assert app_window.last_dirs['composite'] == tests_directory # And used the next time self.message_shown = False app_window.on_import_composites() assert app_window.last_dirs['composite'] == tests_directory # Message about overwriting models must be shown assert self.message_shown def test_on_export_composite(self, qtbot, monkeypatch, scene_with_composite, app_window): tests_directory = pkg_resources.resource_filename(__name__, 'composites') def getSaveFileNameDefault(inst, header, path, fltr): return (os.path.join(path, self.file_name), True) def getSaveFileName(inst, header, path, fltr): return (os.path.join(tests_directory, self.file_name), True) def export_composite(inst, node, file_name): self.final_file_name = file_name # Nothing selected, must silently pass app_window.on_export_composite() # Make a composite node app_window.ufo_scene = scene_with_composite node = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm'])[0] node.graphics_object.setSelected(True) monkeypatch.setattr(ApplicationWindow, "export_composite", export_composite) # Default directory monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault) self.file_name = 'composite' directory = xdg.BaseDirectory.save_data_path('tofu', 'flows', 'composites') app_window.on_export_composite() assert self.final_file_name.endswith('.cm') and not self.final_file_name.endswith('.cm.cm') assert os.path.exists(directory) assert app_window.last_dirs['composite'] == directory # When user picks a different directory it must be remembered monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileName) app_window.on_export_composite() assert self.final_file_name.endswith('.cm') and not self.final_file_name.endswith('.cm.cm') assert app_window.last_dirs['composite'] == tests_directory # And used the next time 
monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault) app_window.on_export_composite() assert app_window.last_dirs['composite'] == tests_directory # .cm must not be added if it's present in the file name self.file_name = 'composite.cm' app_window.on_export_composite() assert self.final_file_name.endswith('.cm') and not self.final_file_name.endswith('.cm.cm') def test_on_reset_view(self, qtbot, app_window): app_window.flow_view.scale_up() app_window.on_reset_view() assert app_window.flow_view.transform().m11() == pytest.approx(1) assert app_window.flow_view.transform().m22() == pytest.approx(1) def test_on_property_links_action(self, qtbot, app_window): qtbot.addWidget(app_window.property_links_widget) app_window.property_links_widget.show() assert app_window.property_links_widget.isVisible() def test_on_run(self, qtbot, monkeypatch, app_window): def executor_run(inst, graph): self.ran = True monkeypatch.setattr(UfoExecutor, "run", executor_run) nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'read', 'flat_field_correct', 'null']) i_0, i_1, ffc, null = nodes # No connections -> many graphs with pytest.raises(FlowError): app_window.on_run() assert app_window.run_action.isEnabled() app_window.ufo_scene.create_connection(i_0['output'][0], ffc['input'][0]) app_window.ufo_scene.create_connection(i_1['output'][0], ffc['input'][1]) app_window.ufo_scene.create_connection(ffc['output'][0], null['input'][0]) # One ffc input is not connected with pytest.raises(FlowError): app_window.on_run() assert app_window.run_action.isEnabled() # All connections present -> must run i_2 = add_nodes_to_scene(app_window.ufo_scene, model_names=['read'])[0] app_window.ufo_scene.create_connection(i_2['output'][0], ffc['input'][2]) self.ran = False app_window.on_run() assert self.ran assert not app_window.run_action.isEnabled() def test_on_save_json(self, qtbot, monkeypatch, app_window): import gi gi.require_version('Ufo', '0.0') from gi.repository import Ufo # Don't pop up file dialog def getSaveFileName(inst, header, path, fltr): return (os.path.join(path, 'flow.json'), True) monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileName) # Don't actually write to disk monkeypatch.setattr(Ufo.TaskGraph, "save_to_json", lambda *args: None) # Empty scene with pytest.raises(FlowError): app_window.on_save_json() # Wrong data types app_window.ufo_scene.clear_scene() read, mem_out, viewer = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'memory_out', 'image_viewer']) app_window.ufo_scene.create_connection(read['output'][0], mem_out['input'][0]) app_window.ufo_scene.create_connection(mem_out['output'][0], viewer['input'][0]) with pytest.raises(FlowError): app_window.on_save_json() # Not connected app_window.ufo_scene.clear_scene() read, null = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'null']) with pytest.raises(FlowError): app_window.on_save_json() # This must pass app_window.ufo_scene.create_connection(read['output'][0], null['input'][0]) app_window.on_save_json() def test_on_execution_finished(self, qtbot, app_window): app_window.run_action.setEnabled(False) app_window.progress_bar.setMaximum(100) app_window.progress_bar.setValue(50) app_window.on_execution_finished() assert app_window.progress_bar.value() == -1 assert app_window.run_action.isEnabled() def test_global_exception_handler(qtbot): handler = GlobalExceptionHandler() def slot(text): handler.called_signal = True handler.exception_occured.connect(slot) handler.called_signal = False 
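    # Raise a real exception so that sys.exc_info() below returns a populated
    # (type, value, traceback) triple for the explicit excepthook call.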
    try:
        raise FlowError('foo')
    except FlowError:
        # Call the hook explicitly, sys.excepthook = ... doesn't seem to have an effect
        handler.excepthook(*sys.exc_info())

    assert handler.called_signal


def test_get_filled_registry():
    registry = get_filled_registry()
    assert 'read' in registry.registered_model_creators()


# File: ufo-tofu-0.13.0/tofu/tests/test_flow_models.py

import pytest
import numpy as np

from PyQt5.QtCore import Qt
from PyQt5.QtGui import QValidator
from PyQt5.QtWidgets import QFileDialog, QInputDialog, QLineEdit

from tofu.flow.main import get_filled_registry
from tofu.flow.models import (CheckBoxViewItem, ComboBoxViewItem, get_composite_model_class,
                              get_composite_model_classes, get_composite_model_classes_from_json,
                              get_ufo_model_class, get_ufo_model_classes, ImageViewerModel,
                              IntQLineEditViewItem, MultiPropertyView, NumberQLineEditViewItem,
                              PropertyModel, PropertyView, QLineEditViewItem,
                              RangeQLineEditViewItem, UfoGeneralBackprojectModel, UfoIntValidator,
                              UfoMemoryOutModel, UfoModelError, UfoRangeValidator, UfoReadModel,
                              UfoRetrievePhaseModel, UfoModel, UfoTaskModel, UfoVaryingInputModel,
                              UfoWriteModel, ViewItem)
from tofu.flow.scene import UfoScene
from tofu.flow.util import CompositeConnection, MODEL_ROLE, PROPERTY_ROLE
from tofu.tests.flow_util import populate_link_model


def check_property_changed_emit(qtbot, view_item, expected, gui_func, gui_args,
                                gui_kwargs=None, show=False):
    if gui_kwargs is None:
        gui_kwargs = {}

    def on_changed(vit):
        vit.change_called = True

    view_item.change_called = False
    view_item.property_changed.connect(on_changed)
    qtbot.addWidget(view_item.widget)

    if show:
        # without show the mouse click for QCheckBox doesn't happen, bug in pytest-qt?
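# --- Editor's note: illustrative sketch, not part of the tofu test suite ---
# The check_property_changed_emit helper above drives view-item widgets through
# pytest-qt's qtbot fixture rather than setting values programmatically, so that
# property_changed is only triggered by simulated user input. The bare pytest-qt
# idiom it builds on is sketched below; the test name is invented, while
# qtbot.addWidget and qtbot.keyClicks are real pytest-qt APIs.
from PyQt5.QtWidgets import QLineEdit


def test_line_edit_key_clicks_example(qtbot):
    edit = QLineEdit()
    # Register the widget so pytest-qt cleans it up after the test
    qtbot.addWidget(edit)
    # Simulate the user typing into the widget
    qtbot.keyClicks(edit, 'abc')
    assert edit.text() == 'abc'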
view_item.widget.show() # Store old value for later check of programmatic change old_value = view_item.get() # Simulate user interaction gui_func(*gui_args, **gui_kwargs) # Value must have been set assert view_item.get() == expected # Signal must have been emitted assert view_item.change_called view_item.change_called = False view_item.set(old_value) # Signal must be emitted only on user interacion, not programmatic access assert not view_item.change_called def make_properties(): return { 'int': [IntQLineEditViewItem(0, 100, default_value=10), True], 'float': [NumberQLineEditViewItem(0, 100, default_value=0), True], 'string': [QLineEditViewItem(default_value='foo'), True], 'range': [RangeQLineEditViewItem(default_value=[1, 2, 3], num_items=3, is_float=True), True], 'choices': [ComboBoxViewItem(['a', 'b', 'c']), True], 'check': [CheckBoxViewItem(checked=True), True] } class DummyPropertyModel(PropertyModel): def make_properties(self): return make_properties() @pytest.fixture(scope='function') def property_view(): return PropertyView(properties=make_properties(), scrollable=False) @pytest.fixture(scope='function') def multi_property_view(nodes): groups = {nodes['cpm'].model: True, nodes['read'].model: False} return MultiPropertyView(groups=groups) def make_composite_model_class(nodes, name='foobar'): # We want to connect cpm:Pad to average, thus we need to get the outside port of the cpm # composite which corresponds to the pad model pad_index = nodes['cpm'].model.get_outside_port('Pad', 'output', 0)[1] connections = [CompositeConnection('cpm', pad_index, 'Average', 0)] state = [('cpm', nodes['cpm'].model.save(), True, None), ('average', nodes['average'].model.save(), True, None)] return get_composite_model_class(name, state, connections) def create_scene(qtbot, registry): scene = UfoScene(registry=registry) if scene.views(): for view in scene.views(): qtbot.addWidget(view) return scene def make_composite_node_in_scene(qtbot, nodes): model_cls = make_composite_model_class(nodes) registry = get_filled_registry() # Register both composites so that we can create them registry.register_model(nodes['cpm'].model.__class__, category='Composite', registry=registry) registry.register_model(model_cls, category='Composite', registry=registry) scene = create_scene(qtbot, registry) node = scene.create_node(model_cls) return (scene, node) @pytest.fixture(scope='function') def composite_model(nodes): # Make sure 'cpm', which is inside this composite model, has been registered registry = nodes['cpm'].model._registry model_cls = make_composite_model_class(nodes) registry.register_model(model_cls, category='Composite', registry=registry) return model_cls(registry=registry) @pytest.fixture(scope='function') def general_backproject(qtbot): model = UfoGeneralBackprojectModel() qtbot.addWidget(model.embedded_widget()) return model @pytest.fixture(scope='function') def read_model(qtbot): model = UfoReadModel() qtbot.addWidget(model.embedded_widget()) return model @pytest.fixture(scope='function') def write_model(qtbot): model = UfoWriteModel() qtbot.addWidget(model.embedded_widget()) return model @pytest.fixture(scope='function') def memory_out_model(qtbot): model = UfoMemoryOutModel() model['width'] = 100 model['height'] = 100 qtbot.addWidget(model.embedded_widget()) return model @pytest.fixture(scope='function') def image_viewer_model(qtbot): model = ImageViewerModel() qtbot.addWidget(model.embedded_widget()) return model def test_ufo_int_validator(): validator = UfoIntValidator(-10, 10) def check(input_str, 
expected): assert validator.validate(input_str, -1)[0] == expected check('0', QValidator.Acceptable) check('1', QValidator.Acceptable) check('-1', QValidator.Acceptable) check('101', QValidator.Intermediate) check('-101', QValidator.Intermediate) check('-', QValidator.Intermediate) check('1.', QValidator.Invalid) check('1.0', QValidator.Invalid) check('asdf', QValidator.Invalid) validator = UfoIntValidator(3, 10) check('1', QValidator.Intermediate) def test_ufo_range_validator(): def check(validator, input_str, expected): assert validator.validate(input_str, len(input_str))[0] == expected # Integer validator = UfoRangeValidator(num_items=3, is_float=False) check(validator, ',,', QValidator.Intermediate) check(validator, ' ,,', QValidator.Intermediate) check(validator, '1,1,', QValidator.Intermediate) check(validator, ',1,', QValidator.Intermediate) check(validator, '1,-2,3', QValidator.Acceptable) check(validator, '1,1.0,1', QValidator.Invalid) check(validator, '-1,s,-1', QValidator.Invalid) check(validator, '1,1,1,1', QValidator.Invalid) check(validator, '1,1,1,', QValidator.Invalid) # Float validator = UfoRangeValidator(num_items=3, is_float=True) check(validator, ',,', QValidator.Intermediate) check(validator, ' ,,', QValidator.Intermediate) check(validator, '.,,', QValidator.Intermediate) check(validator, '.e,,', QValidator.Intermediate) check(validator, '.e-,,', QValidator.Intermediate) check(validator, '.e+,,', QValidator.Intermediate) check(validator, '1.0e,,', QValidator.Intermediate) check(validator, '1.0e+,,', QValidator.Intermediate) check(validator, '1.0e-,,', QValidator.Intermediate) check(validator, '1e,,', QValidator.Intermediate) check(validator, '1e+,,', QValidator.Intermediate) check(validator, '1e-,,', QValidator.Intermediate) check(validator, '.1e,,', QValidator.Intermediate) check(validator, '.1e+,,', QValidator.Intermediate) check(validator, '.1e-,,', QValidator.Intermediate) check(validator, '1,1,1', QValidator.Acceptable) check(validator, '-1,1,1', QValidator.Acceptable) check(validator, '1.,1.,1', QValidator.Acceptable) check(validator, '-1.,1.,1', QValidator.Acceptable) check(validator, '1.0e1,1.0,1', QValidator.Acceptable) check(validator, '1.0e+1,1.0,1', QValidator.Acceptable) check(validator, '1.0e-1,1.0,1', QValidator.Acceptable) check(validator, '.1,1.0,1', QValidator.Acceptable) check(validator, '.1e-1,1.0,1', QValidator.Acceptable) check(validator, '.1e+1,1.0,1', QValidator.Acceptable) check(validator, '.1e1,1.0,1', QValidator.Acceptable) check(validator, 'e,,', QValidator.Invalid) check(validator, 'e.,,', QValidator.Invalid) check(validator, '+e,,', QValidator.Invalid) check(validator, '-e,,', QValidator.Invalid) check(validator, '+e.,,', QValidator.Invalid) check(validator, '-e.,,', QValidator.Invalid) check(validator, '1+,,', QValidator.Invalid) check(validator, '1-,,', QValidator.Invalid) check(validator, 'gfd,1,3', QValidator.Invalid) def test_view_item_init(qtbot): def get(inst): return inst.widget.text() def set(inst, value): inst.widget.setText(value) def on_changed(vit): vit.change_called = True ViewItem.get = get ViewItem.set = set edit = QLineEdit() qtbot.addWidget(edit) vit = ViewItem(edit, default_value='foo', tooltip='tooltip') edit.textEdited.connect(vit.on_changed) assert vit.widget.toolTip() == 'tooltip' assert vit.widget.text() == 'foo' check_property_changed_emit(qtbot, vit, 'fooa', qtbot.keyClick, (edit, 'a')) def test_check_box_view_item(qtbot): assert CheckBoxViewItem(checked=True).get() vit = CheckBoxViewItem(checked=False, 
tooltip='tooltip') assert vit.widget.toolTip() == 'tooltip' assert not vit.get() check_property_changed_emit(qtbot, vit, True, qtbot.mouseClick, (vit.widget, Qt.LeftButton), show=True) def test_combo_box_view_item(qtbot): items = ['a', 'b', 'c'] vit = ComboBoxViewItem(items, default_value='b', tooltip='tooltip') assert vit.widget.toolTip() == 'tooltip' assert vit.get() == 'b' check_property_changed_emit(qtbot, vit, 'c', qtbot.keyClick, (vit.widget, 'c')) def test_qline_edit_view_item(qtbot): vit = QLineEditViewItem(default_value='foo', tooltip='tooltip') assert vit.widget.toolTip() == 'tooltip' assert vit.get() == 'foo' check_property_changed_emit(qtbot, vit, 'fooc', qtbot.keyClick, (vit.widget, 'c')) def test_number_qline_edit_view_item(qtbot): with pytest.raises(ValueError): NumberQLineEditViewItem(-100, 100, default_value=1000) with pytest.raises(ValueError): NumberQLineEditViewItem(-100, 100, default_value=-1000) vit = NumberQLineEditViewItem(-100., 100., default_value=0., tooltip='tooltip') assert vit.widget.toolTip().startswith('tooltip') assert vit.get() == 0 # is 0.0, after key click "1" will be 0.01 check_property_changed_emit(qtbot, vit, 0.01, qtbot.keyClick, (vit.widget, '1')) def test_int_qline_edit_view_item(qtbot): with pytest.raises(ValueError): IntQLineEditViewItem(-100, 100, default_value=1000) with pytest.raises(ValueError): IntQLineEditViewItem(-100, 100, default_value=-1000) vit = IntQLineEditViewItem(-100, 100, default_value=0, tooltip='tooltip') assert vit.widget.toolTip().startswith('tooltip') assert vit.get() == 0 # is 0, after key click "1" will be 01, thus 1 check_property_changed_emit(qtbot, vit, 1, qtbot.keyClick, (vit.widget, '1')) def test_range_edit_view_item(qtbot): vit = RangeQLineEditViewItem(default_value=[1.0, 2.0, 3.0], tooltip='tooltip') assert vit.widget.toolTip().startswith('tooltip') assert vit.get() == [1.0, 2.0, 3.0] # Last is 3.0, after key click "1" will be 3.01 check_property_changed_emit(qtbot, vit, [1.0, 2.0, 3.01], qtbot.keyClick, (vit.widget, '1')) class TestPropertyView: def test_init(self, qtbot, property_view): assert len(property_view.property_names) > 0 # Defaults must pass PropertyView() def test_get_property(self, qtbot, property_view): assert property_view.get_property('int') == 10 def test_set_property(self, qtbot, property_view): property_view.set_property('int', 50) assert property_view.get_property('int') == 50 def test_on_property_changed(self, qtbot, property_view): widget = property_view._properties['int'].view_item.widget qtbot.addWidget(widget) qtbot.keyClick(widget, '0') assert property_view.get_property('int') == 100 def test_is_property_visible(self, qtbot, property_view): assert property_view.is_property_visible('int') def test_set_property_visible(self, qtbot, property_view): visible = not property_view.is_property_visible('int') property_view.set_property_visible('int', visible) assert property_view.is_property_visible('int') == visible def test_restore_properties(self, qtbot, property_view): props = property_view.export_properties() property_view.set_property('int', props['int'][0] + 1) property_view.restore_properties(props) assert property_view.get_property('int') == props['int'][0] def test_export_properties(self, qtbot, property_view): props = property_view.export_properties() assert 'int' in props assert props['int'][0] == property_view.get_property('int') assert props['int'][1] == property_view.is_property_visible('int') class TestMultiPropertyView: def test_init(self, qtbot, multi_property_view): assert 
len(list(iter(multi_property_view))) == 2 def test_getitem(self, qtbot, multi_property_view, nodes): assert multi_property_view['cpm'] == nodes['cpm'].model def test_contains(self, qtbot, multi_property_view): assert 'cpm' in multi_property_view assert 'foo' not in multi_property_view def test_iter(self, qtbot, multi_property_view): assert set(list(iter(multi_property_view))) == set(['cpm', 'Read']) def test_export_groups(self, qtbot, multi_property_view): state = multi_property_view.export_groups() multi_property_view.set_group_visible('Read', False) assert state['Read']['model']['caption'] == 'Read' assert not state['Read']['visible'] def test_restore_groups(self, qtbot, multi_property_view, nodes): multi_property_view['Read']['number'] = 100 state = multi_property_view.export_groups() multi_property_view['Read']['number'] = 1000 multi_property_view.restore_groups(state) assert multi_property_view['Read']['number'] def test_set_group_visible(self, qtbot, multi_property_view): visible = not multi_property_view.is_group_visible('cpm') multi_property_view.set_group_visible('cpm', visible) assert multi_property_view.is_group_visible('cpm') == visible def test_is_group_visible(self, qtbot, multi_property_view): assert multi_property_view.is_group_visible('cpm') assert not multi_property_view.is_group_visible('Read') class TestUfoModel: def test_init(self): model = UfoModel() assert model.caption == model.base_caption def test_restore(self): model = UfoModel() state = {'caption': 'foo'} old_caption = model.caption model.restore(state, restore_caption=False) assert model.caption == old_caption model.restore(state, restore_caption=True) assert model.caption == 'foo' # 'caption' not in state, the old one must be preserved model = UfoModel() old_caption = model.caption model.restore({}, restore_caption=True) assert model.caption == old_caption def save(self): model = UfoModel() model.caption = 'foo' assert model.save()['caption'] == 'foo' class TestPropertyModel: def test_init(self, qtbot): PropertyModel() model = DummyPropertyModel() # make_properties must be called assert set(model.properties) == set(make_properties().keys()) def test_getitem(self, qtbot): model = DummyPropertyModel() model['int'] with pytest.raises(KeyError): model['foo'] def test_setitem(self, qtbot): model = DummyPropertyModel() model['int'] = 132 assert model['int'] == 132 def test_contains(self, qtbot): model = DummyPropertyModel() assert 'int' in model assert 'foo' not in model def test_iter(self, qtbot): model = DummyPropertyModel() assert set(iter(model)) == set(make_properties().keys()) def test_on_property_changed(self, qtbot): def callback(model, name, value): self.called_name = name self.called_value = value model = DummyPropertyModel() model.property_changed.connect(callback) widget = model._view._properties['int'].view_item.widget qtbot.addWidget(widget) qtbot.keyClick(widget, '0') assert self.called_value == model['int'] assert self.called_name == 'int' def test_make_properties(self, qtbot): props = DummyPropertyModel().make_properties() assert props.keys() == make_properties().keys() assert PropertyModel().make_properties() == {} def test_copy_properties(self, qtbot): model = DummyPropertyModel() model['int'] = 123 visible = not model._view.is_property_visible('int') model._view.set_property_visible('int', visible) properties = model.copy_properties() # It has to be a deep copy, so changing the model properties cannot affect the copy model['int'] = 12 model._view.set_property_visible('int', not visible) assert 
properties['int'][0].get() == 123 assert properties['int'][1] == visible def test_embedded_widget(self, qtbot): assert PropertyModel().embedded_widget() is None assert isinstance(DummyPropertyModel().embedded_widget(), PropertyView) def test_restore(self, qtbot): model = DummyPropertyModel() state = model.save() old_value = model['int'] old_caption = model.caption visible = not model._view.is_property_visible('int') model['int'] = old_value + 1 model._view.set_property_visible('int', visible) model.caption = 'Foo' model.restore(state, restore_caption=False) assert model['int'] == old_value assert model._view.is_property_visible('int') == (not visible) assert model.caption == 'Foo' model.restore(state, restore_caption=True) assert model.caption == old_caption def test_save(self, qtbot): model = DummyPropertyModel() old_value = model['int'] visible = not model._view.is_property_visible('int') model['int'] = old_value + 1 model._view.set_property_visible('int', visible) model.caption = 'Foo' state = model.save() assert state['properties']['int'][0] == old_value + 1 assert state['properties']['int'][1] == visible assert state['caption'] == 'Foo' class TestUfoTaskModel: def test_init(self, qtbot): model = UfoTaskModel('flat-field-correct') assert model.properties # A task doesn't need any special treatment by default assert not model.expects_multiple_inputs assert not model.can_split_gpu_work assert not model.needs_fixed_scheduler def test_make_properties(self, qtbot): model = UfoTaskModel('flat-field-correct') # Config takes effect assert not model._view.is_property_visible('dark-scale') def test_create_ufo_task(self, qtbot): model = UfoTaskModel('flat-field-correct') model['dark-scale'] = 12.3 task = model.create_ufo_task() assert task.props.dark_scale == pytest.approx(12.3) def test_uses_gpu(self, qtbot): model = UfoTaskModel('flat-field-correct') assert model.uses_gpu model = UfoTaskModel('read') assert not model.uses_gpu def test_get_ufo_model_class(qtbot): # flat correction is a fairly complicated task to test task_name = 'flat-field-correct' model_cls = get_ufo_model_class(task_name) # Model class attributes assert model_cls.name == 'flat_field_correct' model = model_cls() # Model instance attributes assert model.num_ports['input'] == 3 assert model.num_ports['output'] == 1 assert model.port_caption['input'][0] == 'radios' assert model.port_caption['input'][1] == 'darks' assert model.port_caption['input'][2] == 'flats' assert model.port_caption['output'][0] == '' class TestBaseCompositeModel: def test_init(self, qtbot, monkeypatch, composite_model, scene): # cpm has 1 input and 2 outputs (read and pad are not connected) and average has 1 input and # 1 output, but cpm is connected with average, which reduces both port types by 1 assert composite_model.num_ports['input'] == 1 assert composite_model.num_ports['output'] == 2 for port_type in ['input', 'output']: for i in range(composite_model.num_ports[port_type]): submodel, j = composite_model.get_model_and_port_index(port_type, i) subcaption = submodel.port_caption[port_type][j] if subcaption: subcaption = ':' + subcaption assert (composite_model.port_caption[port_type][i] == submodel.caption + subcaption) assert composite_model._view # num-inputs must take effect monkeypatch.setattr(QInputDialog, "getInt", lambda *args, **kwargs: (2, True)) monkeypatch.setattr(QInputDialog, "getText", lambda *args, **kwargs: ('with-pr', True)) node = scene.create_node(scene.registry.create('retrieve_phase')) node.graphics_object.setSelected(True) node = 
scene.create_composite() assert node.model.get_model_from_path(['Retrieve Phase']).num_ports['input'] == 2 # and it must not affect default registry creators kwargs = scene.registry.registered_model_creators()['retrieve_phase'][1] assert 'num_inputs' not in kwargs def test_getitem(self, qtbot, composite_model, nodes): assert composite_model['cpm'] assert composite_model['Average'] with pytest.raises(KeyError): composite_model['foo'] def test_contains(self, qtbot, composite_model, nodes): assert 'cpm' in composite_model assert 'foo' not in composite_model def test_iter(self, qtbot, composite_model): assert set(list(iter(composite_model))) == set(['cpm', 'Average']) def test_get_descendant_graph(self, qtbot, monkeypatch, composite_model, nodes): graph = composite_model.get_descendant_graph() cpm = composite_model['cpm'] assert (composite_model, cpm) in graph.edges assert (composite_model, composite_model['Average']) in graph.edges assert (cpm, cpm['Read']) in graph.edges assert (cpm, cpm['Pad']) in graph.edges with pytest.raises(ValueError): composite_model.get_descendant_graph(in_subwindow=True) # Subwindow editing composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) graph = composite_model.get_descendant_graph(in_subwindow=True) cpm = composite_model._window_nodes['cpm'].model average = composite_model._window_nodes['Average'].model assert (composite_model, cpm) in graph.edges assert (composite_model, average) in graph.edges assert (cpm, cpm['Read']) in graph.edges assert (cpm, cpm['Pad']) in graph.edges composite_model._other_view.close() # Create outer composite with foobar inside, get_descendant_graph with in_subwindow=True # when outer is being edited and foobar not must return outer subwindow models and foobar's # internal models scene = create_scene(qtbot, composite_model._registry) inner = scene.create_node(composite_model.__class__) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('outer', True)) inner.graphics_object.setSelected(True) outer = scene.create_composite().model outer.edit_in_window() qtbot.addWidget(outer._other_view) graph = outer.get_descendant_graph(in_subwindow=True) inner = outer._window_nodes['foobar'].model assert (outer, inner) in graph.edges assert (inner, inner['Average']) in graph.edges assert (inner['cpm'], inner['cpm']['Read']) in graph.edges outer._other_view.close() def test_contains_path(self, qtbot, composite_model, nodes): assert composite_model.contains_path(['Average']) assert composite_model.contains_path(['cpm']) assert composite_model.contains_path(['cpm', 'Read']) assert not composite_model.contains_path(['cpm', 'Read 2']) assert not composite_model.contains_path(['foo']) def test_get_model_from_path(self, qtbot, composite_model, nodes): assert composite_model.get_model_from_path(['cpm', 'Read']) with pytest.raises(KeyError): composite_model.get_model_from_path(['foo']) def test_is_model_inside(self, qtbot, composite_model, nodes): model = composite_model.get_model_from_path(['cpm']) assert composite_model.is_model_inside(model) model = composite_model.get_model_from_path(['cpm', 'Read']) assert composite_model.is_model_inside(model) assert not composite_model.is_model_inside(nodes['read_2'].model) def test_get_path_from_model(self, qtbot, composite_model, nodes): cpm = composite_model['cpm'] path = composite_model.get_path_from_model(cpm) assert path == [composite_model, cpm] path = composite_model.get_path_from_model(cpm['Read']) assert path == [composite_model, cpm, cpm['Read']] model = 
composite_model['cpm']['Read'] path = composite_model.get_path_from_model(model) assert path == [composite_model, cpm, model] with pytest.raises(KeyError): composite_model.get_path_from_model(nodes['read_2'].model) def test_leaf_paths(self, qtbot, composite_model, nodes): leaves = composite_model.get_leaf_paths(in_subwindow=False) cpm = composite_model['cpm'] assert len(leaves) == 3 assert [composite_model, cpm, cpm['Read']] in leaves assert [composite_model, cpm, cpm['Pad']] in leaves assert [composite_model, composite_model['Average']] in leaves def test_set_property_links_model(self, qtbot, link_model, composite_model): composite_model.property_links_model = link_model assert composite_model.property_links_model == link_model # The property links model must be set also for children assert composite_model['cpm'].property_links_model == link_model def test_get_outside_port(self, qtbot, composite_model): # There is one input corresponding to cpm's pad model cpm = composite_model.get_model_from_path(['cpm']) pad_index = cpm.get_outside_port('Pad', 'input', 0)[1] composite_model.get_outside_port('cpm', 'input', pad_index) # and two outputs: cpm's read and average read_index = cpm.get_outside_port('Read', 'output', 0)[1] composite_model.get_outside_port('cpm', 'output', read_index) composite_model.get_outside_port('Average', 'output', 0) def test_get_model_and_port_index(self, qtbot, composite_model): model, index = composite_model.get_model_and_port_index('input', 0) cpm = composite_model.get_model_from_path(['cpm']) # There is only one input: Pad. Get it's internal cpm's index and compare with what the # outer composite object gives. pad_index = cpm.get_outside_port('Pad', 'input', 0)[1] assert model == cpm assert index == pad_index # There are two output ports, one from cpm's read model and one from average average = composite_model.get_model_from_path(['Average']) # Get read index from the cpm inside the composite_model and not from the cpm in the 'nodes' # fixsture because those are not the same instance and the read output index might be # different in those two instances because the ports are dictionaries read_index = cpm.get_outside_port('Read', 'output', 0)[1] outputs = [composite_model.get_model_and_port_index('output', 0)] outputs.append(composite_model.get_model_and_port_index('output', 1)) assert (cpm, read_index) in outputs assert (average, 0) in outputs def test_embedded_widget(self, qtbot, composite_model): assert isinstance(composite_model.embedded_widget(), MultiPropertyView) def test_restore(self, qtbot, composite_model): state = composite_model.save() old_value = composite_model['cpm']['Pad']['width'] old_caption = composite_model.caption visible = not composite_model._view.is_group_visible('cpm') composite_model['cpm']['Pad']['width'] = old_value + 1 composite_model._view.set_group_visible('cpm', visible) composite_model.caption = 'Foo' composite_model.restore(state, restore_caption=False) assert composite_model['cpm']['Pad']['width'] == old_value assert composite_model._view.is_group_visible('cpm') == (not visible) assert composite_model.caption == 'Foo' composite_model.restore(state, restore_caption=True) assert composite_model.caption == old_caption conn = composite_model._connections[0] assert [[conn.from_unique_name, conn.from_port_index, conn.to_unique_name, conn.to_port_index]] == state['connections'] def test_restore_links(self, qtbot, nodes): def check_links(node, link_model): assert link_model.rowCount() == 1 assert link_model.columnCount() == 3 assert 
link_model.find_items((node.model['cpm']['Read'], 'number'), (MODEL_ROLE, PROPERTY_ROLE)) assert link_model.find_items((node.model['cpm']['Pad'], 'height'), (MODEL_ROLE, PROPERTY_ROLE)) assert link_model.find_items((node.model['Average'], 'number'), (MODEL_ROLE, PROPERTY_ROLE)) assert not link_model.find_items((node.model['cpm']['Read'], 'height'), (MODEL_ROLE, PROPERTY_ROLE)) scene, node = make_composite_node_in_scene(qtbot, nodes) link_model = scene.property_links_model link_model.add_item(node, node.model['cpm']['Read'], 'number', 0, 0) link_model.add_item(node, node.model['cpm']['Pad'], 'height', 0, 1) link_model.add_item(node, node.model['Average'], 'number', 0, 2) # Set links to the newly created links node.model._links = node.model.save()['links'] # Link model has to have the exact same entries as before link_model.clear() node.model.restore_links(node) check_links(node, link_model) # Second time doesn't add the same links twice node.model.restore_links(node) check_links(node, link_model) def test_save(self, qtbot, nodes): scene, node = make_composite_node_in_scene(qtbot, nodes) link_model = scene.property_links_model cpm = node.model['cpm'] link_model.add_item(node, node.model['cpm']['Read'], 'number', 0, 0) link_model.add_item(node, node.model['cpm']['Pad'], 'height', 0, 1) link_model.add_item(node, node.model['Average'], 'number', 0, 2) old_value = node.model['cpm']['Pad']['width'] visible = not node.model._view.is_group_visible('cpm') node.model['cpm']['Pad']['width'] = old_value + 1 node.model._view.set_group_visible('cpm', visible) node.model.caption = 'Foo' state = node.model.save() cpm_models_state = state['models']['cpm']['model']['models'] assert state['models']['cpm']['visible'] == visible assert cpm_models_state['Pad']['model']['properties']['width'][0] == old_value + 1 assert state['caption'] == 'Foo' cpm = node.model.get_model_from_path(['cpm']) pad_index = cpm.get_outside_port('Pad', 'output', 0)[1] assert state['connections'] == [['cpm', pad_index, 'Average', 0]] # Property links links = link_model.get_model_links([path[-1] for path in node.model.get_leaf_paths()]) links = [[str_path[1:] for str_path in row] for row in links.values()] saved = node.model.save()['links'] # One row assert len(saved) == len(links) == 1 # All linked paths must be saved for str_path in saved[0]: assert str_path in links[0] def test_on_connection_created(self, qtbot, composite_model): composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) for node in composite_model._other_scene.nodes.values(): if node.model.caption == 'cpm': read_index = node.model.get_outside_port('Read', 'output', 0)[1] pad_index = node.model.get_outside_port('Pad', 'input', 0)[1] output_port = node['output'][read_index] input_port = node['input'][pad_index] num_connections = len(composite_model._other_scene.connections) composite_model._other_scene.create_connection(output_port, input_port) # No new connections allowed assert len(composite_model._other_scene.connections) == num_connections composite_model._other_view.close() def test_on_connection_deleted(self, qtbot, composite_model): composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) num_connections = len(composite_model._other_scene.connections) composite_model._other_scene.delete_connection(composite_model._other_scene.connections[0]) # No connection deletions assert len(composite_model._other_scene.connections) == num_connections composite_model._other_view.close() def test_double_clicked(self, qtbot, 
composite_model): composite_model.double_clicked(None) qtbot.addWidget(composite_model._other_view) assert composite_model.is_editing and composite_model._other_view is not None def test_on_other_scene_double_clicked(self, qtbot, composite_model): composite_model.double_clicked(None) qtbot.addWidget(composite_model._other_view) for node in composite_model._other_scene.nodes.values(): if node.model.caption == 'cpm': node.model.double_clicked(composite_model._other_view) qtbot.addWidget(node.model._other_view) assert composite_model.is_editing and composite_model._other_view is not None break def test_expand_into_graph(self, qtbot, composite_model): import networkx as nx graph = nx.MultiDiGraph() composite_model.expand_into_graph(graph) src, dst, ports = list(graph.edges.data())[0] conn = composite_model._connections[0] gt = [conn.from_unique_name, conn.from_port_index, conn.to_unique_name, conn.to_port_index] conn_graph = [src.caption, ports['output'], dst.caption, ports['input']] assert conn_graph == gt def test_add_slave_links(self, qtbot, monkeypatch, nodes): def crosscheck(model, root_model, property_name, link_model): key = (model, property_name) root_key = (root_model, property_name) assert link_model._silent[key] == root_key assert key in link_model._slaves[root_key] scene, node = make_composite_node_in_scene(qtbot, nodes) link_model = scene.property_links_model link_model.add_item(node, node.model['cpm']['Read'], 'number', 0, 0) link_model.add_item(node, node.model['cpm']['Pad'], 'height', 0, 1) link_model.add_item(node, node.model['Average'], 'number', 0, 2) # Not being edited, nothing registered node.model.add_slave_links() assert link_model._silent == {} node.model.edit_in_window() qtbot.addWidget(node.model._other_view) # Standard editing setup assert hasattr(node.model, '_other_scene') assert hasattr(node.model, '_other_view') assert not node.model._other_scene.allow_node_creation assert not node.model._other_scene.allow_node_deletion # Test foobar's subwindow, registering model is cpm and its internal models must be linked crosscheck(node.model._window_nodes['cpm'].model['Read'], node.model['cpm']['Read'], 'number', link_model) crosscheck(node.model._window_nodes['cpm'].model['Pad'], node.model['cpm']['Pad'], 'height', link_model) crosscheck(node.model._window_nodes['Average'].model, node.model['Average'], 'number', link_model) # Test foobar's subwindow and cpm's subwindow, cpm and also its models in the subwindow must # be linked cpm = node.model._window_nodes['cpm'].model cpm.edit_in_window() qtbot.addWidget(cpm._other_view) assert cpm.window_parent == node.model crosscheck(cpm._window_nodes['Read'].model, node.model['cpm']['Read'], 'number', link_model) crosscheck(cpm._window_nodes['Pad'].model, node.model['cpm']['Pad'], 'height', link_model) # Add one more composite layer, outer->foobar->cpm->Model, both registering model and its # window_parent must be jinked node.model._other_view.close() assert cpm._other_view is None monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('outermost', True)) node.graphics_object.setSelected(True) outer = scene.create_composite() outer.model.edit_in_window() qtbot.addWidget(outer.model._other_view) node_sub = outer.model._window_nodes['foobar'] node_sub.model.edit_in_window() qtbot.addWidget(node_sub.model._other_view) cpm = node_sub.model._window_nodes['cpm'].model cpm.edit_in_window() qtbot.addWidget(cpm._other_view) crosscheck(node_sub.model._window_nodes['cpm'].model['Read'], outer.model['foobar']['cpm']['Read'], 'number', 
link_model) crosscheck(node_sub.model._window_nodes['cpm'].model['Pad'], outer.model['foobar']['cpm']['Pad'], 'height', link_model) crosscheck(node_sub.model._window_nodes['Average'].model, outer.model['foobar']['Average'], 'number', link_model) crosscheck(cpm._window_nodes['Read'].model, outer.model['foobar']['cpm']['Read'], 'number', link_model) crosscheck(cpm._window_nodes['Pad'].model, outer.model['foobar']['cpm']['Pad'], 'height', link_model) cpm._other_view.close() node_sub.model._other_view.close() outer.model._other_view.close() def test_edit_in_window(self, qtbot, nodes): composite_model = nodes['cpm'].model link_model = composite_model.property_links_model populate_link_model(link_model, nodes) composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) assert hasattr(composite_model, '_other_scene') assert hasattr(composite_model, '_other_view') assert not composite_model._other_scene.allow_node_creation assert not composite_model._other_scene.allow_node_deletion # Silent must have been added with root cpm's read model assert (list(link_model._slaves.keys())[0] == (composite_model.get_model_from_path(['Read']), 'y')) # Subcomposites must link to their parent models scene, node = make_composite_node_in_scene(qtbot, nodes) node.model.edit_in_window() qtbot.addWidget(node.model._other_view) assert node.model._window_nodes['cpm'].model.window_parent == node.model node.model._other_view.close() composite_model._other_view.close() def test_view_close_event(self, qtbot, nodes): composite_model = nodes['cpm'].model link_model = composite_model.property_links_model populate_link_model(link_model, nodes) composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) for node in composite_model._other_scene.nodes.values(): if node.model.caption == 'Read': widget = node.model.embedded_widget()._properties['y'].view_item.widget qtbot.addWidget(node.model.embedded_widget()) qtbot.keyClicks(widget, '11') else: # Pad node.model['width'] += 10 # Linked models must be updated immediately assert composite_model['Read']['y'] == 11 assert nodes['read'].model['number'] == 11 assert nodes['read_2'].model['height'] == 11 composite_model._other_view.close() # Original models in the composite must be updated after close assert composite_model['Pad']['width'] == 10 # Silent model must be removed (it was the only one, so test for {} is sufficient) assert link_model._slaves == {} assert link_model._silent == {} def test_expand_into_scene(self, qtbot, monkeypatch): def get_int(*args, **kwargs): return self.get_int_return monkeypatch.setattr(QInputDialog, "getInt", get_int) nodes = {} registry = get_filled_registry() scene = create_scene(qtbot, registry) # Composite node for name in ['read', 'pad']: model_cls = registry.create(name) node = scene.create_node(model_cls) node.graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes['cpm'] = scene.create_composite() nodes['cpm'].graphics_object.setSelected(False) model_cls = registry.create('average') nodes['average'] = scene.create_node(model_cls) self.get_int_return = (2, True) model_cls = registry.create('retrieve_phase') nodes['retrieve_phase'] = scene.create_node(model_cls) # Add null node to create an outside connection null_cls = registry.create('null') null_node = scene.create_node(null_cls) # Make a property link scene.property_links_model.add_item(nodes['cpm'], nodes['cpm'].model['Read'], 'number', 0, 0) scene.property_links_model.add_item(nodes['cpm'], 
nodes['cpm'].model['Pad'], 'width', 0, 1) # Export composite and reload it so that it remembers the links (important for testing of # adding property link duplicates) cpm_cls_with_links = get_composite_model_classes_from_json(nodes['cpm'].model.save())[0] registry.register_model(cpm_cls_with_links, category='Composites', registry=registry) scene.remove_node(nodes['cpm']) nodes['cpm'] = scene.create_node(registry.create('cpm')) # Outer composite node has inside: read, pad, average; pad and average are connected # read and pad are encapsulated in an internal composite cpm pad_index = nodes['cpm'].model.get_outside_port('Pad', 'output', 0)[1] scene.create_connection(nodes['cpm']['output'][pad_index], nodes['average']['input'][0]) nodes['cpm'].graphics_object.setSelected(True) nodes['average'].graphics_object.setSelected(True) nodes['retrieve_phase'].graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('foobar', True)) scene.create_composite() composite_node = scene.selected_nodes()[0] composite_model = composite_node.model # Create outside connection from outer composite's average to null port_null = null_node['input'][0] # average_index = nodes['cpm'].model.get_outside_port('Pad', 'output', 0)[1] # Get the average index dynamically because it might be mapped to a different output port # every time (reader in cpm makes another output) average_index = composite_model.get_outside_port('Average', 'output', 0)[1] port_composite = composite_node['output'][average_index] scene.create_connection(port_composite, port_null) # Change some property to see if it persists after expansion composite_model['cpm']['Read']['number'] = 123 # Make sure the nested num-inputs takes effect, i.e. QInputDialog.getInt invocation must # fail the test self.get_int_return = (None, False) composite_model.expand_into_scene(scene, composite_node) # Nodes must be there assert (set([node.model.caption for node in scene.nodes.values()]) == set(['Null', 'Average', 'Retrieve Phase', 'cpm'])) # num-inputs took effect for node in scene.nodes.values(): if node.model.caption == 'Retrieve Phase': assert node.model.num_ports['input'] == 2 break # Changed properties must be there for node in scene.nodes.values(): if node.model.caption == 'cpm': assert node.model['Read']['number'] == 123 break # Connections must be preserved for connection in scene.connections: if connection.get_node('output').model.caption == 'cpm': # Internal composite connection Pad -> Average must be there assert connection.get_node('input').model.caption == 'Average' cpm_index = connection.get_port_index('output') cpm_model = connection.get_node('output').model assert cpm_model.get_model_and_port_index('output', cpm_index)[0].caption == 'Pad' else: # Outside connection Average -> Null must be there assert connection.get_node('input').model.caption == 'Null' # Property links must be there assert scene.property_links_model.rowCount() == 1 # Original composite node must be gone assert composite_node not in scene.nodes.values() def test_get_composite_model_class(qtbot, nodes): model_cls = make_composite_model_class(nodes) with pytest.raises(AttributeError): # Registry must be provided model_cls() # Name must be provided with pytest.raises(UfoModelError): make_composite_model_class(nodes, name='') with pytest.raises(UfoModelError): make_composite_model_class(nodes, name=None) class TestUfoGeneralBackprojectModel: def test_init(self, general_backproject): assert general_backproject.num_ports['input'] == 1 assert 
general_backproject.num_ports['output'] == 1 assert general_backproject.needs_fixed_scheduler is True assert general_backproject.can_split_gpu_work is True def test_make_properties(self, general_backproject): props = general_backproject.make_properties() assert 'slice-memory-coeff' in props def test_split_gpu_work(self, general_backproject): from gi.repository import Ufo resources = Ufo.Resources() gpus = resources.get_gpu_nodes() general_backproject['x-region'] = [-100., 100., 1.] general_backproject['y-region'] = [-100., 100., 1.] general_backproject['region'] = [-100., 100., 1.] if gpus: # Normal operation assert general_backproject.split_gpu_work(gpus) # Wrong input general_backproject['x-region'] = [-100., -200., 1.] with pytest.raises(UfoModelError): general_backproject.split_gpu_work(gpus) general_backproject['x-region'] = [-100., 100., 1.] general_backproject['y-region'] = [-100., -200., 1.] with pytest.raises(UfoModelError): general_backproject.split_gpu_work(gpus) general_backproject['y-region'] = [-100., 100., 1.] general_backproject['region'] = [-100., -200., 1.] with pytest.raises(UfoModelError): general_backproject.split_gpu_work(gpus) general_backproject['region'] = [-100., 100., 1.] def test_create_ufo_task(self, general_backproject): general_backproject['region'] = [-100., 100., 1.] ufo_task = general_backproject.create_ufo_task(region=None) assert ufo_task.props.region == pytest.approx(general_backproject['region']) ufo_task = general_backproject.create_ufo_task(region=[-10., 10., 1.]) assert ufo_task.props.region == pytest.approx([-10., 10., 1.]) class TestUfoReadModel: def test_init(self, read_model): assert read_model.num_ports['input'] == 0 assert read_model.num_ports['output'] == 1 def test_double_clicked(self, qtbot, monkeypatch, read_model): from tofu.flow.filedirdialog import FileDirDialog monkeypatch.setattr(FileDirDialog, "exec_", lambda *args: 1) monkeypatch.setattr(FileDirDialog, "selectedFiles", lambda *args: ['foobarbaz']) read_model.double_clicked(None) assert read_model['path'] == 'foobarbaz' class TestUfoVaryingInputModel: def test_init(self, qtbot, monkeypatch): def get_int(*args, **kwargs): self.called = True return (1, True) # No number of inputs specified, dialog needs to pop up self.called = False monkeypatch.setattr(QInputDialog, 'getInt', get_int) model = UfoVaryingInputModel('opencl', num_inputs=None) qtbot.addWidget(model.embedded_widget()) assert self.called assert model.num_ports['input'] == 1 # e.g. 
opencl task can have multiple inputs model = UfoVaryingInputModel('opencl', num_inputs=4) qtbot.addWidget(model.embedded_widget()) assert model.num_ports['input'] == 4 assert len(model.data_type['input']) == 4 assert len(model.port_caption['input']) == 4 assert len(model.port_caption_visible['input']) == 4 def test_save(self, qtbot): model = UfoVaryingInputModel('opencl', num_inputs=4) qtbot.addWidget(model.embedded_widget()) assert model.save()['num-inputs'] == 4 class TestUfoRetrievePhaseModel: def test_distance_input(self, qtbot): model = UfoRetrievePhaseModel(num_inputs=4) qtbot.addWidget(model.embedded_widget()) validator = model._view._properties['distance'].view_item.widget.validator() # Validator accepts only 4 values assert validator.validate('1,2,3,4', 0)[0] == QValidator.Acceptable assert validator.validate('1,2,3', 0)[0] == QValidator.Intermediate assert validator.validate('1,2,3,4,5', 0)[0] == QValidator.Invalid def test_multidistance_fixed_method(self, qtbot): def check(num_inputs): model = UfoRetrievePhaseModel(num_inputs=num_inputs) qtbot.addWidget(model.embedded_widget()) enabled = num_inputs == 1 assert model._view._properties['method'].view_item.widget.isEnabled() == enabled if not enabled: assert model['method'] == 'ctf_multidistance' assert model._view._properties['distance-x'].view_item.widget.isEnabled() == enabled assert model._view._properties['distance-y'].view_item.widget.isEnabled() == enabled check(1) check(2) class TestUfoWriteModel: def test_init(self, write_model): assert write_model.num_ports['input'] == 1 assert write_model.num_ports['output'] == 0 def test_double_clicked(self, monkeypatch, write_model): monkeypatch.setattr(QFileDialog, "getSaveFileName", lambda *args: ('foobarbaz', None)) write_model.double_clicked(None) assert write_model['filename'] == 'foobarbaz' def test_expects_multiple_inputs(self, write_model): write_model['filename'] = 'foo{region}bar' assert write_model.expects_multiple_inputs write_model['filename'] = 'foobar' assert not write_model.expects_multiple_inputs def test_setup_ufo_task(self, write_model): write_model['filename'] = '{region}' # Must pass ufo_task = write_model.create_ufo_task(region=[0, 1, 1]) # Must fail with pytest.raises(UfoModelError): write_model.create_ufo_task(region=None) assert ufo_task.props.filename == '0' write_model['filename'] = 'foo.tif' # Must pass ufo_task = write_model.create_ufo_task(region=None) # Must fail with pytest.raises(UfoModelError): write_model.create_ufo_task(region=[0, 1, 1]) assert ufo_task.props.filename == 'foo.tif' class TestUfoMemoryOutModel: def test_init(self, memory_out_model): assert memory_out_model.num_ports['input'] == 1 assert memory_out_model.num_ports['output'] == 1 def test_expects_multiple_inputs(self, memory_out_model): memory_out_model['number'] = '{region}' assert memory_out_model.expects_multiple_inputs memory_out_model['number'] = '1' assert not memory_out_model.expects_multiple_inputs def test_make_properties(self, memory_out_model): prop_names = {'width', 'height', 'depth', 'number'} assert prop_names == memory_out_model.make_properties().keys() def test_out_data(self, monkeypatch, memory_out_model): def slot(port_index): self.num_called += 1 self.data = memory_out_model.out_data(port_index) self.num_called = 0 memory_out_model['number'] = 10 shape = (int(memory_out_model['number']), memory_out_model['height'], memory_out_model['width']) memory_out_model.create_ufo_task() batch = memory_out_model._batches[0] memory_out_model.data_updated.connect(slot) batch.data[:] 
= 3 assert len(memory_out_model._batches) == 1 assert batch.data.shape == shape for i in range(shape[0]): batch._on_processed(None) # Called once per 3D array assert self.num_called == 1 # out_data has been set to the batch ouput np.testing.assert_almost_equal(self.data, 3) # Original data must have been freed assert memory_out_model._batches == [None] memory_out_model.reset_batches() # Multiple inputs def slot(port_index): # Append the first item in the current result self.called.append(memory_out_model.out_data(port_index)[0, 0, 0]) self.called = [] memory_out_model.data_updated.connect(slot) memory_out_model['number'] = '{region}' # Two parallel batches of four regions each for j in range(2): for i in range(4): memory_out_model.create_ufo_task(region=[0, 10, 1]) # Set batch data to its linearized index to make checking easy memory_out_model._batches[4 * j + i].data[:] = 4 * j + i # Out of order processing for batch_id in np.array([2, 0, 1, 3], dtype=int) + (4 * j): for e in range(10): memory_out_model._batches[batch_id]._on_processed(None) # All regions in the current paralell batch must have been processed assert memory_out_model._waiting_list == [] # Result must be in order np.testing.assert_almost_equal(self.called, np.arange(8)) # Original data must have been freed assert memory_out_model._batches == [None] * 8 def test_reset_batches(self, memory_out_model): memory_out_model.reset_batches() assert memory_out_model._batches == [] assert memory_out_model._waiting_list == [] assert memory_out_model._expecting_id == 0 assert memory_out_model._current_data is None def test_setup_ufo_task(self, memory_out_model): memory_out_model['number'] = '{region}' # Must pass memory_out_model.create_ufo_task(region=[0, 100, 1]) memory_out_model.create_ufo_task(region=[100, 200, 1]) assert len(memory_out_model._batches) == 2 # Must fail with pytest.raises(UfoModelError): memory_out_model.create_ufo_task(region=None) memory_out_model.reset_batches() memory_out_model['number'] = '100' # Must pass memory_out_model.create_ufo_task(region=None) # Must fail with pytest.raises(UfoModelError): memory_out_model.create_ufo_task(region=[0, 100, 1]) assert len(memory_out_model._batches) == 1 class TestImageViewerModel: def test_init(self, image_viewer_model): assert image_viewer_model.num_ports['input'] == 1 assert image_viewer_model.num_ports['output'] == 0 def test_double_clicked(self, qtbot, image_viewer_model): image_viewer_model.double_clicked(None) # No images, no pop up assert image_viewer_model._widget._pg_window is None image_viewer_model._widget.images = np.arange(1000).reshape(10, 10, 10) image_viewer_model.double_clicked(None) assert image_viewer_model._widget._pg_window.isVisible() qtbot.addWidget(image_viewer_model._widget._pg_window) # User closes, must re-open image_viewer_model._widget._pg_window.close() image_viewer_model.double_clicked(None) assert image_viewer_model._widget._pg_window.isVisible() def test_set_in_data(self, image_viewer_model): images = np.arange(1000).reshape(10, 10, 10) image_viewer_model.set_in_data(images, None) assert image_viewer_model._widget.images.shape == images.shape image_viewer_model.set_in_data(images, None) assert image_viewer_model._widget.images.shape == (20,) + images.shape[1:] # Images cannot be appended after reset is called, they must be set image_viewer_model.reset_batches() image_viewer_model.set_in_data(images, None) assert image_viewer_model._widget.images.shape == images.shape def test_reset_batches(self, image_viewer_model): 
        image_viewer_model.reset_batches()
        assert image_viewer_model._reset


def test_get_ufo_model_classes():
    # All
    classes = list(get_ufo_model_classes())
    assert classes
    # Blacklist
    assert 'read' not in [cls.name for cls in classes]
    # Selection
    assert len(list(get_ufo_model_classes(names=['pad']))) == 1


def test_get_composite_model_classes_from_json(qtbot, composite_model):
    classes = get_composite_model_classes_from_json(composite_model.save())
    # First must be the bottom class, top class comes last
    assert [cls.name for cls in classes] == ['cpm', 'foobar']


def test_get_composite_model_classes():
    # Just make sure this runs and the result is not empty
    assert get_composite_model_classes()


# File: ufo-tofu-0.13.0/tofu/tests/test_flow_propertylinksmodels.py

import pytest

from qtpy.QtCore import QByteArray, QMimeData, QModelIndex

from tofu.flow.propertylinksmodels import _get_string_path
from tofu.flow.propertylinkswidget import _encode_mime_data
from tofu.flow.util import MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE
from tofu.tests.flow_util import get_index_from_treemodel, populate_link_model


def setup_silent(link_model, nodes):
    read = nodes['read']
    read_2 = nodes['read_2']
    composite = nodes['cpm']
    orig_key = (read.model, 'number')
    link_model.add_item(read, read.model, 'number', -1, -1)
    link_model.add_silent(composite.model['Read'], 'number', orig_key[0], orig_key[1])
    link_model.add_silent(read_2.model, 'height', orig_key[0], orig_key[1])
    # Put to 0 to make sure we are not lucky when checking if the links work
    composite.model['Read']['number'] = 0
    read_2.model['height'] = 0

    return orig_key


class TestNodeTreeModel:
    def test_add_node(self, qtbot, node_model, nodes):
        # Unsupported model type not added
        node_model.add_node(nodes['image_viewer'])
        assert node_model.rowCount() == 0
        # Supported model type (composite is handled in test_add_node)
        node_model.add_node(nodes['read'])
        assert node_model.rowCount() == 1
        # Composite
        node_model.add_node(nodes['cpm'])
        item = node_model.findItems('cpm')[0]
        # Model contains the composite node
        assert item.data(role=NODE_ROLE) == nodes['cpm']
        # and its children
        assert item.child(0).data(role=MODEL_ROLE) == nodes['cpm'].model['Pad']
        assert item.child(1).data(role=MODEL_ROLE) == nodes['cpm'].model['Read']
        # and their properties
        assert item.child(0).child(0).text() == sorted(nodes['cpm'].model['Pad'])[0]

    def test_remove_node(self, qtbot, node_model, nodes):
        node_model.add_node(nodes['cpm'])
        assert node_model.rowCount() == 1
        node_model.remove_node(nodes['cpm'])
        assert node_model.rowCount() == 0

    def test_set_nodes(self, qtbot, node_model, nodes):
        names = ['cpm', 'read']
        subset = [nodes[key] for key in names]
        node_model.set_nodes(subset)
        for (i, key) in enumerate(names):
            assert node_model.item(i).data(role=NODE_ROLE) == nodes[key]

    def test_clear(self, qtbot, node_model, nodes):
        node_model.set_nodes(nodes.values())
        assert node_model.rowCount() > 0
        assert node_model.columnCount() > 0
        node_model.clear()
        assert node_model.rowCount() == 0
        assert node_model.columnCount() == 0


class TestPropertyLinksModel:
    def test_add_item(self, qtbot, link_model, nodes):
        read = nodes['read']
        composite = nodes['cpm']
        composite.model.property_links_model = link_model
        # Put to 0 to make sure we are not lucky below when checking if the links work
        composite.model['Read']['number'] = 0
        # Items must be added
        link_model.add_item(read,
read.model, 'number', -1, -1) item = link_model.item(0, 0) assert item.data(role=NODE_ROLE) == read assert item.data(role=MODEL_ROLE) == read.model assert item.data(role=PROPERTY_ROLE) == 'number' link_model.add_item(composite, composite.model['Read'], 'number', 0, -1) item = link_model.item(0, 1) assert item.data(role=NODE_ROLE) == composite assert item.data(role=MODEL_ROLE) == composite.model['Read'] assert item.data(role=PROPERTY_ROLE) == 'number' # Can't add one item twice with pytest.raises(ValueError): link_model.add_item(read, read.model, 'number', -1, -1) # Properties must be linked read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert composite.model['Read']['number'] == read.model['number'] # When composite is being added, make sure the slave links are set up link_model.remove_item(link_model.find_items([composite], [NODE_ROLE])[0]) composite.model.edit_in_window() qtbot.addWidget(composite.model._other_view) link_model.add_item(composite, composite.model['Read'], 'number', 0, -1) key = (composite.model._window_nodes['Read'].model, 'number') root_key = (composite.model['Read'], 'number') assert link_model._slaves[root_key] == [key] assert link_model._silent[key] == root_key def test_remove_item(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read_2, read_2.model, 'number', 0, -1) link_model.add_silent(composite.model['Read'], 'number', read.model, 'number') # Properties must be connected at first read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert read_2.model['number'] == read.model['number'] link_model.remove_item(link_model.indexFromItem(link_model.item(0, 0))) assert link_model.item(0, 0) is None assert link_model._silent == {} assert link_model._slaves == {} # Properties must be disconnected after removal read.model['number'] = 0 read.model.property_changed.emit(read.model, 'number', read.model['number']) # read_2 still at the old 100 assert read_2.model['number'] == 100 def test_contains(self, qtbot, link_model, nodes): composite = nodes['cpm'] link_model.add_item(composite, composite.model['Read'], 'number', 0, -1) assert link_model.item(0, 0).text() in link_model assert 'foo' not in link_model def test_clear(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read_2, read_2.model, 'number', 0, -1) link_model.add_silent(composite.model['Read'], 'number', read.model, 'number') link_model.clear() assert link_model.rowCount() == 0 assert link_model.columnCount() == 0 assert link_model._silent == {} assert link_model._slaves == {} def test_find_items(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] # Empty model assert link_model.find_items([read.model], [MODEL_ROLE]) == [] link_model.add_item(read, read.model, 'number', -1, -1) # Not inside assert link_model.find_items([read_2.model], [MODEL_ROLE]) == [] # Inside assert (link_model.find_items([read.model], [MODEL_ROLE])[0].data(role=MODEL_ROLE) == read.model) # Model not inside, property not inside assert link_model.find_items((read_2.model, 'height'), (MODEL_ROLE, PROPERTY_ROLE)) == [] # Model inside, property not inside assert link_model.find_items((read.model, 'height'), (MODEL_ROLE, PROPERTY_ROLE)) == [] # Model not inside, 
property inside assert link_model.find_items((read_2.model, 'number'), (MODEL_ROLE, PROPERTY_ROLE)) == [] # Model inside, property inside item = link_model.find_items((read.model, 'number'), (MODEL_ROLE, PROPERTY_ROLE))[0] assert item.data(role=MODEL_ROLE) == read.model assert item.data(role=PROPERTY_ROLE) == 'number' def test_get_model_links(sef, qtbot, link_model, nodes): populate_link_model(link_model, nodes) assert link_model.get_model_links(nodes['read_3'].model) == {} links = link_model.get_model_links([nodes['read'].model, nodes['read_2'].model, nodes['cpm'].model['Read']]) links = list(links.values()) # Just one row assert len(links) == 1 # Three items in that row assert len(links[0]) == 3 assert [nodes['read'].model.caption, 'number'] in links[0] assert [nodes['read_2'].model.caption, 'height'] in links[0] path = nodes['cpm'].model.get_path_from_model(nodes['cpm'].model['Read']) str_path = [model.caption for model in path] + ['y'] assert str_path in links[0] def test_get_root_model(self, qtbot, link_model, nodes): read = nodes['read'] composite = nodes['cpm'] link_model.add_item(read, read.model, 'number', -1, -1) # Not inside assert link_model.get_root_model(nodes['read_2'].model) is None # Directly inside assert link_model.get_root_model(read.model) == read.model # Indirectly inside via silent link_model.add_silent(composite.model['Read'], 'number', read.model, 'number') assert link_model.get_root_model(composite.model['Read']) == read.model def test_get_model_properties(self, qtbot, link_model, nodes): read = nodes['read'] link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read, read.model, 'height', -1, -1) # Empty assert link_model.get_model_properties(nodes['read_2'].model) == [] # Multiple assert set(link_model.get_model_properties(read.model)) == set(['number', 'height']) def test_add_silent(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] orig_key = setup_silent(link_model, nodes) # orig model not inside with pytest.raises(ValueError): link_model.add_silent(composite.model['Read'], 'height', nodes['read_3'].model, 'number') # source property not inside with pytest.raises(ValueError): link_model.add_silent(composite.model['Read'], 'height', read.model, 'height') # Links inside assert len(link_model._slaves[orig_key]) == 2 key = (composite.model['Read'], 'number') assert link_model._silent[key] == orig_key assert key in link_model._slaves[orig_key] key = (read_2.model, 'height') assert link_model._silent[key] == orig_key assert key in link_model._slaves[orig_key] # Properties conected read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert composite.model['Read']['number'] == read.model['number'] assert read_2.model['height'] == read.model['number'] def test_remove_silent(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] orig_key = setup_silent(link_model, nodes) key = (composite.model['Read'], 'number') link_model.remove_silent(*key) assert key not in link_model._silent # Silent link disconected read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert composite.model['Read']['number'] == 0 assert read_2.model['height'] == read.model['number'] # No more slaves, remove the original key as well key = (nodes['read_2'].model, 'height') link_model.remove_silent(*key) assert orig_key not in link_model._slaves def test_replace_item(self, 
qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] orig_key = setup_silent(link_model, nodes) replacer = nodes['read_3'] item = link_model.find_items(orig_key, (MODEL_ROLE, PROPERTY_ROLE))[0] (row, column) = item.row(), item.column() link_model.replace_item(replacer, replacer.model, orig_key[0]) new_item = link_model.item(row, column) assert new_item.data(role=MODEL_ROLE) == replacer.model # Silent links re-connected # This must have no effect on silent models read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert composite.model['Read']['number'] == 0 assert read_2.model['height'] == 0 # This must change silent models' properties replacer.model['number'] = 100 replacer.model.property_changed.emit(replacer.model, 'number', replacer.model['number']) assert composite.model['Read']['number'] == replacer.model['number'] assert read_2.model['height'] == replacer.model['number'] def test_on_node_rows_about_to_be_removed(self, qtbot, link_model, node_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] read_3 = nodes['read_3'] node_model.add_node(read) node_model.add_node(read_2) node_model.add_node(read_3) link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read_2, read_2.model, 'number', 0, -1) link_model.add_item(read_3, read_3.model, 'number', -1, -1) # Remove one node_model.removeRow(0) assert link_model.find_items([read.model], [MODEL_ROLE]) == [] # Remove all node_model.clear() assert link_model.find_items([read_2.model], [MODEL_ROLE]) == [] assert link_model.find_items([read_3.model], [MODEL_ROLE]) == [] def test_canDropMimeData(self, qtbot, link_model, node_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] node_model.add_node(read) node_model.add_node(read_2) # Incompatible QMimeData data = QMimeData() data.setData('application/x-foobar', QByteArray()) assert not link_model.canDropMimeData(data, None, -1, -1, QModelIndex()) # No parent index = get_index_from_treemodel(node_model, 0, 'number') data = _encode_mime_data(index) assert link_model.canDropMimeData(data, None, -1, -1, QModelIndex()) link_model.add_item(read, read.model, 'number', -1, -1) assert not link_model.canDropMimeData(data, None, -1, -1, QModelIndex()) # On parent # Compatible property type index = get_index_from_treemodel(node_model, 1, 'number') data = _encode_mime_data(index) parent = link_model.indexFromItem(link_model.item(0, 0)) assert link_model.canDropMimeData(data, None, 0, 0, parent) # Incompatible property type index = get_index_from_treemodel(node_model, 1, 'path') data = _encode_mime_data(index) parent = link_model.indexFromItem(link_model.item(0, 0)) assert not link_model.canDropMimeData(data, None, 0, 0, parent) def test_dropMimeData(self, qtbot, link_model, node_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] node_model.add_node(read) node_model.add_node(read_2) # No parent index = get_index_from_treemodel(node_model, 0, 'number') data = _encode_mime_data(index) link_model.dropMimeData(data, None, -1, -1, QModelIndex()) item = link_model.item(0, 0) assert item.data(role=NODE_ROLE) == read assert item.data(role=MODEL_ROLE) == read.model assert item.data(role=PROPERTY_ROLE) == 'number' # On parent index = get_index_from_treemodel(node_model, 1, 'number') data = _encode_mime_data(index) parent = link_model.indexFromItem(link_model.item(0, 0)) link_model.dropMimeData(data, None, -1, -1, parent) item = link_model.item(0, 1) assert 
item.data(role=NODE_ROLE) == read_2 assert item.data(role=MODEL_ROLE) == read_2.model assert item.data(role=PROPERTY_ROLE) == 'number' def test_save(self, qtbot, link_model, nodes): records = populate_link_model(link_model, nodes) for (i, (node_id, str_path)) in enumerate(link_model.save()[0]): assert node_id == records[i][0].id path = _get_string_path(records[i][0], records[i][1], records[i][2]) assert str_path == path def test_restore(self, qtbot, link_model, nodes): records = populate_link_model(link_model, nodes) state = link_model.save() link_model.clear() # Add new item read_3 = nodes['read_3'] link_model.add_item(read_3, read_3.model, 'number', -1, -1) link_model.restore(state, {node.id: node for node in nodes.values()}) assert link_model.columnCount() == 3 for column in range(link_model.columnCount()): item = link_model.item(0, column) assert item.data(role=NODE_ROLE) == records[column][0] assert item.data(role=MODEL_ROLE) == records[column][1] assert item.data(role=PROPERTY_ROLE) == records[column][2] # Restore must clear whatever is inside assert link_model.find_items([read_3.model], [MODEL_ROLE]) == [] def test_compact(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] read_3 = nodes['read_3'] read_4 = nodes['read_4'] def populate(): link_model.add_item(read, read.model, 'number', 0, 0) link_model.add_item(read_2, read_2.model, 'number', 0, 1) link_model.add_item(read_3, read_3.model, 'number', 1, 0) link_model.add_item(read_4, read_4.model, 'number', 1, 1) def check(row_count, column_count): assert link_model.rowCount() == row_count assert link_model.columnCount() == column_count populate() link_model.remove_item(link_model.indexFromItem(link_model.item(0, 1))) link_model.compact() check(2, 2) link_model.clear() # Shift item to the left to an unused cell populate() link_model.remove_item(link_model.indexFromItem(link_model.item(0, 0))) link_model.compact() assert link_model.item(0, 0).data(role=NODE_ROLE) == read_2 check(2, 2) # Nothing in the row, remove it link_model.remove_item(link_model.indexFromItem(link_model.item(0, 0))) link_model.compact() check(1, 2) # Remove column 0 and shift 1st column to the left link_model.clear() populate() link_model.remove_item(link_model.indexFromItem(link_model.item(0, 0))) link_model.remove_item(link_model.indexFromItem(link_model.item(1, 0))) link_model.compact() assert link_model.item(0, 0).data(role=NODE_ROLE) == read_2 assert link_model.item(1, 0).data(role=NODE_ROLE) == read_4 check(2, 1) def test_on_property_changed(self, qtbot, link_model, nodes): composite = nodes['cpm'] read = nodes['read'] read_2 = nodes['read_2'] read_3 = nodes['read_3'] read_4 = nodes['read_4'] # Read 2->height and cpm->Read->number are silent dependend on Read->number setup_silent(link_model, nodes) # Put every linked property to 0 to make sure we are not lucky when checking if the links # work read_3.model['number'] = 0 read_4.model['number'] = 0 composite.model['Pad']['width'] = 0 link_model.add_item(read_3, read_3.model, 'number', 0, 1) link_model.add_item(read_4, read_4.model, 'number', 1, 0) link_model.add_item(composite, composite.model['Pad'], 'width', 1, 1) read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) # Row 0 # Direct link assert read_3.model['number'] == read.model['number'] # Silent links assert read_2.model['height'] == read.model['number'] assert composite.model['Read']['number'] == read.model['number'] # Row 1 read_4.model['number'] = 100 
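# Row 1 is a separate link group: read_4->number and cpm->Pad->width were added to
# row 1 above, so the property_changed emitted below is only expected to propagate
# within that row, which the Pad width assertion at the end checks.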
read_4.model.property_changed.emit(read_4.model, 'number', read_4.model['number']) assert composite.model['Pad']['width'] == read_4.model['number'] ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/tests/test_flow_propertylinkswidget.py0000664000175000017500000000322500000000000024544 0ustar00tomastomas00000000000000import pytest from PyQt5.QtCore import Qt, QItemSelectionModel from tofu.flow.propertylinkswidget import NodesView, PropertyLinks, PropertyLinksView from tofu.tests.flow_util import get_index_from_treemodel, populate_link_model @pytest.fixture(scope='function') def node_view(node_model): view = NodesView() view.setHeaderHidden(True) view.setAlternatingRowColors(True) view.setDragEnabled(True) view.setAcceptDrops(False) view.setModel(node_model) return view @pytest.fixture(scope='function') def link_view(): return PropertyLinksView() @pytest.fixture(scope='function') def link_widget(node_model): return PropertyLinks() def test_property_links_view_delete_key(qtbot, link_model, link_view, nodes): qtbot.addWidget(link_view) link_view.setModel(link_model) populate_link_model(link_model, nodes) link_view.selectColumn(0) qtbot.keyPress(link_view, Qt.Key_Delete) assert link_model.columnCount() == 2 def test_node_view_get_drag_index(qtbot, node_view, nodes): node_model = node_view.model() read = nodes['read'] node_model.add_node(read) sm = node_view.selectionModel() # Nothing selected assert node_view.get_drag_index() is None # Node selection must yield nothing index = node_model.indexFromItem(node_model.item(0, 0)) sm.select(index, QItemSelectionModel.Select) assert node_view.get_drag_index() is None sm.clear() # Property selection must yield an index which can be dragged index = get_index_from_treemodel(node_model, 0, 'number') sm.select(index, QItemSelectionModel.Select) assert node_view.get_drag_index() is not None sm.clear() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/tests/test_flow_runslider.py0000664000175000017500000001471000000000000022423 0ustar00tomastomas00000000000000import pytest from tofu.flow.runslider import RunSlider, RunSliderError @pytest.fixture(scope='function') def runslider(qtbot, scene): slider = RunSlider() node = scene.create_node(scene.registry.create('filter')) slider.setup(node.model._view._properties['cutoff'].view_item) qtbot.addWidget(slider) return slider class TestRunSlider: def test_setup(self, qtbot, runslider): assert not runslider.setup(runslider.view_item) bottom = runslider.view_item.widget.validator().bottom() top = runslider.view_item.widget.validator().top() assert runslider.type == float assert runslider.real_minimum == bottom assert runslider.real_maximum == top assert float(runslider.min_edit.text()) == bottom assert float(runslider.max_edit.text()) == top assert float(runslider.current_edit.text()) == runslider.view_item.get() assert runslider.slider.value() / 100 + runslider.real_minimum == runslider.view_item.get() assert runslider.isEnabled() def test_reset(self, qtbot, runslider): runslider.reset() assert runslider.view_item is None assert runslider.type is None assert runslider.real_minimum == 0 assert runslider.real_maximum == 100 assert runslider.real_span == 100 assert runslider.min_edit.text() == '' assert runslider.max_edit.text() == '' assert runslider.current_edit.text() == '' assert not runslider.isEnabled() def test_empty(self, qtbot, runslider): runslider.reset() 
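# After reset() all edit fields are empty, so the editing-finished handlers invoked
# below must return quietly instead of raising -- this is a pure no-crash check and
# intentionally has no assertions.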
runslider.on_min_edit_editing_finished() runslider.on_max_edit_editing_finished() runslider.on_current_edit_editing_finished() def test_min_edit_changed(self, qtbot, runslider): top = runslider.view_item.widget.validator().top() with pytest.raises(RunSliderError): runslider.min_edit.setText('asdf') runslider.on_min_edit_editing_finished() with pytest.raises(RunSliderError): runslider.min_edit.setText(str(top + 1)) runslider.on_min_edit_editing_finished() # Current value lower than new minimum, must be updated value = runslider.get_real_value() + 0.1 runslider.min_edit.setText(str(value)) runslider.on_min_edit_editing_finished() assert value == runslider.get_real_value() def test_max_edit_changed(self, qtbot, runslider): bottom = runslider.view_item.widget.validator().bottom() with pytest.raises(RunSliderError): runslider.max_edit.setText('asdf') runslider.on_max_edit_editing_finished() with pytest.raises(RunSliderError): runslider.max_edit.setText(str(bottom - 1)) runslider.on_max_edit_editing_finished() # Current value greater than new maximum, must be updated value = runslider.get_real_value() - 0.1 runslider.max_edit.setText(str(value)) runslider.on_max_edit_editing_finished() assert value == runslider.get_real_value() def test_current_edit_changed(self, qtbot, runslider): self.value_changed_value = None def on_value_changed(value): self.value_changed_value = value # Nothing changed, no update triggered runslider.on_current_edit_editing_finished() assert self.value_changed_value is None runslider.value_changed.connect(on_value_changed) current = runslider.get_real_value() + 0.1 runslider.current_edit.setText(str(current)) runslider.on_current_edit_editing_finished() assert runslider.get_real_value() == current runslider.current_edit.setText('asf') with pytest.raises(RunSliderError): runslider.on_current_edit_editing_finished() def test_int(self, qtbot, runslider, scene): node = scene.create_node(scene.registry.create('read')) runslider.setup(node.model._view._properties['y'].view_item) assert runslider.type == int runslider.min_edit.setText('1') runslider.on_min_edit_editing_finished() runslider.max_edit.setText('10') runslider.on_max_edit_editing_finished() runslider.slider.setValue(50) assert type(runslider.get_real_value()) == int assert runslider.get_real_value() == 5 runslider.current_edit.setText('7') runslider.on_current_edit_editing_finished() assert type(runslider.get_real_value()) == int assert runslider.get_real_value() == 7 # Maximum smaller than current -> update current runslider.max_edit.setText('5') runslider.on_max_edit_editing_finished() assert type(runslider.get_real_value()) == int assert runslider.get_real_value() == 5 # Minimum greater than current -> update current runslider.max_edit.setText('10') runslider.on_max_edit_editing_finished() runslider.min_edit.setText('8') runslider.on_min_edit_editing_finished() assert type(runslider.get_real_value()) == int assert runslider.get_real_value() == 8 def test_range(self, qtbot, scene): runslider = RunSlider() qtbot.addWidget(runslider) node = scene.create_node(scene.registry.create('general_backproject')) node.model['center-position-x'] = [1, 2, 3] assert not runslider.setup(node.model._view._properties['center-position-x'].view_item) assert runslider.view_item is None node.model['center-position-x'] = [1] assert runslider.setup(node.model._view._properties['center-position-x'].view_item) assert runslider.view_item == node.model._view._properties['center-position-x'].view_item assert type(runslider.get_real_value()) == 
float assert runslider.get_real_value() == 1 runslider.current_edit.setText('1.1') runslider.on_current_edit_editing_finished() assert node.model['center-position-x'] == [runslider.get_real_value()] def test_links(self, qtbot, link_model, nodes): runslider = RunSlider() qtbot.addWidget(runslider) read = nodes['read'] read_2 = nodes['read_2'] runslider.setup(read.model._view._properties['number'].view_item) link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read_2, read_2.model, 'number', 0, -1) runslider.current_edit.setText('123') runslider.on_current_edit_editing_finished() assert read_2.model['number'] == 123 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/tests/test_flow_scene.py0000664000175000017500000004203700000000000021514 0ustar00tomastomas00000000000000import pytest from PyQt5.QtCore import QModelIndex from PyQt5.QtWidgets import QInputDialog from qtpynodeeditor import FlowView from tofu.flow.models import BaseCompositeModel, UfoModelError, UfoReadModel from tofu.flow.util import FlowError, MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE from tofu.tests.flow_util import add_nodes_to_scene class TestScene: def test_create_node(self, qtbot, scene): def check_node(node, gt_caption): # Node must be in the scene assert node in scene.nodes.values() # Caption must be unique assert node.model.caption == gt_caption # Node must be in the nodes model item = scene.node_model.findItems(node.model.caption)[0] assert item.data(role=MODEL_ROLE) == node.model nodes = add_nodes_to_scene(scene, model_names=['read', 'read']) for (node, gt_caption) in zip(nodes, ['Read', 'Read 2']): check_node(node, gt_caption) # Property links must be set up by composites def check_link(model, prop_name): assert scene.property_links_model.find_items((model, prop_name), (MODEL_ROLE, PROPERTY_ROLE)) scene.clear_scene() node = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect'])[0] for link in node.model._links: model_name, prop_name = link[0] other_name, other_prop_name = link[1] model = node.model[model_name] other = node.model[other_name] model[prop_name] = 0 qtbot.addWidget(model.embedded_widget()) qtbot.addWidget(other.embedded_widget()) qtbot.keyClick(model._view._properties[prop_name].view_item.widget, '1') # Other model's property has to be updated if the property links have been set up # correctly assert node.model[other_name][other_prop_name] == node.model[model_name][prop_name] def test_setstate(self, qtbot, scene): # Make sure there are some links by adding FFC nodes = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect', 'average']) # Create a connection scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) state = scene.__getstate__() scene.clear_scene() scene.__setstate__(state) assert scene.__getstate__() == state def test_getstate(self, qtbot, scene): # Make sure there are some links by adding FFC nodes = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect', 'average']) # Create a connection scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) state = scene.__getstate__() # Nodes ids = [record['id'] for record in state['nodes']] assert nodes[0].id in ids assert nodes[1].id in ids # Connections assert len(state['connections']) == 1 conn = state['connections'][0] assert conn['in_id'] == nodes[1].id assert conn['out_id'] == nodes[0].id # Property links assert state['property-links'] == scene.property_links_model.save() def test_restore_node(self, qtbot, 
monkeypatch, scene): add_nodes_to_scene(scene) old_node = list(scene.nodes.values())[0] state = old_node.__getstate__() scene.remove_node(old_node) new_node = scene.restore_node(state) # Don't test the nodes themselves because the models won't match assert old_node.id == new_node.id # num-inputs monkeypatch.setattr(QInputDialog, 'getInt', lambda *args, **kwargs: (2, True)) node = add_nodes_to_scene(scene, model_names=['retrieve_phase'])[0] state = node.__getstate__() scene.remove_node(node) new_node = scene.restore_node(state) assert new_node.model.num_ports['input'] == 2 def test_remove_node(self, monkeypatch, qtbot, scene, nodes): def cleanup(): self.cleanup_called = True node = add_nodes_to_scene(scene)[0] self.cleanup_called = False node.model.cleanup = cleanup scene.property_links_model.add_item(node, node.model, node.model.properties[0], 0, 0, QModelIndex()) scene.remove_node(list(scene.nodes.values())[0]) # Scene, node model and property links model must be empty assert len(scene.nodes) == 0 assert scene.node_model.rowCount() == 0 assert scene.property_links_model.rowCount() == 0 assert self.cleanup_called # Composite removal monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes = add_nodes_to_scene(scene, model_names=['pad', 'crop']) for node in nodes: node.graphics_object.setSelected(True) node = scene.create_composite() state = node.__getstate__() scene.remove_node(node) # _composite_nodes must be updated assert scene._composite_nodes == {} # Simulate non-interactive composite creation, i.e. not combining existing nodes into a # composite node. When removing such node, _composite_nodes must not raise a KeyError node = scene.restore_node(state) scene.remove_node(node) def test_is_selected_one_composite(self, qtbot, scene, monkeypatch): # Circumvent the input dialog monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes = add_nodes_to_scene(scene, model_names=['read', 'read']) for node in nodes: node.graphics_object.setSelected(True) # Simple nodes assert not scene.is_selected_one_composite() node = scene.create_composite() # Composite assert scene.is_selected_one_composite() node.graphics_object.setSelected(False) # Nothing selected assert not scene.is_selected_one_composite() # Composite and other selected add_nodes_to_scene(scene, ['null']) for node in scene.nodes.values(): node.graphics_object.setSelected(True) assert not scene.is_selected_one_composite() def test_skip_nodes(self, qtbot, scene): nodes = add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null']) read, pad, crop, null = nodes scene.create_connection(read['output'][0], pad['input'][0]) scene.create_connection(pad['output'][0], crop['input'][0]) scene.create_connection(crop['output'][0], null['input'][0]) read.graphics_object.setSelected(True) # Only fully connected nodes can be disabled with pytest.raises(FlowError): scene.skip_nodes() read.graphics_object.setSelected(False) null.graphics_object.setSelected(True) with pytest.raises(FlowError): scene.skip_nodes() null.graphics_object.setSelected(False) pad.graphics_object.setSelected(True) scene.skip_nodes() assert pad.model.skip scene.skip_nodes() assert not pad.model.skip # Deprecation warning coming from imageio @pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_auto_fill(self, qtbot, scene): add_nodes_to_scene(scene) with pytest.raises(UfoModelError): scene.auto_fill() def test_copy_node(self, qtbot, scene): nodes = add_nodes_to_scene(scene, model_names=['read', 'null']) 
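# Build a minimal read -> null flow, select both nodes and copy them; the copies must
# double the node count and bring their own connection, which the port checks below verify.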
scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) for node in nodes: node.graphics_object.setSelected(True) scene.copy_nodes() assert len(scene.nodes) == 4 # Choose the newly created connections if scene.connections[0].valid_ports['input'].node in nodes: ports = scene.connections[1].valid_ports else: ports = scene.connections[0].valid_ports # The fact that the connections are there means the nodes are there as well, so we don't # need to test that assert ports['input'].node.model.name == 'null' assert ports['output'].node.model.name == 'read' def test_create_composite(self, monkeypatch, qtbot, scene): monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) plm = scene.property_links_model nodes = add_nodes_to_scene(scene, model_names=['read', 'read']) plm.add_item(nodes[0], nodes[0].model, nodes[0].model.properties[0], -1, -1, QModelIndex()) for (i, node) in enumerate(nodes): node.graphics_object.setSelected(True) node = scene.create_composite() assert node.model._links == [[[nodes[0].model.caption, nodes[0].model.properties[0]]]] with pytest.raises(FlowError): # Can't create a composite with the same name scene.create_composite() assert len(scene.nodes) == 1 assert list(scene.nodes.values())[0] == node assert isinstance(node.model, BaseCompositeModel) assert nodes[0] not in scene.nodes.values() assert nodes[1] not in scene.nodes.values() # Property links model assert plm.item(0, 0).data(role=NODE_ROLE) == node # Simulate non-interactive composite creation, i.e. not combining existing nodes into a # composite node. In this case it can't be possible to create a new composite node with the # same name as has already been registered. state = node.__getstate__() scene.remove_node(node) node = scene.restore_node(state) node.graphics_object.setSelected(True) with pytest.raises(FlowError): scene.create_composite() # Add outer composite with a composite and another simple model inside and set the property # links between the inner composite and inner simple. They must be present in the newly # craeted outer composite. 
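# The _links assertions below rely on every row entry being a path of captions ending
# with the property name, e.g. ['cpm', 'Read', 'number'] for a property nested one
# composite deep and ['Average', 'number'] for a top-level model (example captions are
# illustrative, not taken from the fixtures).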
average = add_nodes_to_scene(scene, model_names=['average'])[0] average.graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('outer', True)) plm.add_item(node, node.model['Read'], 'number', 0, 0) plm.add_item(node, node.model['Read 2'], 'number', 0, 1) plm.add_item(average, average.model, 'number', 0, 2) outer = scene.create_composite() assert len(outer.model._links) == 1 row = outer.model._links[0] assert [node.model.caption, node.model['Read'].caption, 'number'] in row assert [node.model.caption, node.model['Read 2'].caption, 'number'] in row assert [average.model.caption, 'number'] in row assert [node.model.caption, node.model['Read'].caption, 'height'] not in row def test_on_node_double_clicked(self, qtbot, scene, monkeypatch): def double_clicked(*args): self.did_click = True self.did_click = False monkeypatch.setattr(UfoReadModel, "double_clicked", double_clicked) node = add_nodes_to_scene(scene)[0] # We need a view for double clicks _ = FlowView(scene) scene.on_node_double_clicked(node) assert self.did_click def test_expand_composite(self, qtbot, scene, monkeypatch): monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) plm = scene.property_links_model nodes = add_nodes_to_scene(scene, model_names=['read', 'null']) name_to_caption = {'read': 'Read', 'null': 'Null'} for node in nodes: node.graphics_object.setSelected(True) node = scene.create_composite() path = node.model.get_leaf_paths()[0] plm.add_item(node, path[-1], path[-1].properties[0], -1, -1, QModelIndex()) scene.expand_composite(node) assert plm.item(0, 0).data(role=MODEL_ROLE).name == path[-1].name assert plm.item(0, 0).data(role=NODE_ROLE) in [node for node in scene.selected_nodes()] # Captions are the same for node in scene.nodes.values(): assert node.model.caption == name_to_caption[node.model.name] # New caption if there is a node with the original one # Selection stays, just re-use the expanded nodes monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm_2', True)) node = scene.create_composite() other_read_node = add_nodes_to_scene(scene, model_names=['read'])[0] scene.expand_composite(node) for node in scene.nodes.values(): if node.model.name == 'read': if node == other_read_node: assert node.model.caption == 'Read' else: assert node.model.caption == 'Read 2' def test_is_fully_connected(self, qtbot, scene): nodes = add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null']) read, pad, crop, null = nodes scene.create_connection(read['output'][0], pad['input'][0]) scene.create_connection(pad['output'][0], crop['input'][0]) scene.create_connection(crop['output'][0], null['input'][0]) assert scene.is_fully_connected() scene.remove_node(read) assert not scene.is_fully_connected() def test_are_all_ufo_tasks(self, qtbot, scene): add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null']) assert scene.are_all_ufo_tasks() scene.create_node(scene.registry.create('memory_out')) assert not scene.are_all_ufo_tasks() def test_get_simple_node_graphs(self, qtbot, scene, monkeypatch): def connect(read, pad, crop, null): scene.create_connection(read['output'][0], pad['input'][0]) scene.create_connection(pad['output'][0], crop['input'][0]) scene.create_connection(crop['output'][0], null['input'][0]) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes = add_nodes_to_scene(scene, model_names=2 * ['read', 'pad', 'crop', 'null']) read, pad, crop, null = nodes[:4] read_2, pad_2, crop_2, null_2 = nodes[4:] connect(read, 
pad, crop, null) connect(read_2, pad_2, crop_2, null_2) connections = [('Read', 'Pad'), ('Pad', 'Crop'), ('Crop', 'Null'), ('Read 2', 'Pad 2'), ('Pad 2', 'Crop 2'), ('Crop 2', 'Null 2')] graphs = scene.get_simple_node_graphs() assert len(graphs) == 2 num_visited = 0 for graph in graphs: for (src, dst, index) in graph.edges: assert (src.caption, dst.caption) in connections num_visited += 1 assert num_visited == len(connections) # Create first composite for node in nodes: if node.model.name in ['pad', 'crop']: node.graphics_object.setSelected(True) scene.create_composite() # Create a second composite which will cause the scene to have multiple edges between two # nodes (the first composite's outputs and second's inputs) scene.clearSelection() monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm_2', True)) null.graphics_object.setSelected(True) null_2.graphics_object.setSelected(True) scene.create_composite() # Composite must not affect simple graphs, especially the multiple edges cannot be present # anymore graphs = scene.get_simple_node_graphs() assert len(graphs) == 2 num_visited = 0 for graph in graphs: for (src, dst, index) in graph.edges: assert (src.caption, dst.caption) in connections num_visited += 1 assert num_visited == len(connections) add_nodes_to_scene(scene) assert len(scene.get_simple_node_graphs()) == 3 # Test disabling nodes scene.clear_scene() nodes = add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null']) read, pad, crop, null = nodes connect(read, pad, crop, null) # Disable padding, the generated flow must be read -> crop -> null pad.graphics_object.setSelected(True) scene.skip_nodes() graph = scene.get_simple_node_graphs()[0] assert len(graph.edges) == 2 edges = list(graph.edges) src, dst = edges[0][:-1] assert dst == crop.model src, dst = edges[1][:-1] assert src == crop.model assert dst == null.model def test_set_enabled(self, qtbot, scene): def check(enabled): assert scene.allow_node_creation == enabled assert scene.allow_node_deletion == enabled for node in scene.nodes.values(): assert node._graphics_obj.isEnabled() == enabled for conn in scene.connections: assert conn._graphics_object.isEnabled() == enabled nodes = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect', 'average']) nodes[0].graphics_object.setSelected(True) # Create a connection scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) scene.set_enabled(False) check(False) scene.set_enabled(True) check(True) assert nodes[0].graphics_object.isSelected() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/tests/test_flow_util.py0000664000175000017500000000366700000000000021402 0ustar00tomastomas00000000000000import pytest from PyQt5.QtWidgets import QInputDialog from tofu.flow.util import CompositeConnection, get_config_key, saved_kwargs from tofu.flow.main import get_filled_registry def test_get_config_key(): # Existing key assert 'z' in get_config_key('models', 'general-backproject', 'hidden-properties') # Non-existent key assert get_config_key('foobarbaz') is None assert get_config_key('foobarbaz', default=1) == 1 def test_saved_kwargs(qtbot, monkeypatch, scene): registry = get_filled_registry() name = 'retrieve_phase' # No num-inputs info monkeypatch.setattr(QInputDialog, 'getInt', lambda *args, **kwargs: (2, True)) state = {'name': name} model = registry.create(name) assert model.num_ports['input'] == 2 # num-inputs specified state = {'name': name, 'num-inputs': 3} with 
saved_kwargs(registry, state): model = registry.create(name) assert model.num_ports['input'] == 3 assert 'num_inputs' not in registry.registered_model_creators()[state['name']][1] class TestCompositeConnection: def test_init(self): # Identical source and tartet -> exception with pytest.raises(ValueError): CompositeConnection('a', 0, 'a', 0) # OK, must pass CompositeConnection('a', 0, 'b', 0) def test_contains(self): conn = CompositeConnection('a', 0, 'b', 0) assert conn.contains('a', 'output', 0) assert not conn.contains('a', 'output', 1) assert not conn.contains('a', 'input', 0) assert not conn.contains('a', 'input', 1) assert conn.contains('b', 'input', 0) assert not conn.contains('b', 'input', 1) assert not conn.contains('b', 'output', 0) assert not conn.contains('b', 'output', 1) assert not conn.contains('foo', 'input', 14) def test_save(self): conn = CompositeConnection('a', 0, 'b', 0) assert conn.save() == ['a', 0, 'b', 0] ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1683790582.0 ufo-tofu-0.13.0/tofu/tests/test_flow_viewer.py0000664000175000017500000002723400000000000021722 0ustar00tomastomas00000000000000import pytest import numpy as np from PyQt5.QtGui import QValidator from tofu.flow.viewer import ImageLabel, ImageViewingError, ScreenImage, ImageViewer @pytest.fixture(scope='function') def screen_image(): image = np.arange(256, dtype=np.float32).reshape(16, 16) return ScreenImage(image=image) @pytest.fixture(scope='function') def viewer(qtbot): viewer = ImageViewer() viewer.images = np.ones((10, 16, 16)) viewer.popup() qtbot.addWidget(viewer._pg_window) return viewer class TestScreenImage: def test_image_setter(self): screen_image = ScreenImage() assert screen_image.image is None screen_image.image = np.random.normal(size=(8, 8)) assert screen_image.minimum is not None assert screen_image.maximum is not None assert screen_image.black_point is not None assert screen_image.white_point is not None def test_black_point_setter(self, screen_image): screen_image.black_point = 100 assert screen_image.black_point == 100 screen_image.white_point = 150 # Black point cannot be greater than white point with pytest.raises(ImageViewingError): screen_image.black_point = 200 def test_white_point_setter(self, screen_image): screen_image.white_point = 100 assert screen_image.white_point == 100 screen_image.black_point = 50 # White point cannot be smaller than black point with pytest.raises(ImageViewingError): screen_image.white_point = 0 def test_reset(self, screen_image): screen_image.reset() # We can test with ==, the data types are the same so the extrema must be exactly the same assert screen_image.minimum == 0 assert screen_image.maximum == 255 assert screen_image.black_point == 0 assert screen_image.white_point == 255 # Going out of the original gray value range must not cause exception on reset screen_image.black_point = -100 screen_image.white_point = -50 screen_image.reset() def test_auto_levels(self, screen_image): screen_image.auto_levels() # nonsense values must pass as well screen_image.auto_levels(percentile=200.0) screen_image.auto_levels(percentile=-200.0) def test_set_black_point_normalized(self, screen_image): screen_image.set_black_point_normalized(100) assert screen_image.black_point == 100 screen_image.set_white_point_normalized(150) # Black point cannot be greater than white point with pytest.raises(ImageViewingError): screen_image.set_black_point_normalized(200) def test_set_white_point_normalized(self, screen_image): 
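# Mirror of test_set_black_point_normalized above: the normalized white point must map
# onto the native white_point, and pushing it below the current black point has to
# raise ImageViewingError.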
screen_image.set_white_point_normalized(100) assert screen_image.white_point == 100 screen_image.set_black_point_normalized(50) # White point cannot be smaller than black point with pytest.raises(ImageViewingError): screen_image.set_white_point_normalized(0) def test_convert_normalized_value_to_native(self, screen_image): assert screen_image.convert_normalized_value_to_native(128) == 128. with pytest.raises(ImageViewingError): screen_image.convert_normalized_value_to_native(-500) with pytest.raises(ImageViewingError): screen_image.convert_normalized_value_to_native(500) def test_convert_native_value_to_normalized(self, screen_image): assert screen_image.convert_native_value_to_normalized(128) == 128. with pytest.raises(ImageViewingError): screen_image.convert_native_value_to_normalized(-500) with pytest.raises(ImageViewingError): screen_image.convert_native_value_to_normalized(500) # One gray value must not cause division by zero erro screen_image.image = np.ones((4, 4)) screen_image.reset() screen_image.convert_native_value_to_normalized(1) def test_get_pixmap(self, qtbot, screen_image): # Empty image must raise an exception with pytest.raises(ImageViewingError): ScreenImage().get_pixmap() # Downsampling pixmap = screen_image.get_pixmap() assert (pixmap.height(), pixmap.width()) == screen_image.image.shape pixmap = screen_image.get_pixmap(downsampling=2) assert (pixmap.height(), pixmap.width()) == tuple(dim // 2 for dim in screen_image.image.shape) # One gray value must not cause division by zero erro screen_image.image = np.ones((4, 4)) screen_image.reset() screen_image.get_pixmap() class TestImageLabel: def test_updateImage(self, qtbot, screen_image): label = ImageLabel() # Empty image must pass label.updateImage() label.screen_image = screen_image label.updateImage() assert label.pixmap() is not None def test_resizeEvent(self, qtbot, screen_image): label = ImageLabel(screen_image) label.updateImage() old_size = label.pixmap().size() # ensure the label will get the resize event label.show() label.resize(8, 8) new_size = label.pixmap().size() assert new_size != old_size class TestImageViewer: def test_images_setter(self, qtbot): viewer = ImageViewer() viewer.images = np.zeros((16, 16)) assert viewer.images.ndim == 3 assert viewer.slider.isHidden() assert float(viewer.min_slider_edit.text()) == 0 assert float(viewer.max_slider_edit.text()) == 0 viewer.images = np.ones((5, 16, 16)) assert viewer.images.ndim == 3 assert not viewer.slider.isHidden() assert viewer.slider.minimum() == 0 assert viewer.slider.maximum() == viewer.images.shape[0] - 1 assert float(viewer.min_slider_edit.text()) == 1 assert float(viewer.max_slider_edit.text()) == 1 # Test viewer and popup window equality viewer.popup() qtbot.addWidget(viewer._pg_window) np.testing.assert_almost_equal(viewer.images, 1) np.testing.assert_almost_equal(viewer._pg_window.image, 1) # 3D viewer.images = np.ones((5, 16, 16)) * 5 np.testing.assert_almost_equal(viewer.images, 5) np.testing.assert_almost_equal(viewer._pg_window.image, 5) # 2D viewer.images = np.ones((16, 16)) * 3 np.testing.assert_almost_equal(viewer.images, 3) np.testing.assert_almost_equal(viewer._pg_window.image, 3) # validators viewer.images = 10 + np.arange(200 * 8 ** 2).reshape(200, 8, 8) validator = viewer.slider_edit.validator() assert validator.validate('199', 0)[0] == QValidator.Acceptable assert validator.validate('2000', 0)[0] == QValidator.Invalid assert viewer.min_slider_edit.validator().bottom() == viewer.images[0].min() assert 
viewer.min_slider_edit.validator().top() == viewer.images[0].max() viewer._pg_window.close() def test_append(self, qtbot): viewer = ImageViewer() # Append to empty viewer.append(np.zeros((4, 4))) assert viewer.images.ndim == 3 assert viewer.images.shape == (1, 4, 4) # Append 2D viewer.append(np.zeros((4, 4))) assert viewer.images.shape == (2, 4, 4) # Append 3D viewer.append(np.zeros((3, 4, 4))) assert viewer.images.shape == (5, 4, 4) # Append wrong shape with pytest.raises(ImageViewingError): viewer.append(np.zeros((3, 2, 2))) def test_set_enabled_adjustments(self, qtbot): viewer = ImageViewer() def assert_all(value): viewer.set_enabled_adjustments(value) assert viewer.slider.isEnabled() == value assert viewer.slider_edit.isEnabled() == value assert viewer.min_slider.isEnabled() == value assert viewer.min_slider_edit.isEnabled() == value assert viewer.max_slider.isEnabled() == value assert viewer.max_slider_edit.isEnabled() == value assert_all(True) assert_all(False) def test_reset_clim(self, viewer): image = np.arange(16 ** 2).reshape(16, 16) viewer.images = image viewer.append(image * 2) viewer.slider.setValue(1) viewer.reset_clim(auto=False) si = viewer.screen_image min_converted = si.convert_native_value_to_normalized(si.black_point) max_converted = si.convert_native_value_to_normalized(si.white_point) assert viewer.screen_image.maximum == pytest.approx(510) assert viewer.min_slider.value() == pytest.approx(min_converted) assert viewer.max_slider.value() == pytest.approx(max_converted) assert float(viewer.min_slider_edit.text()) == pytest.approx(si.black_point) assert float(viewer.max_slider_edit.text()) == pytest.approx(si.white_point) # Pop up window must be updated assert viewer._pg_window.getLevels() == pytest.approx((si.black_point, si.white_point)) viewer._pg_window.close() def test_on_slider_value_changed(self, viewer): viewer.slider.setValue(5) assert viewer._pg_window.currentIndex == 5 assert viewer.slider_edit.text() == '5' viewer._pg_window.close() def test_on_slider_edit_return_pressed(self, viewer): viewer.slider_edit.setText('5') viewer.slider_edit.returnPressed.emit() assert viewer.slider.value() == 5 assert viewer._pg_window.currentIndex == 5 viewer._pg_window.close() def test_on_min_slider_edit_return_pressed(self, viewer): viewer.images = np.arange(256).reshape(16, 16) viewer.min_slider_edit.setText('100') viewer.min_slider_edit.returnPressed.emit() assert viewer.screen_image.black_point == pytest.approx(100) assert viewer.min_slider.value() == pytest.approx(100) assert viewer._pg_window.getLevels()[0] == pytest.approx(100) viewer._pg_window.close() def test_on_max_slider_edit_return_pressed(self, viewer): viewer.images = np.arange(256).reshape(16, 16) viewer.max_slider_edit.setText('100') viewer.max_slider_edit.returnPressed.emit() assert viewer.screen_image.white_point == pytest.approx(100) assert viewer.max_slider.value() == pytest.approx(100) assert viewer._pg_window.getLevels()[1] == pytest.approx(100) viewer._pg_window.close() def test_on_min_slider_value_changed(self, viewer): viewer.images = np.arange(256).reshape(16, 16) viewer.min_slider.valueChanged.emit(100) assert viewer.screen_image.black_point == pytest.approx(100) assert float(viewer.min_slider_edit.text()) == pytest.approx(100) assert viewer._pg_window.getLevels()[0] == pytest.approx(100) viewer._pg_window.close() def test_on_max_slider_value_changed(self, viewer): viewer.images = np.arange(256).reshape(16, 16) viewer.max_slider.valueChanged.emit(100) assert viewer.screen_image.white_point == 
pytest.approx(100) assert float(viewer.max_slider_edit.text()) == pytest.approx(100) assert viewer._pg_window.getLevels()[1] == pytest.approx(100) viewer._pg_window.close() def test_popup(self, qtbot, viewer): # Close and another popup call must show the widget viewer._pg_window.close() viewer.popup() assert viewer._pg_window.isVisible() # 2D must work other = ImageViewer() other.images = np.ones((4, 4)) other.popup() qtbot.addWidget(other._pg_window) assert other._pg_window is not None viewer._pg_window.close() other._pg_window.close() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/util.py0000664000175000017500000003071400000000000016143 0ustar00tomastomas00000000000000"""Various utility functions.""" import argparse import gi import glob import logging import math import os from collections import OrderedDict gi.require_version('Ufo', '0.0') from gi.repository import Ufo LOG = logging.getLogger(__name__) RESOURCES = None def range_list(value): """ Split *value* separated by ':' into int triple, filling missing values with 1s. """ def check(region): if region[0] >= region[1]: raise argparse.ArgumentTypeError("{} must be less than {}".format(region[0], region[1])) lst = [int(x) for x in value.split(':')] if len(lst) == 1: frm = lst[0] return (frm, frm + 1, 1) if len(lst) == 2: check(lst) return (lst[0], lst[1], 1) if len(lst) == 3: check(lst) return (lst[0], lst[1], lst[2]) raise argparse.ArgumentTypeError("Cannot parse {}".format(value)) def make_subargs(args, subargs): """Return an argparse.Namespace consisting of arguments from *args* which are listed in the *subargs* list.""" namespace = argparse.Namespace() for subarg in subargs: setattr(namespace, subarg, getattr(args, subarg)) return namespace def set_node_props(node, args): """Set up *node*'s properties to *args* which is a dictionary of values.""" for name in dir(node.props): if not name.startswith('_') and hasattr(args, name): value = getattr(args, name) if value is not None: LOG.debug("Setting {}:{} to {}".format(node.get_plugin_name(), name, value)) node.set_property(name, getattr(args, name)) def get_filenames(path): """ Get all filenams from *path*, which could be a directory or a pattern for matching files in a directory. """ if not path: return [] return sorted(glob.glob(os.path.join(path, '*') if os.path.isdir(path) else path)) def setup_read_task(task, path, args): """Set up *task* and take care of handling file types correctly.""" task.props.path = path fnames = get_filenames(path) if fnames and fnames[0].endswith('.raw'): if not args.width or not args.height: raise RuntimeError("Raw files require --width, --height and --bitdepth arguments.") task.props.raw_width = args.width task.props.raw_height = args.height task.props.raw_bitdepth = args.bitdepth def restrict_value(limits, dtype=float): """Convert value to *dtype* and make sure it is within *limits* (included) specified as tuple (min, max). 
If one of the tuple values is None it is ignored.""" def check(value=None, clamp=False): if value is None: return limits result = dtype(value) if limits[0] is not None and result < limits[0]: if clamp: result = dtype(limits[0]) else: raise argparse.ArgumentTypeError('Value cannot be less than {}'.format(limits[0])) if limits[1] is not None and result > limits[1]: if clamp: result = dtype(limits[1]) else: raise argparse.ArgumentTypeError('Value cannot be greater than {}'.format(limits[1])) return result return check def convert_filesize(value): multiplier = 1 conv = OrderedDict((('k', 2 ** 10), ('m', 2 ** 20), ('g', 2 ** 30), ('t', 2 ** 40))) if not value[-1].isdigit(): if value[-1] not in list(conv.keys()): raise argparse.ArgumentTypeError('--output-bytes-per-file must either be a ' + 'number or end with {} '.format(list(conv.keys())) + 'to indicate kilo, mega, giga or terabytes') multiplier = conv[value[-1]] value = value[:-1] value = int(float(value) * multiplier) if value < 0: raise argparse.ArgumentTypeError('--output-bytes-per-file cannot be less than zero') return value def tupleize(num_items=None, conv=float, dtype=tuple): """Convert comma-separated string values to a *num-items*-tuple of values converted with *conv*. """ def split_values(value=None): """Convert comma-separated string *value* to a tuple of numbers.""" if not value: # empty value or string return dtype([]) if type(value) is float or type(value) is int: return dtype([value]) try: result = dtype([conv(x) for x in value.split(',')]) except: raise argparse.ArgumentTypeError('Expect comma-separated tuple') if num_items and len(result) != num_items: raise argparse.ArgumentTypeError('Expected {} items'.format(num_items)) return result return split_values def next_power_of_two(number): """Compute the next power of two of the *number*.""" return 2 ** int(math.ceil(math.log(number, 2))) def read_image(filename): """Read image from file *filename*.""" if filename.lower().endswith('.tif') or filename.lower().endswith('.tiff'): from tifffile import TiffFile with TiffFile(filename) as tif: return tif.asarray(out='memmap') elif '.edf' in filename.lower(): import fabio edf = fabio.edfimage.edfimage() edf.read(filename) return edf.data else: raise ValueError('Unsupported image format') def write_image(filename, image): import tifffile directory = os.path.dirname(filename) os.makedirs(directory, exist_ok=True) tifffile.imwrite(filename, image) def get_image_shape(filename): """Determine image shape (numpy order) from file *filename*.""" if filename.lower().endswith('.tif') or filename.lower().endswith('.tiff'): from tifffile import TiffFile with TiffFile(filename) as tif: page = tif.pages[0] shape = (page.imagelength, page.imagewidth) if len(tif.pages) > 1: shape = (len(tif.pages),) + shape else: # fabio doesn't seem to be able to read the shape without reading the data shape = read_image(filename).shape return shape def get_first_filename(path): """Returns the first valid image filename in *path*.""" if not path: raise RuntimeError("Path to sinograms or projections not set.") filenames = get_filenames(path) if not filenames: raise RuntimeError("No files found in `{}'".format(path)) return filenames[0] def determine_shape(args, path=None, store=False, do_raise=False): """Determine input shape from *args* which means either width and height are specified in args or try to read the *path* and determine the shape from it. The default path is args.projections, which is the typical place to find the input. 
If *store* is True, assign the determined values if they aren't already present in *args*. Return a tuple (width, height). If *do_raise* is True, raise an exception if shape cannot be determined. """ width = args.width height = args.height if not (width and height): filename = get_first_filename(path or args.projections) try: shape = get_image_shape(filename) # Now set the width and height if not specified width = width or shape[-1] height = height or shape[-2] except Exception as exc: LOG.info("Couldn't determine image dimensions from '{}'".format(filename)) if do_raise: raise exc if store: if not args.width: args.width = width if not args.height: args.height = height - args.y return (width, height) def get_filtering_padding(width): """Get the number of horizontal padded pixels in order to avoid convolution artifacts.""" return next_power_of_two(2 * width) - width def setup_padding(pad, width, height, mode, crop=None, pad_width=None, pad_height=0, centered=True): if pad_width is not None and pad_width < 0: raise ValueError("pad_width must be >= 0") if pad_height < 0: raise ValueError("pad_height must be >= 0") if pad_width is None: # Default is horizontal padding only pad_width = get_filtering_padding(width) pad.props.width = width + pad_width pad.props.height = height + pad_height pad.props.x = pad_width // 2 if centered else 0 pad.props.y = pad_height // 2 if centered else 0 pad.props.addressing_mode = mode LOG.debug( "Padding (x=0, y=0, w=%d, h=%d) -> (x=%d, y=%d, w=%d, h=%d) with mode `%s'", width, height, pad.props.x, pad.props.y, pad.props.width, pad.props.height, mode, ) if crop: # crop to original width after filtering crop.props.width = width crop.props.height = height crop.props.x = pad_width // 2 if centered else 0 crop.props.y = pad_height // 2 if centered else 0 return (pad_width, pad_height) def make_region(n, dtype=int): """Make region in such a way that in case of odd *n* it is centered around 0. Use *dtype* as data type. """ return (-dtype(n / 2), dtype(n / 2 + n % 2), dtype(1)) def get_reconstructed_cube_shape(x_region, y_region, z_region): """Get the shape of the reconstructed cube as (slice width, slice height, num slices).""" import numpy as np z_start, z_stop, z_step = z_region y_start, y_stop, y_step = y_region x_start, x_stop, x_step = x_region num_slices = len(np.arange(z_start, z_stop, z_step)) slice_height = len(np.arange(y_start, y_stop, y_step)) slice_width = len(np.arange(x_start, x_stop, x_step)) return slice_width, slice_height, num_slices def get_reconstruction_regions(params, store=False, dtype=int): """Compute reconstruction regions along all three axes, use *dtype* for as data type for x and y regions, z region is always float. 
""" width, height = determine_shape(params) if getattr(params, 'transpose_input', False): # In case down the pipeline there is a transpose task tmp = width width = height height = tmp if params.x_region[1] == -1: x_region = make_region(width, dtype=dtype) else: x_region = params.x_region if params.y_region[1] == -1: y_region = make_region(width, dtype=dtype) else: y_region = params.y_region if params.region[1] == -1: region = make_region(height, dtype=float) else: region = params.region LOG.info('X region: {}'.format(x_region)) LOG.info('Y region: {}'.format(y_region)) LOG.info('Parameter region: {}'.format(region)) if store: params.x_region = x_region params.y_region = y_region params.region = region return x_region, y_region, region def get_scarray_value(scarray, index): if len(scarray) == 1: return scarray[0] return scarray[index] def run_scheduler(scheduler, graph): from threading import Thread # Reuse resources until https://github.com/ufo-kit/ufo-core/issues/191 is solved. global RESOURCES if not RESOURCES: RESOURCES = Ufo.Resources() scheduler.set_resources(RESOURCES) thread = Thread(target=scheduler.run, args=(graph,)) thread.setDaemon(True) thread.start() try: thread.join() return True except KeyboardInterrupt: LOG.info('Processing interrupted') scheduler.abort() return False def fbp_filtering_in_phase_retrieval(args): if args.energy is None or args.propagation_distance is None: # No phase retrieval at all return False return ( args.projection_filter != 'none' and ( args.retrieval_method != 'tie' or args.tie_approximate_logarithm ) ) class Vector(object): """A vector based on axis-angle representation.""" def __init__(self, x_angle=0, y_angle=0, z_angle=0, position=None): import numpy as np self.position = np.array(position, dtype=float) if position is not None else None self.x_angle = x_angle self.y_angle = y_angle self.z_angle = z_angle def __repr__(self): return 'Vector(position={}, angles=({}, {}, {}))'.format(self.position, self.x_angle, self.y_angle, self.z_angle) def __str__(self): return repr(self) ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1698416097.773776 ufo-tofu-0.13.0/tofu/vis/0000775000175000017500000000000000000000000015410 5ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760162.0 ufo-tofu-0.13.0/tofu/vis/__init__.py0000664000175000017500000000000000000000000017507 0ustar00tomastomas00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/vis/qt.py0000664000175000017500000001337200000000000016414 0ustar00tomastomas00000000000000import pyqtgraph as pg try: import pyqtgraph.opengl as gl except ImportError: pass import logging import numpy as np import tifffile from PyQt5 import QtCore, QtWidgets LOG = logging.getLogger(__name__) def read_tiff(filename): tiff = tifffile.TiffFile(filename) array = tiff.asarray() return array.T def remove_extrema(data): upper = np.percentile(data, 99) lower = np.percentile(data, 1) data[data > upper] = upper data[data < lower] = lower return data def create_volume(data): gradient = (data - np.roll(data, 1))**2 cmin = gradient.min() div = gradient.max() - cmin gradient = (gradient - cmin) / div * 255 volume = np.empty(data.shape + (4, ), dtype=np.ubyte) volume[..., 0] = data volume[..., 1] = data volume[..., 2] = data volume[..., 3] = gradient return volume class ImageViewer(QtWidgets.QWidget): """ Present a sequence of 
././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1698416097.773776 ufo-tofu-0.13.0/tofu/vis/0000775000175000017500000000000000000000000015410 5ustar00tomastomas00000000000000
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1665760162.0 ufo-tofu-0.13.0/tofu/vis/__init__.py0000664000175000017500000000000000000000000017507 0ustar00tomastomas00000000000000
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698414563.0 ufo-tofu-0.13.0/tofu/vis/qt.py0000664000175000017500000001337200000000000016414 0ustar00tomastomas00000000000000
import pyqtgraph as pg
try:
    import pyqtgraph.opengl as gl
except ImportError:
    pass
import logging
import numpy as np
import tifffile
from PyQt5 import QtCore, QtWidgets

LOG = logging.getLogger(__name__)


def read_tiff(filename):
    tiff = tifffile.TiffFile(filename)
    array = tiff.asarray()

    return array.T


def remove_extrema(data):
    upper = np.percentile(data, 99)
    lower = np.percentile(data, 1)
    data[data > upper] = upper
    data[data < lower] = lower

    return data


def create_volume(data):
    # Use the normalized squared difference between neighboring values as the
    # opacity (alpha) channel
    gradient = (data - np.roll(data, 1)) ** 2
    cmin = gradient.min()
    div = gradient.max() - cmin
    gradient = (gradient - cmin) / div * 255

    volume = np.empty(data.shape + (4,), dtype=np.ubyte)
    volume[..., 0] = data
    volume[..., 1] = data
    volume[..., 2] = data
    volume[..., 3] = gradient

    return volume


class ImageViewer(QtWidgets.QWidget):
    """
    Present a sequence of files that can be browsed with a slider.

    To get the currently selected position connect to the *slider* attribute's
    valueChanged signal.
    """

    def __init__(self, filenames, parent=None):
        super(ImageViewer, self).__init__(parent)
        image_view = pg.ImageView()
        image_view.getView().setAspectLocked(True)
        self.image_item = image_view.getImageItem()
        self.slider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.slider.valueChanged.connect(self.update_image)
        self.main_layout = QtWidgets.QVBoxLayout(self)
        self.main_layout.addWidget(image_view)
        self.main_layout.addWidget(self.slider)
        self.setLayout(self.main_layout)
        self.load_files(filenames)

    def load_files(self, filenames):
        """Load *filenames* for display."""
        self.filenames = filenames
        self.slider.setRange(0, len(self.filenames) - 1)
        self.slider.setSliderPosition(0)
        self.update_image()

    def update_image(self):
        """Update the currently displayed image."""
        if self.filenames:
            pos = self.slider.value()
            image = read_tiff(self.filenames[pos])
            self.image_item.setImage(image)


class ImageWindow(object):
    """
    Stand-alone window to display image sequences.
    """

    global_app = None

    def __init__(self, filenames):
        self.global_app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([])
        self.viewer = ImageViewer(filenames)
        self.viewer.show()
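
# --- Editor's note: illustrative sketch, not part of tofu itself -----------
# ImageWindow creates (or reuses) a QApplication and shows an ImageViewer, so
# a minimal standalone slice browser could look roughly like this; the path
# below is only a placeholder:
#
#   import glob
#   window = ImageWindow(sorted(glob.glob('/data/slices/*.tif')))
#   window.global_app.exec_()
# ----------------------------------------------------------------------------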

class OverlapViewer(QtWidgets.QWidget):
    """
    Presents two images by subtracting the flipped second from the first.

    To get the current deviation connect to the *slider* attribute's
    valueChanged signal.
    """

    def __init__(self, parent=None, remove_extrema=False):
        super(OverlapViewer, self).__init__()
        image_view = pg.ImageView()
        image_view.getView().setAspectLocked(True)
        self.image_item = image_view.getImageItem()
        self.slider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.slider.setRange(0, 0)
        self.slider.valueChanged.connect(self.update_image)
        self.main_layout = QtWidgets.QVBoxLayout()
        self.main_layout.addWidget(image_view)
        self.main_layout.addWidget(self.slider)
        self.setLayout(self.main_layout)
        self.first, self.second = (None, None)
        self.remove_extrema = remove_extrema
        self.subtract = True

    def set_images(self, first, second):
        """Set *first* and *second* image."""
        self.first, self.second = first.T, np.flipud(second.T)

        if self.remove_extrema:
            self.first = remove_extrema(self.first)
            self.second = remove_extrema(self.second)

        if self.first.shape != self.second.shape:
            LOG.warning("Shape {} of {} is different to {} of {}".format(
                self.first.shape, self.first, self.second.shape, self.second))

        width = self.first.shape[0]
        self.slider.setRange(-(width // 2), int(1.5 * width))
        self.slider.setSliderPosition(width // 2)
        self.update_image()

    def set_position(self, position):
        self.slider.setValue(int(position))
        self.update_image()

    def update_image(self):
        """Update the current subtraction."""
        if self.first is None or self.second is None:
            LOG.warning("No images set yet")
        else:
            pos = self.slider.value()
            moved = np.roll(self.second, self.second.shape[0] // 2 - pos, axis=0)

            if self.subtract:
                self.image_item.setImage(moved - self.first)
            else:
                self.image_item.setImage(moved + self.first)


class VolumeViewer(QtWidgets.QWidget):
    def __init__(self, step=1, density=1, parent=None):
        super(VolumeViewer, self).__init__(parent)
        self.volume_view = gl.GLViewWidget()
        self.main_layout = QtWidgets.QVBoxLayout()
        self.main_layout.addWidget(self.volume_view)
        self.setLayout(self.main_layout)
        self.step = step
        self.density = density

    def load_files(self, filenames):
        """Load *filenames* for display."""
        filenames = filenames[::self.step]
        num = len(filenames)
        first = read_tiff(filenames[0])[::self.step, ::self.step]
        width, height = first.shape

        data = np.empty((width, height, num), dtype=np.float32)
        data[:, :, 0] = first

        for i, filename in enumerate(filenames[1:]):
            data[:, :, i + 1] = read_tiff(filename)[::self.step, ::self.step]

        volume = create_volume(data)
        dx, dy, dz, _ = volume.shape
        volume_item = gl.GLVolumeItem(volume, sliceDensity=self.density)
        volume_item.translate(-dx / 2, -dy / 2, -dz / 2)
        volume_item.scale(0.05, 0.05, 0.05, local=False)
        self.volume_view.addItem(volume_item)
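
# --- Editor's note: illustrative sketch, not part of tofu itself -----------
# VolumeViewer stacks downsampled TIFF slices into a (width, height, num)
# array, converts it with create_volume() into an RGBA volume whose opacity
# follows local intensity changes, and hands it to a GLVolumeItem. A rough
# standalone usage (placeholder path, requires pyqtgraph.opengl):
#
#   import glob
#   from PyQt5 import QtWidgets
#   app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([])
#   viewer = VolumeViewer(step=4)
#   viewer.load_files(sorted(glob.glob('/data/slices/*.tif')))
#   viewer.show()
#   app.exec_()
# ----------------------------------------------------------------------------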
././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1698416097.773776 ufo-tofu-0.13.0/ufo_tofu.egg-info/0000775000175000017500000000000000000000000017152 5ustar00tomastomas00000000000000
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698416097.0 ufo-tofu-0.13.0/ufo_tofu.egg-info/PKG-INFO0000664000175000017500000001123500000000000020251 0ustar00tomastomas00000000000000
Metadata-Version: 2.1
Name: ufo-tofu
Version: 0.13.0
Summary: A fast, versatile and user-friendly image processing toolkit for computed tomography
Home-page: http://github.com/ufo-kit/tofu
Author: Matthias Vogelgesang
Author-email: matthias.vogelgesang@kit.edu
License: LGPL
Requires-Python: >=3
Description-Content-Type: text/markdown
License-File: LICENSE

## About

[![PyPI version](https://badge.fury.io/py/ufo-tofu.png)](http://badge.fury.io/py/ufo-tofu)
[![Documentation status](https://readthedocs.org/projects/tofu/badge/?version=latest)](http://tofu.readthedocs.io/en/latest/?badge=latest)

This repository contains Python data processing scripts to be used with the
UFO framework. At the moment they are targeted at high-performance
reconstruction of tomographic data sets.

If you use this software for publishing your data, we kindly ask you to cite
the article

**Faragó, T., Gasilov, S., Emslie, I., Zuber, M., Helfen, L., Vogelgesang, M. &
Baumbach, T. (2022). J. Synchrotron Rad. 29,
https://doi.org/10.1107/S160057752200282X**

If you want to stay updated, subscribe to our
[newsletter](mailto:sympa@lists.kit.edu?subject=subscribe%20ufo%20YourFirstName%20YourLastName).
Simply leave the body of the e-mail empty and change ``YourFirstName
YourLastName`` in the subject accordingly.

## Installation

First make sure you have [ufo-core](https://github.com/ufo-kit/ufo-core) and
[ufo-filters](https://github.com/ufo-kit/ufo-filters) installed. For that,
please follow the [installation
instructions](https://ufo-core.readthedocs.io/en/latest/install/index.html).
You can either install the prerequisites yourself on
[Linux](https://ufo-core.readthedocs.io/en/latest/install/linux.html) or use
one of our [Docker
containers](https://ufo-core.readthedocs.io/en/latest/install/docker.html).

Then, for the newest version, run the following in *tofu*'s top directory:

    pip install .

or, to install via PyPI:

    pip install ufo-tofu

in a prepared virtualenv or as root for a system-wide installation.

Note that if you plan to use the graphical user interface you need PyQt5,
pyqtgraph and PyOpenGL. You are strongly advised to install PyQt through your
system package manager; pyqtgraph and PyOpenGL can be installed with pip:

    pip install pyqtgraph PyOpenGL

## Usage

### Flow

`tofu flow` is a visual flow programming tool. You can create a flow by using
any task from [ufo-filters](https://github.com/ufo-kit/ufo-filters) and execute
it. It includes visualization of 2D and 3D results, so you can quickly check
the output of your flow, which is useful for finding algorithm parameters.

![flow](https://user-images.githubusercontent.com/2648829/150096902-fdbf1b7e-b34e-4368-98ac-c924cad8a6cd.jpg)

### Reconstruction

To do a tomographic reconstruction you simply call

    $ tofu tomo --sinograms $PATH_TO_SINOGRAMS

from the command line. To get correct results, you may need to append options
such as `--axis-pos/-a` and `--angle-step/-a` (the latter is given in
radians!). Input paths are either directories or glob patterns. Output paths
are either directories or a format that contains one `%i`
[specifier](http://www.pixelbeat.org/programming/gcc/format_specs.html):

    $ tofu tomo --axis-pos=123.4 --angle-step=0.000123 \
        --sinograms="/foo/bar/*.tif" --output="/output/slices-%05i.tif"
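
For illustration, the `%05i` specifier is expanded printf-style with the slice
index and zero-padded to five digits, so — assuming numbering starts at zero —
the command above would write files such as

    /output/slices-00000.tif
    /output/slices-00001.tif
    /output/slices-00002.tif
    ...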
You can get help for all options by running

    $ tofu tomo --help

and more verbose output by running with the `-v/--verbose` flag.

You can also load reconstruction parameters from a configuration file called
`reco.conf`. You may create a template with

    $ tofu init

Note that options passed via the command line always override configuration
parameters!

Besides scripted reconstructions, one can also run a standalone GUI for both
reconstruction and quick assessment of the reconstructed data via

    $ tofu gui

![GUI](https://cloud.githubusercontent.com/assets/115270/6442540/db0b55fe-c0f0-11e4-9577-0048fddae8b7.png)

### Performance measurement

If you are running at least ufo-core/filters 0.6, you can evaluate the
performance of the filtered backprojection (without sinogram transposition!)
with

    $ tofu perf

You can customize parameter scans easily via

    $ tofu perf --width 256:8192:256 --height 512

which will reconstruct all combinations of width between 256 and 8192 with a
step of 256 and a fixed height of 512 pixels.

### Estimating the center of rotation

If you do not know the correct center of rotation from your experimental
setup, you can estimate it with:

    $ tofu estimate -i $PATH_TO_SINOGRAMS

Currently, a modified algorithm based on the work of [Donath et
al.](http://dx.doi.org/10.1364/JOSAA.23.001048) is used to determine the
center.
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698416097.0 ufo-tofu-0.13.0/ufo_tofu.egg-info/SOURCES.txt0000664000175000017500000000465200000000000021045 0ustar00tomastomas00000000000000
LICENSE
MANIFEST.in
README.md
setup.py
bin/tofu
tofu/__init__.py
tofu/config.py
tofu/find_large_spots.py
tofu/genreco.py
tofu/gui.py
tofu/gui.ui
tofu/inpaint.py
tofu/lamino.py
tofu/preprocess.py
tofu/reco.py
tofu/tasks.py
tofu/util.py
tofu/ez/RR_external.py
tofu/ez/__init__.py
tofu/ez/ctdir_walker.py
tofu/ez/evaluate_sharpness.py
tofu/ez/find_axis_cmd_gen.py
tofu/ez/image_read_write.py
tofu/ez/main.py
tofu/ez/params.py
tofu/ez/tofu_cmd_gen.py
tofu/ez/ufo_cmd_gen.py
tofu/ez/util.py
tofu/ez/yaml_in_out.py
tofu/ez/GUI/__init__.py
tofu/ez/GUI/ezufo_launcher.py
tofu/ez/GUI/image_viewer.py
tofu/ez/GUI/login_dialog.py
tofu/ez/GUI/message_dialog.py
tofu/ez/GUI/Advanced/__init__.py
tofu/ez/GUI/Advanced/advanced.py
tofu/ez/GUI/Advanced/ffc.py
tofu/ez/GUI/Advanced/nlmdn.py
tofu/ez/GUI/Advanced/optimization.py
tofu/ez/GUI/Main/__init__.py
tofu/ez/GUI/Main/batch_process.py
tofu/ez/GUI/Main/centre_of_rotation.py
tofu/ez/GUI/Main/config.py
tofu/ez/GUI/Main/filters.py
tofu/ez/GUI/Main/phase_retrieval.py
tofu/ez/GUI/Main/region_and_histogram.py
tofu/ez/GUI/Stitch_tools_tab/__init__.py
tofu/ez/GUI/Stitch_tools_tab/auto_horizontal_stitch_funcs.py
tofu/ez/GUI/Stitch_tools_tab/auto_horizontal_stitch_gui.py
tofu/ez/GUI/Stitch_tools_tab/ez_360_multi_stitch_qt.py
tofu/ez/GUI/Stitch_tools_tab/ez_360_overlap_qt.py
tofu/ez/GUI/Stitch_tools_tab/ezmview_qt.py
tofu/ez/GUI/Stitch_tools_tab/ezstitch_qt.py
tofu/ez/Helpers/__init__.py
tofu/ez/Helpers/find_360_overlap.py
tofu/ez/Helpers/halfacqmode-mpi-stitch.py
tofu/ez/Helpers/mview_main.py
tofu/ez/Helpers/stitch_funcs.py
tofu/flow/__init__.py
tofu/flow/config.json
tofu/flow/execution.py
tofu/flow/filedirdialog.py
tofu/flow/main.py
tofu/flow/models.py
tofu/flow/propertylinksmodels.py
tofu/flow/propertylinkswidget.py
tofu/flow/runslider.py
tofu/flow/scene.py
tofu/flow/util.py
tofu/flow/viewer.py
tofu/flow/composites/ffc-links.cm
tofu/flow/composites/pr.cm
tofu/tests/__init__.py
tofu/tests/conftest.py
tofu/tests/flow_util.py
tofu/tests/test_flow_execution.py
tofu/tests/test_flow_main.py
tofu/tests/test_flow_models.py
tofu/tests/test_flow_propertylinksmodels.py
tofu/tests/test_flow_propertylinkswidget.py
tofu/tests/test_flow_runslider.py
tofu/tests/test_flow_scene.py
tofu/tests/test_flow_util.py
tofu/tests/test_flow_viewer.py
tofu/vis/__init__.py
tofu/vis/qt.py
ufo_tofu.egg-info/PKG-INFO
ufo_tofu.egg-info/SOURCES.txt
ufo_tofu.egg-info/dependency_links.txt
ufo_tofu.egg-info/requires.txt
ufo_tofu.egg-info/top_level.txt
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698416097.0 ufo-tofu-0.13.0/ufo_tofu.egg-info/dependency_links.txt0000664000175000017500000000000100000000000023220 0ustar00tomastomas00000000000000
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698416097.0 ufo-tofu-0.13.0/ufo_tofu.egg-info/requires.txt0000664000175000017500000000011000000000000021542 0ustar00tomastomas00000000000000
PyGObject
imageio
numpy
networkx
PyQt5
pyqtconsole
pyxdg
qtpynodeeditor
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1698416097.0 ufo-tofu-0.13.0/ufo_tofu.egg-info/top_level.txt0000664000175000017500000000000500000000000021677 0ustar00tomastomas00000000000000
tofu