flare.train

+default-train-opts+

iter!

(iter! model optimizer get-loss-node data-gen opts)

run-batch!

(run-batch! factory get-loss-node batch)

Runs forward and backward passes for each of the loss graphs that get-loss-node builds from the batch.

Each entry in batch is treated as the arguments to the module/graph that builds the loss expression for that example.
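The per-example behavior described above can be sketched in plain Clojure without flare. This is a toy illustration, not flare's implementation. A map holding a scalar loss and its gradient with respect to one weight w stands in for a loss node, and the forward and backward passes are written out by hand. The factory argument is omitted because its role is not described here. Summing losses and gradients across the batch is also an assumption. The names toy-loss-node and run-batch-sketch are hypothetical.

```clojure
(defn toy-loss-node
  "Hypothetical loss node for one example (x, y) under y-hat = w * x:
  the squared error and its gradient with respect to w."
  [w x y]
  (let [err (- (* w x) y)]            ; forward pass
    {:loss (* err err)
     :grad (* 2 err x)}))             ; backward pass: d/dw (w*x - y)^2

(defn run-batch-sketch
  "Hypothetical analogue of run-batch!: each entry in batch is spread as the
  argument list of get-loss-node, and the resulting losses and gradients
  are summed over the batch."
  [get-loss-node batch]
  (reduce (fn [acc entry]
            ;; entry such as [x y] becomes (get-loss-node x y)
            (let [node (apply get-loss-node entry)]
              (-> acc
                  (update :loss + (:loss node))
                  (update :grad + (:grad node)))))
          {:loss 0.0 :grad 0.0}
          batch))
```

For example, (run-batch-sketch (partial toy-loss-node 1.0) [[1.0 2.0] [2.0 3.0]]) spreads each two-element entry into toy-loss-node via apply and returns {:loss 2.0 :grad -6.0}.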

train!

(train! model get-loss-node data-batch-fn train-opts)
(train! model get-loss-node data-batch-fn)

Trains a model using the following parameters:

  • model: a flare.model/PModel instance
  • get-loss-node: a function that takes each data input and returns a node representing the loss on that input
  • data-batch-fn: a function assumed to return a lazy sequence of data batches when called; training loops over each batch and calls get-loss-node on its entries
  • train-opts: training options; see spec above
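The loop that train! describes can likewise be sketched in plain Clojure. Everything in this sketch is an assumption for illustration. The model is an atom holding a single weight, and get-loss-node closes over it and returns a {:loss .. :grad ..} map. Plain gradient descent stands in for flare's optimizer. The option keys :num-iters and :learning-rate are hypothetical and are not taken from the train-opts spec. The names toy-model-loss-node and train-sketch! are also hypothetical.

```clojure
(defn toy-model-loss-node
  "Hypothetical loss node for one example (x, y) under y-hat = w * x,
  where w is read from the atom model. Returns the squared error and its
  gradient with respect to w."
  [model x y]
  (let [err (- (* @model x) y)]
    {:loss (* err err)
     :grad (* 2 err x)}))

(defn train-sketch!
  "Hypothetical training loop: for each of num-iters iterations, call
  data-batch-fn for a fresh seq of batches; for each batch, apply
  get-loss-node to every entry, sum the gradients, and take one gradient
  descent step on the weight held in model. Returns the final weight."
  [model get-loss-node data-batch-fn {:keys [num-iters learning-rate]}]
  (dotimes [_ num-iters]
    (doseq [batch (data-batch-fn)]
      ;; reduce realizes every gradient at the current weight before the update
      (let [grad (reduce + 0.0 (map #(:grad (apply get-loss-node %)) batch))]
        ;; plain gradient descent step, standing in for flare's optimizer
        (swap! model - (* learning-rate grad)))))
  @model)
```

For example, with model bound to (atom 0.0), the call (train-sketch! model (partial toy-model-loss-node model) (fn [] [[[1.0 2.0] [2.0 4.0]]]) {:num-iters 100 :learning-rate 0.05}) trains on one batch generated by y = 2x and drives the weight toward 2.0. The sketch calls data-batch-fn anew on every iteration so that a consumed lazy sequence is never reused. The docs here do not say whether flare's train! does the same.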