BERT Classifier


Analytics Zoo provides a built-in BERTClassifier in TFPark for Natural Language Processing (NLP) classification tasks based on TFEstimator and BERT.

Bidirectional Encoder Representations from Transformers (BERT) is Google's state-of-the-art pre-trained NLP model. You may refer to the BERT paper (https://arxiv.org/abs/1810.04805) for more details.

BERTClassifier is a pre-built TFEstimator that takes the hidden state of the first token (the [CLS] token) to do classification.

This page shows the general steps to train and evaluate a BERTClassifier in a distributed fashion and to use this estimator for distributed inference.


BERTClassifier Construction

You can easily construct an estimator for classification based on BERT using the following API.

import tensorflow as tf
from zoo.tfpark.text.estimator import BERTClassifier

# bert_config_file and init_checkpoint point to the pre-trained BERT config and checkpoint files.
estimator = BERTClassifier(num_classes, bert_config_file, init_checkpoint,
                           optimizer=tf.train.AdamOptimizer(learning_rate), model_dir="/tmp/bert")
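
bert_config_file and init_checkpoint typically come from a downloaded pre-trained BERT release, which contains the config, checkpoint and vocabulary files. The paths and values below are placeholders for illustration only:

# Placeholder paths for a downloaded pre-trained BERT release (e.g. uncased_L-12_H-768_A-12).
bert_base_dir = "/path/to/uncased_L-12_H-768_A-12"
bert_config_file = bert_base_dir + "/bert_config.json"
init_checkpoint = bert_base_dir + "/bert_model.ckpt"
vocab_file = bert_base_dir + "/vocab.txt"  # used for tokenization during data preparation
num_classes = 2  # e.g. binary sentiment classification
learning_rate = 2e-5  # a commonly used fine-tuning learning rate for BERT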

Data Preparation

BERT has three inputs of the same sequence length: input_ids, input_mask and token_type_ids.

The preprocessing steps should follow BERT's conventions. You may refer to the BERT TensorFlow run_classifier example for more details.
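
As a rough sketch (not the library's own preprocessing code), the following converts a single sentence into the three feature lists, assuming the tokenization module from the BERT GitHub repo (e.g. via the bert-tensorflow package) and the vocab_file placeholder defined above:

from bert import tokenization  # tokenization.py from the BERT repo / bert-tensorflow package

def convert_sentence(text, tokenizer, max_seq_length):
    # Tokenize and truncate, leaving room for the [CLS] and [SEP] markers.
    tokens = ["[CLS]"] + tokenizer.tokenize(text)[:max_seq_length - 2] + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)      # 1 for real tokens, 0 for padding
    token_type_ids = [0] * len(input_ids)  # single-sentence task: all segment 0
    padding = [0] * (max_seq_length - len(input_ids))  # pad every list to max_seq_length
    return {"input_ids": input_ids + padding,
            "input_mask": input_mask + padding,
            "token_type_ids": token_type_ids + padding}

tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
features = convert_sentence("a simple example sentence", tokenizer, max_seq_length=128)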

To construct the input function for BERTClassifier, you can use the following API:

from zoo.tfpark.text.estimator import bert_input_fn

input_fn = bert_input_fn(rdd, max_seq_length, batch_size)
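
For example, assuming each training record is a tuple of the feature dict produced above and an integer label (this record format is an assumption for illustration; see the TFPark documentation for details), and sc is your SparkContext:

# Hypothetical labeled data for illustration: (text, label) pairs.
texts_and_labels = [("this movie was great", 1), ("a dull and tedious film", 0)]

max_seq_length = 128
batch_size = 32
train_rdd = sc.parallelize(texts_and_labels) \
    .map(lambda x: (convert_sentence(x[0], tokenizer, max_seq_length), x[1]))
train_input_fn = bert_input_fn(train_rdd, max_seq_length, batch_size)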

Estimator Training

You can easily call train to train the BERTClassifier in a distributed fashion.

estimator.train(train_input_fn, steps)

You can find the trained checkpoints saved under model_dir, which is specified when you construct the BERTClassifier.


Estimator Evaluation

You can easily call evaluate to evaluate the BERTClassifier in a distributed fashion using top-1 accuracy.

result = estimator.evaluate(eval_input_fn, eval_methods=["acc"])
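
The returned result is expected to map each evaluation method to its value (the exact structure is an assumption here; check the TFPark documentation):

print(result)  # e.g. the top-1 accuracy under the "acc" key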

Estimator Inference

You can easily call predict to use the trained BERTClassifier for distributed inference.

predictions_rdd = estimator.predict(test_input_fn)
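
Since the predictions are returned as an RDD, you can bring a few records back to the driver for inspection (the per-record structure is not specified on this page):

# Pull a few predictions back to the driver for a quick sanity check.
for prediction in predictions_rdd.take(5):
    print(prediction)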