Saadi07 committed on
Commit 95e1523 · 1 Parent(s): d52fa20
Files changed (4)
  1. README.md +23 -0
  2. app.py +35 -9
  3. colab.py +299 -0
  4. requirements.txt +5 -2
README.md CHANGED
@@ -9,4 +9,27 @@ app_file: app.py
pinned: false
---

+ # Fine-tuned BLIP2 Image Captioning
+
+ This Hugging Face Space hosts a BLIP2 model that has been fine-tuned on the Flickr8k dataset using Low-Rank Adaptation (LoRA).
+
+ ## Model Details
+
+ - Base model: `ybelkada/blip2-opt-2.7b-fp16-sharded`
+ - Fine-tuning technique: LoRA (Low-Rank Adaptation)
+ - Training dataset: Flickr8k
+ - LoRA configuration:
+   - Rank (r): 16
+   - Alpha: 32
+   - Dropout: 0.05
+   - Target modules: q_proj, k_proj
+
+ ## Usage
+
+ Upload an image to generate a caption. The model will process the image and return a descriptive caption based on its fine-tuned knowledge.
+
+ ## Notes
+
+ The model uses 8-bit quantization to reduce memory usage while maintaining performance.
+
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
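
For reference, the hyperparameters listed under Model Details map onto `peft` as in the sketch below, which mirrors the (commented-out) training code added in colab.py in this commit; it is an illustrative snippet, not a packaged script:

```python
# Sketch of the LoRA setup described in the README; mirrors colab.py from this commit.
from transformers import Blip2ForConditionalGeneration
from peft import LoraConfig, get_peft_model

base = Blip2ForConditionalGeneration.from_pretrained(
    "ybelkada/blip2-opt-2.7b-fp16-sharded",
    device_map="auto",
    load_in_8bit=True,          # the 8-bit quantization mentioned in Notes
)

lora_config = LoraConfig(
    r=16,                       # Rank (r)
    lora_alpha=32,              # Alpha
    lora_dropout=0.05,          # Dropout
    bias="none",
    target_modules=["q_proj", "k_proj"],
)

model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the LoRA adapter weights are trainable
```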
app.py CHANGED
@@ -1,19 +1,45 @@
import gradio as gr
from PIL import Image
import torch
- from transformers import BlipProcessor, BlipForConditionalGeneration
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
+ from peft import PeftModel, LoraConfig

- # Load model and processor
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+ # LoRA configuration used during training:
+ # config = LoraConfig(
+ #     r=16,
+ #     lora_alpha=32,
+ #     lora_dropout=0.05,
+ #     bias="none",
+ #     target_modules=["q_proj", "k_proj"]
+ # )
+
+ # Load base model with the same configuration as in training
+ base_model = Blip2ForConditionalGeneration.from_pretrained(
+     "ybelkada/blip2-opt-2.7b-fp16-sharded",
+     device_map="auto",
+     load_in_8bit=True
+ )
+
+ # Load the fine-tuned LoRA weights
+ model = PeftModel.from_pretrained(base_model, "./model")
+
+ # Load processor - use the same one as training
+ processor = AutoProcessor.from_pretrained("./processor")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model = model.to(device)

- # Define the function to generate caption
+ # Define the function to generate caption - exactly as in colab
def generate_caption(image):
-     inputs = processor(images=image, return_tensors="pt").to(device)
+     # Convert image to RGB if needed
+     image = image.convert('RGB') if image.mode != 'RGB' else image
+
+     # Process the image exactly as in colab.py
+     inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+
+     # Generate caption with the same parameters
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+
+     # Decode the caption
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption

@@ -22,8 +48,8 @@ iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil"),
    outputs="text",
-     title="Image Caption Generator",
-     description="Upload an image to generate a caption."
+     title="Fine-tuned BLIP2 Image Caption Generator",
+     description="Upload an image to generate a caption using BLIP2 fine-tuned on Flickr8k with LoRA (r=16, alpha=32)."
)

  # Launch
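
A note on the 8-bit load above: recent `transformers` releases prefer `BitsAndBytesConfig` over passing `load_in_8bit=True` directly to `from_pretrained`. Assuming the pinned `transformers>=4.30.0` resolves to such a release, an equivalent load would look roughly like this (a sketch, not part of the commit):

```python
# Alternative 8-bit load via BitsAndBytesConfig (newer transformers versions prefer
# this over the bare load_in_8bit=True kwarg). The ./model and ./processor paths are
# the ones app.py in this commit expects.
from transformers import AutoProcessor, Blip2ForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel

base_model = Blip2ForConditionalGeneration.from_pretrained(
    "ybelkada/blip2-opt-2.7b-fp16-sharded",
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
model = PeftModel.from_pretrained(base_model, "./model")
processor = AutoProcessor.from_pretrained("./processor")
```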
colab.py ADDED
@@ -0,0 +1,299 @@
+ # Install libraries
+ # !pip install -q git+https://github.com/huggingface/peft.git transformers bitsandbytes datasets
+
+ # # Fix fsspec version mismatch (if needed)
+ # !pip install fsspec==2025.3.0
+
+ # from google.colab import files
+ # files.upload() # Upload kaggle.json here
+ # # Move kaggle.json to the right location
+ # !mkdir -p ~/.kaggle
+ # !cp kaggle.json ~/.kaggle/
+ # !chmod 600 ~/.kaggle/kaggle.json
+
+
+ # # Download the dataset
+ # !kaggle datasets download -d adityajn105/flickr8k --force
+
+ # # Unzip it
+ # !unzip -q flickr8k.zip -d flickr8k
+
+ # DATASET_PATH = '/content/flickr8k'
+ # CAPTIONS_FILE = os.path.join(DATASET_PATH, 'captions.txt')
+ # IMAGES_PATH = os.path.join(DATASET_PATH, 'Images/')
+ # import os
+ # import pandas as pd
+ # from PIL import Image
+ # from torch.utils.data import Dataset, DataLoader
+ # import torch
+ # # Load and process captions
+ # df = pd.read_csv(CAPTIONS_FILE, sep=',', names=["image", "caption"])
+
+ # df["caption"] = df["caption"][1:]
+ # df["caption"]
+ # df = df.dropna().reset_index(drop=True)
+ # df
+ # df = df[:8000]
+
+ # from transformers import AutoProcessor
+ # from PIL import Image
+
+ # processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
+
+ # class Flickr8kDataset(Dataset):
+ #     def __init__(self, dataframe, image_dir, processor):
+ #         self.dataframe = dataframe
+ #         self.image_dir = image_dir
+ #         self.processor = processor
+
+ #     def __len__(self):
+ #         return len(self.dataframe)
+
+ #     def __getitem__(self, idx):
+ #         row = self.dataframe.iloc[idx]
+ #         image_path = os.path.join(self.image_dir, row["image"])
+ #         caption = row["caption"]
+
+ #         # Load image
+ #         image = Image.open(image_path).convert('RGB')
+
+ #         # Process image
+ #         encoding = self.processor(images=image, return_tensors="pt")
+ #         encoding = {k: v.squeeze() for k, v in encoding.items()}
+ #         encoding["text"] = caption
+
+ #         return encoding
+
+ # def collate_fn(batch):
+ #     processed_batch = {}
+ #     for key in batch[0].keys():
+ #         if key != "text":
+ #             processed_batch[key] = torch.stack([example[key] for example in batch])
+ #         else:
+ #             text_inputs = processor.tokenizer(
+ #                 [example["text"] for example in batch], padding=True, return_tensors="pt"
+ #             )
+ #             processed_batch["input_ids"] = text_inputs["input_ids"]
+ #             processed_batch["attention_mask"] = text_inputs["attention_mask"]
+ #     return processed_batch
+
+ # from transformers import Blip2ForConditionalGeneration
+ # from peft import LoraConfig, get_peft_model
+
+ # model = Blip2ForConditionalGeneration.from_pretrained(
+ #     "ybelkada/blip2-opt-2.7b-fp16-sharded",
+ #     device_map="auto",
+ #     load_in_8bit=True
+ # )
+
+ # # Apply LoRA
+ # config = LoraConfig(
+ #     r=16,
+ #     lora_alpha=32,
+ #     lora_dropout=0.05,
+ #     bias="none",
+ #     target_modules=["q_proj", "k_proj"]
+ # )
+ # model = get_peft_model(model, config)
+ # model.print_trainable_parameters()
+
+ # # Load dataset
+ # train_dataset = Flickr8kDataset(df, IMAGES_PATH, processor)
+ # train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=3, collate_fn=collate_fn)
+
+ # # Set up optimizer
+ # optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
+ # model.train()
+
+ # # Training loop
+ # for epoch in range(1): # Use small epochs for testing, increase later
+ #     print(f"Epoch: {epoch}")
+ #     for idx, batch in enumerate(train_dataloader):
+ #         input_ids = batch.pop("input_ids").to(device)
+ #         pixel_values = batch.pop("pixel_values").to(device, torch.float16)
+
+ #         outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)
+ #         loss = outputs.loss
+
+ #         print(f"Batch {idx} Loss: {loss.item():.4f}")
+
+ #         loss.backward()
+ #         optimizer.step()
+ #         optimizer.zero_grad()
+
+ # # Example prediction
+ # sample_image = Image.open(os.path.join(IMAGES_PATH, df.iloc[0]["image"])).convert('RGB')
+ # inputs = processor(images=sample_image, return_tensors="pt").to(device, torch.float16)
+ # generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+ # caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ # print("Generated caption:", caption)
+
+ # import matplotlib.pyplot as plt
+
+ # # Show the sample image with the generated caption
+ # plt.figure(figsize=(6,6))
+ # plt.imshow(sample_image)
+ # plt.axis("off")
+ # plt.title(f"Generated caption:\n{caption}", fontsize=12)
+ # plt.show()
+
+
+
+
+ # # Load a sample image
+ # sample_image = Image.open(os.path.join(IMAGES_PATH, df.iloc[15]["image"])).convert('RGB')
+
+ # # Prepare inputs
+ # inputs = processor(images=sample_image, return_tensors="pt").to(device, torch.float16)
+ # generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+
+ # # Decode caption
+ # caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ # print("Generated caption:", caption)
+
+ # # Show image with caption
+ # import matplotlib.pyplot as plt
+
+ # plt.figure(figsize=(6,6))
+ # plt.imshow(sample_image)
+ # plt.axis("off")
+ # plt.title(f"Generated caption:\n{caption}", fontsize=12)
+ # plt.show()
+
+ # from PIL import Image
+ # import matplotlib.pyplot as plt
+ # import torch
+ # import io
+ # from google.colab import files # Only for Colab
+
+ # # Upload image
+ # uploaded = files.upload()
+
+ # # Get the uploaded file
+ # for filename in uploaded.keys():
+ #     image_path = filename
+
+ # # Load the image
+ # sample_image = Image.open(image_path).convert('RGB')
+
+ # # Prepare inputs
+ # inputs = processor(images=sample_image, return_tensors="pt").to(device, torch.float16)
+
+ # # Generate caption
+ # generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+
+ # # Decode caption
+ # caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ # print("Generated caption:", caption)
+
+ # # Show image with caption
+ # plt.figure(figsize=(6,6))
+ # plt.imshow(sample_image)
+ # plt.axis("off")
+ # plt.title(f"Generated caption:\n{caption}", fontsize=12)
+ # plt.show()
+
+ # !pip install evaluate pycocoevalcap --quiet
+
+ # import evaluate
+ # from tqdm import tqdm
+ # from PIL import Image
+ # import torch
+ # import os
+
+ # # Load metrics
+ # bleu = evaluate.load("bleu")
+
+
+
+ # df = pd.read_csv(CAPTIONS_FILE, sep=',', names=["image", "caption"])
+
+ # # Subset of data
+ # subset_df = df[8001:8092].reset_index(drop=True)
+
+ # # Prepare references and predictions
+ # references = {}
+ # predictions = []
+
+ # for idx in tqdm(range(len(subset_df))):
+ #     image_name = subset_df.iloc[idx]['image']
+
+ #     # Load image
+ #     image_path = os.path.join(IMAGES_PATH, image_name)
+ #     image = Image.open(image_path).convert('RGB')
+
+ #     # Generate caption
+ #     inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+ #     generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+ #     predicted_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+ #     # Save prediction
+ #     predictions.append(predicted_caption)
+
+ #     # Prepare ground-truth references
+ #     if image_name not in references:
+ #         gt = df[df['image'] == image_name]['caption'].tolist()
+ #         references[image_name] = gt
+
+ # # Build reference and prediction lists for scoring
+ # gt_list = [references[name] for name in subset_df['image']]
+ # pred_list = predictions
+
+ # import evaluate
+ # from tqdm import tqdm
+ # from PIL import Image
+ # from pycocoevalcap.cider.cider import Cider
+ # from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
+
+ # bleu = evaluate.load("bleu")
+ # df = pd.read_csv(CAPTIONS_FILE, sep=',', names=["image", "caption"])
+
+ # # Get subset
+ # subset_df = df[8001:8092].reset_index(drop=True)
+
+ # # Collect predictions and references
+ # predictions = []
+ # references = {}
+
+ # for idx in tqdm(range(len(subset_df))):
+ #     row = subset_df.iloc[idx]
+ #     image_name = row['image']
+
+ #     image_path = os.path.join(IMAGES_PATH, image_name)
+ #     image = Image.open(image_path).convert('RGB')
+
+ #     inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+ #     generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+ #     caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+ #     predictions.append(caption)
+
+ #     if image_name not in references:
+ #         refs = df[df['image'] == image_name]['caption'].tolist()
+ #         references[image_name] = refs
+
+ # # Prepare for BLEU
+ # gt_list = [references[name] for name in subset_df["image"]]
+ # pred_list = predictions
+
+ # bleu_score = bleu.compute(predictions=pred_list, references=gt_list)
+ # print("BLEU:", bleu_score)
+
+ # # Prepare COCO-style input
+ # gts = {}
+ # res = {}
+
+ # for i, img in enumerate(subset_df["image"]):
+ #     gts[str(i)] = [{"caption": cap} for cap in references[img]]
+ #     res[str(i)] = [{"caption": predictions[i]}]
+
+ # # Tokenize
+ # tokenizer = PTBTokenizer()
+ # gts_tokenized = tokenizer.tokenize(gts)
+ # res_tokenized = tokenizer.tokenize(res)
+
+ # # Compute CIDEr
+ # cider_scorer = Cider()
+ # cider_score, _ = cider_scorer.compute_score(gts_tokenized, res_tokenized)
+ # print("CIDEr:", cider_score)
requirements.txt CHANGED
@@ -1,4 +1,7 @@
torch
- transformers
+ transformers>=4.30.0
gradio
- Pillow
+ Pillow
+ peft
+ bitsandbytes
+ accelerate
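
With these dependencies installed, the Space serves a single Gradio interface that can also be called remotely from Python. The sketch below uses `gradio_client`; the Space id `Saadi07/<space-name>` and the `/predict` endpoint name are assumptions (the usual default for a plain `gr.Interface`), not values confirmed by the commit:

```python
# Hypothetical remote call to the Space via gradio_client.
# "Saadi07/<space-name>" and api_name="/predict" are assumed, not confirmed here.
from gradio_client import Client, handle_file

client = Client("Saadi07/<space-name>")   # replace with the actual Space id
caption = client.predict(
    handle_file("photo.jpg"),             # local image to caption
    api_name="/predict",
)
print(caption)
```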