pred_step
The `pred_step` method computes the predictions of the main model. By overriding it you can directly control what happens during `predict`.
Inputs
Any of the following input arguments are available to `pred_step`:
name | type | description
---|---|---
`x` | `Any` | Input data
`states` | `States` | Current state of the model
`initializing` | `bool` | Whether the model is initializing or not
`training` | `bool` | Whether the model is training or not
You must request the arguments you want by name, that is, by declaring them as parameters in your `pred_step` signature.
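Requesting arguments by name means the signature of `pred_step` declares which inputs it wants, and only those are passed in. The sketch below illustrates the idea in plain Python with `inspect.signature`; the helper `call_with_requested` and the `available` dictionary are hypothetical and are not meant to reproduce Elegy's internal code.

```python
import inspect


def call_with_requested(fn, available):
    """Call fn, passing only the keyword arguments it declares by name.

    Hypothetical helper illustrating name-based argument injection;
    it is not Elegy's internal implementation.
    """
    params = inspect.signature(fn).parameters
    # Reject names the caller cannot provide.
    missing = [name for name in params if name not in available]
    if missing:
        raise TypeError(f"unknown argument(s) requested: {missing}")
    kwargs = {name: available[name] for name in params}
    return fn(**kwargs)


# Everything the caller could provide to pred_step.
available = {"x": [1.0, 2.0], "states": {}, "initializing": True, "training": False}


def pred_step(x, initializing):
    # Only x and initializing are requested; states and training are not passed.
    return ("init" if initializing else "apply", sum(x))


print(call_with_requested(pred_step, available))  # ('init', 3.0)
```

Requesting a name that is not in the available set (for example `y_true` in this sketch) raises a `TypeError` instead of silently passing `None`.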
Outputs
`pred_step` must return a tuple with the following values:
name | type | description
---|---|---
`y_pred` | `Any` | The predictions of the model
`states` | `States` | The new state of the model
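The second output is the *new* state rather than a mutated copy of the input. The sketch below shows that contract with a hypothetical, immutable `States` stand-in whose `update` method returns a fresh object; it only mimics the `states.update(...)` call used in the example on this page and is not Elegy's `States` class.

```python
from types import MappingProxyType


class States:
    """Hypothetical immutable state container with an `update` method.

    A sketch for illustration only, not Elegy's States class.
    """

    def __init__(self, **values):
        # Read-only view so callers cannot mutate the stored values in place.
        self._values = MappingProxyType(dict(values))

    def __getattr__(self, name):
        # Look up via __dict__ to avoid recursion if _values is not set yet.
        values = self.__dict__.get("_values", {})
        if name in values:
            return values[name]
        raise AttributeError(name)

    def update(self, **changes):
        # Return a new States instead of mutating self.
        return States(**{**self._values, **changes})


def pred_step(x, states):
    # Count calls in the state; return the predictions and the new state.
    calls = getattr(states, "calls", 0) + 1
    y_pred = [2 * v for v in x]
    return y_pred, states.update(calls=calls)


states = States()
y_pred, new_states = pred_step([1, 2, 3], states)
print(y_pred, new_states.calls)  # [2, 4, 6] 1
```

Because the input `states` is left untouched, the caller decides whether to keep the returned state, which is what makes the step a pure function.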
Callers
method | when
---|---
`predict` | always
`test_step` | in its default implementation
`summary_step` | in its default implementation
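The caller table means that overriding `pred_step` changes `predict` directly, and changes `test_step` and `summary_step` as long as you keep their default implementations. The sketch below shows that routing with a hypothetical `MiniModel` class; its simplified signatures (no `states`, a hand-written mean squared error, no `summary_step`) are assumptions made for illustration and do not match Elegy's real methods.

```python
class MiniModel:
    """Hypothetical model showing which methods route through pred_step."""

    def pred_step(self, x):
        # Override point: compute predictions only.
        return [2 * v for v in x]

    def predict(self, x):
        # predict always goes through pred_step.
        return self.pred_step(x)

    def test_step(self, x, y_true):
        # Default test_step: reuse pred_step, then compute a loss on top.
        y_pred = self.pred_step(x)
        mse = sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)
        return {"loss": mse}


class Tripler(MiniModel):
    # Overriding only pred_step also changes predict and test_step.
    def pred_step(self, x):
        return [3 * v for v in x]


model = Tripler()
print(model.predict([1, 2]))            # [3, 6]
print(model.test_step([1, 2], [3, 5]))  # {'loss': 0.5}
```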
Examples
If you want to create a pure JAX, `Module`-less model, you can define your own `Model` subclass that implements `pred_step`, like this:
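```python
import elegy
import jax
import jax.numpy as jnp
import numpy as np
import optax


class LinearClassifier(elegy.Model):
    # Request only the inputs this step needs: x, states and initializing.
    def pred_step(self, x, states, initializing):
        # Flatten each image and scale pixel values to [0, 1].
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
        else:
            w, b = states.net_params

        # model: a single linear layer producing 10 logits
        y_pred = jnp.dot(x, w) + b

        # return the predictions together with the updated state
        return y_pred, states.update(net_params=(w, b))


model = LinearClassifier(
    optimizer=optax.adam(1e-3),
    loss=elegy.losses.Crossentropy(),
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
)

# X_train and y_train are your training images and integer labels (not defined here).
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    batch_size=64,
)
```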
Here we implement the same `LinearClassifier` from the basics section, but we move the model definition into `pred_step` and let the default implementation of `test_step` take care of the `loss` and `metrics`, which we pass to the `LinearClassifier` constructor.
Default Implementation
The default implementation of `pred_step` does the following:

- Calls `api_module.init` or `api_module.apply`, depending on the value of `initializing`. `api_module`, of type `GeneralizedModule`, is a wrapper over the `module` object passed by the user to the `Model` constructor.
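The dispatch described above can be sketched in plain Python as follows. `ToyModule`, its `init(x)` and `apply(params, x)` signatures, and the `"net_params"` key are hypothetical stand-ins chosen for illustration; the real `GeneralizedModule` interface in Elegy may take different arguments.

```python
class ToyModule:
    """Hypothetical stand-in for a GeneralizedModule-like wrapper."""

    def init(self, x):
        # Create fresh parameters, then run a normal forward pass with them.
        params = {"scale": 1.0, "bias": 0.0}
        return self.apply(params, x)

    def apply(self, params, x):
        # Forward pass with existing parameters; returns (y_pred, params).
        y_pred = [params["scale"] * v + params["bias"] for v in x]
        return y_pred, params


def default_pred_step(api_module, x, states, initializing):
    # init while initializing, apply with the stored parameters otherwise.
    if initializing:
        y_pred, params = api_module.init(x)
    else:
        y_pred, params = api_module.apply(states["net_params"], x)
    # Return the predictions and a new state dict holding the parameters.
    return y_pred, {**states, "net_params": params}


module = ToyModule()
y0, s0 = default_pred_step(module, [1.0, 2.0], {}, initializing=True)
print(y0)  # [1.0, 2.0]

trained = {"net_params": {"scale": 2.0, "bias": 1.0}}
y1, s1 = default_pred_step(module, [1.0, 2.0], trained, initializing=False)
print(y1)  # [3.0, 5.0]
```

Routing both branches through the same forward pass keeps initialization and regular prediction consistent, which is the same idea the `LinearClassifier` example applies by hand.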