Saadi07 commited on
Commit
d52fa20
·
1 Parent(s): eb89eec

Model Deployed

Browse files
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -1,31 +1,30 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import torch
4
- from transformers import Blip2ForConditionalGeneration, AutoProcessor
5
 
6
- # Load your fine-tuned model and processor from local directories
7
- processor = AutoProcessor.from_pretrained("./processor")
8
- model = Blip2ForConditionalGeneration.from_pretrained("./model")
9
 
10
- # Inference function
11
- def generate_caption(image: Image.Image) -> str:
12
- # Convert image to RGB and process
13
- image = image.convert("RGB")
14
- inputs = processor(images=image, return_tensors="pt").to(model.device, torch.float16)
15
-
16
- # Generate caption
17
  generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
18
  caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
19
  return caption
20
 
21
- # Gradio UI
22
  iface = gr.Interface(
23
  fn=generate_caption,
24
  inputs=gr.Image(type="pil"),
25
  outputs="text",
26
- title="🖼️ Image Captioning with Fine-Tuned BLIP2",
27
- description="Upload an image to generate a caption using your custom fine-tuned BLIP2 model.",
28
  )
29
 
30
- if __name__ == "__main__":
31
- iface.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
  import torch
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration
5
 
6
+ # Load model and processor
7
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
8
+ model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
9
 
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+ model = model.to(device)
12
+
13
+ # Define the function to generate caption
14
+ def generate_caption(image):
15
+ inputs = processor(images=image, return_tensors="pt").to(device)
 
16
  generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
17
  caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
18
  return caption
19
 
20
+ # Create Gradio interface
21
  iface = gr.Interface(
22
  fn=generate_caption,
23
  inputs=gr.Image(type="pil"),
24
  outputs="text",
25
+ title="Image Caption Generator",
26
+ description="Upload an image to generate a caption."
27
  )
28
 
29
+ # Launch
30
+ iface.launch()