Skip to content

Commit 6eaf90d

Browse files
Parameter optimization for phase and fluorescence reconstructions (mehta-lab#525)
* Add model-level reconstruct() convenience functions Each physics model (phase_thick_3d, isotropic_thin_3d, isotropic_fluorescent_thick_3d, isotropic_fluorescent_thin_3d, inplane_oriented_thick_pol3d) gets a reconstruct() function that chains generate_test_phantom + calculate_transfer_function + apply_transfer_function + apply_inverse_transfer_function. * Add per-model API layer (waveorder.api) Each physics model (birefringence, phase, fluorescence, birefringence_and_phase) gets its own module with Settings, simulate(), compute_transfer_function(), apply_inverse_transfer_function(), and reconstruct(). All API functions accept/return xr.DataArray/Dataset. Shared utilities in _settings.py and _utils.py. * Refactor CLI to delegate to per-model API modules CLI settings.py now re-exports from waveorder.api modules. compute_transfer_function.py and apply_inverse_transfer_function.py dispatch to birefringence, phase, fluorescence, or birefringence_and_phase API functions. Removes apply_inverse_models.py. * Add CLI simulate, view, and download commands New commands: wo sim, wo view, wo download-examples. Adds AliasGroup for short aliases (rec, sim, dle, gui). Updates main.py help formatting. Adds wo entry point. * Add API and CLI examples 6 API examples (birefringence, birefringence_and_phase, phase_2d, phase_3d, fluorescence_2d, fluorescence_3d) and matching CLI shell scripts with YAML config files. * Update tests for new API architecture Rewrite API tests to use per-model imports. Update channel names (BF -> Transmittance, Pol -> Depolarization). Add parametrized example runner tests for API and CLI examples. * Update model examples and configs README Align model example parameters (4-State scheme, yx_pixel_size=0.1, z_position_list sign convention). Update configs README. * Update internal callers to use API transfer function settings Replace BirefringenceTransferFunctionSettings, PhaseTransferFunctionSettings, FluorescenceTransferFunctionSettings with direct imports from waveorder.api modules in plugin, calibration, and test code. * Simplify API examples to only use the API Remove model-level imports (torch, waveorder.models, waveorder.api._utils) and comparison/verification code. Each example now shows only the API: settings, simulate, compute_transfer_function, apply_inverse_transfer_function, reconstruct, and print output. * Simplify configs README Replace outdated instructions, zenodo links, and TODO with a short note that configs are auto-generated by test_generate_example_settings. * Delete download-examples command Remove download.py, its import/registration/alias in main.py, and its reference in the view command docstring. * Move view to end of help, add [v] alias * Simplify test_examples.py with glob parametrization Replace explicit file lists with glob("*.py") and glob("*.sh") so new examples are picked up automatically. * Consolidate duplicated _make_czyx test helpers into shared fixture Replace 5 near-identical _make_czyx / _make_czyx_data / _make_czyx_xarray helpers across test files with a single make_czyx factory fixture in tests/conftest.py. Also remove duplicate test_position_list_from_shape_scale_offset from cli tests (superset exists in api tests). * Documented configs * Move configs README and delete old configs directory * Remove expensive contrast limit computation from CLI * smaller arrays to relive test bottleneck * Migrate to uv, ruff, and hatch (mehta-lab#523) * Migrate build system from setuptools to hatchling + uv-dynamic-versioning * Replace black/isort/autoflake with ruff config * Apply ruff formatting to codebase * Migrate CI workflows to uv * Migrate ReadTheDocs to uv * Update CONTRIBUTING.md for uv workflow * Generate and commit uv.lock * Fix trailing whitespace and missing newlines in non-Python files * Add formatting commits to git-blame-ignore-revs * Update install instructions across docs for uv and removed [all] extra * Update install instructions and keep [all] extra for backward compatibility [all] is now an alias for [visual] since dev/docs/test deps moved to [dependency-groups] (PEP 735). Extras ([project.optional-dependencies]) are published to PyPI for end users; dependency groups are local-only for contributors via uv sync --group <name>. * Update uv.lock * Use --only-group for minimal CI/RTD installs, add pre-commit to dev group * Add waveorder/optim/ module with OptimizableFloat type, losses, logging, and optimizer loop * Make optics differentiable: smooth pupil, tilted pupil on Ewald sphere, gradient-safe kernels * Make models differentiable: remove in-place ops, accept tensor params, add pseudo-SVD and tilt * Add OptimizableFloat and tilt fields to settings, split base vs optimizable settings classes * Add optimize() to phase and fluorescence API modules * Add CLI optimization support: OptimizationSettings, optimizable param detection, birefringence guard * Add tensorboard dependency and update example configs with tilt fields * Add optimization examples and tests * add comment for using your own data * Handle numpy input in generate_pupil for backwards compatibility * Clean up steepness defaults: use 1e4 consistently, fix docstring * Make tilt integration test check loss improvement instead of convergence tolerance * iohub depend on pypi * Reduce 3D midband_power_loss to scalar for optimizer * Use optimized config path for reconstruction after optimization * feat: Added to_device to be target cuda Signed-off-by: Sricharan Reddy Varra <sricharan.varra@biohub.org> * style: formatting Signed-off-by: Sricharan Reddy Varra <sricharan.varra@biohub.org> * Remove z_focus_offset="auto" option and auto-focus CLI code * Fix 3D optimize loss reduction and parameter name unpacking in phase.py Reduce per-slice (Z,) loss to scalar before backward pass, matching fluorescence.py. Use split(".")[-1] for dotted parameter names. * Revert device/CUDA changes and fix phase.py bugs * Add numpy-style docstrings to public API functions * Add 3D optimization integration tests for phase and fluorescence * Apply missing api-refactor changes: rename validators, remove vector_transfer_function * Fix generate_pupil callers to pass tensors instead of numpy arrays * Remove '(or auto)' from z_focus_offset config comments * bump iohub * cleaner docs on pupil * cleaner documentation of z_position_list and na_det * enable tilt_angle_{zenith,azimuth} for phase_thick_3d * remove unecessary branching logic * remove parallel SVD paths...prefer one-shot SVD when available * synchronize fluorescence and phase optimization loops * simplify loss handling * remove redundant _get helper * clean resolution of optimizable floats * clean and test optimization cli * [o] -> (optimizable) * document pupil steepness choice --------- Signed-off-by: Sricharan Reddy Varra <sricharan.varra@biohub.org> Co-authored-by: Sricharan Reddy Varra <sricharan.varra@biohub.org>
1 parent 185d5dd commit 6eaf90d

39 files changed

Lines changed: 2167 additions & 231 deletions

docs/examples/cli/configs/birefringence-and-phase_3d.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ phase:
2020
yx_pixel_size: 0.1 # lateral pixel size in micrometers
2121
z_pixel_size: 0.25 # axial pixel size in micrometers
2222
z_padding: 0 # z slices to pad for axial boundary effects
23-
z_focus_offset: 0 # offset from center slice in slice units
23+
z_focus_offset: 0 # (optimizable) offset from center slice in slice units
2424
index_of_refraction_media: 1.3 # refractive index of imaging media
25-
numerical_aperture_detection: 1.2 # detection objective numerical aperture
26-
numerical_aperture_illumination: 0.9 # condenser numerical aperture
25+
numerical_aperture_detection: 1.2 # (optimizable) detection objective numerical aperture
26+
tilt_angle_zenith: 0.0 # (optimizable) illumination tilt zenith angle in radians
27+
tilt_angle_azimuth: 0.0 # (optimizable) illumination tilt azimuth angle in radians
28+
numerical_aperture_illumination: 0.9 # (optimizable) condenser numerical aperture
2729
invert_phase_contrast: false # invert contrast for positive/negative phase
2830
apply_inverse:
2931
reconstruction_algorithm: Tikhonov # 'Tikhonov' or 'TV' regularization

docs/examples/cli/configs/fluorescence_2d.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ fluorescence:
77
yx_pixel_size: 0.1 # lateral pixel size in micrometers
88
z_pixel_size: 0.25 # axial pixel size in micrometers
99
z_padding: 0 # z slices to pad for axial boundary effects
10-
z_focus_offset: 0 # offset from center slice in slice units
10+
z_focus_offset: 0 # (optimizable) offset from center slice in slice units
1111
index_of_refraction_media: 1.3 # refractive index of imaging media
12-
numerical_aperture_detection: 1.2 # detection objective numerical aperture
12+
numerical_aperture_detection: 1.2 # (optimizable) detection objective numerical aperture
13+
tilt_angle_zenith: 0.0 # (optimizable) illumination tilt zenith angle in radians
14+
tilt_angle_azimuth: 0.0 # (optimizable) illumination tilt azimuth angle in radians
1315
wavelength_emission: 0.532 # emission wavelength in micrometers
1416
confocal_pinhole_diameter: null # confocal pinhole diameter (null = widefield)
1517
apply_inverse:

docs/examples/cli/configs/fluorescence_3d.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ fluorescence:
77
yx_pixel_size: 0.1 # lateral pixel size in micrometers
88
z_pixel_size: 0.25 # axial pixel size in micrometers
99
z_padding: 0 # z slices to pad for axial boundary effects
10-
z_focus_offset: 0 # offset from center slice in slice units
10+
z_focus_offset: 0 # (optimizable) offset from center slice in slice units
1111
index_of_refraction_media: 1.3 # refractive index of imaging media
12-
numerical_aperture_detection: 1.2 # detection objective numerical aperture
12+
numerical_aperture_detection: 1.2 # (optimizable) detection objective numerical aperture
13+
tilt_angle_zenith: 0.0 # (optimizable) illumination tilt zenith angle in radians
14+
tilt_angle_azimuth: 0.0 # (optimizable) illumination tilt azimuth angle in radians
1315
wavelength_emission: 0.532 # emission wavelength in micrometers
1416
confocal_pinhole_diameter: null # confocal pinhole diameter (null = widefield)
1517
apply_inverse:

docs/examples/cli/configs/phase_2d.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ phase:
88
yx_pixel_size: 0.1 # lateral pixel size in micrometers
99
z_pixel_size: 0.25 # axial pixel size in micrometers
1010
z_padding: 0 # z slices to pad for axial boundary effects
11-
z_focus_offset: 0 # offset from center slice in slice units
11+
z_focus_offset: 0 # (optimizable) offset from center slice in slice units
1212
index_of_refraction_media: 1.3 # refractive index of imaging media
13-
numerical_aperture_detection: 1.2 # detection objective numerical aperture
14-
numerical_aperture_illumination: 0.9 # condenser numerical aperture
13+
numerical_aperture_detection: 1.2 # (optimizable) detection objective numerical aperture
14+
tilt_angle_zenith: 0.0 # (optimizable) illumination tilt zenith angle in radians
15+
tilt_angle_azimuth: 0.0 # (optimizable) illumination tilt azimuth angle in radians
16+
numerical_aperture_illumination: 0.9 # (optimizable) condenser numerical aperture
1517
invert_phase_contrast: false # invert contrast for positive/negative phase
1618
apply_inverse:
1719
reconstruction_algorithm: Tikhonov # 'Tikhonov' or 'TV' regularization

docs/examples/cli/configs/phase_3d.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ phase:
88
yx_pixel_size: 0.1 # lateral pixel size in micrometers
99
z_pixel_size: 0.25 # axial pixel size in micrometers
1010
z_padding: 0 # z slices to pad for axial boundary effects
11-
z_focus_offset: 0 # offset from center slice in slice units
11+
z_focus_offset: 0 # (optimizable) offset from center slice in slice units
1212
index_of_refraction_media: 1.3 # refractive index of imaging media
13-
numerical_aperture_detection: 1.2 # detection objective numerical aperture
14-
numerical_aperture_illumination: 0.9 # condenser numerical aperture
13+
numerical_aperture_detection: 1.2 # (optimizable) detection objective numerical aperture
14+
tilt_angle_zenith: 0.0 # (optimizable) illumination tilt zenith angle in radians
15+
tilt_angle_azimuth: 0.0 # (optimizable) illumination tilt azimuth angle in radians
16+
numerical_aperture_illumination: 0.9 # (optimizable) condenser numerical aperture
1517
invert_phase_contrast: false # invert contrast for positive/negative phase
1618
apply_inverse:
1719
reconstruction_algorithm: Tikhonov # 'Tikhonov' or 'TV' regularization

docs/examples/maintenance/PTI_simulation/PTI_Simulation_Forward_2D3D.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import matplotlib.pyplot as plt
2121
import numpy as np
22+
import torch
2223
from numpy.fft import fftshift
2324
from platformdirs import user_data_dir
2425

@@ -337,7 +338,7 @@
337338
# DPC + BF illumination + PolState (sector illumination)
338339

339340
xx, yy, fxx, fyy = util.gen_coordinate((N, M), ps)
340-
radial_frequencies = np.sqrt(fxx**2 + fyy**2)
341+
radial_frequencies = torch.sqrt(torch.as_tensor(fxx**2 + fyy**2, dtype=torch.float32))
341342

342343
Pupil_obj = optics.generate_pupil(radial_frequencies, NA_obj / n_media, lambda_illu / n_media).numpy()
343344
Source_support = optics.generate_pupil(radial_frequencies, NA_illu / n_media, lambda_illu / n_media).numpy()

docs/examples/maintenance/QLIPP_simulation/2D_QLIPP_forward.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import matplotlib.pyplot as plt
2121
import numpy as np
22+
import torch
2223
from numpy.fft import fftshift
2324
from platformdirs import user_data_dir
2425

@@ -70,7 +71,7 @@
7071
# Subsample source pattern for speed
7172

7273
xx, yy, fxx, fyy = util.gen_coordinate((N, M), ps)
73-
radial_frequencies = np.sqrt(fxx**2 + fyy**2)
74+
radial_frequencies = torch.sqrt(torch.as_tensor(fxx**2 + fyy**2, dtype=torch.float32))
7475
Source_cont = optics.generate_pupil(radial_frequencies, NA_illu, lambda_illu).numpy()
7576
Source_discrete = optics.Source_subsample(Source_cont, lambda_illu * fxx, lambda_illu * fyy, subsampled_NA=0.1)
7677
plt.figure(figsize=(10, 10))
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Example: optimize fluorescence reconstruction parameters.
2+
3+
Simulates fluorescence data with a known z_focus_offset,
4+
then optimizes from a wrong initial guess to recover it.
5+
"""
6+
7+
import datetime
8+
9+
from waveorder.api import fluorescence
10+
from waveorder.optim import OptimizableFloat
11+
12+
# Ground truth parameters
13+
gt_z_offset = 0.6
14+
15+
# Simulate with ground truth
16+
gt_settings = fluorescence.Settings(
17+
transfer_function=fluorescence.TransferFunctionSettings(
18+
z_focus_offset=gt_z_offset,
19+
)
20+
)
21+
phantom, data = fluorescence.simulate(gt_settings, recon_dim=2, zyx_shape=(11, 256, 256))
22+
23+
# Optimize from wrong initial guess
24+
opt_settings = fluorescence.Settings(
25+
transfer_function=fluorescence.TransferFunctionSettings(
26+
z_focus_offset=OptimizableFloat(init=0, lr=0.1),
27+
)
28+
)
29+
30+
log_dir = f"./runs/{datetime.datetime.now():%Y%m%d_%H%M%S}"
31+
32+
optimized_settings, recon = fluorescence.optimize(
33+
data,
34+
settings=opt_settings,
35+
num_iterations=50,
36+
midband_fractions=(0.01, 0.5),
37+
log_dir=log_dir,
38+
log_images=True,
39+
)
40+
41+
s = optimized_settings.transfer_function
42+
print(f"\n{'Parameter':<20} {'Ground truth':>12} {'Optimized':>12}")
43+
print(f"{'z_focus_offset':<20} {gt_z_offset:>12.3f} {s.z_focus_offset:>12.3f}")
44+
45+
print("\nTo view optimization logs, run:")
46+
print(" tensorboard --logdir ./runs")
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Example: optimize phase reconstruction parameters.
2+
3+
Simulates brightfield phase data with known z_focus_offset and
4+
tilt illumination, then optimizes from (0, 0, 0) to recover
5+
the ground truth parameters.
6+
"""
7+
8+
import datetime
9+
10+
import numpy as np
11+
12+
from waveorder.api import phase
13+
from waveorder.optim import OptimizableFloat
14+
15+
# To use your own data instead of simulated data, create a CZYX xr.DataArray:
16+
#
17+
# import xarray as xr
18+
#
19+
# zyx_array = np.load("my_data.npy") # shape (Z, Y, X)
20+
# data = xr.DataArray(
21+
# zyx_array[None], # add C dimension -> (1, Z, Y, X)
22+
# dims=("c", "z", "y", "x"),
23+
# )
24+
#
25+
# Then pass `data` to `phase.optimize(data, ...)` below.
26+
27+
# Ground truth parameters
28+
gt_z_offset = 0.6
29+
gt_tilt_zenith = 0.5
30+
gt_tilt_azimuth = np.pi / 4
31+
32+
# Simulate with ground truth
33+
gt_settings = phase.Settings(
34+
transfer_function=phase.TransferFunctionSettings(
35+
z_focus_offset=gt_z_offset,
36+
tilt_angle_zenith=gt_tilt_zenith,
37+
tilt_angle_azimuth=gt_tilt_azimuth,
38+
)
39+
)
40+
phantom, data = phase.simulate(gt_settings, recon_dim=2, zyx_shape=(11, 256, 256))
41+
42+
# Optimize from (0, 0, 0) initial guess
43+
opt_settings = phase.Settings(
44+
transfer_function=phase.TransferFunctionSettings(
45+
z_focus_offset=OptimizableFloat(init=0, lr=0.1),
46+
tilt_angle_zenith=OptimizableFloat(init=0, lr=0.1),
47+
tilt_angle_azimuth=OptimizableFloat(init=0, lr=0.1),
48+
)
49+
)
50+
51+
log_dir = f"./runs/{datetime.datetime.now():%Y%m%d_%H%M%S}"
52+
53+
optimized_settings, recon = phase.optimize(
54+
data,
55+
settings=opt_settings,
56+
num_iterations=50,
57+
midband_fractions=(0.1, 0.5),
58+
log_dir=log_dir,
59+
log_images=True,
60+
)
61+
62+
s = optimized_settings.transfer_function
63+
print(f"\n{'Parameter':<20} {'Ground truth':>12} {'Optimized':>12}")
64+
print(f"{'z_focus_offset':<20} {gt_z_offset:>12.3f} {s.z_focus_offset:>12.3f}")
65+
print(f"{'tilt_angle_zenith':<20} {gt_tilt_zenith:>12.3f} {s.tilt_angle_zenith:>12.3f}")
66+
print(f"{'tilt_angle_azimuth':<20} {gt_tilt_azimuth:>12.3f} {s.tilt_angle_azimuth:>12.3f}")
67+
68+
print("\nTo view optimization logs and images, run TensorBoard in your terminal:")
69+
print(" tensorboard --logdir ./runs")
70+
print("Then open the displayed URL in your browser.")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
input_channel_names: [Brightfield]
2+
time_indices: all
3+
reconstruction_dimension: 2
4+
phase:
5+
transfer_function:
6+
wavelength_illumination: 0.532 # illumination wavelength in micrometers
7+
yx_pixel_size: 0.1 # lateral pixel size in micrometers
8+
z_pixel_size: 0.25 # axial pixel size in micrometers
9+
z_focus_offset: # (optimizable) offset from center slice in slice units
10+
init: 0
11+
lr: 0.01
12+
index_of_refraction_media: 1.3 # refractive index of imaging media
13+
numerical_aperture_detection: 1.2 # (optimizable) detection objective numerical aperture
14+
numerical_aperture_illumination: 0.9 # (optimizable) condenser numerical aperture
15+
tilt_angle_zenith: # (optimizable) illumination tilt zenith angle in radians
16+
init: 0
17+
lr: 0.005
18+
tilt_angle_azimuth: # (optimizable) illumination tilt azimuth angle in radians
19+
init: 0
20+
lr: 0.001
21+
invert_phase_contrast: false
22+
apply_inverse:
23+
reconstruction_algorithm: Tikhonov
24+
regularization_strength: 0.01
25+
optimization:
26+
num_iterations: 10
27+
loss:
28+
type: midband_power
29+
midband_fractions: [0.125, 0.25]
30+
log_dir: ./runs/phase_optim

0 commit comments

Comments
 (0)