Skip to content

Commit a315d06

Browse files
Non-integer focus support (mehta-lab#473)
* refactor `compute_midband_power` * test compute_midband_power * Allow float values for z_focus_offset in settings - Change z_focus_offset type from Union[int, Literal["auto"]] to Union[float, Literal["auto"]] - This enables sub-pixel precision for focus offset values in 2D phase reconstruction - Addresses issue mehta-lab#470 for coarsely sampled slice improvements 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Add sub-pixel precision to focus_from_transverse_band - Add enable_subpixel_precision parameter (default False for backward compatibility) - Use polynomial derivative analysis to find continuous extrema when enabled - Return float focus indices when sub-pixel precision is enabled - Update plotting function to handle float indices via interpolation - Enhance docstring with new parameter and return type information This enables more accurate focus detection for coarsely sampled data by finding focus positions between discrete slice indices. Addresses issue mehta-lab#470. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Add comprehensive tests for non-integer focus support - test_subpixel_precision: Validates float focus detection with synthetic data - test_subpixel_precision_backward_compatibility: Ensures default behavior unchanged - test_subpixel_precision_with_plotting: Tests plotting with float indices - test_z_focus_offset_float_type: Validates settings accept float z_focus_offset - test_position_list_with_float_offset: Tests position calculation pipeline All tests verify both functionality and backward compatibility. Ensures robust implementation of issue mehta-lab#470 requirements. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent ad4dd98 commit a315d06

3 files changed

Lines changed: 210 additions & 7 deletions

File tree

tests/test_focus_estimator.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,156 @@ def test_compute_midband_power_consistency():
156156

157157
expected_focus_slice = np.argmax(manual_powers)
158158
assert focus_slice == expected_focus_slice
159+
160+
161+
def test_subpixel_precision():
162+
"""Test that sub-pixel precision returns float values when enabled."""
163+
# Test parameters
164+
ps = 6.5 / 100
165+
lambda_ill = 0.532
166+
NA_det = 1.4
167+
168+
# Create synthetic test data with a clear peak between slices
169+
z_size, y_size, x_size = 11, 64, 64
170+
x = np.linspace(-1, 1, x_size)
171+
y = np.linspace(-1, 1, y_size)
172+
z = np.linspace(-5, 5, z_size)
173+
174+
# Create a 3D Gaussian that peaks between slice indices
175+
test_data = np.zeros((z_size, y_size, x_size))
176+
true_peak_z = 5.3 # Peak between slices 5 and 6
177+
178+
for i, z_val in enumerate(z):
179+
# Create Gaussian centered at true_peak_z position in physical space
180+
gaussian_2d = np.exp(
181+
-(
182+
(x[None, :] ** 2 + y[:, None] ** 2)
183+
+ (z_val - (true_peak_z - 5)) ** 2
184+
)
185+
)
186+
test_data[i] = gaussian_2d
187+
188+
# Test without sub-pixel precision (should return integer)
189+
focus_slice_int = focus.focus_from_transverse_band(
190+
test_data,
191+
NA_det,
192+
lambda_ill,
193+
ps,
194+
polynomial_fit_order=4,
195+
enable_subpixel_precision=False,
196+
)
197+
assert isinstance(focus_slice_int, (int, np.integer))
198+
199+
# Test with sub-pixel precision (should return float)
200+
focus_slice_float = focus.focus_from_transverse_band(
201+
test_data,
202+
NA_det,
203+
lambda_ill,
204+
ps,
205+
polynomial_fit_order=4,
206+
enable_subpixel_precision=True,
207+
)
208+
209+
# Should return a float
210+
assert isinstance(focus_slice_float, float)
211+
212+
# Should be close to the true peak position
213+
assert abs(focus_slice_float - true_peak_z) < 1.0 # Within 1 slice
214+
215+
# Sub-pixel result should be different from integer result
216+
assert focus_slice_float != focus_slice_int
217+
218+
219+
def test_subpixel_precision_backward_compatibility():
220+
"""Test that default behavior (integer results) is preserved."""
221+
ps = 6.5 / 100
222+
lambda_ill = 0.532
223+
NA_det = 1.4
224+
225+
# Create simple test data
226+
test_data = np.random.random((5, 32, 32)).astype(np.float32)
227+
228+
# Test default behavior (should return integer)
229+
focus_slice = focus.focus_from_transverse_band(
230+
test_data,
231+
NA_det,
232+
lambda_ill,
233+
ps,
234+
polynomial_fit_order=4,
235+
)
236+
237+
assert isinstance(focus_slice, (int, np.integer))
238+
239+
240+
def test_subpixel_precision_with_plotting(tmp_path):
241+
"""Test that sub-pixel precision works with plotting."""
242+
ps = 6.5 / 100
243+
lambda_ill = 0.532
244+
NA_det = 1.4
245+
246+
# Create test data
247+
test_data = np.random.random((7, 32, 32)).astype(np.float32)
248+
plot_path = tmp_path / "subpixel_test.pdf"
249+
250+
# Should work without errors
251+
focus_slice = focus.focus_from_transverse_band(
252+
test_data,
253+
NA_det,
254+
lambda_ill,
255+
ps,
256+
polynomial_fit_order=4,
257+
enable_subpixel_precision=True,
258+
plot_path=str(plot_path),
259+
)
260+
261+
assert isinstance(focus_slice, float)
262+
assert plot_path.exists()
263+
264+
265+
def test_z_focus_offset_float_type():
266+
"""Test that z_focus_offset can accept float values in settings."""
267+
from waveorder.cli.settings import FourierTransferFunctionSettings
268+
269+
# Test that float values are accepted
270+
settings = FourierTransferFunctionSettings(z_focus_offset=1.5)
271+
assert settings.z_focus_offset == 1.5
272+
assert isinstance(settings.z_focus_offset, float)
273+
274+
# Test that "auto" still works
275+
settings_auto = FourierTransferFunctionSettings(z_focus_offset="auto")
276+
assert settings_auto.z_focus_offset == "auto"
277+
278+
# Test that integers are converted to float
279+
settings_int = FourierTransferFunctionSettings(z_focus_offset=2)
280+
assert settings_int.z_focus_offset == 2
281+
assert isinstance(settings_int.z_focus_offset, (int, float))
282+
283+
284+
def test_position_list_with_float_offset():
285+
"""Test that _position_list_from_shape_scale_offset works correctly with float offsets."""
286+
from waveorder.cli.compute_transfer_function import (
287+
_position_list_from_shape_scale_offset,
288+
)
289+
290+
# Test integer offset
291+
pos_int = _position_list_from_shape_scale_offset(5, 1.0, 0)
292+
expected_int = [2.0, 1.0, 0.0, -1.0, -2.0]
293+
assert pos_int == expected_int
294+
295+
# Test float offset
296+
pos_float = _position_list_from_shape_scale_offset(5, 1.0, 0.5)
297+
expected_float = [2.5, 1.5, 0.5, -0.5, -1.5]
298+
assert pos_float == expected_float
299+
300+
# Verify the difference is exactly the offset
301+
import numpy as np
302+
303+
diff = np.array(pos_float) - np.array(pos_int)
304+
assert np.allclose(diff, 0.5)
305+
306+
# Test with different scale and offset
307+
pos_scaled = _position_list_from_shape_scale_offset(4, 2.0, 0.3)
308+
# shape=4, shape//2=2, so indices are [0,1,2,3],
309+
# positions are [(-0+2+0.3)*2, (-1+2+0.3)*2, (-2+2+0.3)*2, (-3+2+0.3)*2] = [4.6, 2.6, 0.6, -1.4]
310+
expected_scaled = [4.6, 2.6, 0.6, -1.4]
311+
assert np.allclose(pos_scaled, expected_scaled)

waveorder/cli/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class FourierTransferFunctionSettings(MyBaseModel):
6666
yx_pixel_size: PositiveFloat = 6.5 / 20
6767
z_pixel_size: PositiveFloat = 2.0
6868
z_padding: NonNegativeInt = 0
69-
z_focus_offset: Union[int, Literal["auto"]] = 0
69+
z_focus_offset: Union[float, Literal["auto"]] = 0
7070
index_of_refraction_media: PositiveFloat = 1.3
7171
numerical_aperture_detection: PositiveFloat = 1.2
7272

waveorder/focus.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def focus_from_transverse_band(
6060
polynomial_fit_order: Optional[int] = None,
6161
plot_path: Optional[str] = None,
6262
threshold_FWHM: float = 0,
63+
enable_subpixel_precision: bool = False,
6364
):
6465
"""Estimates the in-focus slice from a 3D stack by optimizing a transverse spatial frequency band.
6566
@@ -91,12 +92,16 @@ def focus_from_transverse_band(
9192
The default value, 0, applies no threshold, and the maximum midband power is always considered in focus.
9293
For values > 0, the peak's FWHM must be greater than the threshold for the slice to be considered in focus.
9394
If the peak does not meet this threshold, the function returns None.
95+
enable_subpixel_precision: bool, optional
96+
If True and polynomial_fit_order is provided, enables sub-pixel precision focus detection
97+
by finding the continuous extremum of the polynomial fit. Default is False for backward compatibility.
9498
9599
Returns
96-
------
97-
slice : int or None
100+
-------
101+
slice : int, float, or None
98102
If peak's FWHM > peak_width_threshold:
99-
return the index of the in-focus slice
103+
return the index of the in-focus slice (int if enable_subpixel_precision=False,
104+
float if enable_subpixel_precision=True and polynomial_fit_order is not None)
100105
else:
101106
return None
102107
@@ -140,9 +145,44 @@ def focus_from_transverse_band(
140145
else:
141146
x = np.arange(len(midband_sum))
142147
coeffs = np.polyfit(x, midband_sum, polynomial_fit_order)
143-
peak_index = minmaxfunc(np.poly1d(coeffs)(x))
148+
poly_func = np.poly1d(coeffs)
149+
150+
if enable_subpixel_precision:
151+
# Find the continuous extremum using derivative
152+
poly_deriv = np.polyder(coeffs)
153+
# Find roots of the derivative (critical points)
154+
critical_points = np.roots(poly_deriv)
155+
156+
# Filter for real roots within the data range
157+
real_critical_points = []
158+
for cp in critical_points:
159+
if np.isreal(cp) and 0 <= cp.real < len(midband_sum):
160+
real_critical_points.append(cp.real)
161+
162+
if real_critical_points:
163+
# Evaluate the polynomial at critical points to find extremum
164+
critical_values = [
165+
poly_func(cp) for cp in real_critical_points
166+
]
167+
if mode == "max":
168+
best_idx = np.argmax(critical_values)
169+
else: # mode == "min"
170+
best_idx = np.argmin(critical_values)
171+
peak_index = real_critical_points[best_idx]
172+
else:
173+
# Fall back to discrete maximum if no valid critical points
174+
peak_index = float(minmaxfunc(poly_func(x)))
175+
else:
176+
peak_index = minmaxfunc(poly_func(x))
144177

145-
peak_results = peak_widths(midband_sum, [peak_index])
178+
# For peak width calculation, use integer peak index
179+
if enable_subpixel_precision and polynomial_fit_order is not None:
180+
# Use the closest integer index for peak width calculation
181+
integer_peak_index = int(np.round(peak_index))
182+
else:
183+
integer_peak_index = int(peak_index)
184+
185+
peak_results = peak_widths(midband_sum, [integer_peak_index])
146186
peak_FWHM = peak_results[0][0]
147187

148188
if peak_FWHM >= threshold_FWHM:
@@ -215,9 +255,19 @@ def _plot_focus_metric(
215255
):
216256
_, ax = plt.subplots(1, 1, figsize=(4, 4))
217257
ax.plot(midband_sum, "-k")
258+
259+
# Handle floating-point peak_index for plotting
260+
if isinstance(peak_index, float) and not peak_index.is_integer():
261+
# Use interpolation to get the y-value at the floating-point x-position
262+
peak_y_value = np.interp(
263+
peak_index, np.arange(len(midband_sum)), midband_sum
264+
)
265+
else:
266+
peak_y_value = midband_sum[int(peak_index)]
267+
218268
ax.plot(
219269
peak_index,
220-
midband_sum[peak_index],
270+
peak_y_value,
221271
"go" if in_focus_index is not None else "ro",
222272
)
223273
ax.hlines(*peak_results[1:], color="k", linestyles="dashed")

0 commit comments

Comments
 (0)