Fouriax User Guide

Welcome to fouriax, a differentiable free-space optics library for JAX. This guide walks you through building a system, choosing a propagation method, running an optimization, using polarization (Jones mode), using meta-atoms, and reading sensor output.

At the source level, the core optics runtime lives under src/fouriax/optics/, while optimization and analysis helpers are exposed as the top-level modules fx.optim and fx.analysis.

Developer-oriented implementation and workflow notes live under docs/development/.

1. First Optical System

An optical stack in fouriax consists of:

  • a Grid, which defines the sampled spatial coordinates

  • a Spectrum, which defines the wavelength channels

  • a Field, which stores the complex optical wave on that grid

  • an Intensity, for incoherent image-plane data and detector inputs

  • one or more OpticalLayers, composed into an OpticalModule

  • optional readout through a Detector or DetectorArray

The example below is the simplest starting point that involves no optimization. It is the best place to learn the basic object model before moving on to the Optax-based examples.

import jax
import jax.numpy as jnp
import fouriax as fx

# 1. Define grid (micrometers) and spectrum (532 nm)
grid = fx.Grid.from_extent(nx=64, ny=64, dx_um=1.0, dy_um=1.0)
spectrum = fx.Spectrum.from_scalar(0.532)

# 2. Setup initial plane wave field
field = fx.Field.plane_wave(grid=grid, spectrum=spectrum)

# 3. Create components: phase mask and propagator
rng = jax.random.PRNGKey(0)  # PRNG key, used later for detector noise sampling
initial_phase = jnp.zeros((64, 64))  # flat phase matching the 64 x 64 grid
phase_layer = fx.PhaseMask(phase_map_rad=initial_phase)

# plan_propagation(...) is the recommended high-level entry point
propagator = fx.plan_propagation(
    mode="auto",
    grid=grid,
    spectrum=spectrum,
    distance_um=1000.0,  # propagation distance in micrometers (1 mm)
)

# 4. Compose the module and sensor
module = fx.OpticalModule(layers=(phase_layer, propagator))
detector = fx.Detector()

# 5. Forward pass to get intensity
field_out = module.forward(field)
intensity = detector.measure(field_out)

For incoherent imaging, make the coherent-to-incoherent boundary explicit: convert the field to an Intensity, then image it with an IncoherentImager.

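# `lens` is assumed to be a lens layer built earlier that exposes
# `focal_length_um`; `propagator` is the planned propagator from above.
sensor_grid = fx.Grid.from_extent(nx=64, ny=64, dx_um=1.0, dy_um=1.0)
input_grid = fx.Grid.from_extent(nx=64, ny=64, dx_um=2.0, dy_um=2.0)

# Convert the coherent field to intensity, then place those samples
# on the object-plane grid.
object_intensity = field.to_intensity()
object_intensity = fx.Intensity(
    data=object_intensity.data,
    grid=input_grid,
    spectrum=object_intensity.spectrum,
)

# Object and image planes each sit two focal lengths from the lens (2f-2f).
imager = fx.IncoherentImager.for_finite_distance(
    optical_layer=lens,
    propagator=propagator,
    input_distance_um=2.0 * lens.focal_length_um,
    output_distance_um=2.0 * lens.focal_length_um,
)
image_intensity = imager.forward(object_intensity)

# Sample the image on the sensor pixel grid.
detector = fx.DetectorArray(detector_grid=sensor_grid)
measurement = detector.measure(image_intensity)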
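Conceptually, incoherent imaging replaces field propagation with an intensity convolution: each object point radiates independently, so the image is the object intensity convolved with the intensity point-spread function |h|². The plain-JAX sketch below illustrates that idea for one wavelength, assuming a shift-invariant system with unit magnification and a hypothetical circular pupil. The helper name incoherent_image is illustrative only; this is not how fx.IncoherentImager is implemented.

```python
import jax.numpy as jnp

def incoherent_image(object_intensity, coherent_psf):
    """Image an incoherent object through a shift-invariant system.

    Incoherent sources add in intensity, so the image is the object
    intensity circularly convolved (via FFT) with the intensity PSF |h|^2.
    """
    intensity_psf = jnp.abs(coherent_psf) ** 2
    intensity_psf = intensity_psf / jnp.sum(intensity_psf)  # conserve total power
    otf = jnp.fft.fft2(jnp.fft.ifftshift(intensity_psf))    # optical transfer function
    image = jnp.fft.ifft2(jnp.fft.fft2(object_intensity) * otf)
    return jnp.real(image)

# Hypothetical circular pupil with a radius of 8 frequency samples.
n = 64
f = jnp.fft.fftfreq(n) * n                     # integer frequency indices
fx, fy = jnp.meshgrid(f, f, indexing="xy")
pupil = (fx**2 + fy**2 <= 8**2).astype(jnp.complex64)
coherent_psf = jnp.fft.fftshift(jnp.fft.ifft2(pupil))  # centered impulse response h

obj = jnp.zeros((n, n)).at[32, 32].set(1.0)    # a single point source
img = incoherent_image(obj, coherent_psf)
```

Imaging a point source returns the normalized intensity PSF centered on that point, which is a quick sanity check for any incoherent imaging model.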

Why plan_propagation(...) is recommended:

  • it chooses a suitable propagator when mode="auto"

  • it plans the working propagation grid when sampled propagation needs padding or finer spacing

  • for ASM (angular spectrum method) and k-space propagation, it precomputes the diagonal transfer stack automatically, so repeated calls on a fixed grid at a fixed distance reuse that data by default

In other words, plan_propagation(...) is not just a convenience wrapper. It is the intended high-level API for efficient repeated propagation, especially in optimization examples where the same grid, spectrum, and distance_um are reused for many forward passes.
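To see why precomputing pays off, the sketch below builds an angular-spectrum transfer function H(fx, fy) = exp(i 2π z sqrt(1/λ² − fx² − fy²)) once per wavelength. The operator is diagonal in Fourier space, so every later propagation costs only two FFTs and a pointwise multiply. This is a plain-JAX illustration with hypothetical helper names (asm_transfer_stack, propagate). It omits the padding and resampling that plan_propagation(...) may add, and it simply zeroes evanescent components.

```python
import jax
import jax.numpy as jnp

def asm_transfer_stack(nx, ny, dx_um, dy_um, wavelengths_um, distance_um):
    """Angular-spectrum transfer functions, one (ny, nx) plane per wavelength."""
    fx = jnp.fft.fftfreq(nx, d=dx_um)                 # cycles per micrometer
    fy = jnp.fft.fftfreq(ny, d=dy_um)
    FX, FY = jnp.meshgrid(fx, fy, indexing="xy")      # shape (ny, nx)
    inv_lam2 = (1.0 / jnp.asarray(wavelengths_um))[:, None, None] ** 2
    arg = inv_lam2 - FX**2 - FY**2                    # (n_wl, ny, nx)
    kz = 2 * jnp.pi * jnp.sqrt(jnp.maximum(arg, 0.0))
    # Keep propagating waves only; evanescent components are zeroed here.
    return jnp.where(arg > 0, jnp.exp(1j * kz * distance_um), 0.0)

@jax.jit
def propagate(field, transfer):
    """Propagate a (n_wl, ny, nx) field using a precomputed transfer stack."""
    return jnp.fft.ifft2(jnp.fft.fft2(field) * transfer)

# Build the transfer stack once (the expensive, reusable part)...
H = asm_transfer_stack(64, 64, 1.0, 1.0, [0.532], distance_um=1000.0)
# ...then reuse it for as many forward passes as needed.
u0 = jnp.ones((1, 64, 64), dtype=jnp.complex64)
u1 = propagate(u0, H)
```

A uniform plane wave only picks up a global phase, so |u1| stays equal to 1 everywhere.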

If you already know you want ASM, you should still usually write:

propagator = fx.plan_propagation(
    mode="asm",
    grid=grid,
    spectrum=spectrum,
    distance_um=1000.0,
)

Use ASMPropagator(...), RSPropagator(...), or KSpacePropagator(...) directly only when you need manual low-level control over the propagation implementation.

For NA-limited planned propagation, plan_propagation(...) accepts na_limit only on k-space paths. For the spatial ASM and RS paths, planning covers only method selection and sampling.
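A numerical-aperture limit is naturally a k-space operation: spatial frequencies with sqrt(fx² + fy²) > NA/λ cannot pass through the system, so they are removed. The sketch below applies that circular low-pass mask in plain JAX. The function apply_na_limit is hypothetical and only illustrates the idea behind na_limit. It is not fouriax's implementation.

```python
import jax.numpy as jnp

def apply_na_limit(field, dx_um, wavelength_um, na):
    """Remove spatial frequencies outside the cutoff NA / wavelength."""
    ny, nx = field.shape[-2:]
    fx = jnp.fft.fftfreq(nx, d=dx_um)
    fy = jnp.fft.fftfreq(ny, d=dx_um)
    FX, FY = jnp.meshgrid(fx, fy, indexing="xy")
    cutoff = na / wavelength_um                       # cycles per micrometer
    mask = (FX**2 + FY**2) <= cutoff**2
    return jnp.fft.ifft2(jnp.fft.fft2(field) * mask)

# A point source contains all spatial frequencies; the NA limit blurs it.
point = jnp.zeros((64, 64), dtype=jnp.complex64).at[32, 32].set(1.0)
limited = apply_na_limit(point, dx_um=1.0, wavelength_um=0.532, na=0.1)
```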

2. Gradient-Based Optimization

Since every layer is fully traceable in JAX, you can define a loss function and compute gradients with respect to the system parameters. Fouriax provides lightweight wrappers in fx.optim for Optax training loops.

The usual pattern is to create the propagator once with plan_propagation(...) outside the loss, then reuse that same planned propagator through all optimizer steps.

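import jax.numpy as jnp
import optax
import fouriax as fx

# Reuses `field` and the planned `propagator` from Section 1.
# Hypothetical target: a bright 8 x 8 square at the center of the grid.
target_intensity = jnp.zeros((64, 64)).at[28:36, 28:36].set(1.0)

def build_module(params):
    # params is a dictionary mapping parameter names to JAX arrays
    return fx.OpticalModule(layers=(
        fx.PhaseMask(phase_map_rad=params["phase"]),
        propagator,
    ))

def loss_fn(params):
    mod = build_module(params)
    out = mod.forward(field)
    return jnp.mean((out.intensity() - target_intensity) ** 2)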

# Initial parameters
init_params = {"phase": jnp.zeros((64, 64))}
optimizer = optax.adam(learning_rate=0.1)

# Optimize for 100 steps
result = fx.optim.optimize_optical_module(
    init_params=init_params,
    build_module=build_module,
    loss_fn=loss_fn,
    optimizer=optimizer,
    steps=100
)

print(f"Final loss: {result.final_loss}")
optimized_phase = result.best_params["phase"]
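To show the moving parts that such a wrapper manages, the sketch below writes a comparable loop by hand. The transfer function is precomputed once, the loss closes over it, and jax.value_and_grad supplies the gradients. It recovers a phase mask by matching a target field after propagation. The example assumes a paraxial (Fresnel) transfer function and uses plain gradient descent instead of Optax, so it is a sketch of the pattern, not how fx.optim.optimize_optical_module works internally.

```python
import jax
import jax.numpy as jnp

n, dx_um, wl_um, z_um = 32, 1.0, 0.532, 1000.0

# Precompute a paraxial (Fresnel) transfer function once, outside the loss.
f = jnp.fft.fftfreq(n, d=dx_um)
FX, FY = jnp.meshgrid(f, f, indexing="xy")
H = jnp.exp(-1j * jnp.pi * wl_um * z_um * (FX**2 + FY**2))  # |H| = 1

def propagate(u):
    return jnp.fft.ifft2(jnp.fft.fft2(u) * H)

# Hypothetical target: the field produced by a random "true" phase mask.
key_true, key_init = jax.random.split(jax.random.PRNGKey(0))
true_phase = jax.random.uniform(key_true, (n, n), maxval=2 * jnp.pi)
target = propagate(jnp.exp(1j * true_phase))

def loss_fn(phase):
    out = propagate(jnp.exp(1j * phase))
    return jnp.sum(jnp.abs(out - target) ** 2)

loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))

phase = jax.random.uniform(key_init, (n, n), maxval=2 * jnp.pi)
learning_rate = 0.5
losses = []
for _ in range(50):
    loss, grad = loss_and_grad(phase)
    losses.append(float(loss))
    phase = phase - learning_rate * grad   # plain gradient-descent update
```

Switching to an Optax optimizer such as optax.adam changes only the update step. The precomputed H is reused unchanged across every iteration.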

3. Polarization (Jones Mode)

To track arbitrary polarization states, initialize a Jones field, which carries separate Ex and Ey channels. Modulation layers then either broadcast over those two channels or mix them, as their type requires.

# Initialize a linearly polarized input (45 degrees)
jones_field = fx.Field.plane_wave_jones(
    grid=grid, spectrum=spectrum, ex=1.0, ey=1.0
)

# Apply a uniform spatial-domain Jones matrix. This one is a quarter-wave
# plate: it adds a 90-degree phase to Ey relative to Ex (it is not a polarizer)
jones_matrix = jnp.array([
    [1.0, 0.0],
    [0.0, 1j],
])
qwp_layer = fx.JonesMatrixLayer(jones_matrix=jones_matrix)

# Trace through the layer; the 45-degree linear input becomes circular
jones_field_out = qwp_layer.forward(jones_field)
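Conceptually, a Jones layer multiplies the two-component vector (Ex, Ey) at every pixel by a 2×2 matrix. The plain-JAX sketch below assumes a field layout of shape (2, ny, nx) and a hypothetical helper apply_jones. Fouriax's internal array layout may differ. The sketch checks that the Jones matrix above, [[1, 0], [0, 1j]], turns 45° linear light into circular light, with equal amplitudes and a 90° relative phase.

```python
import jax.numpy as jnp

def apply_jones(jones, field):
    """Apply a Jones matrix to a (2, ny, nx) field.

    `jones` may be a single (2, 2) matrix or a spatially varying
    (ny, nx, 2, 2) array; einsum handles both layouts.
    """
    if jones.ndim == 2:
        return jnp.einsum("ij,jyx->iyx", jones, field)
    return jnp.einsum("yxij,jyx->iyx", jones, field)

ny = nx = 8
# Unit-power 45-degree linear light: equal, in-phase Ex and Ey.
field = jnp.ones((2, ny, nx), dtype=jnp.complex64) / jnp.sqrt(2.0)
qwp = jnp.array([[1.0, 0.0], [0.0, 1j]])          # 90-degree relative phase on Ey

out = apply_jones(qwp, field)
ex, ey = out[0], out[1]
relative_phase = jnp.angle(ey / ex)               # pi/2 at every pixel

# The same plate expressed as a per-pixel matrix map gives the same result.
qwp_map = jnp.broadcast_to(qwp, (ny, nx, 2, 2))
out_map = apply_jones(qwp_map, field)
```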

4. Meta-Atom Look-up Tables

Fouriax supports parameterized MetaAtomLibrary models, mapping physical meta-atom dimensions to complex transmission coefficients across a wavelength spectrum.

# Load a precomputed LUT
library = fx.MetaAtomLibrary.from_file("my_meta_atoms.npz")

# Use it in a layer by providing spatial geometry maps, here a uniform
# pillar width of 0.25 at every grid point (units must match the
# library's geometry axis)
meta_layer = fx.MetaAtomInterpolationLayer(
    library=library,
    geometry_maps={"width": jnp.full((64, 64), 0.25)},
)

field_out = meta_layer.forward(field)
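At its core, a meta-atom look-up table is a sampled function that maps a geometry parameter to a complex transmission per wavelength. A layer evaluates that function at every pixel of the geometry map. The sketch below does this with linear interpolation of the real and imaginary parts in plain JAX. The table values, the single width parameter, the helper name, and the choice of linear interpolation are all assumptions for illustration. MetaAtomLibrary may store and interpolate its data differently.

```python
import jax
import jax.numpy as jnp

# Hypothetical LUT: 5 pillar widths x 2 wavelengths of complex transmission.
widths = jnp.array([0.10, 0.15, 0.20, 0.25, 0.30])      # sorted sample points
phases = jnp.array([[0.0, 0.8, 1.6, 2.4, 3.2],
                    [0.0, 0.6, 1.2, 1.8, 2.4]])         # (n_wl, n_widths), radians
table = 0.9 * jnp.exp(1j * phases)                       # complex transmission

def interpolate_transmission(width_map, widths, table):
    """Linearly interpolate a (n_wl, n_widths) complex table at each pixel."""
    flat = width_map.ravel()

    def per_wavelength(row):
        re = jnp.interp(flat, widths, row.real)
        im = jnp.interp(flat, widths, row.imag)
        return (re + 1j * im).reshape(width_map.shape)

    return jax.vmap(per_wavelength)(table)               # (n_wl, ny, nx)

width_map = jnp.full((64, 64), 0.25)
t = interpolate_transmission(width_map, widths, table)  # (2, 64, 64)
```

Interpolating real and imaginary parts, rather than amplitude and phase, avoids phase-wrapping problems. Between table samples, however, it slightly lowers |t|.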

5. Noise and Sensors

A DetectorArray models physical camera properties. It samples the signal on its own detector grid and can apply realistic noise distributions and masks.

sensor = fx.DetectorArray(
    # 32 x 32 pixels at 2 um: the same 64 um extent as the simulation grid
    detector_grid=fx.Grid.from_extent(nx=32, ny=32, dx_um=2.0, dy_um=2.0),
    noise_model=fx.PoissonGaussianNoise(
        count_scale=100.0,   # scales expected intensity to expected photon counts
        read_noise_std=2.5,  # standard deviation of additive Gaussian read noise
    ),
)

# Returns noisy samples over the specified detector_grid; `rng` is the
# PRNG key created in Section 1
noisy_measurement = sensor.measure(field_out, key=rng)
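A Poisson-Gaussian model treats the detected signal as photon shot noise plus additive read noise. The counts are drawn as Poisson(count_scale · I), and Normal(0, read_noise_std²) noise is then added. The plain-JAX sketch below also bins a 64×64 intensity sampled at 1 µm onto a 32×32 grid of 2 µm pixels by summing 2×2 blocks. The binning rule and the helper names bin_pixels and measure_poisson_gaussian are assumptions for illustration. DetectorArray may resample and scale differently.

```python
import jax
import jax.numpy as jnp

def bin_pixels(intensity, factor):
    """Sum factor x factor blocks, modeling larger detector pixels."""
    ny, nx = intensity.shape
    blocks = intensity.reshape(ny // factor, factor, nx // factor, factor)
    return blocks.sum(axis=(1, 3))

def measure_poisson_gaussian(key, intensity, count_scale, read_noise_std):
    """Shot noise on the expected photon counts, plus Gaussian read noise."""
    shot_key, read_key = jax.random.split(key)
    expected_photons = count_scale * intensity
    photons = jax.random.poisson(shot_key, expected_photons).astype(jnp.float32)
    read_noise = read_noise_std * jax.random.normal(read_key, intensity.shape)
    return photons + read_noise

intensity = jnp.ones((64, 64))                   # hypothetical uniform irradiance
pixel_signal = bin_pixels(intensity, factor=2)   # (32, 32), 4.0 per pixel
samples = measure_poisson_gaussian(
    jax.random.PRNGKey(0), pixel_signal, count_scale=100.0, read_noise_std=2.5
)
```

Each pixel expects 400 photons here, so the sample mean sits near 400. The per-pixel standard deviation is about sqrt(400 + 2.5²) ≈ 20, which is dominated by shot noise.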