keras.fit() and keras.fit_generator() in Python are two separate functions of the Keras deep learning library that can be used to train our machine learning and deep learning models. Both functions can perform the same task, but when to use which function is the main question.
The full signature of fit (shown here as in the keras R interface; in Python it is called as a method on the model, model.fit(), with the same core arguments):

fit(object, x = NULL, y = NULL, batch_size = NULL, epochs = 10, verbose = getOption("keras.fit_verbose", default = 1), callbacks = NULL, view_metrics = getOption("keras.view_metrics", default = "auto"), validation_split = 0, validation_data = NULL, shuffle = TRUE, class_weight = NULL, sample_weight = NULL, initial_epoch = 0, steps_per_epoch = NULL, validation_steps = NULL, ...)
Understanding a few important arguments:
- object: the model to train.
- x: our training data. Can be a vector, array, or matrix.
- y: our training labels. Can be a vector, array, or matrix.
- batch_size: any integer value or NULL; by default it is set to 32. It specifies the number of samples per gradient update.
- epochs: an integer, the number of epochs we want to train our model for.
- verbose: specifies the verbosity mode (0 = silent, 1 = progress bar, 2 = one line per epoch).
- shuffle: whether we want to shuffle our training data before each epoch.
- steps_per_epoch: the total number of steps (batches of samples) to process before declaring one epoch finished and starting the next. By default its value is set to NULL.
How to use Keras fit:
model.fit(Xtrain, Ytrain, batch_size = 32, epochs = 100)
Here we first feed in the training data (Xtrain) and training labels (Ytrain), then let Keras train our model for 100 epochs with a batch_size of 32.
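To make the epoch and batch bookkeeping concrete, here is a quick back-of-the-envelope calculation; the sample count of 1000 is a made-up figure for illustration, not taken from the example above:

```python
import math

num_samples = 1000   # hypothetical size of Xtrain
batch_size = 32
epochs = 100

# gradient updates per epoch (the last batch may be smaller than 32)
updates_per_epoch = math.ceil(num_samples / batch_size)
total_updates = updates_per_epoch * epochs

print(updates_per_epoch)  # 32
print(total_updates)      # 3200
```

So the single model.fit() call above would perform 3200 weight updates in total for a dataset of that size.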
When we call the .fit() function, it makes the following assumptions:
- The entire training set can fit into the Random Access Memory (RAM) of the computer.
- Calling the model.fit method a second time will not reinitialize our already trained weights, which means we can make consecutive calls to fit if we want to, and manage the training properly.
- There is no need for Keras generators (i.e., no data augmentation is applied).
- The raw data itself is used for training our network, and that raw data fits into memory.
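Whether a dataset fits into RAM can be estimated with simple arithmetic. The figures below (image count, resolution, and float32 storage) are illustrative assumptions, not a real dataset:

```python
# rough RAM estimate for holding an image dataset entirely in memory
num_images = 50000           # assumed dataset size
height, width, channels = 224, 224, 3
bytes_per_value = 4          # float32

total_bytes = num_images * height * width * channels * bytes_per_value
total_gib = total_bytes / (1024 ** 3)

print(round(total_gib, 1))   # 28.0 (GiB)
```

At roughly 28 GiB, such a dataset would exceed the RAM of most machines, which is exactly the situation where fit_generator becomes useful.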
The full signature of fit_generator (again as in the keras R interface):

fit_generator(object, generator, steps_per_epoch, epochs = 1, verbose = getOption("keras.fit_verbose", default = 1), callbacks = NULL, view_metrics = getOption("keras.view_metrics", default = "auto"), validation_data = NULL, validation_steps = NULL, class_weight = NULL, max_queue_size = 10, workers = 1, initial_epoch = 0)
Understanding a few important arguments:
- object: the Keras model object.
- generator: a generator whose output must be a list of the form (inputs, targets) or (inputs, targets, sample_weights). A single output of the generator makes a single batch, so all arrays in the list must have a length equal to the batch size. The generator is expected to loop over its data indefinitely; it should never return or exit.
- steps_per_epoch: the total number of steps (batches) to draw from the generator before declaring one epoch finished and starting the next. We can calculate its value as the total number of samples in the dataset divided by the batch size.
- epochs: an integer, the number of epochs we want to train our model for.
- verbose: specifies the verbosity mode (0 = silent, 1 = progress bar, 2 = one line per epoch).
- callbacks: a list of callback functions applied during the training of our model.
- validation_data: can be either an (inputs, targets) list, a generator, or an (inputs, targets, sample_weights) list. It is used to evaluate the loss and metrics of the model after each epoch has ended.
- validation_steps: can be used only if validation_data is a generator. It specifies the total number of steps (batches) to draw from the generator at the end of every epoch, calculated as the total number of validation data points in the dataset divided by the validation batch size.
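The generator contract described above — yield (inputs, targets) batches and loop forever — can be sketched in plain Python. The toy data and batch size here are assumptions for illustration, not part of the Keras API:

```python
import random

def batch_generator(inputs, targets, batch_size):
    """Yield (inputs, targets) batches indefinitely, reshuffling each pass."""
    indices = list(range(len(inputs)))
    while True:                      # never return: Keras decides when to stop reading
        random.shuffle(indices)
        for start in range(0, len(indices) - batch_size + 1, batch_size):
            batch = indices[start:start + batch_size]
            yield ([inputs[i] for i in batch], [targets[i] for i in batch])

# toy usage: 10 samples, batches of 4
gen = batch_generator(list(range(10)), [i % 2 for i in range(10)], batch_size=4)
xb, yb = next(gen)
print(len(xb), len(yb))  # 4 4
```

Note that every yielded list has exactly batch_size elements, matching the requirement that all arrays in the output share the batch length.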
How to use Keras fit_generator:
# performing data augmentation with a training image generator
dataAugmentaion = ImageDataGenerator(rotation_range = 30, zoom_range = 0.20,
    fill_mode = "nearest", shear_range = 0.20, horizontal_flip = True,
    width_shift_range = 0.1, height_shift_range = 0.1)

# training the model
model.fit_generator(dataAugmentaion.flow(trainX, trainY, batch_size = 32),
    validation_data = (testX, testY), steps_per_epoch = len(trainX) // 32,
    epochs = 10)
Here we are training our network for 10 epochs with a batch size of 32.
For small and less complex datasets, it is recommended to use the keras.fit function. Dealing with real-world datasets is not that simple, however, because they are huge in size and much harder to fit into the computer's memory.
Those datasets are more challenging to deal with, and an important step in handling them is to perform data augmentation, both to avoid overfitting the model and to increase its ability to generalize.
Data augmentation is a method of artificially creating new training samples from the existing training dataset to improve the performance of deep learning neural networks with the amount of data available. It is a form of regularization that makes our model generalize better than before.
Here we have used a Keras ImageDataGenerator object to apply data augmentation, randomly translating, resizing, rotating, etc. the images. Each new batch of our data is randomly adjusted according to the parameters supplied to ImageDataGenerator.
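As a simplified stand-in for one such random transformation (this is not the actual ImageDataGenerator internals, just an illustration of the idea), a horizontal flip simply reverses each pixel row:

```python
import random

def horizontal_flip(image):
    """Reverse each row of a row-major image (a list of pixel rows)."""
    return [row[::-1] for row in image]

def random_flip(image, p=0.5, rng=random):
    """Apply the flip with probability p, as an augmentation pipeline would."""
    return horizontal_flip(image) if rng.random() < p else image

img = [[1, 2, 3],
       [4, 5, 6]]
print(horizontal_flip(img))  # [[3, 2, 1], [6, 5, 4]]
```

Because the transformation is applied randomly per batch, the network rarely sees the exact same sample twice, which is what gives augmentation its regularizing effect.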
When we call the .fit_generator() function, the following steps take place:
- Keras first calls the generator function (dataAugmentaion).
- The generator function (dataAugmentaion) supplies batches of 32 samples to the .fit_generator() function.
- The .fit_generator() function accepts a batch of the dataset, performs backpropagation on it, and then updates the weights of our model.
- This process is repeated for the number of epochs specified (10 in our case).
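The steps above amount to the following loop; the dummy generator and the train-step stub are hypothetical stand-ins for illustration, not Keras internals:

```python
def run_training(generator, steps_per_epoch, epochs, train_step):
    """Mimic the fit_generator schedule: pull a batch, update weights, repeat."""
    updates = 0
    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            batch = next(generator)   # generator hands over one batch
            train_step(batch)         # backprop + weight update (stubbed here)
            updates += 1
    return updates

def dummy_batches():
    """Endless stream of dummy (inputs, targets) batches of size 32."""
    while True:
        yield ([0] * 32, [0] * 32)

total = run_training(dummy_batches(), steps_per_epoch=5, epochs=10,
                     train_step=lambda batch: None)
print(total)  # 50
```

This also shows why the generator must never exit: next() is called steps_per_epoch × epochs times, and Keras, not the generator, decides when training ends.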
So, we have learned the difference between the Keras.fit and Keras.fit_generator functions used to train a deep learning neural network:
.fit is used when the entire training dataset can fit into memory and no data augmentation is applied.
.fit_generator is used when either the dataset is too huge to fit into memory or data augmentation needs to be applied.