[Computer Vision] CLIP in Practice: Zero-Shot Prediction (with Source Code)

1. Hands-on code

The code below performs zero-shot prediction using CLIP. This example takes an image from the CIFAR-100 dataset and predicts the most likely label out of the 100 text labels in the dataset.

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("./data/"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\\
Top predictions:\\
")
for value, index in zip(values, indices):
    print(f"{<!-- -->cifar100.classes[index]:>16s}: {<!-- -->100 * value.item():.2f}%")

Running the script prints the five most likely labels together with their probabilities.

Let’s visualize this picture:

import os
import pickle
from PIL import Image
import matplotlib.pyplot as plt

# Define the path to the CIFAR-100 dataset
dataset_path = os.path.expanduser('./data/cifar-100-python')

# Load the test batch
with open(os.path.join(dataset_path, 'test'), 'rb') as f:
    cifar100 = pickle.load(f, encoding='latin1')

# Select an image index to visualize
image_index = 3637

# Extract the image and its label
image = cifar100['data'][image_index]
label = cifar100['fine_labels'][image_index]

# Reshape and transpose the image to the correct format
image = image.reshape((3, 32, 32)).transpose((1, 2, 0))

# Create a PIL image from the numpy array
pil_image = Image.fromarray(image)

# Display the image
plt.imshow(pil_image, interpolation='bilinear')
plt.title('Label: ' + str(label))
plt.axis('off')
plt.show()


The picture looks very blurry because CIFAR-100 images are only 32×32 pixels; this low resolution is inherent to the dataset and cannot be changed.

2. Line-by-line code explanation

2.1 Prediction

import os
import clip
import torch
from torchvision.datasets import CIFAR100

First, import the required libraries and modules: os, clip, torch, and the CIFAR100 dataset class.

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

Determine the device type (GPU if available, otherwise CPU) and load the pre-trained CLIP model (Vision Transformer B/32). clip.load() returns both the model and its matching preprocessing function.
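
If you are unsure which checkpoint names clip.load() accepts, the package provides a helper to list them (a quick sketch):

import clip

# List the checkpoint names that clip.load() accepts,
# e.g. 'RN50', 'ViT-B/32', 'ViT-L/14'.
print(clip.available_models())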

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("./data/"), download=True, train=False)

Download the CIFAR-100 dataset into the specified root directory ("./data/"). The CIFAR100 class comes from the torchvision.datasets module; train=False selects the 10,000-image test split.
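
As a quick sanity check, you can inspect the loaded dataset (the printed values assume the standard test split):

print(len(cifar100))          # 10000 images in the test split
print(len(cifar100.classes))  # 100 class names
print(cifar100.classes[:5])   # ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver']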

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

Prepare the input data. First, get the image and its class ID at index 3637 from the CIFAR-100 dataset. Then preprocess the image (resizing, normalization, and conversion to the tensor format the model expects) and move it to the device (GPU or CPU). Finally, build the text inputs: one tokenized prompt of the form "a photo of a {class}" for each of the 100 CIFAR-100 classes, also moved to the device.
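
A small sketch for verifying the tensor shapes at this point (the sizes assume the ViT-B/32 checkpoint, which uses 224×224 inputs and a text context length of 77):

print(image_input.shape)  # torch.Size([1, 3, 224, 224]) - one 224x224 RGB image
print(text_inputs.shape)  # torch.Size([100, 77]) - 100 tokenized prompts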

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

Compute the feature vectors for the image and the text. The model's encode_image() and encode_text() methods map both inputs into a shared embedding space. The torch.no_grad() context manager disables gradient tracking, which is unnecessary for inference.
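
Both encoders project into the same embedding space; for ViT-B/32 the embedding dimension is 512, which you can confirm with a quick check:

print(image_features.shape)  # torch.Size([1, 512])
print(text_features.shape)   # torch.Size([100, 512])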

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

Select the 5 labels most similar to the image. First, normalize the image and text feature vectors to unit length. Then compute the cosine similarity between the image features and every text feature via matrix multiplication, scale by 100, and apply softmax to turn the similarities into probabilities. Finally, take the 5 highest values and their corresponding indices.
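
An equivalent route (a sketch) is to call the model directly, which returns logits already scaled by the model's learned temperature; the manual 100.0 factor above approximates that scale:

with torch.no_grad():
    # logits_per_image has shape [1, 100]: one score per class prompt
    logits_per_image, logits_per_text = model(image_input, text_inputs)
probs = logits_per_image.softmax(dim=-1)
values, indices = probs[0].topk(5)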

# Print the result
print("\\
Top predictions:\\
")
for value, index in zip(values, indices):
    print(f"{<!-- -->cifar100.classes[index]:>16s}: {<!-- -->100 * value.item():.2f}%")

Print the result: the 5 most similar labels and their similarities, formatted as a right-aligned class name followed by the probability as a percentage.

This code encodes both an image and a set of text labels with the CLIP model and finds the label most similar to the image. The same pattern can be used for tasks such as image classification or image retrieval.
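
The pipeline transfers directly to your own images and label sets. A minimal sketch (the file name and label list here are hypothetical):

from PIL import Image

labels = ["a dog", "a cat", "a car"]  # hypothetical label set
image = preprocess(Image.open("my_photo.jpg")).unsqueeze(0).to(device)
text = clip.tokenize([f"a photo of {label}" for label in labels]).to(device)

with torch.no_grad():
    logits_per_image, _ = model(image, text)
probs = logits_per_image.softmax(dim=-1)
print(labels[probs.argmax().item()])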

2.2 Visualization

import os
import pickle
from PIL import Image
import matplotlib.pyplot as plt

First import the required libraries and modules, including os, pickle, Image and matplotlib.pyplot.

# Define the path to the CIFAR-100 dataset
dataset_path = os.path.expanduser('./data/cifar-100-python')

Define the path to the CIFAR-100 dataset. os.path.expanduser() expands a leading ~ to the user's home directory; since the path here is relative, it has no effect.

# Load the test batch
with open(os.path.join(dataset_path, 'test'), 'rb') as f:
    cifar100 = pickle.load(f, encoding='latin1')

Load the test data. The 'test' file in the CIFAR-100 python release is a single pickled batch containing all 10,000 test images; pickle.load() unpickles it into the cifar100 variable. encoding='latin1' is needed because the file was pickled under Python 2.
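
The unpickled batch is a plain dictionary. A sketch for inspecting it (the keys listed are those of the standard CIFAR-100 python release):

print(cifar100.keys())
# dict_keys(['filenames', 'batch_label', 'fine_labels', 'coarse_labels', 'data'])
print(cifar100['data'].shape)  # (10000, 3072) - one flat row per image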

# Select an image index to visualize
image_index = 3637

Select the index of the image to visualize. Index 3637 is the same image used in the prediction example above.

# Extract the image and its label
image = cifar100['data'][image_index]
label = cifar100['fine_labels'][image_index]

Extract the selected image and its label. The raw image data comes from the 'data' key of the cifar100 dictionary, and the corresponding fine (100-class) label from the 'fine_labels' key.
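
The label is a plain integer. To map it to a class name, you can read the dataset's 'meta' file, which holds the 'fine_label_names' and 'coarse_label_names' lists (a sketch):

with open(os.path.join(dataset_path, 'meta'), 'rb') as f:
    meta = pickle.load(f, encoding='latin1')
print(meta['fine_label_names'][label])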

# Reshape and transpose the image to the correct format
image = image.reshape((3, 32, 32)).transpose((1, 2, 0))

Adjust the shape and axis order of the image to the correct format. reshape() turns the flat one-dimensional array of 3072 values into a (3, 32, 32) array, representing channels, height, and width. transpose() then moves the channel dimension to the end, giving the (32, 32, 3) layout that PIL and matplotlib expect.
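
A small sketch of the layout conversion in isolation (CIFAR-100 stores each image as 3072 bytes: 1024 red, then 1024 green, then 1024 blue):

import numpy as np

flat = np.arange(3072, dtype=np.uint8)
chw = flat.reshape(3, 32, 32)   # (channels, height, width)
hwc = chw.transpose(1, 2, 0)    # (height, width, channels), as PIL expects
print(chw.shape, hwc.shape)     # (3, 32, 32) (32, 32, 3)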

# Create a PIL image from the numpy array
pil_image = Image.fromarray(image)

Convert the NumPy array image into a PIL image object pil_image using Image.fromarray().

# Display the image
plt.imshow(pil_image, interpolation='bilinear')
plt.title('Label: ' + str(label))
plt.axis('off')
plt.show()

Display the image. plt.imshow() renders it; setting the interpolation parameter to 'bilinear' smooths the upscaled pixels. plt.title() sets a title containing the numeric label, plt.axis('off') hides the axes, and plt.show() displays the figure.
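
To see what interpolation does (and does not) buy you, here is a sketch that renders the same image with two modes side by side:

fig, axes = plt.subplots(1, 2)
for ax, mode in zip(axes, ['nearest', 'bilinear']):
    ax.imshow(pil_image, interpolation=mode)  # 'nearest' shows the raw 32x32 pixels
    ax.set_title(mode)
    ax.axis('off')
plt.show()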

This code loads image data from the CIFAR-100 dataset and visualizes the image at the specified index along with its label. Note that display options such as bilinear interpolation can make the image look smoother, but they cannot add detail beyond the original 32×32 resolution.