Optical Neural Network (ONN) on MNIST — Phase-Mask Classifier

This is a standalone advanced notebook. It trains a small diffractive optical network to classify MNIST digits using alternating phase masks and free-space propagation layers. The final intensity is sampled by a 2×5 detector grid whose 10 outputs serve as class logits.

Modeling assumptions

  • grayscale MNIST images are injected as field amplitudes on the input plane,

  • each trainable layer is phase-only, with propagation between layers fixed by the chosen spacing, and

  • the final detector array is treated as a task-specific readout rather than a physically calibrated camera.

This example is best read as a compact differentiable-optics classifier demo, not as a benchmark for state-of-the-art digit recognition.

0 Imports

We use JAX and Optax for training, Matplotlib for visual diagnostics, and fouriax optics modules for the phase-mask stack, intermediate intensity monitors, and detector-array readout.

from __future__ import annotations

import urllib.request
from pathlib import Path

import jax
import jax.image
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from matplotlib.patches import Rectangle

import fouriax as fx
%matplotlib inline

# Example directories are resolved against the current working directory, so the
# notebook presumably expects to be launched from the repository root — TODO confirm.
EXAMPLES_ROOT = Path.cwd().joinpath("examples")
EXAMPLES_DATA_DIR = EXAMPLES_ROOT.joinpath("data")            # dataset caches
EXAMPLES_ARTIFACTS_DIR = EXAMPLES_ROOT.joinpath("artifacts")  # saved figures / summaries

1 Paths and Parameters

The parameters control both the learning problem and the optical architecture: device selection, dataset size, number of phase layers, phase-mask downsampling, propagation distance, and optimization hyperparameters.

# Keep the MNIST cache under examples/data by default.
MNIST_URL = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
# The EXAMPLES_*_DIR constants are already pathlib.Path objects, so join them
# directly instead of round-tripping through str().
DATA_PATH = EXAMPLES_DATA_DIR / 'mnist.npz'
ARTIFACTS_DIR = EXAMPLES_ARTIFACTS_DIR
PLOT_PATH = ARTIFACTS_DIR / "onn_mnist_field_evolution.png"
SUMMARY_PATH = ARTIFACTS_DIR / "onn_mnist_summary.json"

DEVICE = 'cpu'               # JAX platform name applied via jax.config
SEED = 0                     # PRNG seed for the initial phase latents
EPOCHS = 10                  # passes over the training subset
BATCH_SIZE = 64              # minibatch size for training
LEARNING_RATE = 0.05         # Adam step size
NUM_PHASE_LAYERS = 4         # number of (phase mask -> propagation) stages
PHASE_MASK_DOWNSAMPLE = 4    # per-axis ratio of working-grid size to latent mask size
NYQUIST_FACTOR = 1.0         # forwarded to fx.plan_propagation
DISTANCE_UM = 50.0           # propagation distance per stage (micrometres, per the _UM suffix)
TRAIN_SAMPLES = 1000         # leading MNIST training examples kept
TEST_SAMPLES = 100           # leading MNIST test examples kept
PLOT = True                  # render and save the field-evolution figure

2 Helper Functions

The helpers load and cache MNIST, then resize the digit images onto the optical working grid. This keeps the rest of the notebook focused on the optical model and training loop rather than on dataset plumbing.

def load_mnist(cache_path: Path) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST from a local ``.npz`` cache, downloading it on first use.

    Args:
        cache_path: Location of the cached ``mnist.npz`` archive. Missing parent
            directories are created.

    Returns:
        ``(x_train, y_train, x_test, y_test)``: image arrays converted to
        ``float32`` and divided by 255.0, and label arrays converted to ``int32``.
    """
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    if not cache_path.exists():
        print(f"Downloading MNIST to {cache_path} ...")
        # Download into a sibling temp file and rename only on success. Writing
        # straight to cache_path would leave a truncated archive behind after an
        # interrupted download; exists() would then skip the re-download and
        # every later run would fail inside np.load.
        tmp_path = cache_path.with_name(cache_path.name + ".part")
        try:
            urllib.request.urlretrieve(MNIST_URL, tmp_path)
            tmp_path.replace(cache_path)
        finally:
            # No-op after a successful rename; removes partial data otherwise.
            tmp_path.unlink(missing_ok=True)

    with np.load(cache_path) as data:
        x_train = data["x_train"].astype(np.float32) / 255.0
        y_train = data["y_train"].astype(np.int32)
        x_test = data["x_test"].astype(np.float32) / 255.0
        y_test = data["y_test"].astype(np.int32)
    return x_train, y_train, x_test, y_test


def resize_images_to_grid(images: np.ndarray, grid: fx.Grid) -> np.ndarray:
    """Bilinearly resample a stack of 2D images onto the pixel dimensions of ``grid``.

    Args:
        images: Image stack of shape ``(N, H, W)``.
        grid: Target grid; only its ``ny`` and ``nx`` attributes are read.

    Returns:
        ``float32`` NumPy array of shape ``(N, grid.ny, grid.nx)``.
    """
    stack = jnp.asarray(images, dtype=jnp.float32)
    # Resample with a trailing singleton channel axis, then drop it again.
    with_channel = jnp.expand_dims(stack, axis=-1)
    target_shape = (stack.shape[0], grid.ny, grid.nx, 1)
    resampled = jax.image.resize(with_channel, shape=target_shape, method="linear")
    return np.asarray(jnp.squeeze(resampled, axis=-1), dtype=np.float32)

3 Setup

We choose a working grid, build a repeated block of

\[ \text{phase mask} \;\rightarrow\; \text{propagation}, \]

and attach IntensityMonitor layers so the field evolution can be visualized after training. The final DetectorArray has 10 cells arranged as 2×5, one for each digit class.

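The trainable parameters live on a lower-resolution latent grid. Each latent map is bilinearly upsampled to the working grid and passed through a scaled sigmoid, \(\phi = 2\pi\,\sigma(\cdot)\). Every phase therefore lies in \((0, 2\pi)\) before it becomes a physical phase mask.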

jax.config.update("jax_platform_name", DEVICE)
selected_device = jax.devices()[0]
jax.config.update("jax_default_device", selected_device)
print(
    "device="
    f"{selected_device.platform} kind={getattr(selected_device, 'device_kind', 'unknown')}"
)

# Native MNIST resolution: 28 x 28 samples at a 1 um pitch.
input_grid = fx.Grid.from_extent(nx=28, ny=28, dx_um=1.0, dy_um=1.0)
# Single-wavelength spectrum; 1.55 is presumably in micrometres — TODO confirm.
spectrum = fx.Spectrum.from_scalar(1.55)
propagator = fx.plan_propagation(
    mode="auto",
    grid=input_grid,
    spectrum=spectrum,
    distance_um=DISTANCE_UM,
    nyquist_factor=NYQUIST_FACTOR,
    min_padding_factor=2.0,
)
# Use the propagator's precomputed grid when it provides one (it may differ
# from input_grid, e.g. after padding — TODO confirm); otherwise keep input_grid.
work_grid = propagator.precomputed_grid or input_grid

# The trainable latent mask is the working grid shrunk by PHASE_MASK_DOWNSAMPLE
# per axis. Validate up front: a zero factor or a factor larger than the grid
# would otherwise surface later as a ZeroDivisionError or an empty mask grid.
if PHASE_MASK_DOWNSAMPLE < 1:
    raise ValueError(
        f"PHASE_MASK_DOWNSAMPLE must be >= 1, got {PHASE_MASK_DOWNSAMPLE}"
    )
mask_nx = work_grid.nx // PHASE_MASK_DOWNSAMPLE
mask_ny = work_grid.ny // PHASE_MASK_DOWNSAMPLE
if mask_nx < 1 or mask_ny < 1:
    raise ValueError(
        f"PHASE_MASK_DOWNSAMPLE={PHASE_MASK_DOWNSAMPLE} is too large for the "
        f"{work_grid.ny}x{work_grid.nx} working grid"
    )
# Same physical extent as the working grid, coarser pixels.
mask_grid = fx.Grid.from_extent(
    nx=mask_nx,
    ny=mask_ny,
    dx_um=(work_grid.nx * work_grid.dx_um) / mask_nx,
    dy_um=(work_grid.ny * work_grid.dy_um) / mask_ny,
)
# 2 x 5 detector cells tiling the full working-grid extent: one cell per digit.
detector_grid = fx.Grid.from_extent(
    nx=5,
    ny=2,
    dx_um=(work_grid.nx * work_grid.dx_um) / 5.0,
    dy_um=(work_grid.ny * work_grid.dy_um) / 2.0,
)
detector_array = fx.DetectorArray(
    detector_grid=detector_grid,
)

def build_module(raw_params: jnp.ndarray) -> fx.OpticalModule:
    """Assemble the diffractive stack for one set of latent phase parameters.

    Layout: an input intensity monitor, then for every latent map a phase mask,
    the shared free-space ``propagator``, and another intensity monitor. The
    ``detector_array`` is attached as the module's sensor.

    Args:
        raw_params: Stack of 2D latent maps, one per layer (indexed along axis 0).
            Each map is bilinearly resized to the working grid and squashed into
            a phase in ``(0, 2*pi)`` via ``2*pi*sigmoid``.

    Returns:
        The assembled ``fx.OpticalModule``.
    """

    def _spatial_monitor() -> fx.IntensityMonitor:
        # Records wavelength-summed intensity in the spatial domain.
        return fx.IntensityMonitor(sum_wavelengths=True, output_domain="spatial")

    stack = [_spatial_monitor()]
    for latent in raw_params:
        latent_full = jax.image.resize(
            latent,
            shape=(work_grid.ny, work_grid.nx),
            method="linear",
        )
        # The sigmoid keeps the phase bounded without needing clipping.
        phase_rad = 2.0 * jnp.pi * jax.nn.sigmoid(latent_full)
        stack.extend(
            (fx.PhaseMask(phase_map_rad=phase_rad), propagator, _spatial_monitor())
        )
    return fx.OpticalModule(layers=tuple(stack), sensor=detector_array)

x_train, y_train, x_test, y_test = load_mnist(DATA_PATH)
# Keep only the leading TRAIN_SAMPLES / TEST_SAMPLES examples of each split.
x_train, y_train = x_train[:TRAIN_SAMPLES], y_train[:TRAIN_SAMPLES]
x_test, y_test = x_test[:TEST_SAMPLES], y_test[:TEST_SAMPLES]
# Resample the digits onto the optical working grid (train first, then test).
x_train, x_test = (
    resize_images_to_grid(split_images, work_grid) for split_images in (x_train, x_test)
)

# Small random latents: the 2*pi*sigmoid mapping in build_module turns these
# into initial phases close to pi.
key = jax.random.PRNGKey(SEED)
latent_shape = (NUM_PHASE_LAYERS, mask_grid.ny, mask_grid.nx)
phase_params = 0.05 * jax.random.normal(key, latent_shape, dtype=jnp.float32)
device=cpu kind=cpu
W0407 21:11:04.828098 1088825 cpp_gen_intrinsics.cc:74] Empty bitcode string provided for eigen. Optimizations relying on this IR will be disabled.

4 Loss Function and Optimization

The forward pass maps each input image to detector intensities and flattens the 10 detector readings into a logit vector. The raw intensities are used directly, with no learned scaling. Training then minimizes the negative log-likelihood of the correct class. If \(z_{n,c}\) is the detector logit for sample \(n\) and class \(c\), then the softmax probability is

\[ p_{n,c} = \frac{e^{z_{n,c}}}{\sum_{k=1}^{10} e^{z_{n,k}}}, \]

and the minibatch classification loss is

\[ \mathcal{L}_{\mathrm{cls}} = -\frac{1}{B} \sum_{n=1}^{B} \log p_{n,y_n}, \]

where \(y_n\) is the ground-truth digit label. Accuracy is evaluated with the same optical model, without adding a separate electronic classifier head.

This makes the task interpretation very direct: the diffractive stack itself is learning to route optical energy toward the detector region associated with the correct class.

def logits_batch(raw_params: jnp.ndarray, images_3d: jnp.ndarray) -> jnp.ndarray:
    """Run a batch of images through the optical stack and return detector logits.

    Args:
        raw_params: Latent phase parameters forwarded to ``build_module``.
        images_3d: Image batch of shape ``(B, ny, nx)`` on the working grid, used
            as real field amplitudes.

    Returns:
        Detector readings flattened to shape ``(B, num_detectors)``.
    """
    optics = build_module(raw_params)
    batch_size = images_3d.shape[0]
    # Insert a singleton axis after the batch axis -> (B, 1, ny, nx); presumably
    # the spectral axis for the single wavelength — TODO confirm against fx.Field.
    amplitudes = jnp.expand_dims(images_3d, axis=1).astype(jnp.complex64)
    input_field = fx.Field(data=amplitudes, grid=work_grid, spectrum=spectrum)
    readout = optics.measure(input_field)
    return readout.reshape((batch_size, -1))

def batch_loss_fn(
    params: jnp.ndarray,
    batch: tuple[np.ndarray, np.ndarray] | tuple[jnp.ndarray, jnp.ndarray],
) -> jnp.ndarray:
    """Mean softmax cross-entropy of the detector logits against integer labels.

    Args:
        params: Latent phase parameters.
        batch: ``(images, labels)`` pair; images are cast to ``float32`` and
            labels to ``int32``.

    Returns:
        Scalar mean negative log-likelihood of the true classes.
    """
    images_in, labels_in = batch
    labels = jnp.asarray(labels_in, dtype=jnp.int32)
    logits = logits_batch(params, jnp.asarray(images_in, dtype=jnp.float32))
    # log-softmax written as logits - logsumexp for numerical stability.
    log_norm = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
    log_probs = logits - log_norm
    rows = jnp.arange(labels.shape[0])
    true_class_nll = -log_probs[rows, labels]
    return jnp.mean(true_class_nll)

def batch_accuracy(
    params: jnp.ndarray,
    images: np.ndarray,
    labels: np.ndarray,
    batch_size: int | None = None,
) -> float:
    """Fraction of samples whose largest detector logit matches the label.

    Args:
        params: Latent phase parameters forwarded to ``logits_batch``.
        images: Image stack on the working grid, indexed by sample along axis 0.
        labels: Integer class labels aligned with ``images``.
        batch_size: When given, evaluate in chunks of this many samples so the
            optical fields for the whole dataset are never built at once (useful
            for larger TRAIN_SAMPLES / grids). ``None`` (default) evaluates
            everything in a single forward pass, as before.

    Returns:
        Accuracy as a Python float.

    Raises:
        ValueError: If ``batch_size`` is given and is smaller than 1.
    """
    if batch_size is None:
        logits = np.asarray(logits_batch(params, jnp.asarray(images)))
        pred = np.argmax(logits, axis=1)
    else:
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        chunk_preds = [
            np.argmax(
                np.asarray(logits_batch(params, jnp.asarray(images[start:start + batch_size]))),
                axis=1,
            )
            for start in range(0, len(images), batch_size)
        ]
        # An empty dataset yields no chunks; keep the single-pass semantics
        # (mean over zero samples) instead of failing in np.concatenate.
        pred = np.concatenate(chunk_preds) if chunk_preds else np.empty(0, dtype=np.intp)
    return float(np.mean(pred == labels))

train_data = (x_train, y_train)
# NOTE(review): the validation pair is the test subset, so checkpoint selection
# ("best") is driven by test loss; the reported test accuracy is optimistic.
val_data = (x_test, y_test)
optimizer = optax.adam(learning_rate=LEARNING_RATE)

# SEED + 1 keeps this seed distinct from the one used for the initial phase
# latents; presumably it drives minibatch shuffling inside fouriax — TODO confirm.
result = fx.optim.optimize_dataset_optical_module(
    init_params=phase_params,
    build_module=build_module,
    batch_loss_fn=batch_loss_fn,
    optimizer=optimizer,
    train_data=train_data,
    val_data=val_data,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    seed=SEED + 1,
)
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/anaconda3/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 701, in start
    self.io_loop.start()
  File "/opt/anaconda3/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/opt/anaconda3/lib/python3.12/asyncio/base_events.py", line 641, in run_forever
    self._run_once()
  File "/opt/anaconda3/lib/python3.12/asyncio/base_events.py", line 1986, in _run_once
    handle._run()
  File "/opt/anaconda3/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in dispatch_queue
    await self.process_one()
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 523, in process_one
    await dispatch(*args)
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 429, in dispatch_shell
    await result
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
    reply_content = await reply_content
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 429, in do_execute
    res = shell.run_cell(
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
    result = self._run_cell(
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
    result = runner(coro)
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/h9/2zpp40gn24l3ddb3fnkq7fzh0000gn/T/ipykernel_51485/2366447797.py", line 30, in <module>
    result = fx.optim.optimize_dataset_optical_module(
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 752, in optimize_dataset_optical_module
    params_result = optimize_dataset_params(
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 584, in optimize_dataset_params
    epoch_batches = iter_batches(epoch)
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 530, in iter_batches
    return batch_iter_any(train_data, epoch)
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 727, in batch_iter_fn
    return iter_minibatches(
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 200, in iter_minibatches
    from sklearn.utils import shuffle as sklearn_shuffle
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/__init__.py", line 84, in <module>
    from .base import clone
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py", line 19, in <module>
    from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/__init__.py", line 11, in <module>
    from ._chunking import gen_batches, gen_even_slices
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_chunking.py", line 8, in <module>
    from ._param_validation import Interval, validate_params
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 14, in <module>
    from .validation import _is_arraylike_not_scalar
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/validation.py", line 26, in <module>
    from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_array_api.py", line 11, in <module>
    from .fixes import parse_version
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/fixes.py", line 24, in <module>
    import pandas as pd
  File "/opt/anaconda3/lib/python3.12/site-packages/pandas/__init__.py", line 26, in <module>
    from pandas.compat import (
  File "/opt/anaconda3/lib/python3.12/site-packages/pandas/compat/__init__.py", line 27, in <module>
    from pandas.compat.pyarrow import (
  File "/opt/anaconda3/lib/python3.12/site-packages/pandas/compat/pyarrow.py", line 8, in <module>
    import pyarrow as pa
  File "/opt/anaconda3/lib/python3.12/site-packages/pyarrow/__init__.py", line 65, in <module>
    import pyarrow.lib as _lib
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/metasurface/fouriax/src/fouriax/optim.py:527, in optimize_dataset_params.<locals>.iter_batches(epoch)
    526 try:
--> 527     return batch_iter_any(train_data, epoch, rng)
    528 except TypeError:

TypeError: optimize_dataset_optical_module.<locals>.batch_iter_fn() takes 2 positional arguments but 3 were given

During handling of the above exception, another exception occurred:

ImportError                               Traceback (most recent call last)
File /opt/anaconda3/lib/python3.12/site-packages/numpy/core/_multiarray_umath.py:46, in __getattr__(attr_name)
     41     # Also print the message (with traceback).  This is because old versions
     42     # of NumPy unfortunately set up the import to replace (and hide) the
     43     # error.  The traceback shouldn't be needed, but e.g. pytest plugins
     44     # seem to swallow it and we should be failing anyway...
     45     sys.stderr.write(msg + tb_msg)
---> 46     raise ImportError(msg)
     48 ret = getattr(_multiarray_umath, attr_name, None)
     49 if ret is None:

ImportError: 
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/anaconda3/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 701, in start
    self.io_loop.start()
  File "/opt/anaconda3/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/opt/anaconda3/lib/python3.12/asyncio/base_events.py", line 641, in run_forever
    self._run_once()
  File "/opt/anaconda3/lib/python3.12/asyncio/base_events.py", line 1986, in _run_once
    handle._run()
  File "/opt/anaconda3/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in dispatch_queue
    await self.process_one()
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 523, in process_one
    await dispatch(*args)
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 429, in dispatch_shell
    await result
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
    reply_content = await reply_content
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 429, in do_execute
    res = shell.run_cell(
  File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
    result = self._run_cell(
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
    result = runner(coro)
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/h9/2zpp40gn24l3ddb3fnkq7fzh0000gn/T/ipykernel_51485/2366447797.py", line 30, in <module>
    result = fx.optim.optimize_dataset_optical_module(
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 752, in optimize_dataset_optical_module
    params_result = optimize_dataset_params(
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 584, in optimize_dataset_params
    epoch_batches = iter_batches(epoch)
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 530, in iter_batches
    return batch_iter_any(train_data, epoch)
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 727, in batch_iter_fn
    return iter_minibatches(
  File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 200, in iter_minibatches
    from sklearn.utils import shuffle as sklearn_shuffle
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/__init__.py", line 84, in <module>
    from .base import clone
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py", line 19, in <module>
    from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/__init__.py", line 11, in <module>
    from ._chunking import gen_batches, gen_even_slices
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_chunking.py", line 8, in <module>
    from ._param_validation import Interval, validate_params
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 14, in <module>
    from .validation import _is_arraylike_not_scalar
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/validation.py", line 26, in <module>
    from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_array_api.py", line 11, in <module>
    from .fixes import parse_version
  File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/fixes.py", line 24, in <module>
    import pandas as pd
  File "/opt/anaconda3/lib/python3.12/site-packages/pandas/__init__.py", line 49, in <module>
    from pandas.core.api import (
  File "/opt/anaconda3/lib/python3.12/site-packages/pandas/core/api.py", line 1, in <module>
    from pandas._libs import (
  File "/opt/anaconda3/lib/python3.12/site-packages/pandas/_libs/__init__.py", line 17, in <module>
    import pandas._libs.pandas_datetime  # noqa: F401 # isort: skip # type: ignore[reportUnusedImport]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/metasurface/fouriax/src/fouriax/optim.py:527, in optimize_dataset_params.<locals>.iter_batches(epoch)
    526 try:
--> 527     return batch_iter_any(train_data, epoch, rng)
    528 except TypeError:

TypeError: optimize_dataset_optical_module.<locals>.batch_iter_fn() takes 2 positional arguments but 3 were given

During handling of the above exception, another exception occurred:

ImportError                               Traceback (most recent call last)
File /opt/anaconda3/lib/python3.12/site-packages/numpy/core/_multiarray_umath.py:46, in __getattr__(attr_name)
     41     # Also print the message (with traceback).  This is because old versions
     42     # of NumPy unfortunately set up the import to replace (and hide) the
     43     # error.  The traceback shouldn't be needed, but e.g. pytest plugins
     44     # seem to swallow it and we should be failing anyway...
     45     sys.stderr.write(msg + tb_msg)
---> 46     raise ImportError(msg)
     48 ret = getattr(_multiarray_umath, attr_name, None)
     49 if ret is None:

ImportError: 
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
batch=001/010 val_loss=36.778109 train_loss=25.567694 epoch_time=6.16s elapsed=6.16s [best]
batch=002/010 val_loss=10.723537 train_loss=15.951373 epoch_time=3.90s elapsed=10.06s [best]
batch=003/010 val_loss=6.842509 train_loss=5.668540 epoch_time=3.85s elapsed=13.91s [best]
batch=004/010 val_loss=4.999387 train_loss=5.187243 epoch_time=3.91s elapsed=17.81s [best]
batch=005/010 val_loss=4.098100 train_loss=1.708079 epoch_time=3.90s elapsed=21.72s [best]
batch=006/010 val_loss=3.864942 train_loss=2.445751 epoch_time=3.92s elapsed=25.63s [best]
batch=007/010 val_loss=3.319789 train_loss=0.800965 epoch_time=3.91s elapsed=29.54s [best]
batch=008/010 val_loss=2.935117 train_loss=0.241678 epoch_time=3.93s elapsed=33.47s [best]
batch=009/010 val_loss=2.680075 train_loss=0.869893 epoch_time=3.92s elapsed=37.38s [best]
batch=010/010 val_loss=2.751285 train_loss=0.203986 epoch_time=3.92s elapsed=41.30s

5 Evaluation

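After training, we report train/test accuracy of the best checkpoint and inspect one sample digit by observing the monitored intensity checkpoints through the best optical module. This ties the classification result back to a physical picture of how the field evolves across the diffractive layers. The validation data passed to the optimizer is the same test subset used for the test accuracy. The "best" checkpoint is therefore selected on test loss, which makes the reported test accuracy optimistic. Hold out a separate validation split for an unbiased estimate.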

# Accuracies are measured with the best checkpoint (lowest validation loss).
best_params = result.params_result.best_params
train_acc = batch_accuracy(best_params, x_train, y_train)
test_acc = batch_accuracy(best_params, x_test, y_test)
# final_val_metrics belong to the last epoch's parameters, which need not be the
# best checkpoint; label the printed values accordingly so they are not mixed up.
final_val_loss = (
    float(result.params_result.final_val_metrics["val_loss"])
    if result.params_result.final_val_metrics
    else float("nan")
)
print(
    f"best_train_acc={train_acc:.4f} best_test_acc={test_acc:.4f} "
    f"final_val_loss={final_val_loss:.4f}"
)

# Inspect a single test digit through the best module.
module = result.best_module
sample_idx = 0
test_image = jnp.asarray(x_test[sample_idx], dtype=jnp.float32)
sample_field = fx.Field(
    data=test_image[None, :, :].astype(jnp.complex64),
    grid=work_grid,
    spectrum=spectrum,
)
final_train_acc=0.9230 final_test_acc=0.8300 final_val_loss=2.7513

6 Plot Results

The top row shows the input intensity followed by the intensity after each propagation step, while the bottom row shows the learned phase masks. The final panel overlays the detector regions used as class readouts, labeled with their digits 0–9.

A useful read of this figure is architectural rather than aesthetic: it shows whether the network has learned progressively structured energy routing that culminates in a separable detector response.

if PLOT:
    _, intensity_steps = module.observe(sample_field)
    # Phase maps in stack order, taken from the PhaseMask layers of the module.
    phase_masks = [
        np.asarray(stage.phase_map_rad)
        for stage in module.layers
        if isinstance(stage, fx.PhaseMask)
    ]
    # The first monitored frame is the input; every later one follows a propagation.
    titles = ["Input"] + [f"After Propagation {i + 1}" for i in range(len(intensity_steps) - 1)]
    n_cols = max(len(intensity_steps), len(phase_masks))
    fig_field, axes = plt.subplots(
        2,
        n_cols,
        figsize=(max(6.0, 2.8 * n_cols), 6.8),
        squeeze=False,
    )
    for col, ax in enumerate(axes[0]):
        if col >= len(intensity_steps):
            ax.axis("off")
            continue
        title = titles[col]
        image = intensity_steps[col]
        im = ax.imshow(np.asarray(image), cmap="inferno")
        ax.set_title(title, fontsize=9)
        ax.set_xticks([])
        ax.set_yticks([])
        fig_field.colorbar(im, ax=ax, fraction=0.046, pad=0.03)
        if col == len(intensity_steps) - 1:
            # Overlay the detector cells (in working-grid pixel units) on the
            # last frame, numbered row-major to match the logit/class order.
            cell_w = work_grid.nx / detector_grid.nx
            cell_h = work_grid.ny / detector_grid.ny
            digit = 0
            for row in range(detector_grid.ny):
                for det_col in range(detector_grid.nx):
                    # imshow centres pixels on integer coordinates, hence the -0.5.
                    ax.add_patch(
                        Rectangle(
                            (det_col * cell_w - 0.5, row * cell_h - 0.5),
                            cell_w,
                            cell_h,
                            fill=False,
                            edgecolor="red",
                            linewidth=1.2,
                        )
                    )
                    ax.text(
                        det_col * cell_w + 0.5 * cell_w - 0.5,
                        row * cell_h + 0.5 * cell_h - 0.5,
                        str(digit),
                        color="red",
                        fontsize=10,
                        fontweight="bold",
                        ha="center",
                        va="center",
                    )
                    digit += 1

    for col, ax in enumerate(axes[1]):
        if col >= len(phase_masks):
            ax.axis("off")
            continue
        phase = phase_masks[col]
        im = ax.imshow(phase, cmap="twilight", vmin=0.0, vmax=2.0 * np.pi)
        ax.set_title(f"Phase Mask {col + 1}", fontsize=9)
        ax.set_xticks([])
        ax.set_yticks([])
        fig_field.colorbar(im, ax=ax, fraction=0.046, pad=0.03)

    fig_field.suptitle("ONN Intensity Checkpoints and Learned Phase Masks", y=0.98)
    fig_field.tight_layout(rect=(0.0, 0.0, 1.0, 0.93))
    # No earlier cell shown here creates the artifacts directory (load_mnist only
    # creates the data directory), so savefig would raise FileNotFoundError on a
    # fresh checkout without this.
    PLOT_PATH.parent.mkdir(parents=True, exist_ok=True)
    fig_field.savefig(PLOT_PATH, dpi=150)
    plt.show()
    print(f"saved: {PLOT_PATH}")
../../_images/2023f37289e53c90c4798f8302907325c705d5a34304c96e0a4db02791642b5f.png
saved: /Users/liam/metasurface/fouriax/examples/artifacts/onn_mnist_field_evolution.png