TFPredictor



TFPredictor takes a list of TensorFlow tensors as the model outputs, feeds every element of a TFDataset through the model to compute those outputs, and returns a Spark RDD in which each element is the model's prediction for the corresponding input element.
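The one-to-one correspondence between input elements and predictions, even when the data is split into partitions and fed in batches, can be illustrated with a plain-Python sketch. This is not the TFPredictor implementation or API: `partition`, `predict`, and the batch function are hypothetical stand-ins, with ordinary Python lists playing the roles of the TFDataset and of the returned Spark RDD.

```python
from typing import Callable, List, Sequence

# Hypothetical stand-in for a model: maps a batch of inputs to a batch of outputs.
BatchFn = Callable[[List[float]], List[float]]


def partition(elements: Sequence[float], num_partitions: int) -> List[List[float]]:
    """Split elements into contiguous partitions, loosely mimicking RDD partitions."""
    size, rem = divmod(len(elements), num_partitions)
    parts, start = [], 0
    for i in range(num_partitions):
        end = start + size + (1 if i < rem else 0)
        parts.append(list(elements[start:end]))
        start = end
    return parts


def predict(elements: Sequence[float], model: BatchFn,
            num_partitions: int = 2, batch_size: int = 3) -> List[float]:
    """Run the model over every element, batch by batch; output i matches input i."""
    predictions: List[float] = []
    for part in partition(elements, num_partitions):
        for start in range(0, len(part), batch_size):
            batch = part[start:start + batch_size]
            outputs = model(batch)
            if len(outputs) != len(batch):
                raise ValueError("model must return one output per input element")
            predictions.extend(outputs)
    return predictions


if __name__ == "__main__":
    double = lambda batch: [2 * x for x in batch]
    print(predict([1.0, 2.0, 3.0, 4.0, 5.0], double))  # [2.0, 4.0, 6.0, 8.0, 10.0]
```

Because partitions are contiguous and batches are taken in order, concatenating the per-batch outputs keeps each prediction aligned with its input element.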

Example with the output tensors of a TensorFlow graph (Python):

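from zoo.tfpark import TFPredictor

logits = ...  # output tensor of a model built on the TFDataset's input tensors
predictor = TFPredictor.from_outputs(sess, [logits])  # sess: tf.Session holding the trained variables
predictions_rdd = predictor.predict()  # Spark RDD, one prediction per input element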

For a Keras model:

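from tensorflow.keras.models import Model
from zoo.tfpark import TFPredictor

model = Model(inputs=..., outputs=...)
model.load_weights("/tmp/mnist_keras.h5")  # restore the trained weights
predictor = TFPredictor.from_keras(model, dataset)  # dataset: the TFDataset to predict on
predictions_rdd = predictor.predict()  # Spark RDD, one prediction per input element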