
PyTorch image classification with pre-trained networks


In this tutorial, you will learn how to perform image classification with pre-trained networks using PyTorch. Utilizing these networks, you can accurately classify 1,000 common object categories in only a few lines of code.

Today’s tutorial is part four in our five-part series on PyTorch fundamentals:

  1. What is PyTorch?
  2. Intro to PyTorch: Training your first neural network using PyTorch
  3. PyTorch: Training your first Convolutional Neural Network
  4. PyTorch image classification with pre-trained networks (today’s tutorial)
  5. August 2nd: PyTorch object detection with pre-trained networks (next week’s tutorial)

Throughout the rest of this tutorial, you’ll gain experience using PyTorch to classify input images using seminal, state-of-the-art image classification networks, including VGG, Inception, DenseNet, and ResNet.

To learn how to perform image classification with pre-trained PyTorch networks, just keep reading.


PyTorch image classification with pre-trained networks

In the first part of this tutorial, we’ll discuss what pre-trained image classification networks are, including those that are built into the PyTorch library.

From there, we’ll configure our development environment and review our project directory structure.

I’ll then show you how to implement a Python script that can accurately classify input images using pre-trained PyTorch networks.

We’ll wrap up this tutorial with a discussion of our results.

What are pre-trained image classification networks?

Figure 1: Most popular, state-of-the-art neural networks come with weights pre-trained on the ImageNet dataset. The PyTorch library includes many of these popular image classification networks.

When it comes to image classification, there is no dataset/challenge more famous than ImageNet. The goal of ImageNet is to accurately classify input images into a set of 1,000 common object categories that computer vision systems will “see” in everyday life.

Most popular deep learning frameworks, including PyTorch, Keras, TensorFlow, fast.ai, and others, include pre-trained networks. These are highly accurate, state-of-the-art models that computer vision researchers trained on the ImageNet dataset.

After training on ImageNet was complete, researchers saved their models to disk and then published them freely for other researchers, students, and developers to learn from and use in their own projects.

This tutorial will show how to use PyTorch to classify input images using the following state-of-the-art classification networks:

  • VGG16
  • VGG19
  • Inception
  • DenseNet
  • ResNet

Let’s get started!

Configuring your development environment

To follow this guide, you need to have both PyTorch and OpenCV installed on your system.

Luckily, both PyTorch and OpenCV are extremely easy to install using pip:

$ pip install torch torchvision
$ pip install opencv-contrib-python
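Once both installs finish, you may want to confirm that Python can actually find the packages. The sketch below uses a hypothetical helper (check_packages is not part of the tutorial's download) built on the standard library's importlib.util.find_spec, which returns None for a module it cannot locate. Note that the opencv-contrib-python package is imported as cv2.

```python
import importlib.util


def check_packages(names=("torch", "torchvision", "cv2", "numpy")):
    """Report which of the given top-level modules can be located.

    Hypothetical helper; the pip package "opencv-contrib-python" is imported as "cv2".
    """
    # find_spec returns None when a top-level module cannot be found
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_packages().items():
        print("{}: {}".format(name, "found" if ok else "MISSING"))
```

This only checks that each module can be located; it does not verify versions or whether a GPU is usable.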

If you need help configuring your development environment for PyTorch, I highly recommend that you read the PyTorch documentation — PyTorch’s documentation is comprehensive and will have you up and running quickly.

And if you need help installing OpenCV, be sure to refer to my pip install OpenCV tutorial.

Having problems configuring your development environment?

Figure 2: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.

All that said, are you:

  • Short on time?
  • Learning on your employer’s administratively locked system?
  • Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
  • Ready to run the code right now on your Windows, macOS, or Linux system?

Then join PyImageSearch University today!

Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.

And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!

Project structure

Before we implement image classification with PyTorch, let’s first review our project directory structure.

Start by accessing the “Downloads” section of this guide to retrieve the source code and example images. You’ll then be presented with the following directory structure.

$ tree . --dirsfirst
.
├── images
│   ├── bmw.png
│   ├── boat.png
│   ├── clint_eastwood.jpg
│   ├── jemma.png
│   ├── office.png
│   ├── scotch.png
│   ├── soccer_ball.jpg
│   └── tv.png
├── pyimagesearch
│   └── config.py
├── classify_image.py
└── ilsvrc2012_wordnet_lemmas.txt

Inside the pyimagesearch module we have a single file, config.py. This file stores important configurations, such as:

  • Our input image dimensions
  • Mean and standard deviation for mean subtraction and scaling
  • Whether we are using a CPU or GPU for inference
  • Path to the human-readable ImageNet class labels (i.e., ilsvrc2012_wordnet_lemmas.txt)

Our classify_image.py script will load our config and then classify an input image using either VGG16, VGG19, Inception, DenseNet, or ResNet (depending on which model architecture we supply as our command line argument).

The images directory contains a number of sample images where we’ll apply these image classification networks.

Creating our configuration file

Before we implement our image classification driver script, let’s first create a configuration file to store important configurations.

Open the config.py file in the pyimagesearch module and insert the following code:

# import the necessary packages
import torch

# specify image dimension
IMAGE_SIZE = 224

# specify ImageNet mean and standard deviation
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# determine the device we will be using for inference
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# specify path to the ImageNet labels
IN_LABELS = "ilsvrc2012_wordnet_lemmas.txt"

Line 5 defines our input image spatial dimensions, meaning that each image will be resized to 224×224 pixels before being passed through our pre-trained PyTorch network for classification.

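Note: Most networks trained on the ImageNet dataset accept images that are 224×224 or 227×227. Some networks, particularly fully convolutional networks, may accept larger image dimensions. Inception v3, for example, was trained on 299×299 inputs; torchvision’s implementation still runs on the 224×224 images we use here because it ends with an adaptive average pooling layer.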

From there, we define the mean and standard deviation of RGB pixel intensities across the ImageNet training set (Lines 8 and 9). Prior to passing an input image through our network for classification, we first scale the image pixel intensities by subtracting the mean and then dividing by the standard deviation. This preprocessing is typical for CNNs trained on large, diverse image datasets such as ImageNet.
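To make the arithmetic concrete, here is a small sketch that applies the same per-channel formula, (x / 255 - mean) / std, to a single pixel using the MEAN and STD values from config.py. The normalize_pixel helper is illustrative only; the tutorial’s script performs the same operation on the whole image array with NumPy.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then apply (x - mean) / std per channel."""
    return [((v / 255.0) - m) / s for v, m, s in zip(rgb, MEAN, STD)]


# a pure white pixel: the red channel becomes (1.0 - 0.485) / 0.229
print([round(x, 4) for x in normalize_pixel((255, 255, 255))])  # → [2.2489, 2.4286, 2.64]
```

After this step, each channel is centered near zero and has roughly unit spread across ImageNet, which is the input distribution the pre-trained weights expect.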

From there, Line 12 specifies whether we are using our CPU or GPU for inference, while Line 15 defines the path to our input text file of ImageNet class labels.

If you were to open this file in your favorite text editor, you would see the following contents:

tench, Tinca_tinca
goldfish, Carassius_auratus
...
bolete
ear, spike, capitulum
toilet_tissue, toilet_paper, bathroom_tissue

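Each row in this text file holds the name (plus any synonyms) of one class label our pre-trained PyTorch networks were trained to recognize. A row’s position in the file, counting from 0, matches the index of the corresponding network output.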
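The script itself reads this file with dict(enumerate(open(...))), keeping each raw line. If you would rather keep the synonyms separate, a slightly more structured sketch looks like this (load_labels is a hypothetical helper, and the in-memory sample stands in for the real file):

```python
import io


def load_labels(fileobj):
    """Map each output index (row number) to its list of comma-separated synonyms."""
    labels = {}
    for idx, line in enumerate(fileobj):
        # each row holds one class; synonyms are separated by commas
        labels[idx] = [name.strip() for name in line.strip().split(",")]
    return labels


# a stand-in for the first rows of ilsvrc2012_wordnet_lemmas.txt
sample = io.StringIO("tench, Tinca_tinca\ngoldfish, Carassius_auratus\n")
labels = load_labels(sample)
print(labels[1][0])  # → goldfish
```

With a real file, you would pass open(config.IN_LABELS) (ideally inside a with block) instead of the StringIO sample.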

Implementing our image classification script

With our configuration file taken care of, let’s move on to implementing our main driver script used to classify input images using our pre-trained PyTorch networks.

Open the classify_image.py file in your project directory structure, and let’s get to work:

# import the necessary packages
from pyimagesearch import config
from torchvision import models
import numpy as np
import argparse
import torch
import cv2

We start on Lines 2-7 by importing our Python packages, including:

  • config: The configuration file we implemented in the previous section
  • models: Contains PyTorch’s pre-trained neural networks
  • numpy: Numerical array processing
  • argparse: Parses our command line arguments
  • torch: Accesses the PyTorch API
  • cv2: Our OpenCV bindings

With our imports taken care of, let’s define a function to accept an input image and preprocess it:

def preprocess_image(image):
	# swap the color channels from BGR to RGB, resize it, and scale
	# the pixel values to [0, 1] range
	image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
	image = cv2.resize(image, (config.IMAGE_SIZE, config.IMAGE_SIZE))
	image = image.astype("float32") / 255.0

	# subtract ImageNet mean, divide by ImageNet standard deviation,
	# set "channels first" ordering, and add a batch dimension
	image -= config.MEAN
	image /= config.STD
	image = np.transpose(image, (2, 0, 1))
	image = np.expand_dims(image, 0)

	# return the preprocessed image
	return image

Our preprocess_image function takes a single argument, image, which is the image we’ll be preprocessing for classification.

We start the preprocessing operations by:

  1. Swapping from BGR to RGB channel ordering (the pre-trained networks we’re using here utilized RGB channel ordering whereas OpenCV uses BGR ordering by default)
  2. Resizing our image to fixed dimensions (i.e., 224×224), ignoring aspect ratio
  3. Converting our image to a floating point data type and then scaling the pixel intensities to the range [0, 1]

From there, we perform a second set of preprocessing operations:

  1. Subtracting the mean (Line 18) and dividing by the standard deviation (Line 19)
  2. Moving the channels dimension to the front of the array (Line 20), which is called channels-first ordering and is the default channel ordering method that PyTorch expects
  3. Adding a batch dimension to the array (Line 21)

The preprocessed image is then returned to the calling function.
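The sequence of shapes is easy to lose track of, so here is a minimal NumPy sketch of the same steps applied to a random stand-in image. The cv2 color conversion and resize are skipped, on the assumption that the array is already RGB and 224×224.

```python
import numpy as np

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = 224

# stand-in for an RGB image that has already been resized (H x W x C, uint8)
image = np.random.randint(0, 256, size=(IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)

x = image.astype("float32") / 255.0      # (224, 224, 3), values in [0, 1]
x = (x - np.array(MEAN, dtype="float32")) / np.array(STD, dtype="float32")
x = np.transpose(x, (2, 0, 1))           # channels first: (3, 224, 224)
x = np.expand_dims(x, 0)                 # batch dimension: (1, 3, 224, 224)
print(x.shape, x.dtype)                  # → (1, 3, 224, 224) float32
```

After the transpose, element x[0, c, h, w] holds the normalized value of pixel (h, w) in channel c, which is the NCHW layout PyTorch’s convolution layers expect.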

Next, let’s parse our command line arguments:

# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,
	help="path to the input image")
ap.add_argument("-m", "--model", type=str, default="vgg16",
	choices=["vgg16", "vgg19", "inception", "densenet", "resnet"],
	help="name of pre-trained network to use")
args = vars(ap.parse_args())

We have two command line arguments to parse:

  1. --image: The path to the input image that we wish to classify
  2. --model: The pre-trained CNN model we’ll be using to classify the image
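To see how these two arguments behave without a terminal, you can hand parse_args an explicit list of strings, which is a standard argparse feature. The sketch below rebuilds the same parser; when --model is omitted it falls back to vgg16, and any name outside choices is rejected.

```python
import argparse

ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,
    help="path to the input image")
ap.add_argument("-m", "--model", type=str, default="vgg16",
    choices=["vgg16", "vgg19", "inception", "densenet", "resnet"],
    help="name of pre-trained network to use")

# pass an explicit list instead of reading sys.argv
args = vars(ap.parse_args(["--image", "images/boat.png"]))
print(args)  # → {'image': 'images/boat.png', 'model': 'vgg16'}
```

An unsupported value such as --model alexnet makes argparse print a usage error and exit with status 2.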

Let’s now define a MODELS dictionary which maps the name of the --model command line argument to its corresponding PyTorch function:

# define a dictionary that maps model names to their constructor
# functions inside torchvision
MODELS = {
	"vgg16": models.vgg16,
	"vgg19": models.vgg19,
	"inception": models.inception_v3,
	"densenet": models.densenet121,
	"resnet": models.resnet50
}

# load the network weights (downloading them on first use), move the
# model to the current device, and set it to evaluation mode
print("[INFO] loading {}...".format(args["model"]))
model = MODELS[args["model"]](pretrained=True).to(config.DEVICE)
model.eval()

Lines 37-43 create our MODELS dictionary:

  • The key to the dictionary is the human-readable name of the model, passed in via the --model command line argument.
  • The value to the dictionary is the corresponding PyTorch function used to load the model with the weights pre-trained on ImageNet

You’ll be able to use the following pre-trained models to classify an input image with PyTorch:

  1. VGG16
  2. VGG19
  3. Inception
  4. DenseNet
  5. ResNet

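Specifying the pretrained=True flag instructs PyTorch to not only load the model architecture definition, but also download the pre-trained ImageNet weights for the model. (On torchvision 0.13 and newer, pretrained is deprecated in favor of the weights argument, e.g., weights="IMAGENET1K_V1", and passing pretrained=True triggers a deprecation warning.)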

Line 48 then loads the model and pre-trained weights (if you’ve never downloaded the model weights before they will be automatically downloaded and cached for you) and then sets the model to run either on your CPU or GPU, depending on your DEVICE from the configuration file.

Line 49 puts our model into evaluation mode, instructing PyTorch to handle special layers, such as dropout and batch normalization, differently from how it handles them during training. Putting your model into evaluation mode before making predictions is critical, so don’t forget to do it!
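To see why this matters, the toy sketch below (a small hypothetical network, not one of the pre-trained models) runs the same input twice in each mode. In training mode, dropout draws a new random mask on every forward pass; after eval(), dropout is disabled, so repeated passes agree. Batch normalization similarly switches from per-batch statistics to its stored running statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# a toy network with dropout (not one of the tutorial's models)
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
x = torch.randn(1, 8)

net.train()
train_outputs = [net(x) for _ in range(2)]    # dropout mask is resampled each pass

net.eval()
with torch.no_grad():
    eval_outputs = [net(x) for _ in range(2)]  # dropout is disabled

print(torch.equal(eval_outputs[0], eval_outputs[1]))  # → True
```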

Now that our model is loaded, we need an input image — let’s take care of that now:

# load the image from disk, clone it (so we can draw on it later),
# and preprocess it
print("[INFO] loading image...")
image = cv2.imread(args["image"])
orig = image.copy()
image = preprocess_image(image)

# convert the preprocessed image to a torch tensor and move it to
# the current device
image = torch.from_numpy(image)
image = image.to(config.DEVICE)

# load the ImageNet class labels
print("[INFO] loading ImageNet labels...")
imagenetLabels = dict(enumerate(open(config.IN_LABELS)))

Line 54 loads our input image from disk. We make a copy of it on Line 55 so that we can draw on it and visualize the top prediction of our network. We also make use of our preprocess_image function on Line 56 to perform resizing and scaling.

Line 60 converts our image from a NumPy array to a PyTorch tensor, while Line 61 moves the image to our device (either CPU or GPU).

Finally, Line 65 loads our input ImageNet class labels from disk.
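Two details of the NumPy-to-PyTorch conversion are worth knowing. torch.from_numpy keeps the array’s dtype, which is one reason preprocess_image casts to float32 (the pre-trained weights are float32, and a float64 input would fail with a dtype mismatch error). It also shares memory with the NumPy array instead of copying it. The small sketch below demonstrates both on a dummy array:

```python
import numpy as np
import torch

arr = np.zeros((1, 3, 2, 2), dtype="float32")
t = torch.from_numpy(arr)   # no copy: t and arr share the same memory
arr[0, 0, 0, 0] = 1.0
print(t[0, 0, 0, 0].item(), t.dtype)   # → 1.0 torch.float32

# .to(device) returns the same tensor if it is already on that device,
# otherwise a copy on the target device
device = "cuda" if torch.cuda.is_available() else "cpu"
t_dev = t.to(device)
print(t_dev.device.type)
```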

We are now ready to make predictions on our input image using our model:

# classify the image and extract the predictions
print("[INFO] classifying image with '{}'...".format(args["model"]))
logits = model(image)
probabilities = torch.nn.Softmax(dim=-1)(logits)
sortedProba = torch.argsort(probabilities, dim=-1, descending=True)

# loop over the predictions and display the rank-5 predictions and
# corresponding probabilities to our terminal
for (i, idx) in enumerate(sortedProba[0, :5]):
	print("{}. {}: {:.2f}%".format(
		i, imagenetLabels[idx.item()].strip(),
		probabilities[0, idx.item()].item() * 100))

Line 69 performs a forward pass through our network, producing its raw, unnormalized outputs (logits), one per class.

We pass these through the Softmax function on Line 70 to obtain the predicted probabilities for each of the possible 1,000 class labels the model was trained on.
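Softmax converts a vector of logits z into probabilities with p_i = exp(z_i) / sum_j exp(z_j). The pure-Python sketch below applies that formula to three made-up logits; subtracting the largest logit first does not change the result but avoids overflow for large values.

```python
import math


def softmax(logits):
    """Numerically stable softmax: exp(z_i - max(z)) / sum_j exp(z_j - max(z))."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # → [0.659, 0.242, 0.099]
```

Because softmax is monotonic, the ranking of classes by probability is the same as their ranking by raw logit.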

Line 71 then sorts the class indices by their probabilities in descending order, so the indices of the most likely classes come first.

We then display the top-5 predicted class labels and corresponding probabilities to our terminal on Lines 75-78 by:

  • Looping over the top-5 predictions
  • Looking up the name of the class label using our imagenetLabels dictionary
  • Displaying the predicted probability

Our final code block draws the top-1 prediction (i.e., the single most likely label) on our output image:

# draw the top prediction on the image and display the image to
# our screen
(label, prob) = (imagenetLabels[probabilities.argmax().item()],
	probabilities.max().item())
cv2.putText(orig, "Label: {}, {:.2f}%".format(label.strip(), prob * 100),
	(10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
cv2.imshow("Classification", orig)
cv2.waitKey(0)

The result is then displayed to our screen.

Image classification with PyTorch results

We are now ready to apply image classification with PyTorch!

Be sure to access the “Downloads” section of this tutorial to retrieve the source code and example images.

From there, try classifying an input image using the following command:

$ python classify_image.py --image images/boat.png
[INFO] loading vgg16...
[INFO] loading image...
[INFO] loading ImageNet labels...
[INFO] classifying image with 'vgg16'...
0. wreck: 99.99%
1. seashore, coast, seacoast, sea-coast: 0.01%
2. pirate, pirate_ship: 0.00%
3. breakwater, groin, groyne, mole, bulwark, seawall, jetty: 0.00%
4. sea_lion: 0.00%
Figure 3: Using PyTorch and VGG16 to classify an input image.

It appears that Captain Jack Sparrow is stranded on the beach! And sure enough, the VGG16 network is able to correctly classify the input image as a “wreck” (i.e., shipwreck) with 99.99% probability.

It’s also interesting to see that “seashore” is the second top prediction from the model — this prediction is also accurate, due to the boat being on the beach.

Let’s try a different image, this time using the DenseNet model:

$ python classify_image.py --image images/bmw.png --model densenet
[INFO] loading densenet...
[INFO] loading image...
[INFO] loading ImageNet labels...
[INFO] classifying image with 'densenet'...
0. convertible: 96.61%
1. sports_car, sport_car: 2.25%
2. car_wheel: 0.45%
3. beach_wagon, station_wagon, wagon, estate_car, beach_waggon, station_waggon, waggon: 0.22%
4. racer, race_car, racing_car: 0.13%
Figure 4: Applying DenseNet and PyTorch to classify an image.

The top prediction from DenseNet is “convertible” with 96.61% probability. The second top prediction, “sports car,” is also accurate.

This image contains Jemma, my family’s beagle:

$ python classify_image.py --image images/jemma.png --model resnet
[INFO] loading resnet...
[INFO] loading image...
[INFO] loading ImageNet labels...
[INFO] classifying image with 'resnet'...
0. beagle: 95.98%
1. bluetick: 1.46%
2. Walker_hound, Walker_foxhound: 1.11%
3. English_foxhound: 0.45%
4. maraca: 0.25%
Figure 5: Utilizing ResNet and PyTorch to correctly classify an input image.

Here we are using the ResNet architecture to classify our input image. Jemma is a “beagle” (a type of dog), which ResNet accurately predicts with 95.98% probability.

Interestingly, a “bluetick,” “walker hound,” and “English foxhound” are all types of dogs belonging to the “hound” family — all of these would be reasonable predictions from the model.

Let’s take a look at one final example:

$ python classify_image.py --image images/soccer_ball.jpg --model inception
[INFO] loading inception...
[INFO] loading image...
[INFO] loading ImageNet labels...
[INFO] classifying image with 'inception'...
0. soccer_ball: 100.00%
1. volleyball: 0.00%
2. sea_urchin: 0.00%
3. rugby_ball: 0.00%
4. silky_terrier, Sydney_silky: 0.00%
Figure 6: Using Inception and PyTorch to make predictions on an input image.

Our Inception model correctly classifies the input image as “soccer ball” with 100.00% probability (after rounding to two decimal places).

Image classification allows us to assign one or more labels to an input image; however, it tells us nothing about where in the image the object resides.

To determine where in an input image a given object is, we need to apply object detection:

Figure 7: Object detection can not only tell us what is in an image but also where the object is.

Just like we have pre-trained networks for image classification, we also have pre-trained networks for object detection. Next week you’ll learn how to use PyTorch to detect objects in images using specialized object detection networks.

What's next? I recommend PyImageSearch University.

Course information:
25 total classes • 37h 19m video • Last updated: 7/2021
★★★★★ 4.84 (128 Ratings) • 10,597 Students Enrolled

I strongly believe that if you had the right teacher you could master computer vision and deep learning.

Do you think learning computer vision and deep learning has to be time-consuming, overwhelming, and complicated? Or has to involve complex mathematics and equations? Or requires a degree in computer science?

That’s not the case.

All you need to master computer vision and deep learning is for someone to explain things to you in simple, intuitive terms. And that’s exactly what I do. My mission is to change education and how complex Artificial Intelligence topics are taught.

If you're serious about learning computer vision, your next stop should be PyImageSearch University, the most comprehensive computer vision, deep learning, and OpenCV course online today. Here you’ll learn how to successfully and confidently apply computer vision to your work, research, and projects. Join me in computer vision mastery.

Inside PyImageSearch University you'll find:

  • ✓ 25 courses on essential computer vision, deep learning, and OpenCV topics
  • ✓ 25 Certificates of Completion
  • ✓ 37h 19m on-demand video
  • ✓ Brand new courses released every month, ensuring you can keep up with state-of-the-art techniques
  • ✓ Pre-configured Jupyter Notebooks in Google Colab
  • ✓ Run all code examples in your web browser — works on Windows, macOS, and Linux (no dev environment configuration required!)
  • ✓ Access to centralized code repos for all 400+ tutorials on PyImageSearch
  • ✓ Easy one-click downloads for code, datasets, pre-trained models, etc.
  • ✓ Access on mobile, laptop, desktop, etc.

Click here to join PyImageSearch University

Summary

In this tutorial, you learned how to perform image classification using PyTorch. Specifically, we utilized popular pre-trained network architectures, including:

  • VGG16
  • VGG19
  • Inception
  • DenseNet
  • ResNet

These models were trained by the researchers responsible for inventing and proposing the novel architectures listed above. After training was complete, these researchers saved the model weights to disk and then published them for other researchers, students, and developers to learn from and use in their own projects.

While the models are free to use, make sure you check any terms/conditions associated with them, as some models are not free to use in commercial applications (typically entrepreneurs in the AI space get around this restriction by training the models themselves rather than using the pre-trained weights provided by the original authors).

Stay tuned for next week’s blog post, where you’ll learn how to perform object detection using PyTorch.


The post PyTorch image classification with pre-trained networks appeared first on PyImageSearch.

