!pip install -Uqq accelerate
!pip install -Uqq transformers
Why and How to Fine-tune CLIP for remote sensing images
Traditional image classification models sometimes struggle to generalize in real-life situations, and labeling training data for them is a significant challenge, which makes it hard for them to cover every case. Enter CLIP (a text-image pairing model), which benefits from recent developments in NLP (Transformers) and the billions of image captions available on the Internet.
CLIP shows less accuracy degradation in real-world scenarios than previous methods and enables various exciting applications, such as searching for images using text, zero-shot classification, and more.
However, like many preceding deep learning models, even when CLIP is trained on an enormous dataset, it can encounter difficulties if there’s a mismatch between the domain of the inference data and the training data.
In this blog post, you’ll see that CLIP, out of the box, doesn’t perform very well on remote-sensing datasets (images captured by satellites), and you’ll learn how to fine-tune the CLIP model on such a dataset.
This blog post is heavily inspired by the contrastive-image-text example from Hugging Face and by this blog post on fine-tuning CLIP with remote sensing images.
Dataset used: rsicd
import transformers
transformers.__version__
'4.33.2'
How CLIP works
In a nutshell, the CLIP model leverages two pretrained models, one for text and one for images, and fine-tunes them so that their embedding outputs for matching image-caption pairs become as close as possible.
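To make that idea concrete, below is a minimal sketch (not the training code used later in this post) of the symmetric contrastive objective CLIP optimizes: embed a batch of images and their captions, compute all pairwise similarities, and push each image towards its own caption and away from the others. The embedding shapes and the temperature value are illustrative assumptions.

import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    # Normalize embeddings so the dot product becomes a cosine similarity
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    # (batch, batch) similarity matrix: row i = image i vs. every caption in the batch
    logits = image_embeds @ text_embeds.t() / temperature
    # The matching caption for image i sits on the diagonal
    targets = torch.arange(len(logits), device=logits.device)
    loss_i = F.cross_entropy(logits, targets)      # images -> captions
    loss_t = F.cross_entropy(logits.t(), targets)  # captions -> images
    return (loss_i + loss_t) / 2

# Toy batch of 4 image/caption embedding pairs, 512-dim (illustrative)
loss = clip_contrastive_loss(torch.randn(4, 512), torch.randn(4, 512))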
Imports
import os
import datasets
from dataclasses import dataclass, field
from typing import Optional
import matplotlib.pyplot as plt
import requests
import random
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from pdb import set_trace
import transformers
from transformers import (
VisionTextDualEncoderProcessor,
VisionTextDualEncoderModel,
AutoImageProcessor,
AutoModel,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.31.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
Base Model
As mentioned above, the image encoder for our CLIP model is openai/clip-vit-base-patch32 and the text encoder is roberta-base.
model = VisionTextDualEncoderModel.from_vision_text_pretrained(
    "openai/clip-vit-base-patch32", "roberta-base"
)

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)

model.save_pretrained("clip-roberta")
processor.save_pretrained("clip-roberta")
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight', 'logit_scale']` are newly initialized. You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
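These warnings are expected: the projection layers are newly initialized and will be trained later in this post. If you restart the notebook, the combined checkpoint saved above can be reloaded straight from the clip-roberta folder; a minimal sketch using the same classes:

model = VisionTextDualEncoderModel.from_pretrained("clip-roberta")
processor = VisionTextDualEncoderProcessor.from_pretrained("clip-roberta")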
Arguments
We define two argument classes, ModelArguments and DataTrainingArguments. This allows us to utilize Hugging Face’s HfArgumentParser in conjunction with the default TrainingArguments from Hugging Face.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
    image_column: Optional[str] = field(
        default="image_path",
        metadata={"help": "The name of the column in the datasets containing the full image file paths."},
    )
    caption_column: Optional[str] = field(
        default="caption",
        metadata={"help": "The name of the column in the datasets containing the image captions."},
    )
    max_seq_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
args_dict = {'output_dir': './clip-roberta-finetuned',
             'model_name_or_path': './clip-roberta',
             'data_dir': './data',
             'dataset_name': 'arampacha/rsicd',
             'image_column': 'image',
             'caption_column': 'captions',
             'remove_unused_columns': False,
             'per_device_train_batch_size': 64,
             'per_device_eval_batch_size': 64,
             'learning_rate': 5e-05,
             'warmup_steps': 0,
             'weight_decay': 0.1,
             'overwrite_output_dir': True,
             'push_to_hub': False}

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_dict(args_dict)
model_args, data_args
(ModelArguments(model_name_or_path='./clip-roberta', config_name=None, tokenizer_name=None, image_processor_name=None, cache_dir=None, model_revision='main', use_fast_tokenizer=True, use_auth_token=False),
DataTrainingArguments(dataset_name='arampacha/rsicd', data_dir='./data', image_column='image', caption_column='captions', max_seq_length=128, overwrite_cache=False, preprocessing_num_workers=None))
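Here we feed the parser a Python dict, but the same three dataclasses could also be populated from command-line flags if this notebook were turned into a script; a minimal sketch (the script name and flag values are illustrative):

# e.g. python run_clip_finetune.py --output_dir ./clip-roberta-finetuned --model_name_or_path ./clip-roberta ...
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()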
Dataset Preparation
class Transform(torch.nn.Module):
    def __init__(self, image_size, mean, std):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC, antialias=True),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    def forward(self, x) -> torch.Tensor:
        """`x` should be an instance of `PIL.Image.Image`"""
        with torch.no_grad():
            x = self.transforms(x)
        return x
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)
    attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long)
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "return_loss": True,
    }
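Before training, it can help to confirm what collate_fn hands to the model. A quick sanity check on dummy data (the 3x224x224 pixel shape matches CLIP ViT-B/32's input size; the token ids are made up):

dummy = [
    {"pixel_values": torch.zeros(3, 224, 224), "input_ids": [0, 10, 11, 2], "attention_mask": [1, 1, 1, 1]},
    {"pixel_values": torch.zeros(3, 224, 224), "input_ids": [0, 12, 13, 2], "attention_mask": [1, 1, 1, 1]},
]
batch = collate_fn(dummy)
print(batch["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
print(batch["input_ids"].shape)     # torch.Size([2, 4])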
Below is the remote sensing dataset that we use in this blog post.
dataset = datasets.load_dataset("arampacha/rsicd")
Found cached dataset parquet (/home/.cache/huggingface/datasets/arampacha___parquet/arampacha--rsicd-56e24d6cc63cb9d9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
dataset
DatasetDict({
train: Dataset({
features: ['filename', 'captions', 'image'],
num_rows: 8734
})
test: Dataset({
features: ['filename', 'captions', 'image'],
num_rows: 1093
})
valid: Dataset({
features: ['filename', 'captions', 'image'],
num_rows: 1094
})
})
Let’s look at a few examples from this dataset.
def show_images(dset, num_images=8, without_caption=True, num_columns=2, img_size=(4, 4)):
    num_rows = -(-num_images // num_columns)  # Ceiling division
    fig = plt.figure(figsize=(img_size[0] * num_columns, img_size[1] * num_rows))

    _list = list(range(len(dset)))
    for i in range(num_images):
        index = _list[i]
        ax = fig.add_subplot(num_rows, num_columns, i+1)
        image = dset[index]['image']
        plt.imshow(image)

        # Set title as the first caption
        if without_caption:
            caption = dset[index]['captions'][0]
            ax.set_title(caption, fontsize=10)

        # Remove axis
        plt.axis('off')

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.5, hspace=0.01)  # Adjust these values as needed
    plt.show()

show_images(dataset['train'], num_images=8, without_caption=True)
Model Preparation
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
)

image_processor = AutoImageProcessor.from_pretrained(
    model_args.image_processor_name or model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)

model = AutoModel.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)
config = model.config
To ensure reproducible output, we should set the seed.
set_seed(training_args.seed)
image_transformations = Transform(
    config.vision_config.image_size, image_processor.image_mean, image_processor.image_std
)
image_transformations = torch.jit.script(image_transformations)
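Scripting the transform with TorchScript lets the resizing and normalization run without Python overhead during data loading. A quick sanity check on a dummy uint8 image tensor (the 256x256 input size is arbitrary; the output side comes from config.vision_config.image_size, which is 224 for this checkpoint):

dummy_image = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
print(image_transformations(dummy_image).shape)  # expected: torch.Size([3, 224, 224])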
def tokenize_captions(examples):
    captions = [example[0] for example in examples[data_args.caption_column]]
    text_inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True)
    examples["input_ids"] = text_inputs.input_ids
    examples["attention_mask"] = text_inputs.attention_mask
    return examples
def transform_images(examples):
    images = [torch.tensor(np.array(image)).permute(2, 0, 1) for image in examples[data_args.image_column]]
    examples["pixel_values"] = [image_transformations(image) for image in images]
    return examples
def filter_corrupt_images(examples):
    """remove problematic images"""
    valid_images = []
    for image_file in examples[data_args.image_column]:
        try:
            Image.open(image_file)
            valid_images.append(True)
        except Exception:
            valid_images.append(False)
    return valid_images
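filter_corrupt_images comes from the original Hugging Face script, where the image column holds file paths; with rsicd the images arrive already decoded, so it is not applied below. If you were instead loading a folder of paths, a hedged sketch of how it could be used (path_dataset is a hypothetical dataset whose image column contains file paths):

# Hypothetical usage: only relevant when the image column holds file paths (not the case for rsicd)
path_dataset = path_dataset.filter(
    filter_corrupt_images,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
)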
train_dataset = dataset["train"]
train_dataset = train_dataset.map(
    function=tokenize_captions,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    load_from_cache_file=not data_args.overwrite_cache,
    desc="Running tokenizer on train dataset",
)
train_dataset.set_transform(transform_images)
Loading cached processed dataset at /home/.cache/huggingface/datasets/arampacha___parquet/arampacha--rsicd-56e24d6cc63cb9d9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-01af25c2d3c15faa.arrow
Parameter 'transform'=<function transform_images> of the transform datasets.arrow_dataset.Dataset.set_format couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
train_dataset
Dataset({
features: ['filename', 'captions', 'image', 'input_ids', 'attention_mask'],
num_rows: 8734
})
eval_dataset = dataset["valid"]
eval_dataset = eval_dataset.map(
    function=tokenize_captions,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    load_from_cache_file=not data_args.overwrite_cache,
    desc="Running tokenizer on validation dataset",
)
eval_dataset.set_transform(transform_images)
train_dataset, eval_dataset
(Dataset({
features: ['filename', 'captions', 'image', 'input_ids', 'attention_mask'],
num_rows: 8734
}),
Dataset({
features: ['filename', 'captions', 'image', 'input_ids', 'attention_mask'],
num_rows: 1094
}))
processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)
We have a straightforward example below (sourced from Hugging Face) to quickly demonstrate how the CLIP model works by default. There are two images: the first one is of a cat and the second is of a dog. We will use the text “a photo of a cat” and determine which picture has the highest probability.
urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg",
]
images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
inputs = processor(
    text=["a photo of a cat"], images=images, return_tensors="pt", padding=True
)
inputs['input_ids'] = inputs['input_ids'].cuda()
inputs['attention_mask'] = inputs['attention_mask'].cuda()
inputs['pixel_values'] = inputs['pixel_values'].cuda()
model = model.cuda()
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image

logits_per_image
tensor([[-0.9982],
[-0.5772]], device='cuda:0', grad_fn=<PermuteBackward0>)
As you can see, the first picture is more likely to be of a cat (and that’s correct).
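If you want actual probabilities rather than raw similarity logits, a softmax across the image dimension does the job; a small sketch (with one text and two images, logits_per_image has shape (2, 1)):

probs = logits_per_image.softmax(dim=0)  # probability that each image matches "a photo of a cat"
probs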
images[0]
images[1]
Fine-tuning CLIP
Take a look at how the default model performs with these remote-sensing images, which are not well represented in its training data. From a randomly selected set of 8 images, we ask the model to pick the 3 images that best match the prompt “green trees.”
np.random.seed(0)
indices = np.random.choice(len(dataset['valid']), 8, replace=False)
patches = dataset['valid'].select(indices.tolist())

show_images(patches, 8, without_caption=False, num_columns=4, img_size=(3, 3))
def show_result(model, patches, text, top_n=3):
    images = [patch['image'] for patch in patches]
    inputs = processor(text=[text], images=images, return_tensors="pt", padding=True)
    inputs['input_ids'] = inputs['input_ids'].cuda()
    inputs['attention_mask'] = inputs['attention_mask'].cuda()
    inputs['pixel_values'] = inputs['pixel_values'].cuda()

    model = model.cuda()
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    sorted_idx = (torch.sort(logits_per_image, dim=0, descending=True)[1][:, 0]).tolist()
    sorted_idx = sorted_idx[:top_n]
    patches_sorted = patches.select(sorted_idx)
    show_images(patches_sorted, num_images=len(patches_sorted), without_caption=False, num_columns=1, img_size=(3, 3))

show_result(model, patches, 'green trees')
Without fine-tuning, the performance isn’t optimal. As you can see, the top 3 retrieved images don’t show many trees.
# 8. Initialize our trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
)
# 9. Training
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)

metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501: UserWarning: operator() profile_node %380 : int = prim::profile_ivalue(%out_dtype.1)
does not have profile information (Triggered internally at /opt/conda/conda-bld/pytorch_1678411187366/work/third_party/nvfuser/csrc/graph_fuser.cpp:104.)
return forward_call(*args, **kwargs)
***** train metrics *****
epoch = 3.0
total_flos = 3258157GF
train_loss = 1.7008
train_runtime = 0:06:44.15
train_samples_per_second = 64.832
train_steps_per_second = 1.017
***** eval metrics *****
epoch = 3.0
eval_loss = 3.8048
eval_runtime = 0:00:07.26
eval_samples_per_second = 150.574
eval_steps_per_second = 2.477
show_result(model, patches, 'green trees')
After fine-tuning, the result is much better! There are trees in all 3 images.
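To reuse the fine-tuned weights outside this notebook, you could save them alongside the processor and load everything back later; a minimal sketch (the output path matches the output_dir set earlier):

trainer.save_model()  # writes the model to training_args.output_dir ("./clip-roberta-finetuned")
processor.save_pretrained(training_args.output_dir)

finetuned_model = AutoModel.from_pretrained("./clip-roberta-finetuned")
finetuned_processor = VisionTextDualEncoderProcessor.from_pretrained("./clip-roberta-finetuned")
show_result(finetuned_model, patches, 'green trees')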