LSTMForecaster and MTNetForecaster


In this guide, we will show you how to use the built-in LSTMForecaster and MTNetForecaster for time series forecasting.

The built-in LSTMForecaster and MTNetForecaster are both derived from tfpark.KerasModel.

Refer to the network traffic notebook for a demonstration of forecasting network traffic data with the Chronos built-in LSTMForecaster and MTNetForecaster.

Refer to the LSTMForecaster API and MTNetForecaster API for a detailed explanation of all arguments for each forecast model.


Step 0: Prepare environment

We recommend using Anaconda to prepare the environment, especially if you want to run automated training on a YARN cluster (yarn-client mode only).

conda create -n zoo python=3.7  # "zoo" is the conda environment name; you can choose any name you like.
conda activate zoo
pip install "analytics-zoo[automl]==0.9.0.dev0"  # or above

Step 1: Create forecast model

First, create a forecast model, specifying target_dim (the dimension of the model output) and feature_dim (the dimension of the input features) in the constructor.

Below is some example code for creating forecast models.

# import forecast models
from zoo.chronos.forecaster.lstm_forecaster import LSTMForecaster
from zoo.chronos.forecaster.mtnet_forecaster import MTNetForecaster

# build an LSTM forecast model: 4 input features, 1 output series
lstm_forecaster = LSTMForecaster(target_dim=1,
                                 feature_dim=4)

# build an MTNet forecast model
mtnet_forecaster = MTNetForecaster(target_dim=1,
                                   feature_dim=4,
                                   long_series_num=1,  # number of long-term memory series
                                   series_length=3,    # time steps in each long-/short-term series
                                   ar_window_size=2,   # autoregression window size
                                   cnn_height=2)       # height of the CNN filters

Step 2: Use forecast model

Use forecaster.fit/evaluate/predict in the same way as with a tfpark.KerasModel.
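Because both forecasters follow the tfpark.KerasModel calling convention, a typical train/evaluate/predict cycle looks like ordinary Keras code. The sketch below assumes that fit accepts batch_size and epochs keyword arguments and that evaluate(x, y) and predict(x) take NumPy arrays; the helper name train_and_forecast is hypothetical, so check the exact signatures against the tfpark.KerasModel API for your version.

```python
def train_and_forecast(forecaster, x_train, y_train, x_test, y_test,
                       epochs=2, batch_size=32):
    """Hypothetical helper: fit a forecaster, then evaluate and predict on held-out data.

    Assumes `forecaster` follows the tfpark.KerasModel calling convention
    (fit / evaluate / predict), e.g. an LSTMForecaster or MTNetForecaster.
    """
    # Train on the training windows: x is (n, lookback, feature_dim), y is (n, target_dim).
    forecaster.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)
    # Compute the compiled metrics on the test windows.
    metrics = forecaster.evaluate(x_test, y_test)
    # Produce one forecast row per test window.
    predictions = forecaster.predict(x_test)
    return metrics, predictions
```

With lstm_forecaster from Step 1, this would be called as train_and_forecast(lstm_forecaster, x_train, y_train, x_test, y_test), where the arrays already have the shapes the forecaster expects.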

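For univariate forecasting (i.e. predicting one series at a time), you can use either LSTMForecaster or MTNetForecaster. The input data shape for fit/evaluate/predict should match the arguments you used to create the forecaster. Specifically, x should have shape (num_samples, lookback, feature_dim) and y should have shape (num_samples, target_dim), where feature_dim and target_dim are the values passed to the constructor (target_dim is 1 here) and lookback is the number of past time steps in each input window. For MTNetForecaster, lookback must equal (long_series_num + 1) * series_length.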
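A minimal NumPy sketch of turning a raw multi-feature series into fit-ready arrays, assuming one-step-ahead forecasting where the target series is one of the feature columns. The function make_windows and its arguments are hypothetical, not part of the Chronos API.

```python
import numpy as np


def make_windows(data, lookback, target_cols):
    """Hypothetical helper: slide a window over `data` to build (x, y) pairs.

    data: array of shape (num_timesteps, feature_dim)
    lookback: number of past time steps in each input window
    target_cols: list of column indices to forecast (len == target_dim)
    Returns x with shape (num_samples, lookback, feature_dim) and y with
    shape (num_samples, target_dim); y[i] is the time step right after x[i].
    """
    data = np.asarray(data, dtype=np.float32)
    if data.ndim != 2:
        raise ValueError("data must have shape (num_timesteps, feature_dim)")
    num_samples = data.shape[0] - lookback
    if num_samples <= 0:
        raise ValueError("need more time steps than lookback")
    # Stack the windows data[i : i + lookback] for every valid start index i.
    x = np.stack([data[i:i + lookback] for i in range(num_samples)])
    # The target of window i is the next time step, restricted to target_cols.
    y = data[lookback:, target_cols]
    return x, y


# Univariate example matching LSTMForecaster(target_dim=1, feature_dim=4):
rng = np.random.default_rng(0)
series = rng.random((100, 4))             # 100 time steps, 4 features
x, y = make_windows(series, lookback=6, target_cols=[0])
print(x.shape, y.shape)                   # (94, 6, 4) (94, 1)
```

A lookback of 6 is chosen here so the same arrays would also fit the MTNetForecaster from Step 1, assuming its lookback is (1 + 1) * 3; LSTMForecaster itself does not tie lookback to a constructor argument.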

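For multivariate forecasting (i.e. predicting several series at the same time), you must use MTNetForecaster. The input data should meet the following criteria: x has shape (num_samples, lookback, feature_dim), where lookback equals (long_series_num + 1) * series_length, and y has shape (num_samples, target_dim), where target_dim is the number of series forecast together.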
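Before calling fit on an MTNetForecaster, it can help to check the array shapes against the constructor arguments. This sketch assumes MTNet reads the last series_length steps of each window as the short-term series and the preceding long_series_num * series_length steps as long-term memory, so lookback = (long_series_num + 1) * series_length; the function check_mtnet_shapes is hypothetical.

```python
def check_mtnet_shapes(x_shape, y_shape, target_dim, feature_dim,
                       long_series_num, series_length):
    """Hypothetical check: do x/y shapes fit an MTNetForecaster built with these args?

    Assumes lookback = (long_series_num + 1) * series_length, i.e.
    long_series_num long-term series plus one short-term series,
    each series_length steps long. Returns the expected lookback.
    """
    lookback = (long_series_num + 1) * series_length
    if tuple(x_shape[1:]) != (lookback, feature_dim):
        raise ValueError(f"x must have shape (num_samples, {lookback}, {feature_dim}), "
                         f"got {tuple(x_shape)}")
    if tuple(y_shape[1:]) != (target_dim,):
        raise ValueError(f"y must have shape (num_samples, {target_dim}), got {tuple(y_shape)}")
    if x_shape[0] != y_shape[0]:
        raise ValueError("x and y must contain the same number of samples")
    return lookback


# Multivariate example: forecast 3 series at once from 4 features,
# with long_series_num=1 and series_length=3, giving a lookback of 6.
print(check_mtnet_shapes((500, 6, 4), (500, 3), target_dim=3, feature_dim=4,
                         long_series_num=1, series_length=3))  # 6
```

The check works on plain shape tuples, so it can be run on x.shape and y.shape of NumPy arrays before any model is built.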