TrOCR model fine-tuning [Transformer-based optical character recognition]

The TrOCR (Transformer-based Optical Character Recognition) model is one of the best-performing OCR models. In our previous article, we analyzed the performance of TrOCR models on single lines of printed and handwritten text. However, like any other deep learning model, they have their limitations: out of the box, TrOCR does not perform well on curved text. This article takes the TrOCR family one step further by fine-tuning a TrOCR model on a curved text dataset.


From previous articles, we know that TrOCR cannot recognize text in curved and vertical images. These images are part of the SCUT-CTW1500 dataset. We will train the TrOCR model on this dataset and run inference again to analyze the results. This will show how far the TrOCR model can be pushed for different use cases.

We will use the Hugging Face Trainer API to train the model. The complete process consists of the following steps:

  • Prepare and analyze a dataset of curved text images.
  • Load the TrOCR Small Printed model from Hugging Face.
  • Initialize the HF Seq2Seq trainer API.
  • Define the evaluation metrics.
  • Train the model and run inference.

1. Curved text dataset

The SCUT-CTW1500 dataset (hereafter referred to as CTW1500) contains thousands of images of curved text and text in the wild.

The original dataset is available in the official GitHub repository and includes training and test sets. However, only the training set comes with labels (in XML format), so we split it into separate training and validation subsets.

The final dataset contains 6052 training samples and 1651 validation samples. The labels are stored in text files, one line per image.
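The article does not show exactly how the validation subset was carved out of the labeled data. One simple, reproducible way to do it is a seeded shuffle followed by a slice. In this minimal sketch, `split_samples` is a hypothetical helper, and the sample list stands in for the cropped text images:

```python
import random


def split_samples(samples, num_valid, seed=42):
    """Shuffle a copy of `samples` with a fixed seed and split it in two.

    Hypothetical helper: not the procedure used to build the article's split.
    """
    if not 0 <= num_valid <= len(samples):
        raise ValueError("num_valid must be between 0 and len(samples)")
    shuffled = list(samples)
    # A dedicated Random instance keeps the split independent of global state.
    random.Random(seed).shuffle(shuffled)
    valid = shuffled[:num_valid]
    train = shuffled[num_valid:]
    return train, valid


# Example: 7703 items split into 6052 training and 1651 validation items.
train_items, valid_items = split_samples(range(7703), num_valid=1651)
print(len(train_items), len(valid_items))  # 6052 1651
```

Fixing the seed means the same images always end up in the same subset, so validation scores remain comparable across runs.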

Let’s examine some images in the dataset and their text labels.

Figure 2. Baseline image with labels from CTW1500 dataset

A few things stand out in the figure above. In addition to curved text, the dataset also contains blurred and hazy images. Such real-world variation poses challenges to deep learning models, and understanding the characteristics of the images and text in such diverse datasets is critical for achieving state-of-the-art performance with any OCR model. This makes the dataset an interesting challenge for TrOCR, which, as we will see, performs significantly better on such images after fine-tuning.

Let’s get into the technical aspects of this article. From here on, we will discuss the code for the TrOCR training process in detail. All the code is available as a Jupyter Notebook via the download link.

2. Setting up the development environment

The first step is to install all required libraries.

!pip install -q transformers
!pip install -q sentencepiece
!pip install -q jiwer
!pip install -q datasets
!pip install -q evaluate
!pip install -q -U accelerate
 
 
!pip install -q matplotlib
!pip install -q protobuf==3.20.1
!pip install -q tensorboard

The most important of these are:

  • Transformers: This is the Hugging Face Transformers library, which gives us access to hundreds of Transformer-based models, including TrOCR models.
  • SentencePiece: a subword tokenization library that converts text into tokens and token IDs. It is developed by Google, and some Hugging Face tokenizers depend on it.
  • jiwer: a library of speech recognition evaluation metrics, including WER (Word Error Rate) and CER (Character Error Rate). We will use the CER metric to evaluate the model during training.

Next, we import all required libraries and packages.

import os
import torch
import evaluate
import numpy as np
import pandas as pd
import glob
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
 
 
from PIL import Image
from zipfile import ZipFile
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from urllib.request import urlretrieve
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)

Some important import statements in the above code block are:

  • VisionEncoderDecoderModel: We need this class to define different TrOCR models.
  • TrOCRProcessor: TrOCR expects input images to be preprocessed (resized and normalized) in a specific way. This class wraps the image processor and the tokenizer, ensuring that images are processed correctly and that text is tokenized consistently.
  • Seq2SeqTrainer: This is required to initialize the trainer API.
  • Seq2SeqTrainingArguments: The trainer API requires multiple training parameters. This class holds all of them so that they can be passed to the API.
  • transforms: The Torchvision transform module is required to apply data augmentation to images.

Next, we set the seed for reproducibility across runs and define the computing device.

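import random


def seed_everything(seed_value):
    # Seed Python, NumPy, and PyTorch (CPU and all GPUs).
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Make cuDNN use deterministic algorithms (this can be slower).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')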

3. Download and extract the dataset

The next block of code contains a helper function that downloads the CTW1500 data and extracts it.

def download_and_unzip(url, save_path):
    print("Downloading and extracting assets....", end="")

    # Download the ZIP file using the urllib package.
    urlretrieve(url, save_path)

    try:
        # Extract the ZIP file contents into the same directory.
        with ZipFile(save_path) as z:
            z.extractall(os.path.split(save_path)[0])

        print("Done")

    except Exception as e:
        print("\nInvalid file.", e)


URL = r"https://www.dropbox.com/scl/fi/vyvr7jbdvu8o174mbqgde/scut_data.zip?rlkey=fs8axkpxunwu6if9a2su71kxs&dl=1"
asset_zip_path = os.path.join(os.getcwd(), "scut_data.zip")

# Download only if the asset ZIP file does not exist yet.
if not os.path.exists(asset_zip_path):
    download_and_unzip(URL, asset_zip_path)

After extraction, the dataset directory structure looks like this:

scut_data/
├── scut_train
├── scut_test
├── scut_train.txt
└── scut_test.txt

The data is extracted into the scut_data directory. Its scut_train and scut_test subdirectories hold the training and validation images, respectively.

The two text files, scut_train.txt and scut_test.txt, contain the annotations in the following format:

006052.jpg ty Starts with Education
006053.jpg Cardi's
006054.jpg YOU THE BUSINESS SIDE OF GREEN
006055.jpg hat is
...

Each line contains an image file name followed by the text in that image. The file name and the text are separated by the first space on the line, so the text itself may contain further spaces. The number of lines in each text file equals the number of images in the corresponding folder. Image file names must not contain spaces; otherwise, part of the file name would be treated as text.
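The "split at the first space" rule can be implemented directly with `str.partition`. This is a minimal sketch with hypothetical helper names, assuming UTF-8 encoded annotation files. It can serve as an alternative to the `pd.read_fwf` call used later, which infers fixed column widths from the first rows of a file (100 by default) and can therefore mis-split free-form text.

```python
def parse_annotation_line(line):
    """Split one annotation line into (file_name, text) at the first space."""
    # partition() splits only once, so spaces inside the text are preserved.
    file_name, _, text = line.rstrip("\r\n").partition(" ")
    return file_name, text


def parse_annotation_file(path, encoding="utf-8"):
    """Read an annotation file such as scut_train.txt into (file_name, text) pairs."""
    # Assumption: the files are UTF-8 encoded; blank lines are skipped.
    with open(path, encoding=encoding) as f:
        return [parse_annotation_line(line) for line in f if line.strip()]


print(parse_annotation_line("006052.jpg ty Starts with Education\n"))
# ('006052.jpg', 'ty Starts with Education')
```

The resulting pairs can be turned into the same DataFrame layout used later, for example with `pd.DataFrame(parse_annotation_file(path), columns=['file_name', 'text'])`.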

4. Define the training, dataset, and model configurations

Before starting the training part, we define the training, dataset, and model configuration.

@dataclass(frozen=True)
class TrainingConfig:
    BATCH_SIZE: int = 48
    EPOCHS: int = 35
    LEARNING_RATE: float = 0.00005
 
@dataclass(frozen=True)
class DatasetConfig:
    DATA_ROOT: str = 'scut_data'
 
@dataclass(frozen=True)
class ModelConfig:
    MODEL_NAME: str = 'microsoft/trocr-small-printed'

The model will be trained for 35 epochs with a batch size of 48, and the optimizer's learning rate is set to 0.00005. In our experiments, a higher learning rate destabilized training and resulted in higher losses from the very beginning.

Additionally, we define the root dataset directory and the model to use. We fine-tune the TrOCR Small Printed model because it performed best in our experiments on this dataset.

Detailed explanations of all models can be found in the TrOCR Inference blog post.

5. Visualize some samples

Let us visualize some images from the downloaded dataset and their filenames.

def visualize(dataset_path):
    # List the training images once instead of on every loop iteration.
    all_images = os.listdir(f"{dataset_path}/scut_train")
    plt.figure(figsize=(15, 3))
    for i in range(15):
        plt.subplot(3, 5, i + 1)
        image = plt.imread(f"{dataset_path}/scut_train/{all_images[i]}")
        plt.imshow(image)
        plt.axis('off')
        # Use the file name without its extension as the title.
        plt.title(all_images[i].split('.')[0])
    plt.show()


visualize(DatasetConfig.DATA_ROOT)

6. Prepare the dataset

The labels are stored in text files. To make preparing the data loaders easier, we convert them into a simpler format: one Pandas DataFrame each for the training and test sets.

train_df = pd.read_fwf(
    os.path.join(DatasetConfig.DATA_ROOT, 'scut_train.txt'), header=None
)
train_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)
test_df = pd.read_fwf(
    os.path.join(DatasetConfig.DATA_ROOT, 'scut_test.txt'), header=None
)
test_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)

Now, the file_name column contains all the file names corresponding to the image, and the text column contains the text in the image.

Figure 4. CTW1500 dataset DataFrame with file names and labels

The next step is to define the augmentations.

# Augmentations.
train_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=.5, hue=.3),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
])

We apply ColorJitter and GaussianBlur to the images. There is no need for flips or rotations, as the original dataset already contains enough variability.

The best way to prepare a dataset is to write a custom dataset class. This gives us more control over the input. The following code block defines a CustomOCRDataset class to prepare the dataset.

class CustomOCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # The image file name.
        file_name = self.df['file_name'][idx]
        # The text (label).
        text = self.df['text'][idx]
        # Read the image, apply augmentations, and get the transformed pixels.
        image = Image.open(os.path.join(self.root_dir, file_name)).convert('RGB')
        image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        # Pass the text through the tokenizer to get the tokenized labels,
        # padded and truncated to max_target_length tokens.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_target_length
        ).input_ids
        # Replace the padding token IDs with -100 so that the loss ignores them.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

The __init__() method accepts the root directory path, the DataFrame, the TrOCR processor, and the maximum label length as parameters.

The __getitem__() method first reads the label from the DataFrame and the image from disk. It then passes the image through the transforms to apply augmentation, and TrOCRProcessor returns the normalized pixel values as a PyTorch tensor. Next, the text label is passed through the tokenizer. Labels shorter than 128 tokens are padded to a length of 128, and the padding token IDs are then replaced with -100 so that the loss function ignores them. Labels longer than 128 tokens are truncated. Finally, the method returns the pixel values and labels as a dictionary.
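The pad-then-mask step can be illustrated without a tokenizer. This minimal sketch works on a plain list of token IDs; the function name and the IDs are hypothetical, and the truncation is simplified (Hugging Face tokenizers reserve room for special tokens such as the end-of-sequence token when truncating, while this sketch just cuts the list):

```python
def pad_and_mask_labels(token_ids, max_length, pad_token_id, ignore_index=-100):
    """Truncate/pad token IDs to max_length and mask padding with ignore_index."""
    # Simplified truncation: cut the list at max_length.
    ids = list(token_ids)[:max_length]
    # Pad on the right up to max_length.
    ids += [pad_token_id] * (max_length - len(ids))
    # Replace every pad ID with ignore_index, as the dataset class does.
    return [ignore_index if t == pad_token_id else t for t in ids]


# Hypothetical token IDs, with 1 used as the pad ID.
print(pad_and_mask_labels([0, 52, 87, 2], max_length=6, pad_token_id=1))
# [0, 52, 87, 2, -100, -100]
```

-100 is the default `ignore_index` of PyTorch's cross-entropy loss, which is why the padded positions do not contribute to the training loss.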

Before creating the training and validation sets, the TrOCRProcessor needs to be initialized so that it can be passed to the dataset class.

processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'scut_train/'),
    df=train_df,
    processor=processor
)
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'scut_test/'),
    df=test_df,
    processor=processor
)

This concludes the dataset preparation process for fine-tuning the TrOCR model.

7. Prepare the TrOCR model

The VisionEncoderDecoderModel class gives us access to all TrOCR models. The from_pretrained() method accepts a repository name to load the pretrained model.

model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} trainable parameters.")

The model contains 61.5 million parameters, all of which are trainable and will be fine-tuned.

One of the most important aspects of model preparation is model configuration. These configurations are discussed below.

# Set special tokens used for creating the decoder_input_ids from the labels.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Set the correct vocab size.
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Generation settings used by generate() during evaluation and inference.
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0

Pretrained TrOCR models come with their own predefined configurations. To fine-tune the model, however, we need to set some of these explicitly: the special token IDs (decoder start, padding, and end-of-sequence), the vocabulary size, and a few generation parameters, such as the maximum output length.
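The decoder start and padding IDs matter because, during training, the decoder inputs are derived from the labels by shifting them one position to the right. The sketch below is a NumPy illustration of that idea, intended to mirror the logic of the `shift_tokens_right` helper in Transformers' vision encoder-decoder code rather than reproduce the library function itself. The token IDs are hypothetical.

```python
import numpy as np


def shift_tokens_right(labels, pad_token_id, decoder_start_token_id):
    """Build decoder_input_ids from labels for teacher forcing."""
    labels = np.asarray(labels)
    shifted = np.empty_like(labels)
    # Every position is fed the previous target token...
    shifted[:, 1:] = labels[:, :-1]
    # ...and the first position gets the decoder start token.
    shifted[:, 0] = decoder_start_token_id
    # -100 only means "ignore in the loss"; the decoder needs a real ID here.
    shifted[shifted == -100] = pad_token_id
    return shifted


# Hypothetical IDs: start token 0, pad token 1, end-of-sequence token 2.
labels = [[5, 17, 2, -100, -100]]
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0))
# [[ 0  5 17  2  1]]
```

If `decoder_start_token_id` or `pad_token_id` were left unset, the model would have no valid token to place at the first position or in the masked positions.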

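Additionally, early_stopping is set to True. Note that this is a generation setting, not a training one: during beam search, it stops generation once enough complete candidate sequences have been found, so it only has an effect when beam search is used (num_beams > 1). It does not stop training when the metrics stop improving.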

8. Optimizer and evaluation metrics

To optimize the model weights, we choose the AdamW optimizer with a weight decay of 0.0005.

optimizer = optim.AdamW(
    model.parameters(), lr=TrainingConfig.LEARNING_RATE, weight_decay=0.0005
)
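What distinguishes AdamW from Adam with L2 regularization is that the weight decay is applied directly to the weights instead of being added to the gradient. This simplified NumPy sketch shows one update step for a single parameter array, following the update rule documented for `torch.optim.AdamW` (default betas and eps, no amsgrad). The function name `adamw_step` is hypothetical.

```python
import numpy as np


def adamw_step(param, grad, m, v, step, lr=5e-5, weight_decay=5e-4,
               betas=(0.9, 0.999), eps=1e-8):
    """One AdamW update on NumPy arrays (simplified, no amsgrad)."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the weights directly, not via the gradient.
    param = param * (1 - lr * weight_decay)
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    # Bias correction for the zero-initialized moment estimates.
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


p, g = np.array([1.0]), np.array([0.5])
p1, m1, v1 = adamw_step(p, g, np.zeros(1), np.zeros(1), step=1)
print(p1)  # roughly 1 - 5e-5 (Adam step) - 2.5e-8 (weight decay)
```

With a learning rate of 0.00005 and a weight decay of 0.0005, each step shrinks the weights by a factor of 1 - 2.5e-8 before the Adam update is applied.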

The evaluation metric will be CER (Character Error Rate).

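cer_metric = evaluate.load('cer')


def compute_cer(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # Replace -100 (ignored positions) with the pad token ID so they can be decoded;
    # skip_special_tokens=True then drops the padding from the decoded strings.
    pred_ids[pred_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}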

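Without going into too much detail, CER is the character-level edit distance between the prediction and the reference divided by the number of characters in the reference. The edit distance is the number of character substitutions, deletions, and insertions needed to turn one string into the other. The lower the CER, the better the model's performance.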

Note that the -100 values are replaced with the pad token ID and special tokens are skipped during decoding, so padding tokens do not affect the CER.
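To make the CER definition concrete, here is a small pure-Python sketch (the function names are our own). It computes the Levenshtein distance per sample and divides the total number of edits by the total number of reference characters. As far as we know, this matches how the `evaluate` CER metric combines multiple samples, but treat that as an assumption rather than a description of its internals.

```python
def edit_distance(a, b):
    """Levenshtein distance: minimum substitutions, deletions and insertions."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute (free if equal)
            ))
        prev = curr
    return prev[-1]


def corpus_cer(predictions, references):
    """Total character edits divided by the total number of reference characters."""
    total = sum(len(r) for r in references)
    if total == 0:
        raise ValueError("references must contain at least one character")
    edits = sum(edit_distance(r, p) for p, r in zip(predictions, references))
    return edits / total


print(corpus_cer(["hallo", "wrld"], ["hello", "world"]))  # 0.2
```

Because insertions are counted, CER can exceed 1.0 when the model outputs much more text than the reference contains.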

9. Training and validating TrOCR

Training parameters must be initialized before training begins.

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    # Newer transformers releases rename this argument to eval_strategy.
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    fp16=True,
    output_dir='seq2seq_model_printed/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='tensorboard',
    num_train_epochs=TrainingConfig.EPOCHS
)

FP16 (mixed-precision) training is used because it requires less GPU memory, which also allows larger batch sizes. Logging, evaluation, and checkpointing all happen once per epoch, and save_total_limit=5 keeps only the five most recent checkpoints. All logs are reported to TensorBoard.

These training parameters will be passed to the trainer API along with other required parameters.

# Initialize trainer.
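trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator,
    # Use the AdamW optimizer defined above. With None as the scheduler, the
    # trainer creates its default (linear) learning-rate scheduler. Without this
    # argument, the trainer would build its own optimizer from training_args.
    optimizers=(optimizer, None)
)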

The training process can be started by calling the train() method of the trainer object.

res = trainer.train()

The output is as follows:

Epoch   Training Loss   Validation Loss   CER
1       3.822000        2.677871          0.687739
2       2.497100        2.474666          0.690800
3       2.180700        2.336284          0.627641
...
33      0.146800        2.130022          0.504209
34      0.145800        2.167060          0.511095
35      0.138300        2.120335          0.494496

At the end of training, the model reaches a CER of about 0.49 (49%), down from 0.69 after the first epoch. This is a good result considering the small TrOCR model used.

Below is the CER plot from the TensorBoard logs.

Figure 5. CER after training the TrOCR model

The curve is still trending downward at the end of training, so training for longer may yield better results. Here, however, we continue with the current model.

10. Inference using the fine-tuned TrOCR model

After training the TrOCR model, it is time to perform inference on the validation images.

The first step is to load the trained model from the last saved checkpoint.

processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-' + str(res.global_step)).to(device)

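The res object returned by train() has a global_step attribute that holds the total number of training steps. Because the trainer names each checkpoint directory after the step at which it was saved (checkpoint-<step>), the code block above uses this attribute to load the weights from the final epoch.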
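The final step number, and therefore the name of the last checkpoint directory, can also be worked out in advance. This sketch assumes a single GPU, no gradient accumulation, and the default behavior of keeping the last partial batch (dataloader_drop_last=False). The helper name is hypothetical.

```python
import math


def total_training_steps(num_samples, batch_size, epochs):
    """Optimizer steps for a run, assuming one device and no gradient accumulation."""
    # The final partial batch of each epoch still counts as a step.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return steps_per_epoch * epochs


steps = total_training_steps(6052, 48, 35)
print(steps)  # 4445
print(f"seq2seq_model_printed/checkpoint-{steps}")
```

With 6052 training samples and a batch size of 48, each epoch takes 127 steps (126 full batches plus one batch of 4), so 35 epochs give 4445 steps. Multi-GPU training or gradient accumulation would change this number, which is why reading res.global_step is the safer option.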

Next are some helper functions. The first one reads an image.

def read_and_show(image_path):
    """
    :param image_path: String, path to the input image.
 
 
    Returns:
        image: PIL Image.
    """
    image = Image.open(image_path).convert('RGB')
    return image

The next helper function performs the forward pass of the image through the model.

def ocr(image, processor, model):
    """
    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.
 
 
    Returns:
        generated_text: the OCR'd text string.
    """
    # We can directly perform OCR on cropped images.
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

The final helper function loops through the images in the directory and calls the ocr() function on each of the first num_samples images.

def eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'scut_test', '*'),
    num_samples=50
):
    image_paths = glob.glob(data_path)
    for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)):
        if i == num_samples:
            break
        image = read_and_show(image_path)
        text = ocr(image, processor, trained_model)
        plt.figure(figsize=(7, 4))
        plt.imshow(image)
        plt.title(text)
        plt.axis('off')
        plt.show()
 
eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'scut_test', '*'),
    num_samples=100
)

We are doing inference on 100 samples (num_samples=100).

Below are the results for two images that the model failed to recognize before fine-tuning.

Figure 7. TrOCR is able to predict curved text in images

The results are impressive. After fine-tuning the TrOCR model, it is able to correctly predict text in curved and vertical images.

Here are more results where the model performed well.

Figure 8. Inference results for stretched text

In this case, even though the text at the ends is stretched, the model still predicts it correctly.

Figure 9. TrOCR inference results for fuzzy text

In the above three cases, the model can predict the text correctly even if the text is blurry.

11. Conclusion

In this article, we fine-tuned the TrOCR model on a curved text recognition dataset. We started with a discussion of the dataset, followed by dataset preparation and training of the TrOCR model. After training, we ran inference experiments and analyzed the results. Our results show that fine-tuning the TrOCR model can lead to better performance, even on blurry or curved text images.

OCR is not just about recognizing text in a scene. It is also about building applications on top of it, such as a captcha recognizer or a pipeline that combines a TrOCR recognizer with license plate detection.

Original link: TrOCR model fine-tuning – BimAnt
