BERT Classifier
Analytics Zoo provides a built-in BERTClassifier in TFPark for Natural Language Processing (NLP) classification tasks based on TFEstimator and BERT.
Bidirectional Encoder Representations from Transformers (BERT) is Google's state-of-the-art pre-trained NLP model. You may refer to here for more details.
BERTClassifier is a pre-built TFEstimator that takes the hidden state of the first token to do classification.
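To make "takes the hidden state of the first token" concrete, here is a minimal pure-Python sketch of such a classification head. It is an illustration only, not the Analytics Zoo implementation: the function name classify_first_token and the weights and bias parameters are hypothetical, and any pooling or dropout layers a real model may apply are left out.

```python
import math


def classify_first_token(hidden_states, weights, bias):
    """Return class probabilities computed from the first token only.

    hidden_states: list of per-token vectors, shape [seq_len][hidden_size]
    weights:       hypothetical classifier weights, shape [num_classes][hidden_size]
    bias:          hypothetical classifier bias, shape [num_classes]
    """
    first = hidden_states[0]  # hidden state of the first ([CLS]) token
    # One linear layer: logit_c = w_c . first + b_c
    logits = [
        sum(w * h for w, h in zip(row, first)) + b
        for row, b in zip(weights, bias)
    ]
    # Numerically stable softmax over the logits
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = classify_first_token(
    hidden_states=[[1.0, 0.0], [5.0, 5.0]],
    weights=[[1.0, 0.0], [0.0, 1.0]],
    bias=[0.0, 0.0],
)
```

Only the first row of hidden_states influences the result, so the remaining token vectors can change without affecting the predicted class.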
This page shows the general steps to train and evaluate a BERTClassifier in a distributed fashion and to use the estimator for distributed inference.
Remarks:
- You need to install tensorflow==1.15.0 on your driver node.
- Your operating system (OS) must be one of the following 64-bit systems: Ubuntu 16.04 or later, macOS 10.12.6 or later, or Windows 7 or later.
- To run on other systems, you need to manually compile the TensorFlow source code. Instructions can be found here.
BERTClassifier Construction
You can easily construct an estimator for classification based on BERT using the following API.
import tensorflow as tf
from zoo.tfpark.text.estimator import BERTClassifier
estimator = BERTClassifier(num_classes, bert_config_file, init_checkpoint, optimizer=tf.train.AdamOptimizer(learning_rate), model_dir="/tmp/bert")
Data Preparation
BERT has three inputs of the same sequence length: input_ids, input_mask and token_type_ids.
The preprocessing steps should follow BERT's conventions. You may refer to BERT TensorFlow run_classifier example for more details.
To construct the input function for BERTClassifier, you can use the following API:
from zoo.tfpark.text.estimator import bert_input_fn
input_fn = bert_input_fn(rdd, max_seq_length, batch_size)
- For training and evaluation, each element in rdd should be a tuple: (feature dict, label). The label should be an integer.
- For prediction, each element in rdd should be a feature dict.
- The keys of the feature dict should be input_ids, input_mask and token_type_ids, and the values should be the corresponding preprocessed results of length max_seq_length for a record.
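The feature dict described above can be sketched in plain Python as follows. This is a simplified illustration of the BERT input convention, not the preprocessing code shipped with Analytics Zoo or BERT: the helper name build_feature_dict is hypothetical, the inputs are assumed to be already-tokenized WordPiece IDs, and the defaults for cls_id, sep_id and pad_id (101, 102, 0) are assumptions that match the standard uncased BERT vocabulary and may differ for other vocabularies.

```python
def build_feature_dict(ids_a, ids_b=None, max_seq_length=128,
                       cls_id=101, sep_id=102, pad_id=0):
    """Build one BERT feature dict from token IDs (hypothetical helper)."""
    ids_a = list(ids_a)
    ids_b = list(ids_b) if ids_b else []
    # Reserve room for [CLS] and [SEP], plus a second [SEP] for a pair.
    budget = max_seq_length - (3 if ids_b else 2)
    if budget < 0:
        raise ValueError("max_seq_length is too small for the special tokens")
    # Trim the longer segment one token at a time until everything fits.
    while len(ids_a) + len(ids_b) > budget:
        longer = ids_a if len(ids_a) >= len(ids_b) else ids_b
        longer.pop()

    input_ids = [cls_id] + ids_a + [sep_id]
    token_type_ids = [0] * len(input_ids)       # segment A
    if ids_b:
        input_ids += ids_b + [sep_id]
        token_type_ids += [1] * (len(ids_b) + 1)  # segment B

    input_mask = [1] * len(input_ids)           # 1 = real token
    padding = max_seq_length - len(input_ids)
    input_ids += [pad_id] * padding
    input_mask += [0] * padding                 # 0 = padding
    token_type_ids += [0] * padding
    return {
        "input_ids": input_ids,
        "input_mask": input_mask,
        "token_type_ids": token_type_ids,
    }


single = build_feature_dict([7, 8, 9], max_seq_length=8)
pair = build_feature_dict([7, 8, 9], [10, 11, 12], max_seq_length=8)
```

For training and evaluation, each such dict would be paired with an integer label, for example (single, 1), before the elements are placed in the RDD.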
Estimator Training
You can easily call train to train the BERTClassifier in a distributed fashion.
estimator.train(train_input_fn, steps)
You can find the trained checkpoints saved under model_dir, which is specified when you construct the BERTClassifier.
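If you think in epochs rather than steps, the steps argument of train can be derived with a small helper like the one below. This is a sketch under an assumption that is not stated on this page: each training step consumes one batch of batch_size records across the whole cluster. The helper name training_steps is hypothetical.

```python
import math


def training_steps(num_examples, batch_size, epochs=1):
    """Number of steps needed to see every example `epochs` times.

    Assumes one step processes one batch of `batch_size` records.
    """
    if num_examples <= 0 or batch_size <= 0 or epochs <= 0:
        raise ValueError("all arguments must be positive")
    # The last batch of an epoch may be partial, hence the ceiling.
    return math.ceil(num_examples / batch_size) * epochs


steps = training_steps(num_examples=1000, batch_size=32, epochs=3)
```

The resulting value could then be passed as the second argument of estimator.train.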
Estimator Evaluation
You can easily call evaluate to evaluate the BERTClassifier in a distributed fashion using top-1 accuracy.
result = estimator.evaluate(eval_input_fn, eval_methods=["acc"])
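For reference, top-1 accuracy is the fraction of records whose highest-scoring class equals the true label. The sketch below illustrates the metric on local Python lists; it is not how Analytics Zoo computes it internally, and the function name top1_accuracy is hypothetical.

```python
def top1_accuracy(scores, labels):
    """Fraction of rows whose arg-max class equals the integer label.

    scores: list of per-class score lists, one list per record
    labels: list of integer class labels, same length as scores
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("at least one record is required")
    correct = 0
    for row, label in zip(scores, labels):
        # Index of the largest score is the top-1 prediction.
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == label:
            correct += 1
    return correct / len(labels)


accuracy = top1_accuracy(
    scores=[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]],
    labels=[1, 0, 0],
)
```

In this toy input the first two records are classified correctly and the third is not.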
Estimator Inference
You can easily call predict to use the trained BERTClassifier for distributed inference.
predictions_rdd = estimator.predict(test_input_fn)