Transfer Learning
Overview
Analytics Zoo provides some useful utilities for transfer learning.
Loading a pre-trained model
We can use the Net API to load a pre-trained model, including models saved by Analytics Zoo,
BigDL, Torch, Caffe and TensorFlow. Please refer to the Net API Guide for details.
Remove the last few layers
When a model is loaded using Net, we can use the newGraph(output) API to define a new Model
whose output is the layer specified by the parameter. Layers after that output are dropped.
For example,
In Scala:
val inception = Net.loadBigDL[Float](inception_path)
.newGraph(output = "pool5/drop_7x7_s1")
In Python:
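from zoo.pipeline.api.net import Net

# model_path points to a model previously saved in BigDL format
full_model = Net.load_bigdl(model_path)
# create a new model by removing the layers after pool5/drop_7x7_s1
model = full_model.new_graph(["pool5/drop_7x7_s1"])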
The output layer of the returned model is "pool5/drop_7x7_s1".
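To make the effect of cutting a model at a named output concrete, here is a minimal, library-free sketch. It is not how Analytics Zoo implements newGraph; it only illustrates the idea on a toy graph. The dictionary PREDECESSORS, the function truncate_graph, and every layer name other than pool5/drop_7x7_s1 and pool4/3x3_s2 are hypothetical.

```python
from collections import deque

# Toy graph (hypothetical): each layer name maps to the layers feeding into it.
PREDECESSORS = {
    "data": [],
    "conv1": ["data"],
    "pool4/3x3_s2": ["conv1"],
    "pool5/drop_7x7_s1": ["pool4/3x3_s2"],
    "loss3/classifier": ["pool5/drop_7x7_s1"],
    "prob": ["loss3/classifier"],
}


def truncate_graph(predecessors, outputs):
    """Keep only the layers needed to compute the given output layers."""
    keep = set()
    queue = deque(outputs)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue
        keep.add(name)
        # Walk backwards towards the input.
        queue.extend(predecessors[name])
    # Preserve the original ordering of the retained layers.
    return {name: preds for name, preds in predecessors.items() if name in keep}


truncated = truncate_graph(PREDECESSORS, ["pool5/drop_7x7_s1"])
print(list(truncated))
```

In this sketch the classifier and softmax layers that follow the chosen output are no longer part of the graph, which is the situation we want before attaching a new task-specific head.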
Freeze some layers
In transfer learning, we often want to freeze some layers, so that their weights are not updated
during training, to prevent overfitting. In Analytics Zoo, we can use the freezeUpTo(endPoint) API to do that.
For example,
In Scala:
inception.freezeUpTo("pool4/3x3_s2") // freeze layer pool4/3x3_s2 and the layers before it
In Python:
# freeze layers from input to pool4/3x3_s2 inclusive
model.freeze_up_to(["pool4/3x3_s2"])
This will freeze all the layers from the input layer up to and including "pool4/3x3_s2".
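The "up to and including" semantics can be sketched without the library as follows. This is an illustration under stated assumptions, not Analytics Zoo's implementation: the toy graph PREDECESSORS, the helper freeze_up_to defined here, and the layer names data, conv1 and new_classifier are hypothetical.

```python
# Toy graph (hypothetical): each layer name maps to the layers feeding into it.
PREDECESSORS = {
    "data": [],
    "conv1": ["data"],
    "pool4/3x3_s2": ["conv1"],
    "pool5/drop_7x7_s1": ["pool4/3x3_s2"],
    "new_classifier": ["pool5/drop_7x7_s1"],
}


def freeze_up_to(predecessors, end_points):
    """Return the names of frozen layers: each end point and everything upstream of it."""
    frozen = set()
    stack = list(end_points)
    while stack:
        name = stack.pop()
        if name not in frozen:
            frozen.add(name)
            # Everything that feeds a frozen layer is frozen as well.
            stack.extend(predecessors[name])
    return frozen


frozen = freeze_up_to(PREDECESSORS, ["pool4/3x3_s2"])
# Layers that would still be updated during training, in graph order.
trainable = [name for name in PREDECESSORS if name not in frozen]
print(sorted(frozen))
print(trainable)
```

Only the layers after the end point remain trainable in this sketch, so fine-tuning touches the new head and the last few layers while the earlier feature extractor stays fixed.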
Example
For a complete example, refer to the Scala transfer learning example and the Python transfer learning example.