Optical Neural Network (ONN) on MNIST — Phase-Mask Classifier#
This is a standalone advanced notebook. It trains a small diffractive optical network to classify
MNIST digits using alternating phase masks and free-space propagation layers, with the final intensity
sampled by a 2x5 detector grid whose 10 outputs serve as class logits.
Modeling assumptions#
grayscale MNIST images are injected as field amplitudes on the input plane,
each trainable layer is phase-only, with propagation between layers fixed by the chosen spacing, and
the final detector array is treated as a task-specific readout rather than a physically calibrated camera.
This example is best read as a compact differentiable-optics classifier demo, not as a benchmark for state-of-the-art digit recognition.
0 Imports#
We use JAX and Optax for training, Matplotlib for visual diagnostics, and fouriax optics modules for
the phase-mask stack, intermediate intensity monitors, and detector-array readout.
from __future__ import annotations
import urllib.request
from pathlib import Path
import jax
import jax.image
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from matplotlib.patches import Rectangle
import fouriax as fx
%matplotlib inline
# All example inputs/outputs are resolved against the current working
# directory at the time this cell executes.
EXAMPLES_ROOT = Path.cwd().joinpath("examples")
EXAMPLES_DATA_DIR = EXAMPLES_ROOT.joinpath("data")
EXAMPLES_ARTIFACTS_DIR = EXAMPLES_ROOT.joinpath("artifacts")
1 Paths and Parameters#
The parameters control both the learning problem and the optical architecture: device selection, dataset size, number of phase layers, phase-mask downsampling, propagation distance, and optimization hyperparameters.
# Keep the MNIST cache under examples/data by default.
MNIST_URL = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
# The example directories are already Path objects, so derive the concrete
# paths directly instead of round-tripping through str().
DATA_PATH = EXAMPLES_DATA_DIR / "mnist.npz"
ARTIFACTS_DIR = EXAMPLES_ARTIFACTS_DIR
PLOT_PATH = ARTIFACTS_DIR / "onn_mnist_field_evolution.png"
SUMMARY_PATH = ARTIFACTS_DIR / "onn_mnist_summary.json"

# Runtime selection.
DEVICE = "cpu"  # value handed to JAX's "jax_platform_name" config option
SEED = 0  # seeds the phase-parameter init; SEED + 1 seeds the training routine

# Optimization hyperparameters.
EPOCHS = 10
BATCH_SIZE = 64
LEARNING_RATE = 0.05  # Adam step size

# Optical architecture.
NUM_PHASE_LAYERS = 4  # number of trainable phase masks in the stack
PHASE_MASK_DOWNSAMPLE = 4  # work-grid pixels per latent mask pixel, per axis
NYQUIST_FACTOR = 1.0  # forwarded to fx.plan_propagation
DISTANCE_UM = 50.0  # propagation distance after each phase mask, in micrometers

# Dataset subset sizes (the leading samples of each split are used).
TRAIN_SAMPLES = 1000
TEST_SAMPLES = 100

PLOT = True  # gate for the plotting cell
2 Helper Functions#
The helpers load and cache MNIST, then resize the digit images onto the optical working grid. This keeps the rest of the notebook focused on the optical model and training loop rather than on dataset plumbing.
def load_mnist(cache_path: Path) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST from a local ``.npz`` cache, downloading it first if absent.

    Args:
        cache_path: Location of the cached ``mnist.npz`` archive. Missing
            parent directories are created.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. Images are ``float32`` scaled
        by ``1 / 255``; labels are ``int32``.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or ``np.load`` raise on
        download or read failure.
    """
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    if not cache_path.exists():
        print(f"Downloading MNIST to {cache_path} ...")
        # Download next to the final location and rename once complete, so an
        # interrupted transfer never leaves a truncated archive at cache_path
        # (which would pass the exists() check and break every later run).
        partial_path = cache_path.with_name(cache_path.name + ".part")
        try:
            urllib.request.urlretrieve(MNIST_URL, partial_path)
            partial_path.replace(cache_path)
        finally:
            # No-op after a successful rename; removes debris after a failure.
            partial_path.unlink(missing_ok=True)
    with np.load(cache_path) as data:
        x_train = data["x_train"].astype(np.float32) / 255.0
        y_train = data["y_train"].astype(np.int32)
        x_test = data["x_test"].astype(np.float32) / 255.0
        y_test = data["y_test"].astype(np.int32)
    return x_train, y_train, x_test, y_test
def resize_images_to_grid(images: np.ndarray, grid: fx.Grid) -> np.ndarray:
    """Resample a stack of images onto a grid's ``(ny, nx)`` pixel layout.

    Args:
        images: Image stack with a leading batch axis followed by two spatial
            axes.
        grid: Target grid; only its ``ny`` and ``nx`` attributes are read.

    Returns:
        A ``float32`` NumPy array of shape ``(batch, grid.ny, grid.nx)``.
    """
    # jax.image.resize is given an explicit trailing channel axis here, which
    # is stripped again before returning.
    stack = jnp.asarray(images, dtype=jnp.float32)[..., jnp.newaxis]
    target_shape = (stack.shape[0], grid.ny, grid.nx, 1)
    resampled = jax.image.resize(stack, shape=target_shape, method="linear")
    return np.asarray(resampled[..., 0], dtype=np.float32)
3 Setup#
We choose a working grid, build a repeated block of PhaseMask → free-space propagation,
and attach IntensityMonitor layers so the field evolution can be visualized after training. The
final DetectorArray has 10 cells arranged as 2×5, one for each digit class.
The trainable parameters live on a lower-resolution latent grid and are bilinearly upsampled before being converted into physical phase masks.
# Point JAX at the platform named by DEVICE and make its first device the default.
jax.config.update("jax_platform_name", DEVICE)
selected_device = jax.devices()[0]
jax.config.update("jax_default_device", selected_device)
device_kind = getattr(selected_device, "device_kind", "unknown")
print(f"device={selected_device.platform} kind={device_kind}")

# 28x28 input plane with 1 um pixels, matching the native MNIST resolution.
input_grid = fx.Grid.from_extent(nx=28, ny=28, dx_um=1.0, dy_um=1.0)
spectrum = fx.Spectrum.from_scalar(1.55)

# A single propagation plan is built once and reused after every phase mask.
propagator = fx.plan_propagation(
    mode="auto",
    grid=input_grid,
    spectrum=spectrum,
    distance_um=DISTANCE_UM,
    nyquist_factor=NYQUIST_FACTOR,
    min_padding_factor=2.0,
)
# Use the planner's precomputed grid when it provides one; otherwise fall
# back to the input grid.
work_grid = propagator.precomputed_grid or input_grid

# Physical extent of the working plane, shared by the coarser grids below.
work_extent_x_um = work_grid.nx * work_grid.dx_um
work_extent_y_um = work_grid.ny * work_grid.dy_um

# Latent (trainable) mask grid: same extent, fewer and larger pixels.
mask_nx = work_grid.nx // PHASE_MASK_DOWNSAMPLE
mask_ny = work_grid.ny // PHASE_MASK_DOWNSAMPLE
mask_grid = fx.Grid.from_extent(
    nx=mask_nx,
    ny=mask_ny,
    dx_um=work_extent_x_um / mask_nx,
    dy_um=work_extent_y_um / mask_ny,
)

# Readout grid: 2 rows x 5 columns of detector cells tiling the same extent,
# one cell per digit class.
detector_grid = fx.Grid.from_extent(
    nx=5,
    ny=2,
    dx_um=work_extent_x_um / 5.0,
    dy_um=work_extent_y_um / 2.0,
)
detector_array = fx.DetectorArray(detector_grid=detector_grid)
def build_module(raw_params: jnp.ndarray) -> fx.OpticalModule:
    """Assemble the optical stack for a set of latent phase parameters.

    Args:
        raw_params: Unconstrained latent values; the leading axis indexes the
            phase layers and each slice is a 2-D latent mask.

    Returns:
        An ``fx.OpticalModule`` made of an input intensity monitor followed,
        for every layer, by a phase mask, the shared propagator and another
        intensity monitor, read out by ``detector_array``.
    """

    def make_monitor() -> fx.IntensityMonitor:
        return fx.IntensityMonitor(sum_wavelengths=True, output_domain="spatial")

    full_shape = (work_grid.ny, work_grid.nx)
    stack = [make_monitor()]
    for layer_index in range(raw_params.shape[0]):
        # Upsample the coarse latent mask to the working resolution.
        latent = jax.image.resize(
            raw_params[layer_index],
            shape=full_shape,
            method="linear",
        )
        # The sigmoid keeps the physical phase inside (0, 2*pi).
        phase_rad = 2.0 * jnp.pi * jax.nn.sigmoid(latent)
        stack += [
            fx.PhaseMask(phase_map_rad=phase_rad),
            propagator,
            make_monitor(),
        ]
    return fx.OpticalModule(layers=tuple(stack), sensor=detector_array)
# Load (or download) MNIST and keep only the leading samples of each split.
x_train, y_train, x_test, y_test = load_mnist(DATA_PATH)
x_train, y_train = x_train[:TRAIN_SAMPLES], y_train[:TRAIN_SAMPLES]
x_test, y_test = x_test[:TEST_SAMPLES], y_test[:TEST_SAMPLES]

# Resample the digits onto the optical working grid.
x_train = resize_images_to_grid(x_train, work_grid)
x_test = resize_images_to_grid(x_test, work_grid)

# Small random latent values: after the sigmoid in build_module every mask
# starts close to a uniform phase of pi.
key = jax.random.PRNGKey(SEED)
phase_params = 0.05 * jax.random.normal(
    key,
    shape=(NUM_PHASE_LAYERS, mask_grid.ny, mask_grid.nx),
    dtype=jnp.float32,
)
device=cpu kind=cpu
W0407 21:11:04.828098 1088825 cpp_gen_intrinsics.cc:74] Empty bitcode string provided for eigen. Optimizations relying on this IR will be disabled.
4 Loss Function and Optimization#
The forward pass maps each input image to detector intensities, reshapes them into 10 logits, and optimizes the negative log-likelihood of the correct class. If \(z_{n,c}\) is the detector logit for sample \(n\) and class \(c\), then the softmax probability is
\[ p_{n,c} = \frac{\exp(z_{n,c})}{\sum_{c'=0}^{9} \exp(z_{n,c'})}, \]
and the minibatch classification loss is
\[ \mathcal{L} = -\frac{1}{N} \sum_{n=1}^{N} \log p_{n,y_n}, \]
where \(N\) is the minibatch size and \(y_n\) is the ground-truth digit label. Accuracy is evaluated with the same optical model, without adding a separate electronic classifier head.
This makes the task interpretation very direct: the diffractive stack itself is learning to route optical energy toward the detector region associated with the correct class.
def logits_batch(raw_params: jnp.ndarray, images_3d: jnp.ndarray) -> jnp.ndarray:
    """Run a batch of images through the optical stack and return class logits.

    Args:
        raw_params: Latent phase parameters accepted by ``build_module``.
        images_3d: Real-valued image batch with three axes (batch first).

    Returns:
        The detector measurement flattened to ``(batch, -1)``.
    """
    batch_size = images_3d.shape[0]
    # Images are injected as field amplitudes. A singleton axis is inserted
    # after the batch axis — presumably the wavelength axis fx.Field expects;
    # confirm against fx.Field.
    amplitude = images_3d[:, jnp.newaxis, :, :].astype(jnp.complex64)
    input_field = fx.Field(data=amplitude, grid=work_grid, spectrum=spectrum)
    readout = build_module(raw_params).measure(input_field)
    return readout.reshape((batch_size, -1))
def batch_loss_fn(
    params: jnp.ndarray,
    batch: tuple[np.ndarray, np.ndarray] | tuple[jnp.ndarray, jnp.ndarray],
) -> jnp.ndarray:
    """Mean softmax negative log-likelihood of the true class over a minibatch.

    Args:
        params: Latent phase parameters forwarded to ``logits_batch``.
        batch: ``(images, labels)`` pair; images are cast to ``float32`` and
            labels to ``int32``.

    Returns:
        Scalar loss.
    """
    raw_images, raw_labels = batch
    labels = jnp.asarray(raw_labels, dtype=jnp.int32)
    logits = logits_batch(params, jnp.asarray(raw_images, dtype=jnp.float32))

    # log-softmax, written as logits minus the log partition function.
    log_normalizer = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
    log_probs = logits - log_normalizer

    # Pick, for every sample, the log-probability of its ground-truth class.
    sample_index = jnp.arange(labels.shape[0])
    true_class_log_probs = log_probs[sample_index, labels]
    return -jnp.mean(true_class_log_probs)
def batch_accuracy(params: jnp.ndarray, images: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of images whose arg-max detector logit equals the label.

    The whole image set is pushed through ``logits_batch`` in a single call.
    """
    scores = np.asarray(logits_batch(params, jnp.asarray(images)))
    predicted = scores.argmax(axis=1)
    hits = predicted == labels
    return float(np.mean(hits))
# Datasets are handed over as (images, labels) tuples.
# NOTE(review): the test subset doubles as the validation set, and the
# evaluation cell scores the trainer's "best" parameters on that same subset.
# If "best" is selected by validation loss, the reported test accuracy is
# optimistic — confirm against the trainer.
train_data = (x_train, y_train)
val_data = (x_test, y_test)

optimizer = optax.adam(learning_rate=LEARNING_RATE)

result = fx.optim.optimize_dataset_optical_module(
    # model
    init_params=phase_params,
    build_module=build_module,
    batch_loss_fn=batch_loss_fn,
    optimizer=optimizer,
    # data
    train_data=train_data,
    val_data=val_data,
    # schedule
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    seed=SEED + 1,
)
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
Traceback (most recent call last): File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel_launcher.py", line 17, in <module>
app.launch_new_instance()
File "/opt/anaconda3/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
app.start()
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 701, in start
self.io_loop.start()
File "/opt/anaconda3/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
self.asyncio_loop.run_forever()
File "/opt/anaconda3/lib/python3.12/asyncio/base_events.py", line 641, in run_forever
self._run_once()
File "/opt/anaconda3/lib/python3.12/asyncio/base_events.py", line 1986, in _run_once
handle._run()
File "/opt/anaconda3/lib/python3.12/asyncio/events.py", line 88, in _run
self._context.run(self._callback, *self._args)
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in dispatch_queue
await self.process_one()
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 523, in process_one
await dispatch(*args)
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 429, in dispatch_shell
await result
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
reply_content = await reply_content
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 429, in do_execute
res = shell.run_cell(
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
return super().run_cell(*args, **kwargs)
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
result = self._run_cell(
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
result = runner(coro)
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
coro.send(None)
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/var/folders/h9/2zpp40gn24l3ddb3fnkq7fzh0000gn/T/ipykernel_51485/2366447797.py", line 30, in <module>
result = fx.optim.optimize_dataset_optical_module(
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 752, in optimize_dataset_optical_module
params_result = optimize_dataset_params(
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 584, in optimize_dataset_params
epoch_batches = iter_batches(epoch)
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 530, in iter_batches
return batch_iter_any(train_data, epoch)
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 727, in batch_iter_fn
return iter_minibatches(
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 200, in iter_minibatches
from sklearn.utils import shuffle as sklearn_shuffle
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/__init__.py", line 84, in <module>
from .base import clone
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py", line 19, in <module>
from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/__init__.py", line 11, in <module>
from ._chunking import gen_batches, gen_even_slices
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_chunking.py", line 8, in <module>
from ._param_validation import Interval, validate_params
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 14, in <module>
from .validation import _is_arraylike_not_scalar
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/validation.py", line 26, in <module>
from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_array_api.py", line 11, in <module>
from .fixes import parse_version
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/fixes.py", line 24, in <module>
import pandas as pd
File "/opt/anaconda3/lib/python3.12/site-packages/pandas/__init__.py", line 26, in <module>
from pandas.compat import (
File "/opt/anaconda3/lib/python3.12/site-packages/pandas/compat/__init__.py", line 27, in <module>
from pandas.compat.pyarrow import (
File "/opt/anaconda3/lib/python3.12/site-packages/pandas/compat/pyarrow.py", line 8, in <module>
import pyarrow as pa
File "/opt/anaconda3/lib/python3.12/site-packages/pyarrow/__init__.py", line 65, in <module>
import pyarrow.lib as _lib
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/metasurface/fouriax/src/fouriax/optim.py:527, in optimize_dataset_params.<locals>.iter_batches(epoch)
526 try:
--> 527 return batch_iter_any(train_data, epoch, rng)
528 except TypeError:
TypeError: optimize_dataset_optical_module.<locals>.batch_iter_fn() takes 2 positional arguments but 3 were given
During handling of the above exception, another exception occurred:
ImportError Traceback (most recent call last)
File /opt/anaconda3/lib/python3.12/site-packages/numpy/core/_multiarray_umath.py:46, in __getattr__(attr_name)
41 # Also print the message (with traceback). This is because old versions
42 # of NumPy unfortunately set up the import to replace (and hide) the
43 # error. The traceback shouldn't be needed, but e.g. pytest plugins
44 # seem to swallow it and we should be failing anyway...
45 sys.stderr.write(msg + tb_msg)
---> 46 raise ImportError(msg)
48 ret = getattr(_multiarray_umath, attr_name, None)
49 if ret is None:
ImportError:
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
Traceback (most recent call last): File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel_launcher.py", line 17, in <module>
app.launch_new_instance()
File "/opt/anaconda3/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
app.start()
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 701, in start
self.io_loop.start()
File "/opt/anaconda3/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
self.asyncio_loop.run_forever()
File "/opt/anaconda3/lib/python3.12/asyncio/base_events.py", line 641, in run_forever
self._run_once()
File "/opt/anaconda3/lib/python3.12/asyncio/base_events.py", line 1986, in _run_once
handle._run()
File "/opt/anaconda3/lib/python3.12/asyncio/events.py", line 88, in _run
self._context.run(self._callback, *self._args)
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in dispatch_queue
await self.process_one()
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 523, in process_one
await dispatch(*args)
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 429, in dispatch_shell
await result
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
reply_content = await reply_content
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 429, in do_execute
res = shell.run_cell(
File "/opt/anaconda3/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
return super().run_cell(*args, **kwargs)
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
result = self._run_cell(
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
result = runner(coro)
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
coro.send(None)
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/opt/anaconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/var/folders/h9/2zpp40gn24l3ddb3fnkq7fzh0000gn/T/ipykernel_51485/2366447797.py", line 30, in <module>
result = fx.optim.optimize_dataset_optical_module(
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 752, in optimize_dataset_optical_module
params_result = optimize_dataset_params(
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 584, in optimize_dataset_params
epoch_batches = iter_batches(epoch)
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 530, in iter_batches
return batch_iter_any(train_data, epoch)
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 727, in batch_iter_fn
return iter_minibatches(
File "/Users/liam/metasurface/fouriax/src/fouriax/optim.py", line 200, in iter_minibatches
from sklearn.utils import shuffle as sklearn_shuffle
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/__init__.py", line 84, in <module>
from .base import clone
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py", line 19, in <module>
from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/__init__.py", line 11, in <module>
from ._chunking import gen_batches, gen_even_slices
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_chunking.py", line 8, in <module>
from ._param_validation import Interval, validate_params
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 14, in <module>
from .validation import _is_arraylike_not_scalar
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/validation.py", line 26, in <module>
from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/_array_api.py", line 11, in <module>
from .fixes import parse_version
File "/opt/anaconda3/lib/python3.12/site-packages/sklearn/utils/fixes.py", line 24, in <module>
import pandas as pd
File "/opt/anaconda3/lib/python3.12/site-packages/pandas/__init__.py", line 49, in <module>
from pandas.core.api import (
File "/opt/anaconda3/lib/python3.12/site-packages/pandas/core/api.py", line 1, in <module>
from pandas._libs import (
File "/opt/anaconda3/lib/python3.12/site-packages/pandas/_libs/__init__.py", line 17, in <module>
import pandas._libs.pandas_datetime # noqa: F401 # isort: skip # type: ignore[reportUnusedImport]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/metasurface/fouriax/src/fouriax/optim.py:527, in optimize_dataset_params.<locals>.iter_batches(epoch)
526 try:
--> 527 return batch_iter_any(train_data, epoch, rng)
528 except TypeError:
TypeError: optimize_dataset_optical_module.<locals>.batch_iter_fn() takes 2 positional arguments but 3 were given
During handling of the above exception, another exception occurred:
ImportError Traceback (most recent call last)
File /opt/anaconda3/lib/python3.12/site-packages/numpy/core/_multiarray_umath.py:46, in __getattr__(attr_name)
41 # Also print the message (with traceback). This is because old versions
42 # of NumPy unfortunately set up the import to replace (and hide) the
43 # error. The traceback shouldn't be needed, but e.g. pytest plugins
44 # seem to swallow it and we should be failing anyway...
45 sys.stderr.write(msg + tb_msg)
---> 46 raise ImportError(msg)
48 ret = getattr(_multiarray_umath, attr_name, None)
49 if ret is None:
ImportError:
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
batch=001/010 val_loss=36.778109 train_loss=25.567694 epoch_time=6.16s elapsed=6.16s [best]
batch=002/010 val_loss=10.723537 train_loss=15.951373 epoch_time=3.90s elapsed=10.06s [best]
batch=003/010 val_loss=6.842509 train_loss=5.668540 epoch_time=3.85s elapsed=13.91s [best]
batch=004/010 val_loss=4.999387 train_loss=5.187243 epoch_time=3.91s elapsed=17.81s [best]
batch=005/010 val_loss=4.098100 train_loss=1.708079 epoch_time=3.90s elapsed=21.72s [best]
batch=006/010 val_loss=3.864942 train_loss=2.445751 epoch_time=3.92s elapsed=25.63s [best]
batch=007/010 val_loss=3.319789 train_loss=0.800965 epoch_time=3.91s elapsed=29.54s [best]
batch=008/010 val_loss=2.935117 train_loss=0.241678 epoch_time=3.93s elapsed=33.47s [best]
batch=009/010 val_loss=2.680075 train_loss=0.869893 epoch_time=3.92s elapsed=37.38s [best]
batch=010/010 val_loss=2.751285 train_loss=0.203986 epoch_time=3.92s elapsed=41.30s
5 Evaluation#
After training, we report train/test accuracy and inspect one sample digit by observing the monitored intensity checkpoints through the best optical module. This ties the classification result back to a physical picture of how the field evolves across the diffractive layers.
# Accuracies are computed with the best parameters found during training.
best_params = result.params_result.best_params
train_acc = batch_accuracy(best_params, x_train, y_train)
test_acc = batch_accuracy(best_params, x_test, y_test)

# Validation metrics may be empty/absent; report NaN in that case.
final_val_metrics = result.params_result.final_val_metrics
if final_val_metrics:
    final_val_loss = float(final_val_metrics["val_loss"])
else:
    final_val_loss = float("nan")

print(
    f"final_train_acc={train_acc:.4f} final_test_acc={test_acc:.4f} "
    f"final_val_loss={final_val_loss:.4f}"
)

# Build a single-sample field from the first test digit for visualization.
module = result.best_module
sample_idx = 0
test_image = jnp.asarray(x_test[sample_idx], dtype=jnp.float32)
sample_amplitude = test_image[None, :, :].astype(jnp.complex64)
sample_field = fx.Field(
    data=sample_amplitude,
    grid=work_grid,
    spectrum=spectrum,
)
final_train_acc=0.9230 final_test_acc=0.8300 final_val_loss=2.7513
6 Plot Results#
The top row visualizes the propagated intensity after each stage, while the bottom row shows the learned phase masks. The final panel overlays the detector regions used as class readouts.
A useful read of this figure is architectural rather than aesthetic: it shows whether the network has learned progressively structured energy routing that culminates in a separable detector response.
if PLOT:
    # Only the recorded intensity checkpoints from observe() are used here.
    _, intensity_steps = module.observe(sample_field)
    phase_masks = [
        np.asarray(stage.phase_map_rad)
        for stage in module.layers
        if isinstance(stage, fx.PhaseMask)
    ]
    titles = ["Input"] + [f"After Propagation {i + 1}" for i in range(len(intensity_steps) - 1)]
    n_cols = max(len(intensity_steps), len(phase_masks))
    fig_field, axes = plt.subplots(
        2,
        n_cols,
        figsize=(max(6.0, 2.8 * n_cols), 6.8),
        squeeze=False,
    )

    # Top row: intensity after each stage.
    last_step = len(intensity_steps) - 1
    for col, ax in enumerate(axes[0]):
        if col >= len(intensity_steps):
            ax.axis("off")
            continue
        im = ax.imshow(np.asarray(intensity_steps[col]), cmap="inferno")
        ax.set_title(titles[col], fontsize=9)
        ax.set_xticks([])
        ax.set_yticks([])
        fig_field.colorbar(im, ax=ax, fraction=0.046, pad=0.03)
        if col != last_step:
            continue

        # Overlay the detector cells on the final intensity panel. imshow puts
        # pixel centers on integer coordinates, so the image starts at -0.5.
        # Cells are labelled in row-major order, which is assumed to match the
        # order in which the detector readout is flattened into logits.
        cell_w = work_grid.nx / detector_grid.nx
        cell_h = work_grid.ny / detector_grid.ny
        for row in range(detector_grid.ny):
            for det_col in range(detector_grid.nx):
                digit = row * detector_grid.nx + det_col
                ax.add_patch(
                    Rectangle(
                        (det_col * cell_w - 0.5, row * cell_h - 0.5),
                        cell_w,
                        cell_h,
                        fill=False,
                        edgecolor="red",
                        linewidth=1.2,
                    )
                )
                ax.text(
                    det_col * cell_w + 0.5 * cell_w - 0.5,
                    row * cell_h + 0.5 * cell_h - 0.5,
                    str(digit),
                    color="red",
                    fontsize=10,
                    fontweight="bold",
                    ha="center",
                    va="center",
                )

    # Bottom row: learned phase masks on a cyclic colormap over [0, 2*pi].
    for col, ax in enumerate(axes[1]):
        if col >= len(phase_masks):
            ax.axis("off")
            continue
        im = ax.imshow(phase_masks[col], cmap="twilight", vmin=0.0, vmax=2.0 * np.pi)
        ax.set_title(f"Phase Mask {col + 1}", fontsize=9)
        ax.set_xticks([])
        ax.set_yticks([])
        fig_field.colorbar(im, ax=ax, fraction=0.046, pad=0.03)

    fig_field.suptitle("ONN Intensity Checkpoints and Learned Phase Masks", y=0.98)
    fig_field.tight_layout(rect=(0.0, 0.0, 1.0, 0.93))
    # Nothing earlier in the notebook creates the artifacts directory, so make
    # sure it exists before saving; otherwise savefig raises FileNotFoundError.
    PLOT_PATH.parent.mkdir(parents=True, exist_ok=True)
    fig_field.savefig(PLOT_PATH, dpi=150)
    plt.show()
    print(f"saved: {PLOT_PATH}")
saved: /Users/liam/metasurface/fouriax/examples/artifacts/onn_mnist_field_evolution.png