Exploring Transfer Learning with PyTorch: A Deep Learning Journey
I recently embarked on a project to implement a model using PyTorch. This led me to delve deep into the PyTorch framework and to make use of transfer learning (as I did in my last blog post). In this post, I share both what I learned from this experience and how you can create a similar PyTorch model of your own. This post is greatly inspired by a post by Sasank Chilamkurthy, whose excellent tutorial you can find here. It also makes use of Rohit Rawat’s Crocodylia dataset.
Transfer Learning: A Brief Overview
In the field of deep learning, few people train an entire convolutional neural network (ConvNet) from scratch, primarily because it is rare to have a dataset of sufficient size. Instead, a common practice is to employ transfer learning. This approach entails utilizing a model that has been pre-trained on an extensive dataset like ImageNet, which contains 1.2 million images across 1,000 categories. To do this, we freeze the weights of the pre-trained network's layers (except, perhaps, a few of the last fully connected layers), replace the final layer with a new one (making the model usable as, for example, a binary classifier), and then train this new layer so that the model's predictions become accurate for our task.
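As a preview, here is a minimal sketch of that recipe (assuming torchvision's ResNet-18; we will flesh all of this out below):

import torch.nn as nn
from torchvision import models

model = models.resnet18(pretrained=True)   # Start from ImageNet weights
for param in model.parameters():
    param.requires_grad = False            # Freeze the pretrained layers

# Replace the final layer with a fresh, trainable binary classification head
model.fc = nn.Linear(model.fc.in_features, 2)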
Let’s see how all of this can be done using the PyTorch framework!
Loading the Data
To kickstart our exploration, we need a dataset. In this tutorial, we’re working with a small dataset containing images of alligators and crocodiles, with ~350 training images per class and ~70 validation images per class. You can download the dataset here.
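Note that the loading code below relies on torchvision's ImageFolder class, which infers class labels from subdirectory names, so the dataset should be organized roughly as follows (the folder names here are my assumption based on the two classes):

crocodylia_data/
├── train/
│   ├── alligators/   # ~350 training images
│   └── crocodiles/   # ~350 training images
└── val/
    ├── alligators/   # ~70 validation images
    └── crocodiles/   # ~70 validation images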
We’ll leverage PyTorch’s torchvision and torch.utils.data packages for data loading. To get started, we define data transformations for both the training and validation datasets.
import os
import torch
from torchvision import datasets, transforms
# Define data transformations for both training and validation datasets
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
Data transformations are defined using PyTorch’s transforms module. These transformations are applied to both the training and validation datasets to preprocess the images. The transformations include resizing, random cropping, horizontal flipping, converting to tensors, and normalization.
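If you want to confirm what these transforms produce, a quick check on any single image should yield a normalized 3×224×224 tensor (the image path here is hypothetical):

from PIL import Image

img = Image.open('some_image.jpg')   # hypothetical path to any RGB image
x = data_transforms['val'](img)
print(x.shape)   # torch.Size([3, 224, 224])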
Next, we mount Google Drive to access the dataset. (I used Google Colab to execute my code cells; this step may differ slightly depending on the platform you use.) You should update data_dir with your dataset location.
from google.colab import drive
drive.mount('/content/drive')

data_dir = '/content/drive/MyDrive/crocodylia_data'  # Update this path to your dataset location
With the data transformations in place, we can create datasets and data loaders for both training and validation sets.
# Create datasets for both training and validation
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}

# Create data loaders for both training and validation datasets
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}

# Store the sizes of the training and validation datasets
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Get the class names (e.g., 'alligators' and 'crocodiles') from the training dataset
class_names = image_datasets['train'].classes

# Check if a CUDA-enabled GPU is available and set the device accordingly
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Datasets and data loaders are created for both the training and validation sets. We use the ImageFolder class to load images from the specified directories and apply the previously defined data transformations. Data loaders are used to load batches of data during training and validation. The device is also determined based on the availability of a GPU.
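A quick sanity check at this point can catch path or folder-layout mistakes early; the printed values below are illustrative:

print(dataset_sizes)   # e.g. {'train': 700, 'val': 140}
print(class_names)     # e.g. ['alligators', 'crocodiles']
print(device)          # e.g. device(type='cuda', index=0)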
Visualizing Data
Let’s start by visualizing a few training images to get a better understanding of the applied data augmentation techniques.
import numpy as np
import matplotlib.pyplot as plt
import torchvision
def imshow(inp, title=None):
    # Convert the tensor to a NumPy image and undo the normalization
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # Pause briefly so the plot is updated
# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from the batch of images
out = torchvision.utils.make_grid(inputs)

# Display the grid of images with their corresponding class names as titles
imshow(out, title=[class_names[x] for x in classes])
Here, we define a function (imshow) to display images and then use it to visualize a batch of training data. Specifically, this function shows a grid of images with their corresponding class names as titles.
Nice, our function appears to be working well.
Training the Model
Now, let’s write a general function to train a model. This function encompasses several key aspects of deep learning training, including learning rate scheduling and model checkpointing.
import time
import os
import torch
from tempfile import TemporaryDirectory
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # Use a temporary directory to checkpoint the best model parameters
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training phase and a validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set the model to training mode
                else:
                    model.eval()   # Set the model to evaluation mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over the data in batches
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # Zero the parameter gradients
                    optimizer.zero_grad()

                    # Forward pass; track gradients only in the training phase
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # Backward pass and optimizer step only in the training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # Accumulate statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)

                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # Checkpoint the model if it is the best seen so far
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Load the best model weights before returning
        model.load_state_dict(torch.load(best_model_params_path))
    return model
We define a function (train_model) to train a deep learning model. It includes training and validation loops, tracks loss and accuracy, and saves the best model parameters based on validation accuracy. Learning rate scheduling is also applied.
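To make the learning rate scheduling concrete: with the StepLR scheduler we pass in below, the learning rate is multiplied by gamma every step_size epochs. Here is a standalone illustration (the dummy optimizer exists purely for demonstration):

import torch
import torch.optim as optim
from torch.optim import lr_scheduler

dummy_opt = optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.001)
sched = lr_scheduler.StepLR(dummy_opt, step_size=7, gamma=0.1)
for epoch in range(15):
    dummy_opt.step()
    sched.step()
    print(epoch, dummy_opt.param_groups[0]['lr'])
# Prints 0.001 for epochs 0-5, 0.0001 for epochs 6-12, and 0.00001 from epoch 13 on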
Visualizing Model Predictions
Let’s create a function to visualize model predictions. This will help us understand how well the model is performing on specific images.
import matplotlib.pyplot as plt
import numpy as np
def visualize_model(model, num_images=6):
    was_training = model.training  # Remember the current training state
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'Predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)
We define another function (visualize_model) to visualize model predictions. It takes a trained model, makes predictions on a batch of validation data, and displays the input images along with their respective predicted class labels.
Finetuning the Model
Now comes the exciting part: finetuning a pretrained model for our specific task. In this example, we’ll use the popular ResNet18 architecture pretrained on ImageNet and adapt it for our binary classification task (alligators vs. crocodiles).
from torchvision import models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# Load the pretrained ResNet-18 model
model_ft = models.resnet18(pretrained=True)

# Modify the final fully connected layer for binary classification
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# Move the model to the GPU if available
model_ft = model_ft.to(device)
In this code block, a pretrained ResNet-18 model is loaded and modified for binary classification. The final fully connected layer is replaced with a new one with 2 output units (for alligators and crocodiles classification). The loss function, optimizer, and learning rate scheduler are defined. The model is moved to the GPU if available.
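As a quick check, printing the replaced head should show a fresh linear layer mapping ResNet-18's 512 features to our two classes:

print(model_ft.fc)   # Linear(in_features=512, out_features=2, bias=True)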
Now that everything is set up, we can proceed to train the model using the train_model function defined earlier.
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
We train the fine-tuned model using the train_model function defined earlier. It prints the training and validation loss and accuracy for each epoch and records the best model parameters based on validation accuracy.
Epoch 0/24
----------
train Loss: 0.7276 Acc: 0.6117
val Loss: 0.4810 Acc: 0.7767
Epoch 1/24
----------
train Loss: 0.7097 Acc: 0.6749
val Loss: 0.5373 Acc: 0.7573
Epoch 2/24
----------
train Loss: 0.9122 Acc: 0.6185
val Loss: 0.5871 Acc: 0.7379
…
Epoch 23/24
----------
train Loss: 0.3144 Acc: 0.8555
val Loss: 0.4539 Acc: 0.7864
Epoch 24/24
----------
train Loss: 0.3832 Acc: 0.8330
val Loss: 0.4158 Acc: 0.7961
Training complete in 9m 19s
Best val Acc: 0.8350
After training, our best validation accuracy is recorded as 0.8350. But can we do better?
Visualizing Model Predictions
To see how our finetuned model performs, we can use the visualize_model function.
visualize_model(model_ft)
This code block visualizes the predictions of the fine-tuned model on a batch of validation data, allowing us to see how well the model is performing on specific images.
ConvNet as a Fixed Feature Extractor
In the previous example, we finetuned the entire model. However, sometimes we might want to use a pretrained ConvNet as a fixed feature extractor. In this case, we’ll freeze all layers except the final fully connected layer. This can be useful when you have a small dataset and you don’t want to risk overfitting by finetuning the entire model.
Let’s create a function to set the requires_grad attribute of the model’s layers for feature extraction.
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
We define a function (set_parameter_requires_grad) to freeze the weights of layers in a model. It sets requires_grad to False for all parameters if feature_extracting is True.
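To verify that freezing actually worked, a small helper (hypothetical, not part of the original tutorial) can count the parameters that remain trainable:

def count_trainable_params(model):
    # Count only the parameters that will receive gradient updates
    return sum(p.numel() for p in model.parameters() if p.requires_grad)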
Next, we’ll (again) load a pretrained ResNet18 model and replace the final fully connected layer for our binary classification task; however, this time we will freeze all layers except the final fully connected layer.
# Load the pretrained ResNet-18 model
model_conv = models.resnet18(pretrained=True)

# Freeze all of the pretrained layers
set_parameter_requires_grad(model_conv, feature_extracting=True)

# Replace the final fully connected layer for binary classification
# (the parameters of this new layer have requires_grad=True by default)
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Observe that only the parameters of the final layer are being optimized
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

# Move the model to the GPU if available
model_conv = model_conv.to(device)
This freezing technique is frequently used when you want to use a pretrained model as a fixed feature extractor. The rest of the code is similar to the fine-tuning scenario, defining loss, optimizer, and learning rate scheduler and moving the model to the GPU.
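Using the count_trainable_params helper sketched above, we can compare the two setups; the numbers below are what I would expect for ResNet-18 (roughly 11.2 million trainable parameters when finetuning everything, versus just the 512 × 2 + 2 = 1,026 parameters of the new final layer when feature extracting):

print(count_trainable_params(model_ft))    # ~11.2 million: the whole network
print(count_trainable_params(model_conv))  # 1026: only the new fc layer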
Now, we can train the model using the same train_model function as before.
model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)
Again, we print the training and validation loss and accuracy for each epoch and record the best model parameters based on validation accuracy.
Epoch 0/24
----------
train Loss: 0.7446 Acc: 0.5350
val Loss: 0.4117 Acc: 0.7864
Epoch 1/24
----------
train Loss: 0.6309 Acc: 0.6682
val Loss: 0.3449 Acc: 0.8447
Epoch 2/24
----------
train Loss: 0.6819 Acc: 0.6614
val Loss: 0.6134 Acc: 0.6990
…
Epoch 23/24
----------
train Loss: 0.5224 Acc: 0.7381
val Loss: 0.3565 Acc: 0.8738
Epoch 24/24
----------
train Loss: 0.5292 Acc: 0.7336
val Loss: 0.3602 Acc: 0.8738
Training complete in 8m 42s
Best val Acc: 0.8835
After freezing the layers, our best validation accuracy is now recorded as 0.8835 – great!
Let’s now visualize our model!
visualize_model(model_conv)
plt.ioff()
plt.show()
Inference on custom images
While it is interesting to see how well our model performs on a pre-made dataset, it will doubtless be more exciting if we can test our model on any image pulled from Google. Allegheny College (of which I am now an alumnus!) has a statue of an alligator on its campus. For fun, then, let’s see how well our model performs when fed an image of this very statue. Here’s the image:
Let’s create a function that takes a model and an image as input, processes the image, makes predictions using the model, and displays the image along with its respective predicted class label.
from PIL import Image

def visualize_model_predictions(model, img_path):
    # Store the current training state of the model
    was_training = model.training
    # Set the model to evaluation mode
    model.eval()

    # Open and preprocess the image from the specified path
    img = Image.open(img_path)
    img = data_transforms['val'](img)  # Apply the same transformations as for validation data
    img = img.unsqueeze(0)             # Add a batch dimension
    img = img.to(device)               # Move the image to the specified device (GPU or CPU)

    # Disable gradient computation during inference
    with torch.no_grad():
        # Forward pass to obtain predictions
        outputs = model(img)
        _, preds = torch.max(outputs, 1)

        # Display the image and its predicted class
        ax = plt.subplot(2, 2, 1)  # Create a subplot
        ax.axis('off')             # Turn off axis labels
        ax.set_title(f'Predicted: {class_names[preds[0]]}')  # Title with the predicted class name
        imshow(img.cpu().data[0])  # Display the image

    # Restore the model's original training state
    model.train(mode=was_training)
We can now predict the label of our image.
# Visualize predictions for a single image using the specified model
visualize_model_predictions(
    model_conv,
    img_path='/content/drive/MyDrive/crocodylia_data/alligator.jpeg'  # Specify the path to the image
)
# Turn off interactive mode for matplotlib to display the plot
plt.ioff()
# Display the plot with the image and its predictions
plt.show()
Indeed, our model correctly identifies the statue as an alligator. Wonderful!
Conclusion
Using transfer learning with PyTorch, we’ve explored how to finetune a pretrained ConvNet and employ it in a novel situation (in this case, on a small dataset of alligator and crocodile images). By harnessing the knowledge encoded in pretrained models, we can accelerate our model development process and achieve great results with relatively small datasets!