Saadi07 committed
Commit c76cc1b · 1 Parent(s): 95e1523
Files changed (3):
  1. README.md +5 -3
  2. app.py +49 -18
  3. requirements.txt +6 -6
README.md CHANGED
@@ -9,13 +9,13 @@ app_file: app.py
 pinned: false
 ---
 
-# Fine-tuned BLIP2 Image Captioning
+# Fine-tuned BLIP2 Image Caption Generator
 
 This Hugging Face Space hosts a BLIP2 model that has been fine-tuned on the Flickr8k dataset using Low-Rank Adaptation (LoRA).
 
 ## Model Details
 
-- Base model: `ybelkada/blip2-opt-2.7b-fp16-sharded`
+- Base model: `Salesforce/blip2-opt-2.7b` (with fallback to `Salesforce/blip2-opt-560m` for CPU environments)
 - Fine-tuning technique: LoRA (Low-Rank Adaptation)
 - Training dataset: Flickr8k
 - LoRA configuration:
@@ -30,6 +30,8 @@ Upload an image to generate a caption. The model will process the image and retu
 
 ## Notes
 
-The model uses 8-bit quantization to reduce memory usage while maintaining performance.
+- The app will automatically detect if CUDA is available
+- If running on CPU, it will use a smaller model version to maintain performance
+- The app includes fallback mechanisms to ensure it works in various environments
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
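
For context on the LoRA setup the README describes, here is a minimal training-side sketch. The `r=16`, `lora_alpha=32`, and `target_modules=["q_proj", "k_proj"]` values come from this commit (the app.py comment and the Gradio description); `lora_dropout` and `bias` are illustrative assumptions, since the commit does not show them.

```python
# Sketch of the fine-tuning setup described in the README; not part of this commit.
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from peft import LoraConfig, get_peft_model

base = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

config = LoraConfig(
    r=16,                                 # from the commit
    lora_alpha=32,                        # from the commit
    target_modules=["q_proj", "k_proj"],  # from the app.py comment
    lora_dropout=0.05,                    # assumption: not shown in the commit
    bias="none",                          # assumption: PEFT default
)
model = get_peft_model(base, config)
model.print_trainable_parameters()

# After training on Flickr8k, the adapter and processor would be saved to the
# paths that app.py loads from:
# model.save_pretrained("./model")
# processor.save_pretrained("./processor")
```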
app.py CHANGED
@@ -3,6 +3,7 @@ from PIL import Image
 import torch
 from transformers import AutoProcessor, Blip2ForConditionalGeneration
 from peft import PeftModel, LoraConfig
+import os
 
 # LoRA configuration used during training:
 # config = LoraConfig(
@@ -13,20 +14,42 @@ from peft import PeftModel, LoraConfig
 #     target_modules=["q_proj", "k_proj"]
 # )
 
-# Load base model with the same configuration as in training
-base_model = Blip2ForConditionalGeneration.from_pretrained(
-    "ybelkada/blip2-opt-2.7b-fp16-sharded",
-    device_map="auto",
-    load_in_8bit=True
-)
-
-# Load the fine-tuned LoRA weights
-model = PeftModel.from_pretrained(base_model, "./model")
+# Check if we're running on CPU or GPU
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print(f"Using device: {device}")
 
-# Load processor - use the same one as training
+# Load processor first
 processor = AutoProcessor.from_pretrained("./processor")
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+# Load base model without 8-bit quantization for CPU compatibility
+try:
+    # Try loading with device_map for better memory usage if available
+    base_model = Blip2ForConditionalGeneration.from_pretrained(
+        "Salesforce/blip2-opt-2.7b",
+        device_map="auto" if torch.cuda.is_available() else None,
+        load_in_8bit=torch.cuda.is_available()  # Only use 8-bit if CUDA is available
+    )
+except Exception as e:
+    print(f"Error loading full model: {e}")
+    print("Falling back to smaller model...")
+    # Fall back to a smaller model if the large one fails
+    base_model = Blip2ForConditionalGeneration.from_pretrained(
+        "Salesforce/blip2-opt-560m",
+        device_map=None
+    )
+
+# Load the fine-tuned LoRA weights
+try:
+    model = PeftModel.from_pretrained(base_model, "./model")
+    print("Successfully loaded fine-tuned LoRA weights")
+except Exception as e:
+    print(f"Error loading LoRA weights: {e}")
+    print("Continuing with base model only")
+    model = base_model
+
+# Move model to device if not using device_map
+if not hasattr(model, "hf_device_map"):
+    model = model.to(device)
 
 # Define the function to generate caption - exactly as in colab
 def generate_caption(image):
@@ -34,14 +57,21 @@ def generate_caption(image):
     image = image.convert('RGB') if image.mode != 'RGB' else image
 
     # Process the image exactly as in colab.py
-    inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+    inputs = processor(images=image, return_tensors="pt").to(device)
 
-    # Generate caption with the same parameters
-    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
+    # Use fp32 instead of fp16 for CPU compatibility
+    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    pixel_values = inputs.pixel_values.to(dtype)
 
-    # Decode the caption
-    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return caption
+    try:
+        # Generate caption with the same parameters
+        generated_ids = model.generate(pixel_values=pixel_values, max_length=25)
+
+        # Decode the caption
+        caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return caption
+    except Exception as e:
+        return f"Error generating caption: {str(e)}"
 
 # Create Gradio interface
 iface = gr.Interface(
@@ -49,7 +79,8 @@ iface = gr.Interface(
     inputs=gr.Image(type="pil"),
     outputs="text",
     title="Fine-tuned BLIP2 Image Caption Generator",
-    description="Upload an image to generate a caption using BLIP2 fine-tuned on Flickr8k with LoRA (r=16, alpha=32)."
+    description="Upload an image to generate a caption using BLIP2 fine-tuned on Flickr8k with LoRA (r=16, alpha=32).",
+    examples=["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"]
 )
 
 # Launch
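
On GPU the updated app.py still opts into 8-bit loading via the bare `load_in_8bit` flag, which depends on the `bitsandbytes` package that this commit drops from requirements.txt. Newer transformers releases express the same thing through `BitsAndBytesConfig`; a hedged sketch of that loading path (an illustration, not a change made in this commit):

```python
# Sketch only: 8-bit loading via a quantization config instead of the bare
# load_in_8bit kwarg. Requires a CUDA GPU and the bitsandbytes package.
from transformers import Blip2ForConditionalGeneration, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)
base_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    device_map="auto",
    quantization_config=quant_config,
)
```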
requirements.txt CHANGED
@@ -1,7 +1,7 @@
-torch
-transformers>=4.30.0
-gradio
+torch>=2.0.0
+transformers>=4.31.0
+gradio>=3.40.0
 Pillow
-peft
-bitsandbytes
-accelerate
+peft>=0.5.0
+safetensors
+accelerate>=0.25.0
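
To confirm a runtime satisfies these minimum versions, a small illustrative check (the package names mirror requirements.txt; unpinned entries such as Pillow and safetensors are omitted):

```python
# Illustrative check of installed versions against the minimums in requirements.txt.
import importlib.metadata as md

pins = {
    "torch": "2.0.0",
    "transformers": "4.31.0",
    "gradio": "3.40.0",
    "peft": "0.5.0",
    "accelerate": "0.25.0",
}
for pkg, minimum in pins.items():
    print(f"{pkg}: installed {md.version(pkg)}, requires >= {minimum}")
```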