Getting Started: High-level API¶
In this tutorial we will explore the basic features of Elegy. If you are a Keras user you should feel at home; if you are currently learning JAX, things will appear much more streamlined. To get started you will first need to install the following dependencies:
! pip install --upgrade pip
! pip install elegy datasets matplotlib
# For GPU support, install the version matching your CUDA; the following will work in Colab:
# ! pip install --upgrade jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
Note that Elegy depends on the JAX CPU version hosted on PyPI; if you want to run JAX on GPU you will need to install it separately. If you are running this example on Colab, JAX is already preinstalled, but you can uncomment the last line of the previous cell if you want to update it.
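If you want to double-check which backend JAX picked up, a quick sanity check (a minimal sketch, nothing Elegy-specific) is to list the available devices:

import jax

# Lists the devices JAX will use; with the PyPI CPU wheel this prints something
# like [CpuDevice(id=0)], while a CUDA build should show GPU devices instead.
print(jax.devices())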
Loading the Data¶
In this tutorial we will train a neural network on the MNIST dataset, so first we need to download the data and load it into memory. Here we will use the datasets library to load the dataset.
from datasets.load import load_dataset
import numpy as np
dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = np.stack(dataset["train"]["image"])[..., None]
y_train = dataset["train"]["label"]
X_test = np.stack(dataset["test"]["image"])[..., None]
y_test = dataset["test"]["label"]
print("X_train:", X_train.shape, X_train.dtype)
print("y_train:", y_train.shape, y_train.dtype)
print("X_test:", X_test.shape, X_test.dtype)
print("y_test:", y_test.shape, y_test.dtype)
X_train: (60000, 28, 28, 1) uint8
y_train: (60000,) int64
X_test: (10000, 28, 28, 1) uint8
y_test: (10000,) int64
Defining the Architecture¶
The first thing we need to do is define our model's architecture inside a Module. To do this we just create a class that inherits from Module and implements a __call__ method. In this example we will create a simple 2-layer MLP:
import jax.numpy as jnp
import jax
import elegy as eg
class MLP(eg.Module):
def __init__(self, n1: int, n2: int):
self.n1 = n1
self.n2 = n2
@eg.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = x.astype(jnp.float32) / 255.0
x = eg.nn.Flatten()(x)
# first layer
x = eg.nn.Linear(self.n1)(x)
x = jax.nn.relu(x)
# second layer
x = eg.nn.Linear(self.n2)(x)
x = jax.nn.relu(x)
# output layer
x = eg.nn.Linear(10)(x)
return x
This code should feel familiar to most Keras / PyTorch users; the main difference is that we are using the @compact decorator to define submodules (e.g. Linear) inline, which tends to produce much shorter and more readable code.
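If you prefer not to use @compact, a rough sketch of an equivalent module would declare its submodules in __init__ instead. This is an illustrative assumption about the non-compact style (the class name ExplicitMLP is hypothetical), not code from this tutorial:

class ExplicitMLP(eg.Module):
    def __init__(self, n1: int, n2: int):
        # submodules are created up-front instead of inline inside __call__
        self.linear1 = eg.nn.Linear(n1)
        self.linear2 = eg.nn.Linear(n2)
        self.linear_out = eg.nn.Linear(10)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = x.astype(jnp.float32) / 255.0
        x = x.reshape(x.shape[0], -1)  # flatten manually instead of using eg.nn.Flatten
        x = jax.nn.relu(self.linear1(x))
        x = jax.nn.relu(self.linear2(x))
        return self.linear_out(x)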
Creating the Model¶
Now that we have this module we can create an Elegy Model
which is Elegy's central API:
import optax
model = eg.Model(
module=MLP(n1=300, n2=100),
loss=[
eg.losses.Crossentropy(),
eg.regularizers.L2(l=1e-4),
],
metrics=eg.metrics.Accuracy(),
optimizer=optax.adam(1e-3),
)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
If you are a Keras user this code should look familiar; the main differences are:
- You need to pass a module with the architecture.
- loss and metrics are a bit more flexible in that they do not need to match the label's structure.
- There is no compile step; everything is done in the constructor.
- For the optimizer you can use any optax optimizer (see the sketch after this list).
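For instance, since the optimizer slot accepts any optax gradient transformation, you could swap in a learning-rate schedule plus gradient clipping. This is just a sketch of an alternative configuration, not something used in the rest of the tutorial:

import optax

# cosine decay over roughly 100 epochs of 200 steps each (assumed horizon)
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=200 * 100)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # clip gradients before the update
    optax.adamw(learning_rate=schedule, weight_decay=1e-4),
)
# model = eg.Model(module=MLP(n1=300, n2=100), ..., optimizer=optimizer)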
As in Keras, you can get a rich description of the model by calling Model.summary
with a sample input:
model.summary(X_train[:64])
path       module      params                                 inputs                    outputs               Parameter
*          MLP()                                              (64, 28, 28, 1) uint8     (64, 10) float32
.flatten   Flatten()                                          (64, 28, 28, 1) float32   (64, 784) float32
.linear    Linear()    bias: Parameter(300) float32           (64, 784) float32         (64, 300) float32     235,500   942.0KB
                       kernel: Parameter(784, 300) float32
.linear2   Linear()    bias: Parameter(100) float32           (64, 300) float32         (64, 100) float32     30,100    120.4KB
                       kernel: Parameter(300, 100) float32
.linear3   Linear()    bias: Parameter(10) float32            (64, 100) float32         (64, 10) float32      1,010     4.0KB
                       kernel: Parameter(100, 10) float32

Total Parameters: 266,610   1.1MB
Training the Model¶
We are now ready to pass our model some data to start training. As in Keras this is done via the fit method, which has more or less the same signature. Elegy supports a variety of input data sources such as TensorFlow Datasets, PyTorch DataLoaders, Elegy DataLoaders, and Python generators; check out the guide on Data Sources for more information.
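For example, if your data did not fit in memory you could feed fit a Python generator that yields (inputs, labels) batches. The generator below is purely illustrative (train_batches is a hypothetical name) and assumes the Keras-like convention of passing steps_per_epoch instead of batch_size when batches come pre-made:

def train_batches(batch_size: int = 64):
    # yields random (inputs, labels) mini-batches forever
    while True:
        idxs = np.random.randint(0, len(X_train), size=batch_size)
        yield X_train[idxs], y_train[idxs]

# history = model.fit(inputs=train_batches(), epochs=10, steps_per_epoch=200)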
The following code will train our model for 100 epochs while limiting each epoch to 200 steps and using a batch size of 64:
history = model.fit(
inputs=X_train,
labels=y_train,
epochs=100,
steps_per_epoch=200,
batch_size=64,
validation_data=(X_test, y_test),
shuffle=True,
callbacks=[eg.callbacks.ModelCheckpoint("models/high-level", save_best_only=True)],
)
...
199/200 [============================>.] - 1s 7ms/step - accuracy: 0.9932 - l2_loss: 0.0455 - loss: 0.0682 - sparse_categorical_crossentropy_loss: 0.0227 - val_accuracy: 0.9824 - val_l2_loss: 0.0453 - val_loss: 0.1015 - val_sparse_categorical_crossentropy_loss: 0.0562
Epoch 100/100
200/200 [==============================] - 1s 7ms/step - accuracy: 0.9955 - l2_loss: 0.0451 - loss: 0.0596 - sparse_categorical_crossentropy_loss: 0.0146 - val_accuracy: 0.9804 - val_l2_loss: 0.0448 - val_loss: 0.1068 - val_sparse_categorical_crossentropy_loss: 0.0619
The ModelCheckpoint callback will periodically save the model in a folder called "models/high-level"; we will use it later. fit returns a History object which contains the losses and metrics recorded during training, which we can visualize:
import matplotlib.pyplot as plt
def plot_history(history):
n_plots = len(history.history.keys()) // 2
plt.figure(figsize=(14, 24))
for i, key in enumerate(list(history.history.keys())[:n_plots]):
metric = history.history[key]
val_metric = history.history[f"val_{key}"]
plt.subplot(n_plots, 1, i + 1)
plt.plot(metric, label=f"Training {key}")
plt.plot(val_metric, label=f"Validation {key}")
plt.legend(loc="lower right")
plt.ylabel(key)
plt.title(f"Training and Validation {key}")
plt.show()
plot_history(history)
Generating Predictions¶
With our trained model we can now take some samples from the test set and generate predictions. Let's select 9 random images and call .predict:
import numpy as np
idxs = np.random.randint(0, len(X_test), size=(9,))
x_sample = X_test[idxs]
y_pred = model.predict(x=x_sample)
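The predictions are raw logits of shape (9, 10), since the last Linear layer has no softmax. If you want class labels or per-class probabilities you can post-process them yourself; a small sketch, assuming jax.nn.softmax for the normalization:

probs = jax.nn.softmax(y_pred, axis=-1)  # per-class probabilities
labels = np.argmax(y_pred, axis=-1)      # predicted digit for each image
print(labels)
print(probs.max(axis=-1))                # confidence of each prediction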
Easy, right? Finally, let's plot the results to see if they are accurate.
plt.figure(figsize=(12, 12))
for i in range(3):
for j in range(3):
k = 3 * i + j
plt.subplot(3, 3, k + 1)
plt.title(f"{np.argmax(y_pred[k])}")
plt.imshow(x_sample[k], cmap="gray")
Perfect!
Serialization¶
To serialize the Model you can use model.save(...); this will create a folder with files containing the model's code plus all parameters and states. However, since we already used the ModelCheckpoint callback, we can load the model with elegy.load. Let's get a new model reference containing the same weights and call its evaluate method to verify that it loaded correctly:
# You could use `save` here, but `ModelCheckpoint` already serialized the model
# model.save("model")
# current model reference
print("current model id:", id(model))
# load model from disk
model = eg.load("models/high-level")
# new model reference
print("new model id: ", id(model))
# check that it works!
model.evaluate(x=X_test, y=y_test)
current model id: 140139818231600
new model id:  140134587163168
313/313 [==============================] - 2s 7ms/step - accuracy: 0.9790 - crossentropy_loss: 0.0661 - l2_loss: 0.0460 - loss: 0.1121
{'accuracy': DeviceArray(0.9831, dtype=float32), 'crossentropy_loss': DeviceArray(0.05327696, dtype=float32), 'l2_loss': DeviceArray(0.04596329, dtype=float32), 'loss': DeviceArray(0.09924024, dtype=float32), 'size': 32}
Saved Models¶
You can also serialize your Elegy Model as a TensorFlow SavedModel, which is portable to many platforms and services, by using the saved_model method with a sample input and the output path:
model.saved_model(x_sample, "saved-models/high-level")
INFO:tensorflow:Assets written to: saved-models/high-level/assets
We can test our saved model by loading it with TensorFlow and generating a couple of predictions as we did previously:
import tensorflow as tf
saved_model = tf.saved_model.load("saved-models/high-level")
y_pred_tf = saved_model(x_sample.astype(np.uint8))
plt.figure(figsize=(12, 12))
for i in range(3):
for j in range(3):
k = 3 * i + j
plt.subplot(3, 3, k + 1)
plt.title(f"{np.argmax(y_pred_tf[k])}")
plt.imshow(x_sample[k], cmap="gray")
Excellent! We hope you've enjoyed this tutorial.