fouriax.optim#
Optimization-loop helpers for example scripts and notebooks.
Functions
- apply_optax_updates: Apply one Optax update step and return (params, opt_state).
- batch_slices: Yield (lo, hi) slices that cover n_items in minibatches.
- focal_spot_loss: Maximize power concentration around a target focal spot.
- iter_minibatches: Yield aligned minibatches from one or more arrays.
- num_batches: Return the number of minibatches produced for a dataset size.
- optimize_dataset_hybrid_module: Optimize optical and decoder params over minibatches using a shared loss.
- optimize_dataset_optical_module: Optimize an optical module over minibatches using a shared train/val loss.
- optimize_dataset_params: Optimize arbitrary params over minibatches with optional validation tracking.
- optimize_optical_module: Run an Optax optimization loop and return best module + optimization metadata.
- random_batch_indices: Sample batch indices using sklearn's NumPy-style random_state handling.
- should_log_step: Return True on periodic log steps and the final step.
- shuffled_arrays: Apply the same random permutation to multiple arrays.
- train_val_split: Split arrays into (train_arrays, val_arrays) with aligned indexing.
Classes
- BestValueTracker: Track the best scalar metric and a copied snapshot of a JAX pytree.
- DatasetOptResult: Outputs from dataset optimization over minibatches.
- HybridModuleDatasetOptResult: Hybrid dataset optimization outputs for optical module + decoder params.
- ModuleDatasetOptResult: Dataset optimization outputs plus built optical modules.
- ModuleOptResult: Outputs from optimize_optical_module.
- ValidationRecord: Validation metrics recorded at a training step/epoch boundary.
- num_batches(n_items, batch_size, *, drop_last=False)#
Return the number of minibatches produced for a dataset size.
- Parameters:
n_items (int)
batch_size (int)
drop_last (bool)
- Return type:
int
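A small usage sketch for orientation; the ceil/floor split under drop_last is the conventional reading of that flag, not something stated explicitly above:

    from fouriax.optim import num_batches

    # Expected counts under the usual convention: a partial final batch
    # still counts when drop_last=False (ceil), and is dropped otherwise.
    print(num_batches(10, batch_size=4))                  # expect 3 (4, 4, 2)
    print(num_batches(10, batch_size=4, drop_last=True))  # expect 2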
- batch_slices(n_items, batch_size, *, drop_last=False)#
Yield (lo, hi) slices that cover n_items in minibatches.
- Parameters:
n_items (int)
batch_size (int)
drop_last (bool)
- Return type:
Iterator[tuple[int, int]]
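A minimal sketch, assuming each yielded (lo, hi) pair is a half-open index range and the short tail is kept when drop_last=False:

    from fouriax.optim import batch_slices

    for lo, hi in batch_slices(10, batch_size=4):
        print(lo, hi)   # expect (0, 4), (4, 8), (8, 10)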
- iter_minibatches(*arrays, batch_size, seed=None, shuffle=True, drop_last=False)#
Yield aligned minibatches from one or more arrays.
- Parameters:
arrays (ArrayT)
batch_size (int)
seed (int | None)
shuffle (bool)
drop_last (bool)
- Return type:
_MinibatchIterable
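A sketch of multi-array iteration; that each item is a tuple of per-array batches is an assumption based on "aligned minibatches", not a documented guarantee:

    import numpy as np
    from fouriax.optim import iter_minibatches

    x = np.arange(12).reshape(6, 2)
    y = np.arange(6)

    # Assumption: with two input arrays, each item is a (x_batch, y_batch) pair.
    for xb, yb in iter_minibatches(x, y, batch_size=2, seed=0):
        assert xb.shape[0] == yb.shape[0]   # rows stay aligned across arrays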
- random_batch_indices(n_items, batch_size, *, seed=None, replace=True)#
Sample batch indices using sklearn’s NumPy-style random_state handling.
- Parameters:
n_items (int)
batch_size (int)
seed (int | None)
replace (bool)
- Return type:
ndarray
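An illustrative call; that the result is a flat array of batch_size indices into range(n_items) is an assumption from the signature, not a stated contract:

    from fouriax.optim import random_batch_indices

    # replace=False samples without replacement; seed fixes the draw.
    idx = random_batch_indices(100, batch_size=16, seed=42, replace=False)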
- shuffled_arrays(*arrays, seed=None)#
Apply the same random permutation to multiple arrays.
- Parameters:
arrays (ArrayT)
seed (int | None)
- Return type:
tuple[ArrayT, ...]
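Because one shared permutation is applied to every array, row pairings survive the shuffle:

    import numpy as np
    from fouriax.optim import shuffled_arrays

    x = np.arange(5)
    y = x * 10
    xs, ys = shuffled_arrays(x, y, seed=0)
    assert np.array_equal(ys, xs * 10)   # pairing preserved after shuffling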
- train_val_split(*arrays, val_fraction, seed=None, shuffle=True)#
Split arrays into (train_arrays, val_arrays) with aligned indexing.
- Parameters:
arrays (ArrayT)
val_fraction (float)
seed (int | None)
shuffle (bool)
- Return type:
tuple[tuple[ArrayT, ...], tuple[ArrayT, ...]]
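A usage sketch; the exact rounding of the validation count from val_fraction is an assumption here:

    import numpy as np
    from fouriax.optim import train_val_split

    x = np.arange(20).reshape(10, 2)
    y = np.arange(10)
    (train_x, train_y), (val_x, val_y) = train_val_split(
        x, y, val_fraction=0.2, seed=0
    )
    # 10 rows at val_fraction=0.2: expect 8 train rows and 2 val rows,
    # with x/y rows still paired on both sides of the split.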
- focal_spot_loss(intensity, target_xy, window_px=2, eps=1e-12)#
Maximize power concentration around a target focal spot.
- Parameters:
intensity (Array)
target_xy (tuple[int, int])
window_px (int)
eps (float)
- Return type:
Array
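A toy input for illustration; the loss's sign convention and normalization are not pinned down above, so only the qualitative behavior is assumed:

    import jax.numpy as jnp
    from fouriax.optim import focal_spot_loss

    # 64x64 intensity map with a bright Gaussian spot at (32, 32); the loss
    # should improve (for a minimizer) as power concentrates inside the
    # window_px window around target_xy.
    yy, xx = jnp.mgrid[0:64, 0:64]
    intensity = jnp.exp(-((xx - 32) ** 2 + (yy - 32) ** 2) / 8.0)
    loss = focal_spot_loss(intensity, target_xy=(32, 32), window_px=2)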
- class BestValueTracker(mode='min')#
Bases: Generic[ParamsT]
Track the best scalar metric and a copied snapshot of a JAX pytree.
- Parameters:
mode (Literal['min', 'max'])
- mode: Literal['min', 'max'] = 'min'#
- best_value: float#
- best_state: ParamsT | None = None#
- best_step: int | None = None#
- update(value, state, *, step=None)#
- Parameters:
value (float)
state (ParamsT)
step (int | None)
- Return type:
bool
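A minimal tracking loop; update returns True whenever a new best value is recorded:

    from fouriax.optim import BestValueTracker

    tracker = BestValueTracker(mode="min")
    for step, loss in enumerate([1.0, 0.5, 0.7]):
        improved = tracker.update(loss, {"w": step}, step=step)

    assert tracker.best_value == 0.5
    assert tracker.best_step == 1
    # best_state holds a copied snapshot of the pytree, not a live reference.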
- class DatasetOptResult(best_params, final_params, train_loss_history, val_history, best_metric_name, best_metric_value, best_step, best_epoch, final_train_loss, final_val_metrics)#
Bases: Generic[ParamsT]
Outputs from dataset optimization over minibatches.
- Parameters:
best_params (ParamsT)
final_params (ParamsT)
train_loss_history (list[float])
val_history (list[ValidationRecord])
best_metric_name (str)
best_metric_value (float)
best_step (int | None)
best_epoch (int | None)
final_train_loss (float)
final_val_metrics (dict[str, float] | None)
- best_params: ParamsT#
- final_params: ParamsT#
- train_loss_history: list[float]#
- val_history: list[ValidationRecord]#
- best_metric_name: str#
- best_metric_value: float#
- best_step: int | None#
- best_epoch: int | None#
- final_train_loss: float#
- final_val_metrics: dict[str, float] | None#
- class HybridModuleDatasetOptResult(best_module, final_module, best_optical_params, final_optical_params, best_decoder_params, final_decoder_params, params_result)#
Bases: Generic[ModuleT, ParamsT, DecoderParamsT]
Hybrid dataset optimization outputs for optical module + decoder params.
- Parameters:
best_module (ModuleT)
final_module (ModuleT)
best_optical_params (ParamsT)
final_optical_params (ParamsT)
best_decoder_params (DecoderParamsT)
final_decoder_params (DecoderParamsT)
params_result (DatasetOptResult[dict[str, Any]])
- best_module: ModuleT#
- final_module: ModuleT#
- best_optical_params: ParamsT#
- final_optical_params: ParamsT#
- best_decoder_params: DecoderParamsT#
- final_decoder_params: DecoderParamsT#
- params_result: DatasetOptResult[dict[str, Any]]#
- class ModuleOptResult(best_module, best_params, history, best_loss, final_loss, best_step)#
Bases: Generic[ModuleT, ParamsT]
Outputs from optimize_optical_module.
- Parameters:
best_module (ModuleT)
best_params (ParamsT)
history (list[float])
best_loss (float)
final_loss (float)
best_step (int | None)
- best_module: ModuleT#
- best_params: ParamsT#
- history: list[float]#
- best_loss: float#
- final_loss: float#
- best_step: int | None#
- class ModuleDatasetOptResult(best_module, final_module, params_result)#
Bases: Generic[ModuleT, ParamsT]
Dataset optimization outputs plus built optical modules.
- Parameters:
best_module (ModuleT)
final_module (ModuleT)
params_result (DatasetOptResult[ParamsT])
- best_module: ModuleT#
- final_module: ModuleT#
- params_result: DatasetOptResult[ParamsT]#
- class ValidationRecord(step, epoch, metrics)#
Bases: object
Validation metrics recorded at a training step/epoch boundary.
- Parameters:
step (int)
epoch (int)
metrics (dict[str, float])
- step: int#
- epoch: int#
- metrics: dict[str, float]#
- apply_optax_updates(optimizer, params, opt_state, grads)#
Apply one Optax update step and return (params, opt_state).
- Parameters:
optimizer (GradientTransformation)
params (ParamsT)
opt_state (OptStateT)
grads (ParamsT)
- Return type:
tuple[ParamsT, OptStateT]
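One step of a standard Optax loop; this helper stands in for the usual optimizer.update(...) followed by optax.apply_updates(...):

    import jax
    import jax.numpy as jnp
    import optax
    from fouriax.optim import apply_optax_updates

    params = {"w": jnp.ones(3)}
    optimizer = optax.adam(1e-2)
    opt_state = optimizer.init(params)

    # Gradients of a toy quadratic loss with respect to the params pytree.
    grads = jax.grad(lambda p: jnp.sum(p["w"] ** 2))(params)
    params, opt_state = apply_optax_updates(optimizer, params, opt_state, grads)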
- optimize_dataset_hybrid_module(*, init_optical_params, init_decoder_params, build_module, batch_loss_fn=None, sample_loss_fn=None, train_data, batch_size, epochs, val_data=None, optimizer=None, optical_optimizer=None, decoder_optimizer=None, val_every_epochs=1, val_every_steps=0, log_every_steps=0, report_batch_progress=False, jit=True, seed=0, drop_last_train=False)#
Optimize optical and decoder params over minibatches using a shared loss.
This wrapper mirrors optimize_dataset_optical_module() but carries two parameter groups: optical parameters used to build the module and decoder parameters used by the hybrid loss.
Exactly one of batch_loss_fn or sample_loss_fn must be provided. When sample_loss_fn is used, it is vmapped over each minibatch and then averaged into a batch loss.
Optimizer configuration can be supplied in one of two ways:
- pass optimizer directly for a single optimizer over the combined parameter dict
- pass both optical_optimizer and decoder_optimizer to build an internal optax.multi_transform(...) optimizer with separate groups
- Parameters:
init_optical_params (ParamsT)
init_decoder_params (DecoderParamsT)
build_module (Callable[[ParamsT], ModuleT])
batch_loss_fn (Callable[[ParamsT, DecoderParamsT, Any], Array] | None)
sample_loss_fn (Callable[[ParamsT, DecoderParamsT, Any], Array] | None)
train_data (Any)
batch_size (int)
epochs (int)
val_data (Any | None)
optimizer (GradientTransformation | None)
optical_optimizer (GradientTransformation | None)
decoder_optimizer (GradientTransformation | None)
val_every_epochs (int)
val_every_steps (int)
log_every_steps (int)
report_batch_progress (bool)
jit (bool)
seed (int)
drop_last_train (bool)
- Return type:
HybridModuleDatasetOptResult[ModuleT, ParamsT, DecoderParamsT]
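A heavily hedged sketch: the module here is a bare phase multiply, and train_data is assumed to be an (inputs, targets) pair of aligned arrays; none of those shapes or conventions come from fouriax.optim itself:

    import jax.numpy as jnp
    import optax
    from fouriax.optim import optimize_dataset_hybrid_module

    def make_module(optical_params):
        # Illustrative "optical module": elementwise phase modulation.
        def apply(x):
            return x * jnp.cos(optical_params["phase"])
        return apply

    def sample_loss_fn(optical_params, decoder_params, sample):
        x, y = sample
        features = make_module(optical_params)(x)
        pred = features @ decoder_params["w"]
        return jnp.mean((pred - y) ** 2)

    train_x = jnp.ones((128, 16))
    train_y = jnp.zeros((128, 4))

    result = optimize_dataset_hybrid_module(
        init_optical_params={"phase": jnp.zeros(16)},
        init_decoder_params={"w": jnp.zeros((16, 4))},
        build_module=make_module,
        sample_loss_fn=sample_loss_fn,        # vmapped per sample, then averaged
        train_data=(train_x, train_y),
        batch_size=32,
        epochs=2,
        optical_optimizer=optax.adam(1e-2),   # separate groups via multi_transform
        decoder_optimizer=optax.adam(1e-3),
    )
    print(result.params_result.final_train_loss)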
- optimize_dataset_optical_module(*, init_params, build_module, batch_loss_fn=None, sample_loss_fn=None, optimizer, train_data, batch_size, epochs, val_data=None, val_every_epochs=1, val_every_steps=0, log_every_steps=0, report_batch_progress=False, jit=True, seed=0, drop_last_train=False)#
Optimize an optical module over minibatches using a shared train/val loss.
This is a high-level convenience wrapper:
- minibatching is derived internally from iter_minibatches(...)
- exactly one of batch_loss_fn or sample_loss_fn must be provided
- validation uses the same effective batch loss averaged over val_data
- reporting uses a uniform built-in console formatter
- Parameters:
init_params (ParamsT)
build_module (Callable[[ParamsT], ModuleT])
batch_loss_fn (Callable[[ParamsT, Any], Array] | None)
sample_loss_fn (Callable[[ParamsT, Any], Array] | None)
optimizer (GradientTransformation)
train_data (Any)
batch_size (int)
epochs (int)
val_data (Any | None)
val_every_epochs (int)
val_every_steps (int)
log_every_steps (int)
report_batch_progress (bool)
jit (bool)
seed (int)
drop_last_train (bool)
- Return type:
ModuleDatasetOptResult[ModuleT, ParamsT]
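A sketch under the same assumptions as above (train_data as an (inputs, targets) pair, a toy phase-modulation module); batch_loss_fn consumes a whole minibatch at once:

    import jax.numpy as jnp
    import optax
    from fouriax.optim import optimize_dataset_optical_module

    def build_module(params):
        def apply(x):
            return x * jnp.cos(params["phase"])
        return apply

    def batch_loss_fn(params, batch):
        x, y = batch
        return jnp.mean((build_module(params)(x) - y) ** 2)

    x = jnp.ones((64, 8))
    y = jnp.zeros((64, 8))

    result = optimize_dataset_optical_module(
        init_params={"phase": jnp.zeros(8)},
        build_module=build_module,
        batch_loss_fn=batch_loss_fn,
        optimizer=optax.adam(1e-2),
        train_data=(x, y),
        batch_size=16,
        epochs=2,
        val_data=(x, y),   # validation reuses the same effective batch loss
    )
    print(result.params_result.best_metric_value)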
- optimize_dataset_params(*, init_params, optimizer, train_data, batch_iter_fn, train_loss_fn, epochs, val_eval_fn=None, val_data=None, val_every_epochs=1, val_every_steps=0, log_every_steps=50, report_batch_progress=False, select_metric=None, select_mode='min', jit=True, reporter=None, rng=None)#
Optimize arbitrary params over minibatches with optional validation tracking.
batch_iter_fn should typically accept (train_data, epoch, rng) and return an iterable of batches for that epoch. Simpler callables that accept fewer positional arguments are also supported.
- Parameters:
init_params (ParamsT)
optimizer (GradientTransformation)
train_data (TrainDataT)
batch_iter_fn (Callable[[TrainDataT], Iterable[BatchT]] | Callable[[TrainDataT, int], Iterable[BatchT]] | Callable[[TrainDataT, int, Any], Iterable[BatchT]])
train_loss_fn (Callable[[ParamsT, BatchT], Array])
epochs (int)
val_eval_fn (Callable[[ParamsT, ValDataT], Mapping[str, Any]] | None)
val_data (ValDataT | None)
val_every_epochs (int)
val_every_steps (int)
log_every_steps (int)
report_batch_progress (bool)
select_metric (str | None)
select_mode (Literal['min', 'max'])
jit (bool)
reporter (Callable[[dict[str, Any]], None] | None)
rng (Any | None)
- Return type:
DatasetOptResult[ParamsT]
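A minimal sketch of the generic interface, pairing it with iter_minibatches for epoch-wise batching; the scalar-fit setup is purely illustrative:

    import jax.numpy as jnp
    import optax
    from fouriax.optim import iter_minibatches, optimize_dataset_params

    x = jnp.linspace(0.0, 1.0, 64)[:, None]
    y = 3.0 * x

    def batch_iter_fn(data, epoch, rng):
        # One epoch of aligned (x, y) minibatches; epoch-seeded reshuffle.
        return iter_minibatches(*data, batch_size=16, seed=epoch)

    def train_loss_fn(params, batch):
        xb, yb = batch
        return jnp.mean((xb * params["w"] - yb) ** 2)

    result = optimize_dataset_params(
        init_params={"w": jnp.zeros(())},
        optimizer=optax.sgd(0.1),
        train_data=(x, y),
        batch_iter_fn=batch_iter_fn,
        train_loss_fn=train_loss_fn,
        epochs=3,
    )
    print(result.final_train_loss)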
- optimize_optical_module(*, init_params, build_module, loss_fn, optimizer, steps, log_every=50, mode='min', jit=True, reporter=None)#
Run an Optax optimization loop and return best module + optimization metadata.
- Parameters:
init_params (ParamsT)
build_module (Callable[[ParamsT], ModuleT])
loss_fn (Callable[[ParamsT], Array])
optimizer (GradientTransformation)
steps (int)
log_every (int)
mode (Literal['min', 'max'])
jit (bool)
reporter (Callable[[int, float], None] | None)
- Return type:
ModuleOptResult[ModuleT, ParamsT]
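A full-batch sketch: loss_fn sees only the params, and build_module is assumed to return any object constructed from them (here a trivial dict):

    import jax.numpy as jnp
    import optax
    from fouriax.optim import optimize_optical_module

    def build_module(params):
        return {"phase": params["phase"]}

    def loss_fn(params):
        # Minimized as each phase entry approaches pi.
        return jnp.sum(jnp.cos(params["phase"]))

    result = optimize_optical_module(
        init_params={"phase": jnp.zeros(8)},
        build_module=build_module,
        loss_fn=loss_fn,
        optimizer=optax.adam(1e-1),
        steps=200,
        log_every=50,
    )
    print(result.best_loss, result.best_step)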
- should_log_step(step, *, every, total_steps)#
Return True on periodic log steps and the final step.
- Parameters:
step (int)
every (int)
total_steps (int)
- Return type:
bool
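A quick illustration; whether steps are 0- or 1-indexed is an assumption here:

    from fouriax.optim import should_log_step

    total = 120
    # Periodic logging every 50 steps plus a guaranteed log on the final step.
    logged = [s for s in range(total)
              if should_log_step(s, every=50, total_steps=total)]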