Spaces: Runtime error
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import traceback
import tempfile
import torch
# import mimetypes
from PIL import Image
import av
import numpy as np
import os
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
from my_lib.preproces_video import read_video_pyav
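
# Note: my_lib/preproces_video.py is not shown in this Space. For reference,
# a minimal sketch of read_video_pyav, following the sampling helper from the
# transformers LLaVA-NeXT-Video docs (an assumption -- adapt to the real module):
#
#   def read_video_pyav(container, indices):
#       """Decode the frames at `indices` into an (N, H, W, 3) uint8 RGB array."""
#       frames = []
#       container.seek(0)
#       for i, frame in enumerate(container.decode(video=0)):
#           if i > indices[-1]:
#               break
#           if i in indices:
#               frames.append(frame)
#       return np.stack([f.to_ndarray(format="rgb24") for f in frames])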
app = FastAPI()

# Load model and processor
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading model and processor...")
processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)

# Optional: pre-cache the model on HF Spaces to avoid re-downloading
# from huggingface_hub import snapshot_download
# snapshot_download(MODEL_ID)
if device.type == "cuda":
    try:
        from transformers import BitsAndBytesConfig  # requires bitsandbytes and a GPU

        # Passing load_in_4bit=True directly is deprecated, and calling
        # .to(device) on a 4-bit bitsandbytes model raises a ValueError:
        # device_map places the quantized weights on the GPU instead.
        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            device_map="auto",
        )
        print("Loaded model in 4-bit quantized mode.")
    except Exception as e:
        print("Failed to load in 4-bit mode:", e)
        print("Falling back to full-precision FP16.")
        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(device)
else:
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
    ).to(device)
model.eval()

print(f"Model and processor loaded on {device}.")

@app.get("/")
async def root():
    return {"message": "Welcome to the Summarization API. Use /summarize to summarize media files."}

@app.get("/health")
async def health():
    return {"status": "ok", "device": device.type}

@app.post("/summarize")
async def summarize_media(file: UploadFile = File(...)):
    tmp_path = None
    try:
        # Persist the upload to a temp file, keeping only the original extension.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name

        content_type = file.content_type or ""
        is_video = content_type.startswith("video/")
        is_image = content_type.startswith("image/")
        if not (is_video or is_image):
            return JSONResponse(status_code=400, content={"error": f"Unsupported file type: {content_type}"})
        if is_video:
            # Count frames; some containers don't store a count in metadata,
            # so fall back to decoding once, then reopen to reset position.
            container = av.open(tmp_path)
            total_frames = container.streams.video[0].frames or sum(1 for _ in container.decode(video=0))
            container.close()
            if total_frames == 0:
                raise ValueError("Could not extract frames: total frame count is zero.")
            container = av.open(tmp_path)

            # Sample up to 8 evenly spaced frames across the clip.
            num_frames = min(8, total_frames)
            indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
            clip = read_video_pyav(container, indices)
            container.close()

            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Summarize this video and explain the key highlights."},
                        {"type": "video"},
                    ],
                },
            ]
            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=prompt, videos=clip, return_tensors="pt").to(device)
        elif is_image:
            image = Image.open(tmp_path).convert("RGB")
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the image and summarize its content."},
                        {"type": "image"},
                    ],
                },
            ]
            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
        with torch.inference_mode():
            output_ids = model.generate(**inputs, max_new_tokens=512)
        # Decode only the newly generated tokens: for decoder-only models,
        # generate() also returns the prompt tokens.
        generated = output_ids[:, inputs["input_ids"].shape[1]:]
        response_text = processor.batch_decode(generated, skip_special_tokens=True)[0].strip()
        return JSONResponse(content={"summary": response_text})
    except Exception as e:
        print("Unhandled error:", e)
        print(traceback.format_exc())
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
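
# Local entry point (an assumption: this file is app.py; on HF Spaces the
# server is usually launched by the Space itself, e.g. via uvicorn in the
# Dockerfile, so this block only matters for local runs):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)  # 7860 is the Spaces default port

# Example request once the server is up:
#   curl -X POST "http://localhost:7860/summarize" -F "file=@clip.mp4;type=video/mp4"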