from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import traceback
import tempfile
import torch
# import mimetypes
from PIL import Image
import av
import numpy as np
import os
from transformers import BitsAndBytesConfig, LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
from my_lib.preproces_video import read_video_pyav
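# Reference sketch of the frame-extraction helper, kept commented out for documentation.
# The local my_lib.preproces_video.read_video_pyav is assumed to mirror the reference
# implementation from the LLaVA-NeXT-Video model card: decode the container once and keep
# only the frames at the requested indices as a (num_frames, height, width, 3) uint8 array.
#
# def read_video_pyav(container, indices):
#     frames = []
#     container.seek(0)
#     start_index, end_index = indices[0], indices[-1]
#     for i, frame in enumerate(container.decode(video=0)):
#         if i > end_index:
#             break
#         if i >= start_index and i in indices:
#             frames.append(frame)
#     return np.stack([x.to_ndarray(format="rgb24") for x in frames])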
app = FastAPI()
# Load model and processor
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Loading model and processor...")
processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)
# Optional: Pre-cache model on HF Spaces to avoid redownloading
# from huggingface_hub import snapshot_download
# snapshot_download(MODEL_ID)
if device.type == "cuda":
try:
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
load_in_4bit=True # Requires bitsandbytes and GPU
).to(device)
print("Loaded model in 4-bit quantized mode.")
except Exception as e:
print("Failed to load in 4-bit mode:", e)
print("Falling back to full precision FP16.")
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(device)
else:
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32
).to(device)
print(f"Model and processor loaded on {device}.")
@app.get("/")
async def root():
return {"message": "Welcome to the Summarization API. Use /summarize to summarize media files."}
@app.get("/health")
async def health():
return {"status": "ok", "device": device.type}
@app.post("/summarize")
async def summarize_media(file: UploadFile = File(...)):
try:
        # Persist the upload to a temporary file, keeping only the extension as the suffix.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name

        content_type = file.content_type or ""
        is_video = content_type.startswith("video/")
        is_image = content_type.startswith("image/")
if not (is_video or is_image):
os.unlink(tmp_path)
return JSONResponse(status_code=400, content={"error": f"Unsupported file type: {content_type}"})
        if is_video:
            container = av.open(tmp_path)
            total_frames = container.streams.video[0].frames
            if not total_frames:
                # Some containers do not report a frame count; count by decoding once.
                total_frames = sum(1 for _ in container.decode(video=0))
            container.close()
            container = av.open(tmp_path)  # reopen to reset the decode position
            if total_frames == 0:
                raise ValueError("Could not extract frames: total frame count is zero.")
num_frames = min(8, total_frames)
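            # Uniformly sample up to 8 frames across the clip; e.g. a 240-frame video
            # yields indices [0, 34, 68, 102, 136, 170, 204, 239].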
indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
clip = read_video_pyav(container, indices)
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Summarize this video and explain the key highlights."},
{"type": "video"},
],
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=prompt, videos=clip, return_tensors="pt").to(device)
elif is_image:
image = Image.open(tmp_path).convert("RGB")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Describe the image and summarize its content."},
{"type": "image"},
],
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
        output_ids = model.generate(**inputs, max_new_tokens=512)
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
        response_text = processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
        return JSONResponse(content={"summary": response_text})
except Exception as e:
print("Unhandled error:", e)
print(traceback.format_exc())
return JSONResponse(status_code=500, content={"error": str(e)})
finally:
if 'tmp_path' in locals() and os.path.exists(tmp_path):
os.unlink(tmp_path)
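
# To run locally (assuming this file is main.py):
#   uvicorn main:app --host 0.0.0.0 --port 7860
#
# Minimal client sketch using the `requests` library (hypothetical file name and port,
# adjust to your deployment):
#
# import requests
# with open("clip.mp4", "rb") as f:
#     resp = requests.post(
#         "http://localhost:7860/summarize",
#         files={"file": ("clip.mp4", f, "video/mp4")},
#     )
# print(resp.json().get("summary") or resp.json())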