README.md
CHANGED
@@ -9,4 +9,27 @@ app_file: app.py
 pinned: false
 ---
 
+# Fine-tuned BLIP2 Image Captioning
+
+This Hugging Face Space hosts a BLIP2 model that has been fine-tuned on the Flickr8k dataset using Low-Rank Adaptation (LoRA).
+
+## Model Details
+
+- Base model: `ybelkada/blip2-opt-2.7b-fp16-sharded`
+- Fine-tuning technique: LoRA (Low-Rank Adaptation)
+- Training dataset: Flickr8k
+- LoRA configuration:
+  - Rank (r): 16
+  - Alpha: 32
+  - Dropout: 0.05
+  - Target modules: q_proj, k_proj
+
+## Usage
+
+Upload an image to generate a caption. The model will process the image and return a descriptive caption based on its fine-tuned knowledge.
+
+## Notes
+
+The model uses 8-bit quantization to reduce memory usage while maintaining performance.
+
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
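
The LoRA configuration listed in the README above corresponds one-to-one to a `peft.LoraConfig`; the following is a minimal sketch using the same values that appear later in colab.py (the variable name `lora_config` is only illustrative):

from peft import LoraConfig

# LoRA settings described in the README and used in colab.py
lora_config = LoraConfig(
    r=16,                                 # rank of the low-rank update matrices
    lora_alpha=32,                        # scaling factor applied to the update
    lora_dropout=0.05,                    # dropout on the LoRA layers during training
    bias="none",                          # leave bias parameters untouched
    target_modules=["q_proj", "k_proj"],  # attention projections that receive adapters
)
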
app.py
CHANGED
@@ -1,19 +1,45 @@
 import gradio as gr
 from PIL import Image
 import torch
-from transformers import
+from transformers import AutoProcessor, Blip2ForConditionalGeneration
+from peft import PeftModel, LoraConfig
 
-#
-
-
+# LoRA configuration used during training:
+# config = LoraConfig(
+#     r=16,
+#     lora_alpha=32,
+#     lora_dropout=0.05,
+#     bias="none",
+#     target_modules=["q_proj", "k_proj"]
+# )
+
+# Load base model with the same configuration as in training
+base_model = Blip2ForConditionalGeneration.from_pretrained(
+    "ybelkada/blip2-opt-2.7b-fp16-sharded",
+    device_map="auto",
+    load_in_8bit=True
+)
+
+# Load the fine-tuned LoRA weights
+model = PeftModel.from_pretrained(base_model, "./model")
+
+# Load processor - use the same one as training
+processor = AutoProcessor.from_pretrained("./processor")
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model = model.to(device)
 
-# Define the function to generate caption
+# Define the function to generate caption - exactly as in colab
 def generate_caption(image):
-
+    # Convert image to RGB if needed
+    image = image.convert('RGB') if image.mode != 'RGB' else image
+
+    # Process the image exactly as in colab.py
+    inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+
+    # Generate caption with the same parameters
     generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+
+    # Decode the caption
     caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
     return caption
 
@@ -22,8 +48,8 @@ iface = gr.Interface(
     fn=generate_caption,
     inputs=gr.Image(type="pil"),
     outputs="text",
-    title="Image Caption Generator",
-    description="Upload an image to generate a caption."
+    title="Fine-tuned BLIP2 Image Caption Generator",
+    description="Upload an image to generate a caption using BLIP2 fine-tuned on Flickr8k with LoRA (r=16, alpha=32)."
 )
 
 # Launch
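
A note on the loading code above: `load_in_8bit=True` relies on bitsandbytes, which requires a CUDA GPU, and the inputs are cast to `torch.float16`. If the Space could also be scheduled on CPU-only hardware, a more defensive variant of the same logic might look like this sketch (the `use_cuda` flag and the float32 fallback are assumptions, not part of the diff; the model, adapter, and processor paths mirror app.py):

import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from peft import PeftModel

use_cuda = torch.cuda.is_available()

# 8-bit quantization (bitsandbytes) needs a CUDA GPU; fall back to full precision on CPU.
base_model = Blip2ForConditionalGeneration.from_pretrained(
    "ybelkada/blip2-opt-2.7b-fp16-sharded",
    device_map="auto" if use_cuda else None,
    load_in_8bit=use_cuda,
)
model = PeftModel.from_pretrained(base_model, "./model")
processor = AutoProcessor.from_pretrained("./processor")

def generate_caption(image):
    # Match the input dtype/device to how the model was loaded.
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32
    inputs = processor(images=image.convert("RGB"), return_tensors="pt").to(device, dtype)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
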
colab.py
ADDED
@@ -0,0 +1,299 @@
+# Create image # Install libraries
+# !pip install -q git+https://github.com/huggingface/peft.git transformers bitsandbytes datasets
+
+# # Fix fsspec version mismatch (if needed)
+# !pip install fsspec==2025.3.0
+
+# from google.colab import files
+# files.upload()  # Upload kaggle.json here
+# # Move kaggle.json to the right location
+# !mkdir -p ~/.kaggle
+# !cp kaggle.json ~/.kaggle/
+# !chmod 600 ~/.kaggle/kaggle.json
+
+
+# # Download the dataset
+# !kaggle datasets download -d adityajn105/flickr8k --force
+
+# # Unzip it
+# !unzip -q flickr8k.zip -d flickr8k
+
+# DATASET_PATH = '/content/flickr8k'
+# CAPTIONS_FILE = os.path.join(DATASET_PATH, 'captions.txt')
+# IMAGES_PATH = os.path.join(DATASET_PATH, 'Images/')
+# import os
+# import pandas as pd
+# from PIL import Image
+# from torch.utils.data import Dataset, DataLoader
+# import torch
+# # Load and process captions
+# df = pd.read_csv(CAPTIONS_FILE, sep=',', names=["image", "caption"])
+
+# df["caption"] = df["caption"][1:]
+# df["caption"]
+# df = df.dropna().reset_index(drop=True)
+# df
+# df = df[:8000]
+
+# from transformers import AutoProcessor
+# from PIL import Image
+
+# processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
+
+# class Flickr8kDataset(Dataset):
+#     def __init__(self, dataframe, image_dir, processor):
+#         self.dataframe = dataframe
+#         self.image_dir = image_dir
+#         self.processor = processor
+
+#     def __len__(self):
+#         return len(self.dataframe)
+
+#     def __getitem__(self, idx):
+#         row = self.dataframe.iloc[idx]
+#         image_path = os.path.join(self.image_dir, row["image"])
+#         caption = row["caption"]
+
+#         # Load image
+#         image = Image.open(image_path).convert('RGB')
+
+#         # Process image
+#         encoding = self.processor(images=image, return_tensors="pt")
+#         encoding = {k: v.squeeze() for k, v in encoding.items()}
+#         encoding["text"] = caption
+
+#         return encoding
+
+# def collate_fn(batch):
+#     processed_batch = {}
+#     for key in batch[0].keys():
+#         if key != "text":
+#             processed_batch[key] = torch.stack([example[key] for example in batch])
+#         else:
+#             text_inputs = processor.tokenizer(
+#                 [example["text"] for example in batch], padding=True, return_tensors="pt"
+#             )
+#             processed_batch["input_ids"] = text_inputs["input_ids"]
+#             processed_batch["attention_mask"] = text_inputs["attention_mask"]
+#     return processed_batch
+
+# from transformers import Blip2ForConditionalGeneration
+# from peft import LoraConfig, get_peft_model
+
+# model = Blip2ForConditionalGeneration.from_pretrained(
+#     "ybelkada/blip2-opt-2.7b-fp16-sharded",
+#     device_map="auto",
+#     load_in_8bit=True
+# )
+
+# # Apply LoRA
+# config = LoraConfig(
+#     r=16,
+#     lora_alpha=32,
+#     lora_dropout=0.05,
+#     bias="none",
+#     target_modules=["q_proj", "k_proj"]
+# )
+# model = get_peft_model(model, config)
+# model.print_trainable_parameters()
+
+# # Load dataset
+# train_dataset = Flickr8kDataset(df, IMAGES_PATH, processor)
+# train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=3, collate_fn=collate_fn)
+
+# # Set up optimizer
+# optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+# model.train()
+
+# # Training loop
+# for epoch in range(1):  # Use small epochs for testing, increase later
+#     print(f"Epoch: {epoch}")
+#     for idx, batch in enumerate(train_dataloader):
+#         input_ids = batch.pop("input_ids").to(device)
+#         pixel_values = batch.pop("pixel_values").to(device, torch.float16)
+
+#         outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)
+#         loss = outputs.loss
+
+#         print(f"Batch {idx} Loss: {loss.item():.4f}")
+
+#         loss.backward()
+#         optimizer.step()
+#         optimizer.zero_grad()
+
+# # Example prediction
+# sample_image = Image.open(os.path.join(IMAGES_PATH, df.iloc[0]["image"])).convert('RGB')
+# inputs = processor(images=sample_image, return_tensors="pt").to(device, torch.float16)
+# generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+# caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+# print("Generated caption:", caption)
+
+# import matplotlib.pyplot as plt
+
+# # Show the sample image with the generated caption
+# plt.figure(figsize=(6,6))
+# plt.imshow(sample_image)
+# plt.axis("off")
+# plt.title(f"Generated caption:\n{caption}", fontsize=12)
+# plt.show()
+
+
+
+
+# # Load a sample image
+# sample_image = Image.open(os.path.join(IMAGES_PATH, df.iloc[15]["image"])).convert('RGB')
+
+# # Prepare inputs
+# inputs = processor(images=sample_image, return_tensors="pt").to(device, torch.float16)
+# generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+
+# # Decode caption
+# caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+# print("Generated caption:", caption)
+
+# # Show image with caption
+# import matplotlib.pyplot as plt
+
+# plt.figure(figsize=(6,6))
+# plt.imshow(sample_image)
+# plt.axis("off")
+# plt.title(f"Generated caption:\n{caption}", fontsize=12)
+# plt.show()
+
+# from PIL import Image
+# import matplotlib.pyplot as plt
+# import torch
+# import io
+# from google.colab import files  # Only for Colab
+
+# # Upload image
+# uploaded = files.upload()
+
+# # Get the uploaded file
+# for filename in uploaded.keys():
+#     image_path = filename
+
+# # Load the image
+# sample_image = Image.open(image_path).convert('RGB')
+
+# # Prepare inputs
+# inputs = processor(images=sample_image, return_tensors="pt").to(device, torch.float16)
+
+# # Generate caption
+# generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+
+# # Decode caption
+# caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+# print("Generated caption:", caption)
+
+# # Show image with caption
+# plt.figure(figsize=(6,6))
+# plt.imshow(sample_image)
+# plt.axis("off")
+# plt.title(f"Generated caption:\n{caption}", fontsize=12)
+# plt.show()
+
+# !pip install evaluate pycocoevalcap --quiet
+
+# import evaluate
+# from tqdm import tqdm
+# from PIL import Image
+# import torch
+# import os
+
+# # Load metrics
+# bleu = evaluate.load("bleu")
+
+
+
+# df = pd.read_csv(CAPTIONS_FILE, sep=',', names=["image", "caption"])
+
+# # Subset of data
+# subset_df = df[8001:8092].reset_index(drop=True)
+
+# # Prepare references and predictions
+# references = {}
+# predictions = []
+
+# for idx in tqdm(range(len(subset_df))):
+#     image_name = subset_df.iloc[idx]['image']
+
+#     # Load image
+#     image_path = os.path.join(IMAGES_PATH, image_name)
+#     image = Image.open(image_path).convert('RGB')
+
+#     # Generate caption
+#     inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+#     generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+#     predicted_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+#     # Save prediction
+#     predictions.append(predicted_caption)
+
+#     # Prepare ground-truth references
+#     if image_name not in references:
+#         gt = df[df['image'] == image_name]['caption'].tolist()
+#         references[image_name] = gt
+
+# # Build reference and prediction lists for scoring
+# gt_list = [references[name] for name in subset_df['image']]
+# pred_list = predictions
+
+# import evaluate
+# from tqdm import tqdm
+# from PIL import Image
+# from pycocoevalcap.cider.cider import Cider
+# from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
+
+# bleu = evaluate.load("bleu")
+# df = pd.read_csv(CAPTIONS_FILE, sep=',', names=["image", "caption"])
+
+# # Get subset
+# subset_df = df[8001:8092].reset_index(drop=True)
+
+# # Collect predictions and references
+# predictions = []
+# references = {}
+
+# for idx in tqdm(range(len(subset_df))):
+#     row = subset_df.iloc[idx]
+#     image_name = row['image']
+
+#     image_path = os.path.join(IMAGES_PATH, image_name)
+#     image = Image.open(image_path).convert('RGB')
+
+#     inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+#     generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+#     caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+#     predictions.append(caption)
+
+#     if image_name not in references:
+#         refs = df[df['image'] == image_name]['caption'].tolist()
+#         references[image_name] = refs
+
+# # Prepare for BLEU
+# gt_list = [references[name] for name in subset_df["image"]]
+# pred_list = predictions
+
+# bleu_score = bleu.compute(predictions=pred_list, references=gt_list)
+# print("BLEU:", bleu_score)
+
+# # Prepare COCO-style input
+# gts = {}
+# res = {}
+
+# for i, img in enumerate(subset_df["image"]):
+#     gts[str(i)] = [{"caption": cap} for cap in references[img]]
+#     res[str(i)] = [{"caption": predictions[i]}]
+
+# # Tokenize
+# tokenizer = PTBTokenizer()
+# gts_tokenized = tokenizer.tokenize(gts)
+# res_tokenized = tokenizer.tokenize(res)
+
+# # Compute CIDEr
+# cider_scorer = Cider()
+# cider_score, _ = cider_scorer.compute_score(gts_tokenized, res_tokenized)
+# print("CIDEr:", cider_score)
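
colab.py ends after the BLEU/CIDEr evaluation, while app.py loads the LoRA adapter from ./model and the processor from ./processor. The save step does not appear anywhere in this diff; presumably it looked roughly like the sketch below (the paths simply mirror what app.py expects, and `model`/`processor` are the objects created in the notebook cells above):

# Assumed follow-up to the notebook: persist the artifacts that app.py loads.
# `model` is the PEFT-wrapped BLIP2 model, `processor` the AutoProcessor from colab.py.
model.save_pretrained("./model")          # writes the LoRA adapter weights and adapter_config.json
processor.save_pretrained("./processor")  # writes the image processor / tokenizer files
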
requirements.txt
CHANGED
@@ -1,4 +1,7 @@
 torch
-transformers
+transformers>=4.30.0
 gradio
-Pillow
+Pillow
+peft
+bitsandbytes
+accelerate