# plavu_MA / app.py
# retromarz's picture
# Update app.py
# b031819 verified
import gradio as gr
import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor
import logging
import json
import os
from datetime import datetime
import uuid
from huggingface_hub import snapshot_download
import shutil
# Path of the JSON log that every captioning request (success or failure)
# is appended to.
OUTPUT_JSON_PATH = "captions.json"

# Configure INFO-level logging for the process and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Clear Hugging Face cache and download model
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
try:
    # Resolve the hub cache directory the same way huggingface_hub does, so
    # HF_HOME / HF_HUB_CACHE overrides are honoured. The previous hard-coded
    # "~/.cache/huggingface/hub" silently skipped the cleanup whenever the
    # cache had been relocated.
    from huggingface_hub import constants as hf_constants

    cache_dir = hf_constants.HF_HUB_CACHE
    model_cache = os.path.join(cache_dir, f"models--{MODEL_PATH.replace('/', '--')}")
    # Deliberately wipe this model's cached files so a previously corrupted
    # download cannot be reused. Note: this forces a full re-download on
    # every start-up.
    if os.path.exists(model_cache):
        shutil.rmtree(model_cache)
        logger.info(f"Cleared cache for {MODEL_PATH}")

    # Pre-download the snapshot. Loading from the returned local directory
    # guarantees the processor and model use exactly the files just fetched.
    model_dir = snapshot_download(repo_id=MODEL_PATH)
    logger.info(f"Downloaded model {MODEL_PATH}")

    # Load processor and model for CPU inference.
    processor = AutoProcessor.from_pretrained(model_dir)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_dir,
        torch_dtype=torch.float32,  # CPU-compatible dtype
        low_cpu_mem_usage=True,  # Minimize peak memory while loading
        use_safetensors=True,  # Refuse pickle-based weight files
    ).to("cpu")
    model.eval()
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    # Startup cannot proceed without the model: log with traceback, re-raise.
    logger.exception("Error loading model: %s", e)
    raise
# Function to append results to JSON file
def save_to_json(image_name, caption, caption_type, caption_length, error=None):
    """Append one caption record to the JSON array stored at ``OUTPUT_JSON_PATH``.

    The existing array is read, the new record appended, and the whole array
    written back. Failures are logged rather than raised, so recording a
    result never breaks the request that produced it.

    Args:
        image_name: Identifier of the processed image.
        caption: Generated caption text (an error message or "" on failure).
        caption_type: Caption style requested by the user.
        caption_length: Caption length requested by the user.
        error: Error message for a failed request, or ``None`` on success.
    """
    result = {
        "image_name": image_name,
        "caption": caption,
        "caption_type": caption_type,
        "caption_length": caption_length,
        "timestamp": datetime.now().isoformat(),
        "error": error,
    }

    data = []
    if os.path.exists(OUTPUT_JSON_PATH):
        loaded = None
        try:
            with open(OUTPUT_JSON_PATH, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            logger.error(f"Error reading JSON file: {str(e)}")
        if isinstance(loaded, list):
            data = loaded
        elif not _move_aside_unreadable_json(OUTPUT_JSON_PATH):
            # The old file could be neither parsed nor preserved; refuse to
            # overwrite it, which would silently discard its history.
            return

    data.append(result)

    tmp_path = f"{OUTPUT_JSON_PATH}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        # Atomic swap on the same filesystem: a crash mid-write leaves the
        # previous file intact instead of a truncated one.
        os.replace(tmp_path, OUTPUT_JSON_PATH)
        logger.info(f"Saved result to {OUTPUT_JSON_PATH}")
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"Error writing to JSON file: {str(e)}")


def _move_aside_unreadable_json(path):
    """Rename an unparsable (or non-list) log file to a timestamped backup.

    Returns True if the file was moved, False if the rename failed.
    """
    backup_path = f"{path}.corrupt-{datetime.now().strftime('%Y%m%dT%H%M%S%f')}"
    try:
        os.replace(path, backup_path)
    except OSError as e:
        logger.error(f"Could not back up unreadable JSON file {path}: {str(e)}")
        return False
    logger.warning(f"Moved unreadable JSON file {path} to {backup_path}")
    return True
# Define the captioning function
def generate_caption(input_image: Image.Image, caption_type: str = "descriptive", caption_length: str = "medium") -> str:
    """Caption ``input_image`` with the loaded model and log the outcome.

    Args:
        input_image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        caption_type: Style word interpolated into the prompt (e.g. "descriptive").
        caption_length: Length word interpolated into the prompt (e.g. "medium").

    Returns:
        The generated caption, or a human-readable error message. Errors are
        returned rather than raised so the UI always shows text; every call,
        successful or not, is recorded via ``save_to_json``.
    """
    if input_image is None:
        error_msg = "Please upload an image."
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg

    # Random identifier for the log record; the image itself is not saved.
    image_name = f"image_{uuid.uuid4().hex}.jpg"
    try:
        # Normalise to RGB (uploads may be RGBA or palette images) and shrink.
        image = input_image.convert("RGB").resize((256, 256))

        prompt = f"Write a {caption_length} {caption_type} caption for this image."
        convo = [
            {
                "role": "system",
                "content": "You are a helpful assistant that generates accurate and relevant image captions.",
            },
            {
                "role": "user",
                "content": prompt.strip(),
            },
        ]
        # Render both turns through the model's chat template (the approach
        # shown on the JoyCaption model card). Previously only the bare user
        # string was passed, dropping the system turn and the template's
        # formatting / image placeholder.
        convo_string = processor.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(images=[image], text=[convo_string], return_tensors="pt").to("cpu")

        with torch.no_grad():
            # do_sample=True is required for temperature/top_p to take effect;
            # without it generate() decodes greedily and ignores them.
            output = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        # generate() returns prompt tokens followed by new tokens; decode only
        # the continuation so the prompt text does not leak into the caption.
        prompt_len = inputs["input_ids"].shape[1]
        caption = processor.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

        save_to_json(image_name, caption, caption_type, caption_length, error=None)
        return caption
    except Exception as e:
        # UI boundary: log with traceback, record the failure, show the message.
        error_msg = f"Error generating caption: {str(e)}"
        logger.exception(error_msg)
        save_to_json(image_name, "", caption_type, caption_length, error=error_msg)
        return error_msg
# Create the Gradio interface: an image plus two dropdowns in, caption text out.
interface = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive"),
        gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium"),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
    title="Image Captioning with JoyCaption",
    # Derived from the module constants so the UI text cannot drift from the
    # model actually loaded or the file actually written.
    description=(
        f"Upload an image to generate a caption using the {MODEL_PATH} model. "
        f"Results are saved to {OUTPUT_JSON_PATH}."
    ),
)

if __name__ == "__main__":
    interface.launch()