Train, evaluate or predict with a model


This page shows how to train and evaluate a model, and how to use it for prediction, with the Keras-Style API.

You may refer to the User Guide page to see how to define a model in Python or Scala.

You may refer to the Layers section to find all the available layers.

After defining a model with the Keras-Style API, you can call the following methods on the model:


Compile

Configure the learning process. This must be called before fit or evaluate.

Scala:

compile(optimizer, loss, metrics = null)

Parameters:

Alternatively, you can pass the corresponding Keras-Style string representations to compile, for example optimizer = "sgd", loss = "mse", metrics = List("accuracy").
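To make the string shortcuts concrete, the plain-Python sketch below computes what loss = "mse" and a top-1 style "accuracy" metric conventionally measure. It illustrates the math only and is not Analytics Zoo's implementation; the function names mse and top1_accuracy are hypothetical.

```python
# Illustration of the quantities behind loss = "mse" and metrics = ["accuracy"].
# NOT Analytics Zoo code: hypothetical helpers showing the conventional math.

def mse(y_true, y_pred):
    """Mean squared error over paired scalar targets and predictions."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def top1_accuracy(labels, scores):
    """Fraction of samples whose highest-scoring class equals the label.

    labels: zero-based integer class labels, one per sample
    scores: one list of per-class scores per sample
    """
    correct = 0
    for label, row in zip(labels, scores):
        predicted = max(range(len(row)), key=lambda i: row[i])  # argmax of the row
        if predicted == label:
            correct += 1
    return correct / len(labels)


print(mse([1.0, 2.0, 3.0], [1.0, 2.5, 2.0]))  # (0 + 0.25 + 1) / 3, about 0.4167
print(top1_accuracy([0, 2, 1], [[0.7, 0.2, 0.1],
                                [0.1, 0.3, 0.6],
                                [0.5, 0.4, 0.1]]))  # 2 of 3 correct, about 0.6667
```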

Python:

compile(optimizer, loss, metrics=None)

Parameters:


Fit

Train a model for a fixed number of epochs on a DataSet.
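As a rough picture of what batchSize / batch_size and nbEpoch / nb_epoch control, here is a minimal minibatch gradient-descent loop for a one-variable linear model in plain Python. It is a single-machine sketch that assumes one epoch is one shuffled pass over the data split into consecutive minibatches; Analytics Zoo's distributed training partitions data across executors and may batch differently. fit_linear and its lr and seed arguments are hypothetical.

```python
import random


def fit_linear(xs, ys, batch_size=32, nb_epoch=10, lr=0.1, seed=0):
    """Hypothetical minibatch SGD for y ~ w * x + b. Returns (w, b, iterations)."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    indices = list(range(len(xs)))
    iterations = 0
    for _ in range(nb_epoch):  # one epoch = one full pass over the data
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]  # the last batch may be smaller
            # Gradients of the mean squared error over this minibatch.
            grad_w = sum(2 * (w * xs[i] + b - ys[i]) * xs[i] for i in batch) / len(batch)
            grad_b = sum(2 * (w * xs[i] + b - ys[i]) for i in batch) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
            iterations += 1
    return w, b, iterations


xs = [i / 100 for i in range(100)]
ys = [3.0 * x + 1.0 for x in xs]
w, b, iters = fit_linear(xs, ys, batch_size=32, nb_epoch=200)
print(iters)  # 200 epochs * 4 batches per epoch = 800
```

With 100 samples and batch_size=32, each epoch takes 4 updates (32 + 32 + 32 + 4 samples), so the number of parameter updates grows with nb_epoch and shrinks as batch_size grows.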

Scala:

fit(x, batchSize = 32, nbEpoch = 10, validationData = null)

Parameters:

Python:

fit(x, y=None, batch_size=32, nb_epoch=10, validation_data=None, distributed=True)

Parameters:


Evaluate

Evaluate a model on a given dataset in distributed mode.
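When evaluation runs over many batches or partitions, the partial results have to be combined. One safe way, sketched below in plain Python under the assumption that the metric is a per-sample average such as MSE, is to accumulate per-sample sums and counts and divide once at the end, so a smaller final batch is not over-weighted. evaluate_in_batches and predict_fn are hypothetical names; this does not claim to be how Analytics Zoo aggregates results internally.

```python
def evaluate_in_batches(predict_fn, xs, ys, batch_size=32):
    """Hypothetical batched MSE evaluation that combines partial sums correctly."""
    total_loss = 0.0
    count = 0
    for start in range(0, len(xs), batch_size):
        batch_x = xs[start:start + batch_size]
        batch_y = ys[start:start + batch_size]
        # Accumulate the SUM of squared errors for this batch, not its mean,
        # so every sample carries the same weight regardless of batch size.
        total_loss += sum((predict_fn(x) - y) ** 2 for x, y in zip(batch_x, batch_y))
        count += len(batch_x)
    return total_loss / count


xs = [float(i) for i in range(10)]
ys = [2.0 * x for x in xs]
model = lambda x: 2.0 * x + 1.0  # off by exactly 1 on every sample
print(evaluate_in_batches(model, xs, ys, batch_size=4))  # 1.0
```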

Scala:

evaluate(x, batchSize = 32)

Parameters:

Python:

evaluate(x, y=None, batch_size=32)

Parameters:


Predict

Use a model to make predictions.
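As its name suggests, batchPerThread / batch_per_thread sets how many samples each worker thread handles at a time. The sketch below mimics that idea on one machine with Python's concurrent.futures: it splits the input into batches of batch_per_thread samples, runs a forward function on a thread pool, and returns outputs in input order. It is an assumption-level illustration, not Analytics Zoo's predictor; batched_predict, forward and num_threads are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor


def batched_predict(forward, samples, batch_per_thread=4, num_threads=2):
    """Hypothetical batched prediction: split samples into batches of
    batch_per_thread, run them on a thread pool, keep the original order."""
    batches = [samples[i:i + batch_per_thread]
               for i in range(0, len(samples), batch_per_thread)]
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # Executor.map yields results in input order, so outputs line up with samples.
        results = pool.map(lambda batch: [forward(s) for s in batch], batches)
        return [pred for batch_preds in results for pred in batch_preds]


print(batched_predict(lambda x: 10 * x, [1, 2, 3, 4, 5], batch_per_thread=2))
# [10, 20, 30, 40, 50]
```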

Scala:

predict(x, batchPerThread = 4)

Parameters:

Python:

predict(x, batch_per_thread=4, distributed=True)

Parameters:

Use a model to predict class labels.
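A common way to turn per-class scores into labels is to take the index of the largest score. The plain-Python sketch below does that and shows the effect of a zero_based_label flag: with True the first class is 0, with False it is 1. It assumes argmax semantics and is not Analytics Zoo's code; classes_from_scores is a hypothetical name.

```python
def classes_from_scores(scores, zero_based_label=True):
    """Hypothetical helper: map each row of per-class scores to the index of its
    largest score. With zero_based_label=False, labels start at 1 instead of 0."""
    offset = 0 if zero_based_label else 1
    labels = []
    for row in scores:
        best_index = max(range(len(row)), key=lambda i: row[i])  # argmax; first index wins ties
        labels.append(best_index + offset)
    return labels


probs = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]]
print(classes_from_scores(probs))                          # [1, 0]
print(classes_from_scores(probs, zero_based_label=False))  # [2, 1]
```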

Scala:

predictClasses(x, batchPerThread = 4, zeroBasedLabel = true)

Parameters:

Python:

predict_classes(x, batch_per_thread=4, zero_based_label=True)

Parameters:

Visualization

We use TensorBoard-compatible event files to store the training and validation metrics. You can then visualize training with TensorBoard, or read the metrics back with the Analytics Zoo built-in API.

Enable training metrics

The training metrics will be saved to logDir/appName/training, and the validation metrics will be saved to logDir/appName/validation.
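The directory layout above can be computed directly, which helps when locating event files or cleaning up old runs. The short Python sketch below builds both paths from a log directory and an app name; summary_dirs and the example values "/tmp/zoo_logs" and "mnist_run" are hypothetical. Since each run writes under its own appName, giving every experiment a distinct name keeps its curves separate, and pointing TensorBoard at the shared log directory shows the runs side by side.

```python
import os


def summary_dirs(log_dir, app_name):
    """Hypothetical helper returning (training_dir, validation_dir) for a run,
    following the logDir/appName/training and logDir/appName/validation layout."""
    run_dir = os.path.join(log_dir, app_name)
    return os.path.join(run_dir, "training"), os.path.join(run_dir, "validation")


train_dir, val_dir = summary_dirs("/tmp/zoo_logs", "mnist_run")
print(train_dir)  # /tmp/zoo_logs/mnist_run/training on POSIX systems
print(val_dir)    # /tmp/zoo_logs/mnist_run/validation on POSIX systems
```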

Scala:

setTensorBoard(logDir, appName)

Parameters:

Python:

set_tensorboard(log_dir, app_name)

Parameters:

Visualization with TensorBoard

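Start TensorBoard on the log directory passed to setTensorBoard / set_tensorboard, for example tensorboard --logdir=logDir, then open the address it prints (http://localhost:6006 by default).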

Reading metrics with the built-in API

To read scalar metrics with the built-in API, use the following methods.

Scala:

getTrainSummary(tag)

Get training metrics by tag. Parameters:

Scala:

getValidationSummary(tag)

Get validation metrics by tag. Parameters:

Python:

get_train_summary(tag)

Get training metrics by tag. Parameters:

Python:

get_validation_summary(tag)

Get validation metrics by tag. Parameters:
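Once a summary has been read back, a common next step is to pull out the latest and the best value for a tag. The sketch below assumes, without confirming it against the Analytics Zoo source, that each record is a (step, value, wall_time) triple, and treats lower values as better unless told otherwise (as for a loss curve). summarize_scalar is a hypothetical helper that works on plain Python sequences.

```python
def summarize_scalar(records, higher_is_better=False):
    """Hypothetical summary of one scalar tag.

    records: iterable of (step, value, wall_time) triples. This layout is an
    ASSUMPTION about what get_train_summary / get_validation_summary return.
    """
    ordered = sorted(records, key=lambda r: r[0])  # order records by training step
    if not ordered:
        raise ValueError("no records for this tag")
    pick = max if higher_is_better else min
    best = pick(ordered, key=lambda r: r[1])  # record with the best value
    return {
        "last_step": ordered[-1][0],
        "last_value": ordered[-1][1],
        "best_step": best[0],
        "best_value": best[1],
    }


loss_records = [(3, 0.5, 12.0), (1, 0.9, 4.0), (2, 0.4, 8.0)]
print(summarize_scalar(loss_records))
# {'last_step': 3, 'last_value': 0.5, 'best_step': 2, 'best_value': 0.4}
```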