---
title: "Training Callbacks"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{Training Callbacks}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
type: docs
repo: https://github.com/rstudio/keras
menu:
main:
name: "Training Callbacks"
identifier: "keras-training-callbacks"
parent: "keras-advanced-top"
weight: 30
aliases:
- /keras/articles/training_callbacks.html
---
```{r setup, include = FALSE}
library(keras)
knitr::opts_chunk$set(comment = NA, eval = FALSE)
```
## Overview
A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view of the model's internal states and statistics during training. You can pass a list of callbacks to `fit()` through its `callbacks` argument; the relevant methods of each callback will then be called at each stage of training.
For example:
```{r}
library(keras)
# generate dummy training data: 1000 samples with 784 features each
data <- matrix(rexp(1000*784), nrow = 1000, ncol = 784)
# one-hot encoded labels for 10 classes
labels <- to_categorical(sample(0:9, 1000, replace = TRUE), num_classes = 10)
# create model
model <- keras_model_sequential()
# add layers and compile
model %>%
  layer_dense(32, input_shape = c(784)) %>%
  layer_activation('relu') %>%
  layer_dense(10) %>%
  layer_activation('softmax') %>%
  compile(
    loss = 'categorical_crossentropy',
    optimizer = optimizer_sgd(),
    metrics = 'accuracy'
  )
# fit with callbacks; validation_split provides the "val_loss" metric
# that callback_reduce_lr_on_plateau() monitors
model %>% fit(
  data, labels,
  epochs = 10,
  validation_split = 0.2,
  callbacks = list(
    callback_model_checkpoint("checkpoints.h5"),
    callback_reduce_lr_on_plateau(monitor = "val_loss", factor = 0.1)
  )
)
```
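
To see what `factor = 0.1` does, here is a simplified, pure-R sketch of the rule that `callback_reduce_lr_on_plateau()` applies to the monitored `val_loss` at the end of each epoch. `simulate_reduce_lr()` is a hypothetical helper written for illustration and is not part of keras; it assumes the monitored quantity should decrease and it ignores the real callback's `cooldown` option, so treat it as an approximation rather than the library's implementation.

```{r}
# Hypothetical helper (not part of keras): a simplified model of the
# reduce-on-plateau rule, assuming the monitored quantity should decrease
# and ignoring the real callback's `cooldown` option.
simulate_reduce_lr <- function(val_loss, lr = 0.01, factor = 0.1,
                               patience = 2, min_delta = 1e-4, min_lr = 0) {
  best <- Inf   # best value of the monitored quantity seen so far
  wait <- 0     # epochs since the last improvement
  lrs <- numeric(length(val_loss))
  for (i in seq_along(val_loss)) {
    if (val_loss[i] < best - min_delta) {
      # improvement: remember it and reset the patience counter
      best <- val_loss[i]
      wait <- 0
    } else {
      wait <- wait + 1
      # the plateau has lasted `patience` epochs: shrink the learning rate
      if (wait >= patience && lr > min_lr) {
        lr <- max(lr * factor, min_lr)
        wait <- 0
      }
    }
    # learning rate in effect after epoch i
    lrs[i] <- lr
  }
  lrs
}

simulate_reduce_lr(c(1.0, 0.8, 0.8, 0.8, 0.7, 0.7, 0.7))
```

With these illustrative values the learning rate drops by a factor of 10 after each two-epoch plateau, going from 0.01 to 0.001 and then to 0.0001.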
## Built-in Callbacks
The following built-in callbacks are available as part of Keras:

| Callback | Description |
|----------|-------------|
| `callback_progbar_logger()` | Callback that prints metrics to stdout. |
| `callback_model_checkpoint()` | Save the model after every epoch. |
| `callback_early_stopping()` | Stop training when a monitored quantity has stopped improving. |
| `callback_remote_monitor()` | Callback used to stream events to a server. |
| `callback_learning_rate_scheduler()` | Learning rate scheduler. |
| `callback_tensorboard()` | TensorBoard basic visualizations. |
| `callback_reduce_lr_on_plateau()` | Reduce learning rate when a metric has stopped improving. |
| `callback_csv_logger()` | Callback that streams epoch results to a CSV file. |
| `callback_lambda()` | Create a custom callback. |

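`callback_learning_rate_scheduler()` takes a schedule function. The sketch below is a step-decay schedule that halves the learning rate every 10 epochs. It assumes the schedule is called with a zero-based epoch index and the current learning rate, as in recent versions of the package; `step_decay()` and its `initial_lr`, `drop`, and `epochs_per_drop` parameters are hypothetical names chosen for this example.

```{r}
# Hypothetical step-decay schedule (names and defaults are illustrative).
# Assumes keras calls it as schedule(epoch, lr) with a zero-based epoch;
# the current `lr` is accepted but ignored here.
step_decay <- function(epoch, lr, initial_lr = 0.1, drop = 0.5,
                       epochs_per_drop = 10) {
  initial_lr * drop^floor(epoch / epochs_per_drop)
}

# learning rates for epochs 0, 9, 10 and 25
sapply(c(0, 9, 10, 25), step_decay, lr = 0.1)

# To use it with keras (not run here):
# model %>% fit(x, y, epochs = 30,
#   callbacks = list(callback_learning_rate_scheduler(step_decay)))
```

Because the schedule is a plain R function, it can be checked on its own, as in the `sapply()` call above, before it is attached to a training run.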
## Custom Callbacks
You can create a custom callback by creating a new [R6 class](https://CRAN.R-project.org/package=R6) that inherits from the `KerasCallback` class.
Here's a simple example that records the loss of each batch during training:
```{r}
library(keras)
# define custom callback class
LossHistory <- R6::R6Class("LossHistory",
  inherit = KerasCallback,
  public = list(
    losses = NULL,
    # append the loss reported at the end of each batch
    on_batch_end = function(batch, logs = list()) {
      self$losses <- c(self$losses, logs[["loss"]])
    }
  )
)
# generate dummy training data: 1000 samples, 784 features, 10 one-hot classes
X_train <- matrix(runif(1000 * 784), nrow = 1000, ncol = 784)
Y_train <- to_categorical(sample(0:9, 1000, replace = TRUE), num_classes = 10)
# define model
model <- keras_model_sequential()
# add layers and compile
model %>%
  layer_dense(units = 10, input_shape = c(784)) %>%
  layer_activation(activation = 'softmax') %>%
  compile(
    loss = 'categorical_crossentropy',
    optimizer = 'rmsprop'
  )
# create history callback object and use it during training
history <- LossHistory$new()
model %>% fit(
  X_train, Y_train,
  batch_size = 128, epochs = 20, verbose = 0,
  callbacks = list(history)
)
# print the accumulated losses
history$losses
```
```
[1] 0.6604760 0.3547246 0.2595316 0.2590170 ...
```
### Fields
Custom callback objects have access to the current model and its training parameters through the following fields:
`self$params`
: Named list with training parameters (e.g. verbosity, batch size, number of epochs).
`self$model`
: Reference to the Keras model being trained.
### Methods
Custom callback objects can implement one or more of the following methods:
`on_epoch_begin(epoch, logs)`
: Called at the beginning of each epoch.
`on_epoch_end(epoch, logs)`
: Called at the end of each epoch.
`on_batch_begin(batch, logs)`
: Called at the beginning of each batch.
`on_batch_end(batch, logs)`
: Called at the end of each batch.
`on_train_begin(logs)`
: Called at the beginning of training.
`on_train_end(logs)`
: Called at the end of training.
`on_train_batch_begin(batch, logs)`
: Called at the beginning of every training batch.
`on_train_batch_end(batch, logs)`
: Called at the end of every training batch.
`on_predict_batch_begin(batch, logs)`
: Called at the beginning of a batch in predict methods.
`on_predict_batch_end(batch, logs)`
: Called at the end of a batch in predict methods.
`on_predict_begin(logs)`
: Called at the beginning of prediction.
`on_predict_end(logs)`
: Called at the end of prediction.
`on_test_batch_begin(batch, logs)`
: Called at the beginning of a batch in evaluate methods. Also called at the beginning of a validation batch in the fit methods, if validation data is provided.
`on_test_batch_end(batch, logs)`
: Called at the end of a batch in evaluate methods. Also called at the end of a validation batch in the fit methods, if validation data is provided.
`on_test_begin(logs)`
: Called at the beginning of evaluation or validation.
`on_test_end(logs)`
: Called at the end of evaluation or validation.
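
The order in which `fit()` invokes these methods can be pictured with a small mock. `mock_fit_hook_order()` is a hypothetical, keras-free helper that only records method names; it does not train a model. It assumes validation runs within each epoch, before `on_epoch_end()` (which is why validation metrics such as `val_loss` can appear in that method's logs), and only when validation data is provided. For training batches it records only the `on_train_batch_*` names.

```{r}
# Hypothetical mock (not the keras implementation): records the sequence of
# callback method names that a fit() call with the given sizes would trigger.
mock_fit_hook_order <- function(epochs = 1, batches = 2, val_batches = 0) {
  events <- character(0)
  # append one method name to the recorded sequence
  fire <- function(name) events <<- c(events, name)
  fire("on_train_begin")
  for (e in seq_len(epochs)) {
    fire("on_epoch_begin")
    for (b in seq_len(batches)) {
      fire("on_train_batch_begin")
      fire("on_train_batch_end")
    }
    # validation pass, only when validation data is provided
    if (val_batches > 0) {
      fire("on_test_begin")
      for (b in seq_len(val_batches)) {
        fire("on_test_batch_begin")
        fire("on_test_batch_end")
      }
      fire("on_test_end")
    }
    fire("on_epoch_end")
  }
  fire("on_train_end")
  events
}

mock_fit_hook_order(epochs = 1, batches = 2, val_batches = 1)
```

A custom callback only needs to implement the methods for the stages it cares about; `LossHistory` above, for example, implements just `on_batch_end()`.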