fouriax.optim

Optimization-loop helpers for example scripts and notebooks.

Functions

apply_optax_updates(optimizer, params, ...)

Apply one Optax update step and return (params, opt_state).

batch_slices(n_items, batch_size, *[, drop_last])

Yield (lo, hi) slices that cover n_items in minibatches.

focal_spot_loss(intensity, target_xy[, ...])

Maximize power concentration around a target focal spot.

iter_minibatches(*arrays, batch_size[, ...])

Yield aligned minibatches from one or more arrays.

num_batches(n_items, batch_size, *[, drop_last])

Return the number of minibatches produced for a dataset size.

optimize_dataset_hybrid_module(*, ...[, ...])

Optimize optical and decoder params over minibatches using a shared loss.

optimize_dataset_optical_module(*, ...[, ...])

Optimize an optical module over minibatches using a shared train/val loss.

optimize_dataset_params(*, init_params, ...)

Optimize arbitrary params over minibatches with optional validation tracking.

optimize_optical_module(*, init_params, ...)

Run an Optax optimization loop and return best module + optimization metadata.

random_batch_indices(n_items, batch_size, *)

Sample batch indices using sklearn's NumPy-style random_state handling.

should_log_step(step, *, every, total_steps)

Return True on periodic log steps and the final step.

shuffled_arrays(*arrays[, seed])

Apply the same random permutation to multiple arrays.

train_val_split(*arrays, val_fraction[, ...])

Split arrays into (train_arrays, val_arrays) with aligned indexing.

Classes

BestValueTracker([mode])

Track the best scalar metric and a copied snapshot of a JAX pytree.

DatasetOptResult(best_params, final_params, ...)

Outputs from dataset optimization over minibatches.

HybridModuleDatasetOptResult(best_module, ...)

Hybrid dataset optimization outputs for optical module + decoder params.

ModuleDatasetOptResult(best_module, ...)

Dataset optimization outputs plus built optical modules.

ModuleOptResult(best_module, best_params, ...)

Outputs from optimize_optical_module.

ValidationRecord(step, epoch, metrics)

Validation metrics recorded at a training step/epoch boundary.

num_batches(n_items, batch_size, *, drop_last=False)

Return the number of minibatches produced for a dataset size.

Parameters:
  • n_items (int)

  • batch_size (int)

  • drop_last (bool)

Return type:

int
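The counting rule can be illustrated with a short sketch. It assumes the usual convention that a trailing partial batch is counted when drop_last=False and discarded when drop_last=True; the page does not state this explicitly. The name num_batches_sketch is hypothetical and this is not the library's implementation.

```python
def num_batches_sketch(n_items: int, batch_size: int, *, drop_last: bool = False) -> int:
    """Count minibatches under an assumed ceil/floor convention for drop_last."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full, remainder = divmod(n_items, batch_size)
    # Assumption: a partial trailing batch is kept unless drop_last is set.
    if drop_last or remainder == 0:
        return full
    return full + 1
```

Under these assumptions, 10 items with batch_size=4 give three batches by default and two with drop_last=True.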

batch_slices(n_items, batch_size, *, drop_last=False)

Yield (lo, hi) slices that cover n_items in minibatches.

Parameters:
  • n_items (int)

  • batch_size (int)

  • drop_last (bool)

Return type:

Iterator[tuple[int, int]]
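A minimal sketch of how such slices might be generated, assuming (lo, hi) are half-open bounds suitable for array[lo:hi] and that drop_last skips a short final slice. batch_slices_sketch is a hypothetical name used only for illustration.

```python
from typing import Iterator, Tuple


def batch_slices_sketch(
    n_items: int, batch_size: int, *, drop_last: bool = False
) -> Iterator[Tuple[int, int]]:
    """Yield half-open (lo, hi) bounds that tile range(n_items)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for lo in range(0, n_items, batch_size):
        hi = min(lo + batch_size, n_items)
        # Assumption: with drop_last, a final slice shorter than batch_size is skipped.
        if drop_last and hi - lo < batch_size:
            break
        yield lo, hi
```

With these assumptions, list(batch_slices_sketch(10, 4)) covers the items as (0, 4), (4, 8), (8, 10).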

iter_minibatches(*arrays, batch_size, seed=None, shuffle=True, drop_last=False)

Yield aligned minibatches from one or more arrays.

Parameters:
  • arrays (ArrayT)

  • batch_size (int)

  • seed (int | None)

  • shuffle (bool)

  • drop_last (bool)

Return type:

_MinibatchIterable
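"Aligned" here presumably means that every array is indexed with the same positions in each batch, so inputs and labels stay paired. The sketch below shows that idea with plain Python sequences. It is an assumption-laden stand-in: the real function returns a _MinibatchIterable whose behaviour is not documented on this page, and iter_minibatches_sketch is a hypothetical name.

```python
import random
from typing import Iterator, Optional, Sequence, Tuple


def iter_minibatches_sketch(
    *arrays: Sequence,
    batch_size: int,
    seed: Optional[int] = None,
    shuffle: bool = True,
    drop_last: bool = False,
) -> Iterator[Tuple[list, ...]]:
    """Yield tuples of batches that share one index order across all arrays."""
    if not arrays:
        raise ValueError("at least one array is required")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_items = len(arrays[0])
    if any(len(a) != n_items for a in arrays):
        raise ValueError("all arrays must share the same length")

    order = list(range(n_items))
    if shuffle:
        # One permutation, reused for every array, keeps samples paired.
        random.Random(seed).shuffle(order)

    for lo in range(0, n_items, batch_size):
        idx = order[lo:lo + batch_size]
        if drop_last and len(idx) < batch_size:
            break
        yield tuple([a[i] for i in idx] for a in arrays)
```

A typical use is to loop with `for x_batch, y_batch in iter_minibatches_sketch(x, y, batch_size=32, seed=0)`.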

random_batch_indices(n_items, batch_size, *, seed=None, replace=True)

Sample batch indices using sklearn’s NumPy-style random_state handling.

Parameters:
  • n_items (int)

  • batch_size (int)

  • seed (int | None)

  • replace (bool)

Return type:

ndarray

shuffled_arrays(*arrays, seed=None)

Apply the same random permutation to multiple arrays.

Parameters:
  • arrays (ArrayT)

  • seed (int | None)

Return type:

tuple[ArrayT, …]

train_val_split(*arrays, val_fraction, seed=None, shuffle=True)

Split arrays into (train_arrays, val_arrays) with aligned indexing.

Parameters:
  • arrays (ArrayT)

  • val_fraction (float)

  • seed (int | None)

  • shuffle (bool)

Return type:

tuple[tuple[ArrayT, …], tuple[ArrayT, …]]
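The nested return type can be read as ((train_x, train_y, ...), (val_x, val_y, ...)). A plain-Python sketch of an aligned split follows. The rounding of the validation size and the choice to take validation samples from the end of the index order are assumptions, and train_val_split_sketch is a hypothetical name rather than the library function.

```python
import random
from typing import Optional, Sequence, Tuple


def train_val_split_sketch(
    *arrays: Sequence,
    val_fraction: float,
    seed: Optional[int] = None,
    shuffle: bool = True,
) -> Tuple[Tuple[list, ...], Tuple[list, ...]]:
    """Split every array with the same index order so rows stay paired."""
    if not arrays:
        raise ValueError("at least one array is required")
    if not 0.0 <= val_fraction <= 1.0:
        raise ValueError("val_fraction must lie in [0, 1]")
    n_items = len(arrays[0])
    if any(len(a) != n_items for a in arrays):
        raise ValueError("all arrays must share the same length")

    order = list(range(n_items))
    if shuffle:
        random.Random(seed).shuffle(order)

    # Assumption: the validation size is rounded to the nearest integer.
    n_val = int(round(n_items * val_fraction))
    train_idx = order[: n_items - n_val]
    val_idx = order[n_items - n_val:]

    train = tuple([a[i] for i in train_idx] for a in arrays)
    val = tuple([a[i] for i in val_idx] for a in arrays)
    return train, val
```

Unpacking follows the return type: `(train_x, train_y), (val_x, val_y) = train_val_split_sketch(x, y, val_fraction=0.2, seed=1)`.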

focal_spot_loss(intensity, target_xy, window_px=2, eps=1e-12)

Maximize power concentration around a target focal spot.

Parameters:
  • intensity (Array)

  • target_xy (tuple[int, int])

  • window_px (int)

  • eps (float)

Return type:

Array
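The page does not give the formula. One plausible reading, shown purely as a sketch, is the negative fraction of total power that falls inside a square window around the target, with eps guarding the division. The (column, row) ordering of target_xy, the window shape, and the clipping at the array edges are all assumptions. focal_spot_loss_sketch is a hypothetical name operating on nested lists rather than JAX arrays.

```python
from typing import Sequence, Tuple


def focal_spot_loss_sketch(
    intensity: Sequence[Sequence[float]],
    target_xy: Tuple[int, int],
    window_px: int = 2,
    eps: float = 1e-12,
) -> float:
    """Negative fraction of total power inside a window around target_xy."""
    x0, y0 = target_xy  # Assumption: (column, row) ordering.
    height, width = len(intensity), len(intensity[0])

    # Assumption: square window of half-width window_px, clipped to the array.
    rows = range(max(0, y0 - window_px), min(height, y0 + window_px + 1))
    cols = range(max(0, x0 - window_px), min(width, x0 + window_px + 1))

    window_power = sum(intensity[r][c] for r in rows for c in cols)
    total_power = sum(sum(row) for row in intensity)

    # Minimizing the negative fraction maximizes power concentration.
    return -window_power / (total_power + eps)
```

Under this reading the loss approaches -1 when all power sits inside the window and 0 when none does.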

class BestValueTracker(mode='min')

Bases: Generic[ParamsT]

Track the best scalar metric and a copied snapshot of a JAX pytree.

Parameters:

mode (Literal['min', 'max'])

mode: Literal['min', 'max'] = 'min'
best_value: float
best_state: ParamsT | None = None
best_step: int | None = None

update(value, state, *, step=None)
Parameters:
  • value (float)

  • state (ParamsT)

  • step (int | None)

Return type:

bool
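The fields above suggest a simple best-so-far pattern. The sketch below assumes that update returns True only on a strict improvement and that the snapshot is a deep copy; the page confirms neither detail. BestValueTrackerSketch is a hypothetical stand-in, and copy.deepcopy is used in place of whatever pytree copy the library performs.

```python
import copy
from typing import Any, Optional


class BestValueTrackerSketch:
    """Track the best scalar seen so far and a copy of the matching state."""

    def __init__(self, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        # Start from the worst possible value for the chosen mode.
        self.best_value = float("inf") if mode == "min" else float("-inf")
        self.best_state: Optional[Any] = None
        self.best_step: Optional[int] = None

    def update(self, value: float, state: Any, *, step: Optional[int] = None) -> bool:
        # Assumption: only strict improvements replace the stored best.
        if self.mode == "min":
            improved = value < self.best_value
        else:
            improved = value > self.best_value
        if improved:
            self.best_value = float(value)
            # Copy so later in-place changes to `state` do not alter the snapshot.
            self.best_state = copy.deepcopy(state)
            self.best_step = step
        return improved
```

In a training loop this would be called once per evaluation, for example `tracker.update(val_loss, params, step=step)`, and the returned flag can gate checkpointing.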

class DatasetOptResult(best_params, final_params, train_loss_history, val_history, best_metric_name, best_metric_value, best_step, best_epoch, final_train_loss, final_val_metrics)

Bases: Generic[ParamsT]

Outputs from dataset optimization over minibatches.

Parameters:
  • best_params (ParamsT)

  • final_params (ParamsT)

  • train_loss_history (list[float])

  • val_history (list[ValidationRecord])

  • best_metric_name (str)

  • best_metric_value (float)

  • best_step (int | None)

  • best_epoch (int | None)

  • final_train_loss (float)

  • final_val_metrics (dict[str, float] | None)

best_params: ParamsT
final_params: ParamsT
train_loss_history: list[float]
val_history: list[ValidationRecord]
best_metric_name: str
best_metric_value: float
best_step: int | None
best_epoch: int | None
final_train_loss: float
final_val_metrics: dict[str, float] | None

class HybridModuleDatasetOptResult(best_module, final_module, best_optical_params, final_optical_params, best_decoder_params, final_decoder_params, params_result)

Bases: Generic[ModuleT, ParamsT, DecoderParamsT]

Hybrid dataset optimization outputs for optical module + decoder params.

Parameters:
  • best_module (ModuleT)

  • final_module (ModuleT)

  • best_optical_params (ParamsT)

  • final_optical_params (ParamsT)

  • best_decoder_params (DecoderParamsT)

  • final_decoder_params (DecoderParamsT)

  • params_result (DatasetOptResult[dict[str, Any]])

best_module: ModuleT
final_module: ModuleT
best_optical_params: ParamsT
final_optical_params: ParamsT
best_decoder_params: DecoderParamsT
final_decoder_params: DecoderParamsT
params_result: DatasetOptResult[dict[str, Any]]

class ModuleOptResult(best_module, best_params, history, best_loss, final_loss, best_step)

Bases: Generic[ModuleT, ParamsT]

Outputs from optimize_optical_module.

Parameters:
  • best_module (ModuleT)

  • best_params (ParamsT)

  • history (list[float])

  • best_loss (float)

  • final_loss (float)

  • best_step (int | None)

best_module: ModuleT
best_params: ParamsT
history: list[float]
best_loss: float
final_loss: float
best_step: int | None

class ModuleDatasetOptResult(best_module, final_module, params_result)

Bases: Generic[ModuleT, ParamsT]

Dataset optimization outputs plus built optical modules.

Parameters:
  • best_module (ModuleT)

  • final_module (ModuleT)

  • params_result (DatasetOptResult[ParamsT])

best_module: ModuleT
final_module: ModuleT
params_result: DatasetOptResult[ParamsT]

class ValidationRecord(step, epoch, metrics)

Bases: object

Validation metrics recorded at a training step/epoch boundary.

Parameters:
  • step (int)

  • epoch (int)

  • metrics (dict[str, float])

step: int
epoch: int
metrics: dict[str, float]

apply_optax_updates(optimizer, params, opt_state, grads)

Apply one Optax update step and return (params, opt_state).

Parameters:
  • optimizer (GradientTransformation)

  • params (ParamsT)

  • opt_state (OptStateT)

  • grads (ParamsT)

Return type:

tuple[ParamsT, OptStateT]

optimize_dataset_hybrid_module(*, init_optical_params, init_decoder_params, build_module, batch_loss_fn=None, sample_loss_fn=None, train_data, batch_size, epochs, val_data=None, optimizer=None, optical_optimizer=None, decoder_optimizer=None, val_every_epochs=1, val_every_steps=0, log_every_steps=0, report_batch_progress=False, jit=True, seed=0, drop_last_train=False)

Optimize optical and decoder params over minibatches using a shared loss.

This wrapper mirrors optimize_dataset_optical_module() but carries two parameter groups: optical parameters used to build the module and decoder parameters used by the hybrid loss.

Exactly one of batch_loss_fn or sample_loss_fn must be provided. When sample_loss_fn is used, it is vmapped over each minibatch and then averaged into a batch loss.
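The rule above can be sketched in plain Python. The list comprehension stands in for the vmap over the minibatch, and the mean turns per-sample losses into one batch loss. resolve_batch_loss_sketch is a hypothetical helper, not part of the documented API, and the error type raised by the real function is an assumption.

```python
from typing import Any, Callable, Optional, Sequence


def resolve_batch_loss_sketch(
    batch_loss_fn: Optional[Callable[[Any, Any, Any], float]] = None,
    sample_loss_fn: Optional[Callable[[Any, Any, Any], float]] = None,
) -> Callable[[Any, Any, Any], float]:
    """Return the effective batch loss from exactly one of the two callables."""
    # True when both or neither are given; exactly one must be provided.
    if (batch_loss_fn is None) == (sample_loss_fn is None):
        raise ValueError("provide exactly one of batch_loss_fn or sample_loss_fn")
    if batch_loss_fn is not None:
        return batch_loss_fn

    def mean_over_samples(optical_params: Any, decoder_params: Any, batch: Sequence) -> float:
        # Plain-Python stand-in for mapping over the batch axis, then averaging.
        losses = [sample_loss_fn(optical_params, decoder_params, s) for s in batch]
        return sum(losses) / len(losses)

    return mean_over_samples
```

The practical consequence is that a sample_loss_fn only needs to handle a single example; batching is handled by the wrapper.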

Optimizer configuration can be supplied in one of two ways:

  • pass optimizer directly for a single optimizer over the combined parameter dict

  • pass both optical_optimizer and decoder_optimizer to build an internal optax.multi_transform(...) optimizer with separate groups

Parameters:
  • init_optical_params (ParamsT)

  • init_decoder_params (DecoderParamsT)

  • build_module (Callable[[ParamsT], ModuleT])

  • batch_loss_fn (Callable[[ParamsT, DecoderParamsT, Any], Array] | None)

  • sample_loss_fn (Callable[[ParamsT, DecoderParamsT, Any], Array] | None)

  • train_data (Any)

  • batch_size (int)

  • epochs (int)

  • val_data (Any | None)

  • optimizer (GradientTransformation | None)

  • optical_optimizer (GradientTransformation | None)

  • decoder_optimizer (GradientTransformation | None)

  • val_every_epochs (int)

  • val_every_steps (int)

  • log_every_steps (int)

  • report_batch_progress (bool)

  • jit (bool)

  • seed (int)

  • drop_last_train (bool)

Return type:

HybridModuleDatasetOptResult[ModuleT, ParamsT, DecoderParamsT]

optimize_dataset_optical_module(*, init_params, build_module, batch_loss_fn=None, sample_loss_fn=None, optimizer, train_data, batch_size, epochs, val_data=None, val_every_epochs=1, val_every_steps=0, log_every_steps=0, report_batch_progress=False, jit=True, seed=0, drop_last_train=False)

Optimize an optical module over minibatches using a shared train/val loss.

This is a high-level convenience wrapper:

  • minibatching is derived internally from iter_minibatches(…)

  • exactly one of batch_loss_fn or sample_loss_fn must be provided

  • validation uses the same effective batch loss averaged over val_data

  • reporting uses a uniform built-in console formatter

Parameters:
  • init_params (ParamsT)

  • build_module (Callable[[ParamsT], ModuleT])

  • batch_loss_fn (Callable[[ParamsT, Any], Array] | None)

  • sample_loss_fn (Callable[[ParamsT, Any], Array] | None)

  • optimizer (GradientTransformation)

  • train_data (Any)

  • batch_size (int)

  • epochs (int)

  • val_data (Any | None)

  • val_every_epochs (int)

  • val_every_steps (int)

  • log_every_steps (int)

  • report_batch_progress (bool)

  • jit (bool)

  • seed (int)

  • drop_last_train (bool)

Return type:

ModuleDatasetOptResult[ModuleT, ParamsT]

optimize_dataset_params(*, init_params, optimizer, train_data, batch_iter_fn, train_loss_fn, epochs, val_eval_fn=None, val_data=None, val_every_epochs=1, val_every_steps=0, log_every_steps=50, report_batch_progress=False, select_metric=None, select_mode='min', jit=True, reporter=None, rng=None)

Optimize arbitrary params over minibatches with optional validation tracking.

batch_iter_fn should typically accept (train_data, epoch, rng) and return an iterable of batches for that epoch. Simpler callables that accept fewer positional arguments are also supported.
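How the library decides how many arguments to pass is not documented here. One way such flexible dispatch could work is to inspect the callable's signature and pass only as many of (train_data, epoch, rng) as it accepts positionally. The helper name call_batch_iter_fn_sketch is hypothetical.

```python
import inspect
from typing import Any, Callable, Iterable


def call_batch_iter_fn_sketch(
    batch_iter_fn: Callable[..., Iterable], train_data: Any, epoch: int, rng: Any
) -> Iterable:
    """Call batch_iter_fn with as many of (train_data, epoch, rng) as it accepts."""
    params = list(inspect.signature(batch_iter_fn).parameters.values())
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
        # A *args callable can take everything.
        n_args = 3
    else:
        n_args = min(3, sum(p.kind in positional_kinds for p in params))
    return batch_iter_fn(*(train_data, epoch, rng)[:n_args])
```

With this approach a one-argument `lambda data: ...` and a three-argument `def batches(data, epoch, rng): ...` are both valid batch_iter_fn values.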

Parameters:
  • init_params (ParamsT)

  • optimizer (GradientTransformation)

  • train_data (TrainDataT)

  • batch_iter_fn (Callable[[TrainDataT], Iterable[BatchT]] | Callable[[TrainDataT, int], Iterable[BatchT]] | Callable[[TrainDataT, int, Any], Iterable[BatchT]])

  • train_loss_fn (Callable[[ParamsT, BatchT], Array])

  • epochs (int)

  • val_eval_fn (Callable[[ParamsT, ValDataT], Mapping[str, Any]] | None)

  • val_data (ValDataT | None)

  • val_every_epochs (int)

  • val_every_steps (int)

  • log_every_steps (int)

  • report_batch_progress (bool)

  • select_metric (str | None)

  • select_mode (Literal['min', 'max'])

  • jit (bool)

  • reporter (Callable[[dict[str, Any]], None] | None)

  • rng (Any | None)

Return type:

DatasetOptResult[ParamsT]

optimize_optical_module(*, init_params, build_module, loss_fn, optimizer, steps, log_every=50, mode='min', jit=True, reporter=None)

Run an Optax optimization loop and return best module + optimization metadata.

Parameters:
  • init_params (ParamsT)

  • build_module (Callable[[ParamsT], ModuleT])

  • loss_fn (Callable[[ParamsT], Array])

  • optimizer (GradientTransformation)

  • steps (int)

  • log_every (int)

  • mode (Literal['min', 'max'])

  • jit (bool)

  • reporter (Callable[[int, float], None] | None)

Return type:

ModuleOptResult[ModuleT, ParamsT]

should_log_step(step, *, every, total_steps)

Return True on periodic log steps and the final step.

Parameters:
  • step (int)

  • every (int)

  • total_steps (int)

Return type:

bool
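A sketch of the logging predicate follows. The page does not say whether steps are zero-based or how every <= 0 is treated, so the code assumes zero-based steps (final step is total_steps - 1) and treats every <= 0 as "no periodic logging". should_log_step_sketch is a hypothetical name.

```python
def should_log_step_sketch(step: int, *, every: int, total_steps: int) -> bool:
    """Return True on periodic log steps and on the final step."""
    # Assumption: steps are zero-based, so the last step is total_steps - 1.
    is_final = step == total_steps - 1
    # Assumption: every <= 0 disables periodic logging.
    is_periodic = every > 0 and step % every == 0
    return is_periodic or is_final
```

Under these assumptions, every=3 with total_steps=8 logs steps 0, 3, 6 and the final step 7.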