BERT Classifier

Analytics Zoo provides a built-in BERTClassifier in TFPark, based on TFEstimator and BERT, for Natural Language Processing (NLP) classification tasks.

Bidirectional Encoder Representations from Transformers (BERT) is Google's state-of-the-art pre-trained NLP model. You may refer to the original BERT paper for more details.

BERTClassifier is a pre-built TFEstimator that takes the hidden state of the first token (the [CLS] token) to perform classification.

On this page, we show the general steps to train and evaluate a BERTClassifier in a distributed fashion, and how to use this estimator for distributed inference.


BERTClassifier Construction

You can easily construct an estimator for classification based on BERT using the following API.

import tensorflow as tf
from zoo.tfpark.text.estimator import BERTClassifier

estimator = BERTClassifier(num_classes, bert_config_file, init_checkpoint,
                           optimizer=tf.train.AdamOptimizer(learning_rate),
                           model_dir="/tmp/bert")
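As a concrete illustration, the sketch below constructs an estimator from a downloaded pre-trained BERT Base model. The directory layout follows Google's released BERT checkpoints; the path, class count and learning rate are illustrative values of our own, not required by the API.

bert_base_dir = "/path/to/uncased_L-12_H-768_A-12"  # illustrative download location
estimator = BERTClassifier(num_classes=2,
                           bert_config_file=bert_base_dir + "/bert_config.json",
                           init_checkpoint=bert_base_dir + "/bert_model.ckpt",
                           optimizer=tf.train.AdamOptimizer(learning_rate=2e-5),
                           model_dir="/tmp/bert")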

Data Preparation

BERT has three inputs of the same sequence length: input_ids, input_mask and token_type_ids.

The preprocessing steps should follow BERT's conventions. You may refer to the BERT TensorFlow run_classifier example for more details.
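As an illustration, the helper below (our own sketch, not part of the TFPark API) pads pre-tokenized WordPiece ids into the three inputs for a single-sentence task. The token ids are assumed to already include the [CLS] and [SEP] ids produced by a BERT tokenizer.

def to_bert_features(token_ids, max_seq_length):
    input_ids = token_ids[:max_seq_length]
    input_mask = [1] * len(input_ids)       # 1 for real tokens, 0 for padding
    token_type_ids = [0] * len(input_ids)   # all 0 for a single-sentence task
    padding = [0] * (max_seq_length - len(input_ids))
    return {"input_ids": input_ids + padding,
            "input_mask": input_mask + padding,
            "token_type_ids": token_type_ids + padding}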

To construct the input function for BERTClassifier, you can use the following API:

from zoo.tfpark.text.estimator import bert_input_fn

# rdd: each record is (features, label) for training/evaluation, or a
# features dict alone for prediction, where features has the keys
# "input_ids", "input_mask" and "token_type_ids".
input_fn = bert_input_fn(rdd, max_seq_length, batch_size)
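For instance, assuming the record format above and the to_bert_features helper sketched earlier, a training input function could be built as follows (sc is a SparkContext; samples is a hypothetical list of (token_ids, label) pairs):

train_rdd = sc.parallelize(samples) \
    .map(lambda s: (to_bert_features(s[0], max_seq_length), s[1]))
train_input_fn = bert_input_fn(train_rdd, max_seq_length, batch_size)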

Estimator Training

You can easily call train to train the BERTClassifier in a distributed fashion.

estimator.train(train_input_fn, steps)

You can find the trained checkpoints saved under model_dir, which you specified when constructing the BERTClassifier.
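A complete training call might look like the following; the sequence length, batch size and step count are illustrative values only.

train_input_fn = bert_input_fn(train_rdd, max_seq_length=128, batch_size=32)
estimator.train(train_input_fn, steps=3000)
# Checkpoint files are written under model_dir ("/tmp/bert" above).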

Estimator Evaluation

You can easily call evaluate to evaluate the BERTClassifier in a distributed fashion using top-1 accuracy.

result = estimator.evaluate(eval_input_fn, eval_methods=["acc"])
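For example, evaluating on a separate eval_rdd prepared the same way as the training data, and assuming the result is keyed by the requested metric names:

eval_input_fn = bert_input_fn(eval_rdd, max_seq_length=128, batch_size=32)
result = estimator.evaluate(eval_input_fn, eval_methods=["acc"])
print(result)  # the requested metrics, e.g. the "acc" value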

Estimator Inference

You can easily call predict to use the trained BERTClassifier for distributed inference.

predictions_rdd = estimator.predict(test_input_fn)
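Note that predictions_rdd is an RDD, so the results stay distributed; for a quick sanity check you can collect a few records to the driver. Here test_input_fn is assumed to be built with bert_input_fn from records that contain features dicts only, with no labels.

for prediction in predictions_rdd.take(5):  # pull a handful of results locally
    print(prediction)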