Fouriax User Guide
Welcome to fouriax, a differentiable free-space optics library for JAX. This guide walks you through building a system, choosing a propagation method, running an optimization, using polarization (Jones mode), using meta-atom look-up tables, and reading sensor output.
At the source level, the core optics runtime lives under
src/fouriax/optics/, while optimization and analysis helpers are exposed as
the top-level modules fx.optim and fx.analysis.
Developer-oriented implementation and workflow notes live under
docs/development/.
1. First Optical System
An optical stack in fouriax consists of:
- a Grid, which defines the sampled spatial coordinates
- a Spectrum, which defines the wavelength channels
- a Field, which stores the complex optical wave on that grid
- an Intensity, for incoherent image-plane data and detector inputs
- one or more OpticalLayers, composed into an OpticalModule
- optional readout through a Detector or DetectorArray
The following example is the simplest non-optimization starting point, and the best place to learn the basic object model before moving on to the Optax-based examples.
```python
import jax
import jax.numpy as jnp
import fouriax as fx

# 1. Define the grid (micrometers) and the spectrum (wavelength in um: 532 nm)
grid = fx.Grid.from_extent(nx=64, ny=64, dx_um=1.0, dy_um=1.0)
spectrum = fx.Spectrum.from_scalar(0.532)

# 2. Set up an initial plane-wave field
field = fx.Field.plane_wave(grid=grid, spectrum=spectrum)

# 3. Create components: a phase mask and a propagator
rng = jax.random.PRNGKey(0)  # PRNG key, reused later for noisy detector sampling
initial_phase = jnp.zeros((64, 64))  # one phase value (radians) per grid sample
phase_layer = fx.PhaseMask(phase_map_rad=initial_phase)

# plan_propagation(...) is the recommended high-level entry point
propagator = fx.plan_propagation(
    mode="auto", grid=grid, spectrum=spectrum, distance_um=1000.0,
)

# 4. Compose the module and the sensor
module = fx.OpticalModule(layers=(phase_layer, propagator))
detector = fx.Detector()

# 5. Forward pass, then read out the intensity
field_out = module.forward(field)
intensity = detector.measure(field_out)
```
For incoherent imaging, make the coherent-to-incoherent boundary explicit: convert the Field to an Intensity, then image it with an IncoherentImager:
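```python
# `lens` is assumed to be defined elsewhere: an optical layer exposing
# focal_length_um (e.g., a thin lens). `propagator` is a planned propagator.
sensor_grid = fx.Grid.from_extent(nx=64, ny=64, dx_um=1.0, dy_um=1.0)
input_grid = fx.Grid.from_extent(nx=64, ny=64, dx_um=2.0, dy_um=2.0)

# Convert the coherent field to an intensity, then place the same samples
# on the 2 um object-plane grid
object_intensity = field.to_intensity()
object_intensity = fx.Intensity(
    data=object_intensity.data,
    grid=input_grid,
    spectrum=object_intensity.spectrum,
)

# Object and sensor each sit two focal lengths from the lens
# (a unit-magnification, inverted imaging geometry)
imager = fx.IncoherentImager.for_finite_distance(
    optical_layer=lens,
    propagator=propagator,
    input_distance_um=2.0 * lens.focal_length_um,
    output_distance_um=2.0 * lens.focal_length_um,
)

image_intensity = imager.forward(object_intensity)
detector_array = fx.DetectorArray(detector_grid=sensor_grid)
measurement = detector_array.measure(image_intensity)
```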
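Physically, the incoherent path works on intensities rather than complex amplitudes: for a shift-invariant system, the image is the object intensity convolved with the intensity point-spread function |h|². The following is a minimal JAX sketch of that textbook model, not the IncoherentImager implementation; the Gaussian PSF, the 64 x 64 grid, and the helper names intensity_psf and incoherent_image are illustrative assumptions.

```python
import jax.numpy as jnp

def intensity_psf(coherent_psf):
    """Intensity PSF |h|^2, normalized to unit total energy."""
    psf = jnp.abs(coherent_psf) ** 2
    return psf / jnp.sum(psf)

def incoherent_image(object_intensity, psf):
    """Convolve an object intensity with an intensity PSF (circular, via FFT).

    `psf` must share the object's grid and be centered at index (0, 0);
    apply jnp.fft.ifftshift to a PSF that is centered in the array.
    """
    otf = jnp.fft.fft2(psf)  # optical transfer function
    image = jnp.fft.ifft2(jnp.fft.fft2(object_intensity) * otf)
    return jnp.real(image)

# Toy example: Gaussian coherent PSF (sigma = 3 samples) on a 64 x 64 grid
n = 64
coords = jnp.arange(n) - n // 2
yy, xx = jnp.meshgrid(coords, coords, indexing="ij")
h = jnp.exp(-(xx**2 + yy**2) / (2 * 3.0**2))
psf = jnp.fft.ifftshift(intensity_psf(h))

# Two point sources: their images add in intensity, with no interference term
obj = jnp.zeros((n, n)).at[10, 10].set(1.0).at[40, 20].set(2.0)
image = incoherent_image(obj, psf)
print(float(image.sum()))  # total energy is preserved: ≈ 3.0
```

Because the PSF is normalized to unit sum, the convolution preserves total object energy, and the response to several point sources is simply the sum of their individual responses.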
Why plan_propagation(...) is recommended:
- it chooses a suitable propagator when mode="auto"
- it plans the working propagation grid when sampled propagation needs padding or finer spacing
- for ASM and k-space propagation, it precomputes the diagonal transfer stack automatically, so repeated calls with a fixed grid and a fixed distance reuse that data by default
In other words, plan_propagation(...) is not just a convenience wrapper. It is
the intended high-level API for efficient repeated propagation, especially in
optimization examples where the same grid, spectrum, and distance_um are
reused for many forward passes.
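To see why caching the transfer stack pays off, here is a minimal, single-wavelength JAX sketch of angular-spectrum propagation. It is not fouriax's ASM implementation: asm_transfer and propagate are hypothetical helpers, evanescent components are simply zeroed, and no padding is applied. The transfer function H(fx, fy) = exp(i 2π z sqrt(1/λ² - fx² - fy²)) depends only on the grid, wavelength, and distance, so it can be computed once; each later call then costs one FFT pair and a pointwise multiply.

```python
import jax
import jax.numpy as jnp

def asm_transfer(nx, ny, dx_um, dy_um, wavelength_um, distance_um):
    """Angular-spectrum transfer function H(fx, fy) for one wavelength.

    Evanescent components (fx^2 + fy^2 > 1 / wavelength^2) are set to zero.
    """
    fx = jnp.fft.fftfreq(nx, d=dx_um)  # cycles per micrometer
    fy = jnp.fft.fftfreq(ny, d=dy_um)
    FY, FX = jnp.meshgrid(fy, fx, indexing="ij")
    arg = 1.0 / wavelength_um**2 - FX**2 - FY**2
    propagating = arg > 0
    kz = 2 * jnp.pi * jnp.sqrt(jnp.where(propagating, arg, 0.0))
    return jnp.where(propagating, jnp.exp(1j * kz * distance_um), 0.0)

@jax.jit
def propagate(field, H):
    # One FFT pair per call; H is computed once and passed in
    return jnp.fft.ifft2(jnp.fft.fft2(field) * H)

# Precompute once for a fixed grid, wavelength, and distance ...
H = asm_transfer(64, 64, 1.0, 1.0, 0.532, 1000.0)

# ... then reuse it for many forward passes (e.g., inside an optimizer loop)
key = jax.random.PRNGKey(0)
for i in range(3):
    phase = jax.random.uniform(jax.random.fold_in(key, i), (64, 64)) * 2 * jnp.pi
    out = propagate(jnp.exp(1j * phase), H)
```

A production planner would also pad the grid to limit wrap-around from the circular FFT convolution, which this sketch omits.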
If you already know you want ASM, you should still usually write:
```python
propagator = fx.plan_propagation(
    mode="asm",
    grid=grid,
    spectrum=spectrum,
    distance_um=1000.0,
)
```
Use ASMPropagator(...), RSPropagator(...), or KSpacePropagator(...)
directly only when you need manual low-level control over the propagation
implementation.
For NA-limited propagation, plan_propagation(...) accepts na_limit only on
k-space paths. When it plans a spatial ASM or RS path, it handles only method
selection and sampling.
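Conceptually, an NA limit is a circular low-pass pupil in the spatial-frequency domain: plane-wave components with |f| > NA/λ fall outside the collection cone and are discarded. Because that is a pointwise mask on the angular spectrum, it combines naturally with a k-space transfer function. The JAX sketch below assumes a hard binary cutoff; na_mask and apply_na_limit are hypothetical helpers, and the exact pupil fouriax builds from na_limit may differ.

```python
import jax.numpy as jnp

def na_mask(nx, ny, dx_um, dy_um, wavelength_um, na):
    """Binary pupil keeping spatial frequencies with |f| <= NA / wavelength."""
    fx = jnp.fft.fftfreq(nx, d=dx_um)  # cycles per micrometer
    fy = jnp.fft.fftfreq(ny, d=dy_um)
    FY, FX = jnp.meshgrid(fy, fx, indexing="ij")
    return (FX**2 + FY**2) <= (na / wavelength_um) ** 2

def apply_na_limit(field, mask):
    """Low-pass the angular spectrum: drop components outside the pupil."""
    return jnp.fft.ifft2(jnp.fft.fft2(field) * mask)

# NA = 0.1 at 532 nm gives a cutoff of about 0.188 cycles/um
mask = na_mask(64, 64, 1.0, 1.0, 0.532, na=0.1)

x = jnp.arange(64)
slow = jnp.tile(jnp.exp(2j * jnp.pi * 4 * x / 64), (64, 1))   # 0.0625 cycles/um
fast = jnp.tile(jnp.exp(2j * jnp.pi * 16 * x / 64), (64, 1))  # 0.25 cycles/um
slow_out = apply_na_limit(slow, mask)  # inside the pupil: passes unchanged
fast_out = apply_na_limit(fast, mask)  # outside the pupil: removed entirely
```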
2. Gradient-Based Optimization
Since every layer is fully traceable in JAX, you can define a loss function and
compute gradients with respect to the system parameters. Fouriax provides
lightweight wrappers in fx.optim for Optax training loops.
The usual pattern is to create the propagator once with plan_propagation(...)
outside the loss, then reuse that same planned propagator through all optimizer
steps.
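```python
import jax.numpy as jnp
import optax
import fouriax as fx

# Reuses `field` and the planned `propagator` from section 1.
# Hypothetical target: a single bright spot at the center of the 64 x 64 grid.
target_intensity = jnp.zeros((64, 64)).at[32, 32].set(1.0)

def build_module(params):
    # params is a dictionary mapping parameter names to JAX arrays
    return fx.OpticalModule(layers=(
        fx.PhaseMask(phase_map_rad=params["phase"]),
        propagator,  # planned once, reused at every optimizer step
    ))

def loss_fn(params):
    mod = build_module(params)
    out = mod.forward(field)
    # Mean squared error between the output intensity and the target
    return jnp.mean((out.intensity() - target_intensity) ** 2)

# Initial parameters
init_params = {"phase": jnp.zeros((64, 64))}
optimizer = optax.adam(learning_rate=0.1)

# Optimize for 100 steps
result = fx.optim.optimize_optical_module(
    init_params=init_params,
    build_module=build_module,
    loss_fn=loss_fn,
    optimizer=optimizer,
    steps=100,
)
print(f"Final loss: {result.final_loss}")
optimized_phase = result.best_params["phase"]
```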
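Because the whole forward model is traceable, jax.grad is all that is needed to differentiate it. The following self-contained sketch mirrors the pattern above without fouriax: a transfer function built once outside the loss, a phase-only mask as the trainable parameter, and plain gradient descent standing in for optax.adam. It illustrates the idea and is not what fx.optim.optimize_optical_module does internally; the grid size, distance, target, and learning rate are arbitrary choices.

```python
import jax
import jax.numpy as jnp

N = 32
WAVELENGTH_UM, DISTANCE_UM = 0.532, 200.0

# Fixed ASM transfer function, built once outside the loss. With dx = 1 um,
# every sampled frequency is propagating at this wavelength.
f = jnp.fft.fftfreq(N, d=1.0)
FY, FX = jnp.meshgrid(f, f, indexing="ij")
kz = 2 * jnp.pi * jnp.sqrt(1.0 / WAVELENGTH_UM**2 - FX**2 - FY**2)
H = jnp.exp(1j * kz * DISTANCE_UM)

field_in = jnp.ones((N, N), dtype=jnp.complex64)
# Hypothetical target: all N * N units of input energy in the central pixel
target = jnp.zeros((N, N)).at[N // 2, N // 2].set(float(N * N))

def loss_fn(phase):
    field = field_in * jnp.exp(1j * phase)        # phase-only mask
    out = jnp.fft.ifft2(jnp.fft.fft2(field) * H)  # free-space propagation
    intensity = out.real**2 + out.imag**2
    return jnp.mean((intensity - target) ** 2)

loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))

phase = jnp.zeros((N, N))
learning_rate = 0.01  # plain gradient descent, standing in for optax.adam
losses = []
for _ in range(50):
    loss, grad = loss_and_grad(phase)
    losses.append(float(loss))
    phase = phase - learning_rate * grad

print(f"loss: {losses[0]:.1f} -> {losses[-1]:.1f}")
```

Swapping the manual update for an Optax optimizer changes only the update rule; the differentiable forward model and the precomputed H stay the same.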
3. Polarization (Jones Mode)
To track arbitrary polarization states, initialize a Jones field. Modulation layers then either broadcast over the Ex and Ey channels or mix them, depending on the layer.
```python
# Initialize a 45-degree linearly polarized input (equal, in-phase Ex and Ey)
jones_field = fx.Field.plane_wave_jones(
    grid=grid, spectrum=spectrum, ex=1.0, ey=1.0
)

# Apply a spatially uniform Jones matrix. This one is a quarter-wave plate with
# its axes along x and y: it adds a pi/2 phase to Ey relative to Ex.
jones_matrix = jnp.array([
    [1.0, 0.0],
    [0.0, 1j],
])
waveplate_layer = fx.JonesMatrixLayer(jones_matrix=jones_matrix)

# Trace through the layer; the 45-degree input leaves circularly polarized
jones_field_out = waveplate_layer.forward(jones_field)
```
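Per pixel, a Jones layer is a 2 x 2 complex matrix multiplying the (Ex, Ey) vector. The sketch below assumes a channel-first (2, ny, nx) field layout, which may not match fouriax's internal layout, and uses a hypothetical helper apply_jones. It confirms that the matrix above turns a 45-degree linear input into circular polarization: equal amplitudes with a 90-degree relative phase.

```python
import jax.numpy as jnp

def apply_jones(jones_matrix, e):
    """Apply a Jones matrix to a field stored as (2, ny, nx) = (Ex, Ey).

    `jones_matrix` may be one (2, 2) matrix or a per-pixel (ny, nx, 2, 2)
    array; broadcast_to handles both cases.
    """
    j = jnp.broadcast_to(jones_matrix, e.shape[1:] + (2, 2))
    return jnp.einsum("yxij,jyx->iyx", j, e)

# 45-degree linear polarization on a small 4 x 4 grid: Ex = Ey = 1/sqrt(2)
e_in = jnp.ones((2, 4, 4), dtype=jnp.complex64) / jnp.sqrt(2.0)

qwp = jnp.array([[1.0, 0.0], [0.0, 1j]], dtype=jnp.complex64)
e_out = apply_jones(qwp, e_in)

# Equal amplitudes with a 90-degree relative phase: circular polarization
relative_phase = jnp.angle(e_out[1] / e_out[0])
print(float(relative_phase[0, 0]))  # ≈ 1.5708 (pi / 2)
```

Because the matrix is unitary, the total power is unchanged; a polarizer such as [[1, 0], [0, 0]] would instead discard the Ey component.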
4. Meta-Atom Look-up Tables
Fouriax supports parameterized MetaAtomLibrary models, mapping physical meta-atom dimensions to complex transmission coefficients across a wavelength spectrum.
```python
# Load a precomputed LUT ("my_meta_atoms.npz" is a placeholder path)
library = fx.MetaAtomLibrary.from_file("my_meta_atoms.npz")

# Use it in a layer by providing spatial geometry maps
# (e.g., the width of a pillar at each grid point)
meta_layer = fx.MetaAtomInterpolationLayer(
    library=library,
    geometry_maps={"width": jnp.ones((64, 64)) * 0.25},
)
field_out = meta_layer.forward(field)
```
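The core of a LUT-backed layer is a per-pixel lookup from geometry to complex transmission. The sketch below assumes a single wavelength, a single geometry parameter, and linear interpolation of the real and imaginary parts separately (interpolating amplitude and unwrapped phase is another common choice). The table values and the helper interpolate_transmission are hypothetical; MetaAtomInterpolationLayer may use a different scheme and also handles the spectral axis.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-wavelength LUT: pillar widths (in the library's length
# unit) and the complex transmission simulated at each width
widths = jnp.array([0.10, 0.20, 0.30, 0.40])
transmission = jnp.array(
    [0.9 + 0.0j, 0.8 + 0.4j, 0.5 + 0.7j, 0.1 + 0.9j], dtype=jnp.complex64
)

def interpolate_transmission(width_map, widths, t):
    """Per-pixel linear interpolation of the real and imaginary parts."""
    re = jnp.interp(width_map, widths, t.real)
    im = jnp.interp(width_map, widths, t.imag)
    return re + 1j * im

width_map = jnp.full((64, 64), 0.25)
t_map = interpolate_transmission(width_map, widths, transmission)
print(complex(t_map[0, 0]))  # ≈ (0.65+0.55j), midway between the 0.2 and 0.3 entries

# The lookup is differentiable, so geometry maps can be optimized directly
grad_width = jax.grad(
    lambda w: jnp.sum(jnp.abs(interpolate_transmission(w, widths, transmission)) ** 2)
)(width_map)
```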
5. Noise and Sensors
You can model physical camera properties using realistic noise distributions and masks via a DetectorArray.
```python
# 32 x 32 pixels at 2 um cover the same 64 um extent as the 1 um simulation grid
sensor = fx.DetectorArray(
    detector_grid=fx.Grid.from_extent(nx=32, ny=32, dx_um=2.0, dy_um=2.0),
    noise_model=fx.PoissonGaussianNoise(
        count_scale=100.0,   # scales expected intensity to expected photons
        read_noise_std=2.5,  # standard deviation of additive Gaussian read noise
    ),
)

# Returns noisy samples over the specified detector_grid;
# `rng` is the PRNG key created in section 1
noisy_measurement = sensor.measure(field_out, key=rng)
```
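A common form of this noise model draws shot noise from a Poisson distribution with mean count_scale × intensity and adds zero-mean Gaussian read noise. The sketch below implements that textbook model with jax.random; poisson_gaussian is a hypothetical helper, and whether fouriax's PoissonGaussianNoise returns counts or rescales them back to intensity units is not assumed here.

```python
import jax
import jax.numpy as jnp

def poisson_gaussian(key, intensity, count_scale, read_noise_std):
    """Shot noise (Poisson) plus additive Gaussian read noise, in count units.

    Expected photon counts are count_scale * intensity. Returning counts rather
    than rescaling back to intensity units is a choice made for this sketch.
    """
    shot_key, read_key = jax.random.split(key)
    photons = jax.random.poisson(shot_key, count_scale * intensity)
    read = read_noise_std * jax.random.normal(read_key, intensity.shape)
    return photons + read

key = jax.random.PRNGKey(0)
intensity = jnp.full((32, 32), 0.5)  # 50 expected photons per pixel
sample = poisson_gaussian(key, intensity, count_scale=100.0, read_noise_std=2.5)
# Per-pixel variance is about 50 (shot) + 2.5**2 (read) = 56.25
print(float(sample.mean()), float(sample.var()))
```

Splitting the key keeps the shot-noise and read-noise draws statistically independent, and reusing the same key reproduces the same noisy sample.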