---
title: "Training Callbacks"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Training Callbacks}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
type: docs
repo: https://github.com/rstudio/keras
menu:
  main:
    name: "Training Callbacks"
    identifier: "keras-training-callbacks"
    parent: "keras-advanced-top"
    weight: 30
aliases:
  - /keras/articles/training_callbacks.html
---

```{r setup, include = FALSE}
library(keras)
knitr::opts_chunk$set(comment = NA, eval = FALSE)
```

## Overview

A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to inspect the internal state and statistics of the model during training. You can pass a list of callbacks (as the `callbacks` argument) to the `fit()` function. The relevant methods of each callback are then called at each stage of training.

For example:

```{r}
library(keras)

# generate dummy training data: 1000 samples, 784 features, 10 classes
data <- matrix(rexp(1000 * 784), nrow = 1000, ncol = 784)
# one-hot encode random class labels so they match the softmax output
labels <- to_categorical(sample(0:9, 1000, replace = TRUE), num_classes = 10)

# create model
model <- keras_model_sequential()

# add layers and compile
model %>%
  layer_dense(units = 32, input_shape = c(784)) %>%
  layer_activation('relu') %>%
  layer_dense(units = 10) %>%
  layer_activation('softmax') %>%
  compile(
    loss = 'categorical_crossentropy',
    optimizer = optimizer_sgd(),
    metrics = 'accuracy'
  )

# fit with callbacks; validation_split provides the "val_loss"
# that callback_reduce_lr_on_plateau() monitors
model %>% fit(data, labels, validation_split = 0.2, callbacks = list(
  callback_model_checkpoint("checkpoints.h5"),
  callback_reduce_lr_on_plateau(monitor = "val_loss", factor = 0.1)
))
```

## Built-in Callbacks

The following built-in callbacks are available as part of Keras:

`callback_progbar_logger()`

Callback that prints metrics to stdout.

`callback_model_checkpoint()`

Save the model after every epoch.

`callback_early_stopping()`

Stop training when a monitored quantity has stopped improving.

`callback_remote_monitor()`

Callback used to stream events to a server.

`callback_learning_rate_scheduler()`

Set the learning rate at the start of each epoch using a schedule function.

`callback_tensorboard()`

Write logs for basic TensorBoard visualizations.

`callback_reduce_lr_on_plateau()`

Reduce learning rate when a metric has stopped improving.

`callback_csv_logger()`

Callback that streams epoch results to a CSV file.

`callback_lambda()`

Create a simple custom callback from anonymous functions.
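
As a minimal sketch of how several of these built-in callbacks can be combined, the example below reuses `model`, `data`, and `labels` from the Overview example. The `step_decay()` schedule, the `training_log.csv` file name, and the `patience` value are illustrative assumptions, not recommended settings.

```{r}
library(keras)

# Hypothetical schedule: halve the learning rate every 10 epochs.
# `epoch` is the epoch index and `lr` the current learning rate.
step_decay <- function(epoch, lr) {
  if (epoch > 0 && epoch %% 10 == 0) lr * 0.5 else lr
}

callbacks <- list(
  # stop once val_loss has not improved for 5 consecutive epochs
  callback_early_stopping(monitor = "val_loss", patience = 5),
  # apply the schedule defined above
  callback_learning_rate_scheduler(step_decay),
  # append one row of metrics per epoch to a CSV file (assumed file name)
  callback_csv_logger("training_log.csv"),
  # print the validation loss at the end of every epoch
  callback_lambda(on_epoch_end = function(epoch, logs) {
    cat("epoch", epoch, "val_loss:", logs[["val_loss"]], "\n")
  })
)

# validation_split is needed so that "val_loss" is available to monitor
model %>% fit(
  data, labels,
  epochs = 50,
  validation_split = 0.2,
  callbacks = callbacks
)
```

Keeping the schedule as a plain R function makes it easy to check on its own before training, for example by calling `step_decay(10, 0.1)`.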

## Custom Callbacks

You can create a custom callback by creating a new [R6 class](https://CRAN.R-project.org/package=R6) that inherits from the `KerasCallback` class.

Here's a simple example that saves the loss of each batch during training:

```{r}
library(keras)

# define custom callback class
LossHistory <- R6::R6Class("LossHistory",
  inherit = KerasCallback,

  public = list(

    losses = NULL,

    # append the loss reported for the batch that just finished
    on_batch_end = function(batch, logs = list()) {
      self$losses <- c(self$losses, logs[["loss"]])
    }
  )
)

# define model
model <- keras_model_sequential()

# add layers and compile
model %>%
  layer_dense(units = 10, input_shape = c(784)) %>%
  layer_activation(activation = 'softmax') %>%
  compile(
    loss = 'categorical_crossentropy',
    optimizer = 'rmsprop'
  )

# create history callback object and use it during training
# (X_train and Y_train are assumed to exist: an input matrix with 784
# columns and a one-hot label matrix with 10 columns)
history <- LossHistory$new()
model %>% fit(
  X_train, Y_train,
  batch_size = 128, epochs = 20, verbose = 0,
  callbacks = list(history)
)

# print the accumulated losses
history$losses
```

```
[1] 0.6604760 0.3547246 0.2595316 0.2590170 ...
```

### Fields

Custom callback objects have access to the current model and its training parameters via the following fields:

`self$params`
:   Named list with training parameters (e.g. verbosity, batch size, number of epochs).

`self$model`
:   Reference to the Keras model being trained.

### Methods

Custom callback objects can implement one or more of the following methods:

`on_epoch_begin(epoch, logs)`
:   Called at the beginning of each epoch.

`on_epoch_end(epoch, logs)`
:   Called at the end of each epoch.

`on_batch_begin(batch, logs)`
:   Called at the beginning of each batch.

`on_batch_end(batch, logs)`
:   Called at the end of each batch.

`on_train_begin(logs)`
:   Called at the beginning of training.

`on_train_end(logs)`
:   Called at the end of training.

`on_train_batch_begin`
:   Called at the beginning of every batch.

`on_train_batch_end`
:   Called at the end of every batch.

`on_predict_batch_begin`
:   Called at the beginning of a batch in predict methods.

`on_predict_batch_end`
:   Called at the end of a batch in predict methods.

`on_predict_begin`
:   Called at the beginning of prediction.

`on_predict_end`
:   Called at the end of prediction.

`on_test_batch_begin`
:   Called at the beginning of a batch in evaluate methods. Also called at the beginning of a validation batch in the fit methods, if validation data is provided.

`on_test_batch_end`
:   Called at the end of a batch in evaluate methods. Also called at the end of a validation batch in the fit methods, if validation data is provided.

`on_test_begin`
:   Called at the beginning of evaluation or validation.

`on_test_end`
:   Called at the end of evaluation or validation.
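
To illustrate how several of these methods can work together in one class, here is a minimal sketch of a callback that records the wall-clock time and loss of each epoch. The class name `EpochTimer` and its fields are hypothetical and not part of the package; only the method names and the `KerasCallback` base class come from the lists above.

```{r}
library(keras)

# Hypothetical callback: records duration and loss for each epoch
EpochTimer <- R6::R6Class("EpochTimer",
  inherit = KerasCallback,

  public = list(

    epoch_start = NULL,
    durations = NULL,
    losses = NULL,

    # reset state so the same object can be reused across calls to fit()
    on_train_begin = function(logs = list()) {
      self$durations <- numeric(0)
      self$losses <- numeric(0)
    },

    # remember when the epoch started
    on_epoch_begin = function(epoch, logs = list()) {
      self$epoch_start <- Sys.time()
    },

    # store elapsed seconds and the loss reported for the epoch
    on_epoch_end = function(epoch, logs = list()) {
      elapsed <- as.numeric(difftime(Sys.time(), self$epoch_start, units = "secs"))
      self$durations <- c(self$durations, elapsed)
      self$losses <- c(self$losses, logs[["loss"]])
    },

    # print a short summary once training has finished
    on_train_end = function(logs = list()) {
      cat("Trained", length(self$durations), "epochs in",
          round(sum(self$durations), 2), "seconds\n")
    }
  )
)

timer <- EpochTimer$new()
```

The object is used like `history` in the example above: pass `callbacks = list(timer)` to `fit()`, after which `timer$durations` and `timer$losses` should each hold one entry per completed epoch.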