How to plot Keras models
Dec 2, 2017
Some notes that summarize how to plot Keras models. Code is here: https://github.com/yang-zhang/deep-learning/blob/master/keras_plot_model.ipynb
Model example
import keras
import IPython.display

model = keras.models.Sequential()
model.add(keras.layers.Dense(3, activation='relu', input_dim=3))
model.add(keras.layers.Dense(5, activation='relu'))
model.add(keras.layers.Dense(2, activation='softmax'))
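As a sanity check on any plot of this model, the input and output shapes of the three Dense layers can be worked out by hand. The sketch below is plain Python and does not call Keras; the helper name `dense_shapes` is made up for illustration, and `None` stands for the batch dimension.

```python
# Hand-computed shapes for the Dense stack above (no Keras needed).
# `units` lists the number of units of each Dense layer in the model example.
input_dim = 3
units = [3, 5, 2]

def dense_shapes(input_dim, units):
    """Return (input_shape, output_shape, param_count) for each Dense layer."""
    rows = []
    prev = input_dim
    for n in units:
        params = prev * n + n  # weight matrix plus bias vector
        rows.append(((None, prev), (None, n), params))
        prev = n
    return rows

shape_table = dense_shapes(input_dim, units)
for i, (inp, out, params) in enumerate(shape_table):
    print(f"dense {i}: input {inp} -> output {out}, params {params}")
```

If a plot made with show_shapes=True reports different shapes, the model was probably not built as intended.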
Option-1
keras.utils.plot_model(model, to_file='test_keras_plot_model.png', show_shapes=True)
IPython.display.Image('test_keras_plot_model.png')
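plot_model relies on the pydot package and the Graphviz `dot` executable, and fails if either is missing. A quick check along these lines can save some head-scratching; the helper name `plotting_deps` is hypothetical, and the check only tests whether the two can be found, not whether their versions are compatible.

```python
import importlib.util
import shutil

def plotting_deps():
    """Report whether pydot and the Graphviz `dot` binary can be found."""
    return {
        "pydot": importlib.util.find_spec("pydot") is not None,
        "dot": shutil.which("dot") is not None,
    }

deps = plotting_deps()
missing = [name for name, ok in deps.items() if not ok]
if missing:
    print("Missing for plotting:", ", ".join(missing))
else:
    print("pydot and Graphviz found")
```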
Option-2
IPython.display.SVG(keras.utils.vis_utils.model_to_dot(model).create(prog='dot', format='svg'))
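model_to_dot builds a graph description that Graphviz's `dot` program then renders to SVG. To give a rough idea of what such a description looks like, here is a hand-written approximation in plain Python. It is not what Keras actually emits (the real graph also has an input node and richer labels), and the function name `layers_to_dot` and the layer names are invented for illustration.

```python
def layers_to_dot(layer_names):
    """Build a DOT digraph string that chains the given layer names."""
    lines = ["digraph G {", "  node [shape=record];"]
    # One node per layer
    for i, name in enumerate(layer_names):
        lines.append(f'  n{i} [label="{name}"];')
    # One edge between each pair of consecutive layers
    for i in range(len(layer_names) - 1):
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)

# Illustrative layer names; Keras assigns its own.
dot_source = layers_to_dot(["dense_1: Dense", "dense_2: Dense", "dense_3: Dense"])
print(dot_source)
```

Saving the printed text to a file and running `dot -Tsvg` on it should give a three-box chain similar in spirit to the Keras plot.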
Make into a function:
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

def plot_keras_model(model, show_shapes=True, show_layer_names=True):
    return SVG(model_to_dot(model, show_shapes=show_shapes, show_layer_names=show_layer_names).create(prog='dot', format='svg'))

plot_keras_model(model, show_shapes=True, show_layer_names=False)