KerasModel


KerasModel lets users define TensorFlow models with the tf.keras API and then train or evaluate them on top of Spark and BigDL in a distributed fashion.

Remarks:

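from zoo.tfpark import KerasModel, TFDataset
import tensorflow as tf

# Define the network with the standard tf.keras API
model = tf.keras.Sequential(
    [tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
     tf.keras.layers.Dense(64, activation='relu'),
     tf.keras.layers.Dense(10, activation='softmax'),
     ]
)

# Compile the model (optimizer, loss, metrics) before wrapping it
model.compile(optimizer=tf.keras.optimizers.RMSprop(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Wrap the compiled model so it can be trained and evaluated on Spark and BigDL
keras_model = KerasModel(model)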

Methods

__init__

KerasModel(model)

Arguments

fit

fit(x=None, y=None, batch_size=None, epochs=1, validation_data=None, distributed=False)

Arguments
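When training runs in distributed mode, batch_size is best read as the global batch size for one iteration across the whole cluster. BigDL's distributed training typically expects this global value to be an exact multiple of the total number of cores, so that each core receives an equal share. The helper below is a hypothetical sketch of that arithmetic only; the name per_core_batch and the node/core counts are illustrative assumptions and are not part of the KerasModel API.

```python
def per_core_batch(batch_size: int, num_nodes: int, cores_per_node: int) -> int:
    """Return the per-core share of a global training batch.

    Hypothetical helper (not a KerasModel/BigDL function): assumes the global
    batch_size must be a positive exact multiple of the total core count.
    """
    total_cores = num_nodes * cores_per_node
    if batch_size <= 0 or batch_size % total_cores != 0:
        raise ValueError(
            f"batch_size={batch_size} must be a positive multiple of "
            f"the total core count ({total_cores})"
        )
    # Each core processes an equal slice of the global batch per iteration
    return batch_size // total_cores


# 2 nodes x 4 cores: a global batch of 64 gives each core 8 samples.
print(per_core_batch(64, num_nodes=2, cores_per_node=4))  # 8
```

With these assumptions, a value such as batch_size=60 on 8 cores would be rejected, which is why a multiple of the core count is the safe choice when picking batch_size for distributed training.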

evaluate

evaluate(x=None, y=None, batch_per_thread=None, distributed=False)

Arguments
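With distributed=True, each partition of the data produces its own partial metric results, which then have to be combined into one number. The sketch below is a conceptual illustration of why accuracy should be merged from raw counts rather than by averaging per-partition accuracies; it is not BigDL's actual aggregation code, and merge_accuracy is a hypothetical name.

```python
from typing import Iterable, Tuple


def merge_accuracy(partials: Iterable[Tuple[int, int]]) -> float:
    """Combine per-partition (num_correct, num_samples) pairs into one accuracy.

    Hypothetical helper: summing the counts before dividing weights each
    partition by its size, unlike a plain mean of per-partition accuracies.
    """
    correct = 0
    total = 0
    for num_correct, num_samples in partials:
        correct += num_correct
        total += num_samples
    if total == 0:
        raise ValueError("no samples were evaluated")
    return correct / total


# Two unequal partitions: 90/100 and 1/10 correct.
parts = [(90, 100), (1, 10)]
print(f"{merge_accuracy(parts):.4f}")  # 0.8273
```

A naive mean of the two partition accuracies would report 0.5 here, even though 91 of the 110 samples were classified correctly.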

predict

predict(x, batch_per_thread=None, distributed=False)

Arguments
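Unlike batch_size in fit, batch_per_thread is a per-thread value: in distributed inference each core handles up to batch_per_thread records at a time, so one round over the cluster covers roughly batch_per_thread multiplied by the total core count. The sketch below estimates the number of such rounds under that idealized assumption (even partitions, no skew); inference_rounds is a hypothetical helper, not part of the KerasModel API.

```python
def inference_rounds(num_records: int, batch_per_thread: int, total_cores: int) -> int:
    """Estimate how many inference rounds a dataset needs.

    Hypothetical helper: assumes every core processes up to batch_per_thread
    records per round, so one round covers batch_per_thread * total_cores records.
    """
    if batch_per_thread <= 0 or total_cores <= 0:
        raise ValueError("batch_per_thread and total_cores must be positive")
    records_per_round = batch_per_thread * total_cores
    # Integer ceiling division: a partial last round still counts as a round
    return (num_records + records_per_round - 1) // records_per_round


# 10,000 records, 32 per thread, 8 cores -> 256 records per round -> 40 rounds.
print(inference_rounds(10_000, batch_per_thread=32, total_cores=8))  # 40
```

Real runs will deviate when partitions are uneven, but the estimate shows how raising batch_per_thread reduces the number of rounds at the cost of more memory per core.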