PyTorch Estimator


Introduction

Analytics Zoo Orca PyTorch Estimator provides a set of APIs for running PyTorch models on Spark in a distributed fashion.

Orca PyTorch Estimator

Orca PyTorch Estimator is an estimator that performs PyTorch training, evaluation, and prediction on Spark in a distributed fashion.

It supports various data types, such as XShards, PyTorch DataLoader, and PyTorch DataLoader creator functions.

It supports the Horovod backend and the BigDL backend through unified APIs.

Create Estimator from PyTorch Model

You can create an Orca PyTorch Estimator from a native PyTorch model:

from zoo.orca.learn.pytorch import Estimator

Estimator.from_torch(*,
                     model,
                     optimizer,
                     loss=None,
                     metrics=None,
                     scheduler_creator=None,
                     training_operator_cls=TrainingOperator,
                     initialization_hook=None,
                     config=None,
                     scheduler_step_freq="batch",
                     use_tqdm=False,
                     workers_per_node=1,
                     model_dir=None,
                     backend="bigdl")
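As an illustrative sketch (assuming PyTorch is installed; the two-layer model, optimizer, and loss below are hypothetical stand-ins), the objects passed to Estimator.from_torch could be built like this:

```python
import torch
import torch.nn as nn

# Hypothetical two-layer regression model to hand to Estimator.from_torch.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = nn.MSELoss()

# With a running Spark/Orca context, the estimator would then be created as:
# from zoo.orca.learn.pytorch import Estimator
# est = Estimator.from_torch(model=model, optimizer=optimizer,
#                            loss=loss, backend="bigdl")

# Sanity check: the model maps a batch of 4-feature rows to one output each.
out = model(torch.randn(2, 4))
print(tuple(out.shape))
```

The Estimator.from_torch call itself is left commented out here because it requires a running Spark/Orca context.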

Use Horovod Estimator

Train model

After an Estimator is created, you can call the fit method to train the PyTorch model:

fit(self, data, epochs=1, profile=False, reduce_results=True, info=None)
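When reduce_results is True, per-worker training stats are averaged into a single dict rather than returned per worker. A framework-free sketch of that reduction (the worker results and metric names below are hypothetical):

```python
# Hypothetical per-worker stats as fit() might collect from 2 workers.
worker_results = [
    {"train_loss": 0.42, "num_samples": 100},
    {"train_loss": 0.38, "num_samples": 100},
]

def reduce_mean(results):
    """Average each metric across workers (reduce_results=True behavior)."""
    keys = results[0].keys()
    return {k: sum(r[k] for r in results) / len(results) for k in keys}

print(reduce_mean(worker_results))  # one dict, averaged across workers
```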

Evaluate model

After training, you can call the evaluate method to evaluate the PyTorch model:

evaluate(self, data, num_steps=None, profile=False, info=None)

Get model

You can get the trained model using get_model(self).

Save model

You can save the model using save(self, model_path).

* model_path: (str) Path to save the model.

Load model

You can load an existing model saved by save(self, model_path) using load(self, model_path).

* model_path: (str) Path to the existing model.
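The save/load pair round-trips the model through model_path. A minimal framework-free sketch of that pattern (a real estimator persists the PyTorch state dict; the dict below is a hypothetical stand-in):

```python
import os
import pickle
import tempfile

# Hypothetical model state standing in for a PyTorch state dict.
state = {"layer1.weight": [0.1, 0.2], "layer1.bias": [0.0]}

model_path = os.path.join(tempfile.mkdtemp(), "model.bin")

# save(model_path): persist the current model state to disk.
with open(model_path, "wb") as f:
    pickle.dump(state, f)

# load(model_path): restore the state from the existing saved model.
with open(model_path, "rb") as f:
    restored = pickle.load(f)

print(restored == state)  # True: the round trip preserves the state
```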

Shutdown workers

You can shut down the workers and release resources using shutdown(self, force=False).

Use BigDL Estimator

Train model

After an Estimator is created, you can call the fit method to train the PyTorch model:

fit(self, data, epochs=1, batch_size=32, feature_cols=None, label_cols=None, validation_data=None, checkpoint_trigger=None)

Evaluate model

After training, you can call the evaluate method to evaluate the PyTorch model:

evaluate(self, data, batch_size=32, feature_cols=None, label_cols=None)

Inference

After training or loading a trained model, you can call the predict method for inference:

predict(self, data, batch_size=4, feature_cols=None)
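The batch_size parameter controls how many rows are fed to the model per forward pass. A framework-free sketch of that chunking (the helper name is hypothetical):

```python
# Hypothetical sketch of how predict() walks the data in batch_size chunks.
def batches(rows, batch_size=4):
    """Yield successive batch_size-sized chunks of the input rows."""
    for i in range(0, len(rows), batch_size):
        yield rows[i:i + batch_size]

data = list(range(10))
print([len(b) for b in batches(data)])  # [4, 4, 2]
```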

Get model

You can get the model using get_model(self).

Save model

You can save the model using save(self, model_path).

* model_path: (str) Path to save the model.

Load model

You can load an existing model saved by save(self, model_path) using load(self, model_path).

* model_path: (str) Path to the existing model.

Load orca checkpoint

You can load a saved Orca checkpoint using load_orca_checkpoint(self, path, version, prefix). To load a specific checkpoint, provide both version and prefix. If version is None, the latest checkpoint will be loaded.

* path: Path to the existing checkpoint (or to the directory containing Orca checkpoint files if version is None).
* version: checkpoint version, which is the suffix of the model.* file, i.e., for the model.4 file, the version is 4. If it is None, the latest checkpoint is loaded.
* prefix: optimMethod prefix, for example 'optimMethod-TorchModelf53bddcc'. If loading the latest checkpoint, leave it as None.
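Since the version is the numeric suffix of the model.* files, resolving the latest checkpoint amounts to scanning the directory for the largest suffix. A hypothetical sketch of that resolution (the helper and file names are illustrative):

```python
# Hypothetical sketch: when version is None, pick the latest Orca checkpoint
# by the numeric suffix of the model.* files in the checkpoint directory.
def latest_version(filenames):
    versions = [int(name.rsplit(".", 1)[1])
                for name in filenames
                if name.startswith("model.") and name.rsplit(".", 1)[1].isdigit()]
    return max(versions) if versions else None

files = ["model.2", "model.4", "optimMethod-TorchModelf53bddcc.4", "model.3"]
print(latest_version(files))  # 4
```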

Clear gradient clipping

You can clear the gradient clipping parameters using clear_gradient_clipping(self), in which case gradient clipping will not be applied. Note: to take effect, it needs to be called before fit.

Set constant gradient clipping

You can set constant gradient clipping during training using set_constant_gradient_clipping(self, min, max).

* min: The minimum value to clip by.
* max: The maximum value to clip by.

Note: to take effect, it needs to be called before fit.

Clip gradient to a maximum L2-Norm

You can clip gradients to a maximum L2-Norm during training using set_l2_norm_gradient_clipping(self, clip_norm).

* clip_norm: Gradient L2-Norm threshold.

Note: to take effect, it needs to be called before fit.
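L2-Norm clipping rescales the whole gradient vector when its norm exceeds clip_norm, preserving the gradient's direction. A framework-free sketch of the standard technique (the function name is hypothetical):

```python
import math

# Hypothetical sketch of L2-norm gradient clipping: if the gradient's
# L2 norm exceeds clip_norm, scale the whole vector down to that norm.
def clip_by_l2_norm(grads, clip_norm):
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= clip_norm:
        return grads
    scale = clip_norm / norm
    return [g * scale for g in grads]

clipped = clip_by_l2_norm([3.0, 4.0], clip_norm=1.0)  # norm 5 -> rescaled to 1
print(clipped)  # approximately [0.6, 0.8]
```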