TFPredictor
TFPredictor takes a list of TensorFlow tensors as the model outputs, feeds all the elements of a TFDataset through the model to produce those outputs, and returns a Spark RDD in which each element is the model prediction for the corresponding input element.
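The data flow described above can be pictured with a plain-Python model that uses no Spark or TensorFlow: the dataset is split into partitions, each partition is fed through the model batch by batch, and the outputs keep the same positions as their inputs. This is only a conceptual sketch; `predict_partitions` and `batch_size` are hypothetical names, not part of the TFPredictor API, and the real TFPredictor distributes this work through Spark and returns an RDD rather than nested lists.

```python
from typing import Any, Callable, List, Sequence


def predict_partitions(partitions: Sequence[Sequence[Any]],
                       model: Callable[[List[Any]], List[Any]],
                       batch_size: int = 2) -> List[List[Any]]:
    """Hypothetical illustration: run `model` over every element of every
    partition in batches and return one prediction per input element,
    in the same order as the inputs."""
    results = []
    for part in partitions:
        preds = []
        for start in range(0, len(part), batch_size):
            batch = list(part[start:start + batch_size])
            outputs = model(batch)
            # the model must produce exactly one output per input element
            if len(outputs) != len(batch):
                raise ValueError("model returned a different number of outputs")
            preds.extend(outputs)
        results.append(preds)
    return results
```

For example, with two partitions `[[1, 2, 3], [4, 5]]` and a stand-in model `lambda batch: [x * 10 for x in batch]`, the result is `[[10, 20, 30], [40, 50]]`: each input element maps to exactly one prediction at the same position.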
Remarks:
- You need to install tensorflow==1.15.0 on your driver node.
- Your operating system (OS) must be one of the following 64-bit systems: Ubuntu 16.04 or later, or macOS 10.12.6 or later.
- To run on other systems, you need to manually compile the TensorFlow source code. Instructions can be found here.
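The remarks above can be checked on the driver node before running TFPredictor. The script below is a minimal sketch, assuming that reading `tensorflow.__version__`, `platform.system()`, `platform.mac_ver()` and the `ID`/`VERSION_ID` fields of `/etc/os-release` is enough to identify the setup; the function names are hypothetical and not part of Analytics Zoo.

```python
import platform
import sys

REQUIRED_TF = "1.15.0"
MIN_UBUNTU = (16, 4)      # Ubuntu 16.04
MIN_MACOS = (10, 12, 6)   # macOS 10.12.6


def version_tuple(text):
    """Turn "10.12.6" into (10, 12, 6); stops at the first non-numeric part."""
    parts = []
    for piece in (text or "").split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def check_driver(tf_version, system, is_64bit, os_version, distro=None):
    """Return a list of problems; an empty list means the remarks are met."""
    problems = []
    if tf_version != REQUIRED_TF:
        problems.append(f"tensorflow=={REQUIRED_TF} is required, found {tf_version}")
    if not is_64bit:
        problems.append("a 64-bit Python on a 64-bit OS is required")
    if system == "Darwin":
        if version_tuple(os_version) < MIN_MACOS:
            problems.append(f"macOS 10.12.6 or later is required, found {os_version}")
    elif system == "Linux" and distro == "ubuntu":
        if version_tuple(os_version) < MIN_UBUNTU:
            problems.append(f"Ubuntu 16.04 or later is required, found {os_version}")
    else:
        problems.append(f"{system} ({distro}) is not supported; compile TensorFlow from source")
    return problems


def read_os_release(path="/etc/os-release"):
    """Return (ID, VERSION_ID) from os-release, or (None, "") if unavailable."""
    info = {}
    try:
        with open(path) as f:
            for line in f:
                key, sep, value = line.strip().partition("=")
                if sep:
                    info[key] = value.strip('"')
    except OSError:
        pass
    return info.get("ID"), info.get("VERSION_ID", "")


def detect():
    """Collect the facts check_driver needs from the current machine."""
    try:
        import tensorflow as tf
        tf_version = tf.__version__
    except ImportError:
        tf_version = None
    system = platform.system()
    distro, os_version = None, ""
    if system == "Darwin":
        os_version = platform.mac_ver()[0]
    elif system == "Linux":
        distro, os_version = read_os_release()
    # sys.maxsize > 2**32 means a 64-bit Python build, which needs a 64-bit OS
    return tf_version, system, sys.maxsize > 2**32, os_version, distro


if __name__ == "__main__":
    issues = check_driver(*detect())
    print("OK" if not issues else "\n".join(issues))
```

Keeping `check_driver` free of I/O lets it be tested with fixed inputs, while `detect` is the only part that inspects the real machine.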
Python
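For a TensorFlow graph:
from zoo.tfpark import TFPredictor
logits = ...  # output tensor built from the tensors of a TFDataset
predictor = TFPredictor.from_outputs(sess, [logits])  # sess: tf.Session holding the trained variables
predictions_rdd = predictor.predict()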
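For a Keras model:
from tensorflow.keras.models import Model
from zoo.tfpark import TFPredictor
model = Model(inputs=..., outputs=...)
model.load_weights("/tmp/mnist_keras.h5")
dataset = ...  # a TFDataset supplying the input elements
predictor = TFPredictor.from_keras(model, dataset)
predictions_rdd = predictor.predict()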
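When the model is a classifier, a common next step is to turn each prediction into a class index. The sketch below assumes that each element of `predictions_rdd` is a 1-D array of per-class scores (such as logits); this is an assumption, since the exact element structure depends on the model outputs. `to_class_index` is a hypothetical helper, not part of TFPredictor.

```python
import numpy as np


def to_class_index(prediction):
    """Return the index of the highest score in a 1-D array of per-class scores."""
    scores = np.asarray(prediction)
    return int(np.argmax(scores))
```

With PySpark, such a function can be applied element-wise, for example `labels_rdd = predictions_rdd.map(to_class_index)`, and a few results inspected with `labels_rdd.take(10)`.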