afnio.trainer.trainer

Functions

get_batch_size(batch)

Returns the number of samples in a batch, supporting all DataLoader output formats.

Classes

MinutesPerStepColumn([table_column])

Show average minutes per step as Xm/step, styled like TimeElapsedColumn, only for training.

Trainer(*[, max_epochs, ...])

class afnio.trainer.trainer.MinutesPerStepColumn(table_column=None)

Bases: ProgressColumn

Show average minutes per step as Xm/step, styled like TimeElapsedColumn, only for training.

get_table_column()

Get a table column, used to build tasks table.

max_refresh: Optional[float] = None

render(task)

Should return a renderable object.
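The "Xm/step" figure described for MinutesPerStepColumn can be sketched as a plain function. This is only an illustration: the helper name `format_minutes_per_step`, the two-decimal precision, and the placeholder shown before any step completes are assumptions, not taken from afnio's source.

```python
from typing import Optional


def format_minutes_per_step(
    elapsed_seconds: Optional[float], completed_steps: int
) -> str:
    """Hypothetical helper: average minutes per completed step as 'Xm/step'."""
    # Nothing to average before the first step finishes (placeholder is an assumption).
    if elapsed_seconds is None or completed_steps <= 0:
        return "-m/step"
    minutes_per_step = elapsed_seconds / 60 / completed_steps
    # Two decimals is an arbitrary choice made for this sketch.
    return f"{minutes_per_step:.2f}m/step"


print(format_minutes_per_step(600, 4))  # 10 minutes over 4 steps -> 2.50m/step
```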

class afnio.trainer.trainer.Trainer(*, max_epochs=None, enable_checkpointing=True, enable_progress_bar=True, enable_agent_summary=True, default_root_dir=None)

Bases: object

fit(agent, train_dataloader=None, val_dataloader=None, ckpt_path=None, llm_clients=[])

Runs the full optimization routine.

Parameters:
  • agent (Module) – AI agent (or flow) to fit.

  • train_dataloader (Union[Iterable[Any], DataLoader, None]) – An iterable or DataLoader specifying training samples.

  • val_dataloader (Union[Iterable[Any], DataLoader, None]) – An iterable or DataLoader specifying validation samples.

  • ckpt_path (Union[str, Path, None]) – Path of the checkpoint from which training is resumed. If there is no checkpoint file at the path, an exception is raised.

  • llm_clients (Optional[List[BaseModel]]) – Optional list of LLM clients used during training. If provided, this list is used to calculate the total cost of training (in USD).

Raises:

TypeError – If agent is not a Module.

predict()

test(agent, test_dataloader=None, llm_clients=[])

validate(agent, val_dataloader=None, llm_clients=[])

afnio.trainer.trainer.get_batch_size(batch)

Returns the number of samples in a batch, supporting all DataLoader output formats:

  • If batch is a dict: returns the length of the first value if it’s a list/tuple/Variable, else 1.

  • If batch is a tuple or list: recursively checks the first element.

  • If batch is a Variable: returns the length of its .data attribute if possible, else 1.

  • Otherwise: returns 1 (single sample).

Raises:

ValueError – If the batch is empty (size 0).
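The rules listed above can be sketched in plain Python. This is an illustrative approximation, not afnio's implementation: `Variable` below is a hypothetical stand-in for afnio's class, `batch_size_sketch` is an invented name, and two behaviors are assumptions (only list or tuple `.data` is counted, and an empty dict or sequence raises ValueError).

```python
from typing import Any


class Variable:
    """Hypothetical stand-in for afnio's Variable: it only holds a .data payload."""

    def __init__(self, data: Any):
        self.data = data


def _variable_len(var: Variable) -> int:
    # Assumption: only list/tuple payloads have a meaningful sample count.
    if isinstance(var.data, (list, tuple)):
        return len(var.data)
    return 1


def batch_size_sketch(batch: Any) -> int:
    """Approximate the documented get_batch_size rules (sketch only)."""
    if isinstance(batch, dict):
        if not batch:
            raise ValueError("Batch is empty (size 0).")
        # Dict: inspect the first value.
        first = next(iter(batch.values()))
        if isinstance(first, (list, tuple)):
            size = len(first)
        elif isinstance(first, Variable):
            size = _variable_len(first)
        else:
            size = 1
    elif isinstance(batch, (tuple, list)):
        if not batch:
            raise ValueError("Batch is empty (size 0).")
        # Tuple or list: recurse into the first element.
        size = batch_size_sketch(batch[0])
    elif isinstance(batch, Variable):
        size = _variable_len(batch)
    else:
        # Anything else counts as a single sample.
        size = 1
    if size == 0:
        raise ValueError("Batch is empty (size 0).")
    return size


print(batch_size_sketch({"question": ["a", "b", "c"]}))
print(batch_size_sketch((Variable(["a", "b"]), Variable(["x", "y"]))))
```

The first call takes the dict branch and counts the three entries of the first value; the second recurses into the first element of the tuple and counts the two items in that Variable's `.data`.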