from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import traceback
import tempfile
import torch
import mimetypes  # used to guess the MIME type when the client omits Content-Type
from PIL import Image
import av
import numpy as np
import os

from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration, BitsAndBytesConfig
from my_lib.preproces_video import read_video_pyav
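# read_video_pyav is assumed to follow the helper from the Hugging Face
# LLaVA-NeXT-Video examples: given an open PyAV container and an array of frame
# indices, it returns the decoded frames as a NumPy array of shape
# (num_frames, height, width, 3).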

app = FastAPI()

# Load model and processor
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading model and processor...")
processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)

# Optional: Pre-cache model on HF Spaces to avoid redownloading
# from huggingface_hub import snapshot_download
# snapshot_download(MODEL_ID)

if device.type == "cuda":
    try:
        # Load with 4-bit bitsandbytes quantization. device_map places the
        # quantized weights on the GPU, so no explicit .to(device) call is made
        # here (recent transformers versions raise an error when .to() is
        # called on a 4-bit model).
        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # requires bitsandbytes and a GPU
            device_map="auto",
        )
        print("Loaded model in 4-bit quantized mode.")
    except Exception as e:
        print("Failed to load in 4-bit mode:", e)
        print("Falling back to unquantized FP16.")
        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(device)
else:
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32
    ).to(device)

print(f"Model and processor loaded on {device}.")

@app.get("/")
async def root():
    return {"message": "Welcome to the Summarization API. Use /summarize to summarize media files."}

@app.get("/health")
async def health():
    return {"status": "ok", "device": device.type}

@app.post("/summarize")
async def summarize_media(file: UploadFile = File(...)):
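    """Summarize an uploaded video or image with LLaVA-NeXT-Video and return the text."""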
    try:
        # Persist the upload to a temporary file, keeping only the extension as
        # the suffix so downstream libraries can infer the format.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name

        # Fall back to guessing the MIME type from the filename if the client
        # did not send a Content-Type header.
        content_type = file.content_type or mimetypes.guess_type(file.filename or "")[0] or ""
        is_video = content_type.startswith("video/")
        is_image = content_type.startswith("image/")

        if not (is_video or is_image):
            os.unlink(tmp_path)
            return JSONResponse(status_code=400, content={"error": f"Unsupported file type: {content_type}"})

        if is_video:
            # Determine the frame count; some containers do not report it in
            # the stream metadata, so fall back to decoding and counting.
            container = av.open(tmp_path)
            total_frames = container.streams.video[0].frames or sum(1 for _ in container.decode(video=0))
            container.close()
            container = av.open(tmp_path)  # reopen so decoding starts from the first frame

            if total_frames == 0:
                raise ValueError("Could not extract frames: total frame count is zero.")

            # Sample up to 8 frames, spaced uniformly across the video.
            num_frames = min(8, total_frames)
            indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
            clip = read_video_pyav(container, indices)
            container.close()

            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Summarize this video and explain the key highlights."},
                        {"type": "video"},
                    ],
                },
            ]
            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=prompt, videos=clip, return_tensors="pt").to(device)

        elif is_image:
            image = Image.open(tmp_path).convert("RGB")
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the image and summarize its content."},
                        {"type": "image"},
                    ],
                },
            ]
            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

        output_ids = model.generate(**inputs, max_new_tokens=512)
        # Decode only the newly generated tokens; generate() also returns the prompt tokens.
        generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
        response_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        return JSONResponse(content={"summary": response_text})

    except Exception as e:
        print("Unhandled error:", e)
        print(traceback.format_exc())
        return JSONResponse(status_code=500, content={"error": str(e)})

    finally:
        if 'tmp_path' in locals() and os.path.exists(tmp_path):
            os.unlink(tmp_path)
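
# Example usage (assuming this module is saved as app.py; adjust host, port, and
# file path to your setup):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST -F "file=@sample.mp4;type=video/mp4" http://localhost:7860/summarize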