# Imports
import gradio as gr
import spaces
import torch
import os
import math
import gc
import librosa
import tempfile
from PIL import Image, ImageSequence
from decord import VideoReader, cpu
from transformers import AutoModel, AutoTokenizer, AutoProcessor
# Variables
DEVICE = "auto"
if DEVICE == "auto":
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[SYSTEM] | Using {DEVICE} compute device.")
DEFAULT_INPUT = "Describe in one short sentence."
MAX_FRAMES = 64
AUDIO_SR = 16000
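# Load MiniCPM-o 2.6; trust_remote_code=True pulls in the repo's own modeling and chat code,
# and SDPA attention with bfloat16 keeps memory use down.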
model_name = "openbmb/MiniCPM-o-2_6"
repo = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
).to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
css = '''
.gradio-container {max-width: 560px !important}
h1 {text-align: center}
footer {visibility: hidden}
'''
global_instruction = "You will analyze image, GIF, video, and audio input, then use as many keywords as possible to describe the given content and make as many guesses as possible about what it could be."
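# Dispatch table: each media type maps to its extensions, an instruction template
# ('β–ˆ' is replaced with the file name), and the name of a builder function defined below.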
filetypes = {
    "Image": {
        "extensions": [".jpg", ".jpeg", ".png", ".bmp"],
        "instruction": "Analyze the 'β–ˆ' image.",
        "function": "build_image",
    },
    "GIF": {
        "extensions": [".gif"],
        "instruction": "Analyze the 'β–ˆ' GIF.",
        "function": "build_gif",
    },
    "Video": {
        "extensions": [".mp4", ".mov", ".avi", ".mkv"],
        "instruction": "Analyze the 'β–ˆ' video including the audio associated with the video.",
        "function": "build_video",
    },
    "Audio": {
        "extensions": [".wav", ".mp3", ".flac", ".aac"],
        "instruction": "Analyze the 'β–ˆ' audio.",
        "function": "build_audio",
    },
}
# Functions
def uniform_sample(seq, n):
    # Evenly subsample up to n items from a sequence.
    return seq[::max(len(seq) // n, 1)][:n]
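# Example: uniform_sample(list(range(10)), 4) -> [0, 2, 4, 6]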
def build_video(filepath):
    # Sample up to MAX_FRAMES frames evenly across the video.
    vr = VideoReader(filepath, ctx=cpu(0))
    indices = uniform_sample(list(range(len(vr))), MAX_FRAMES)
    batch = vr.get_batch(indices).asnumpy()
    frames = [Image.fromarray(frame.astype("uint8")) for frame in batch]
    # Pair each frame with roughly one second of the video's audio track.
    audio = build_audio(filepath)
    audio_length = math.ceil(len(audio) / AUDIO_SR)
    total_length = max(1, min(len(frames), audio_length))
    contents = []
    for i in range(total_length):
        frame = frames[i] if i < len(frames) else frames[-1]
        start = i * AUDIO_SR
        end = min((i + 1) * AUDIO_SR, len(audio))
        chunk = audio[start:end]
        if chunk.size == 0:
            break
        contents.extend([frame, chunk])
    return contents
def build_image(filepath):
    # Load a single image as RGB.
    return Image.open(filepath).convert("RGB")
def build_gif(filepath):
    # Expand the GIF into RGB frames, capped at MAX_FRAMES.
    image = Image.open(filepath)
    frames = [f.copy().convert("RGB") for f in ImageSequence.Iterator(image)]
    return uniform_sample(frames, MAX_FRAMES)
def build_audio(filepath):
    # Load a mono waveform at the model's expected sample rate.
    audio, _ = librosa.load(filepath, sr=AUDIO_SR, mono=True)
    return audio
@spaces.GPU(duration=30)
def generate(filepath, input=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
    if not input:
        return "No input provided."
    if not filepath:
        return "No file provided."
    extension = os.path.splitext(filepath)[1].lower()
    filetype = next((k for k, v in filetypes.items() if extension in v["extensions"]), None)
    if not filetype:
        return "Unsupported file type."
    filetype_data = filetypes[filetype]
    # Fill the file name into the per-type instruction template.
    input_prefix = filetype_data["instruction"].replace("β–ˆ", os.path.basename(filepath))
    # Dispatch to the matching build_* loader.
    file_content = globals()[filetype_data["function"]](filepath)
    full_instruction = f"{global_instruction}\n{input_prefix}\n{input}"
    content = (file_content if isinstance(file_content, list) else [file_content]) + [full_instruction]
    msgs = [{"role": "user", "content": content}]
    print(msgs)
    output = repo.chat(
        msgs=msgs,
        tokenizer=tokenizer,
        sampling=sampling,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_tokens,
        omni_input=True,
        use_image_id=False,
        max_slice_nums=9,
    )
    torch.cuda.empty_cache()
    gc.collect()
    return output
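# Usage sketch (hypothetical file name): for a video upload, repo.chat receives
# interleaved [frame, audio_chunk, ...] content followed by the combined instruction, e.g.
#   generate("clip.mp4", "What happens in this clip?", sampling=True)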
def cloud():
    print("[CLOUD] | Space maintained.")
# Initialize
with gr.Blocks(css=css) as main:
    with gr.Column():
        file = gr.File(label="File", file_types=["image", "video", "audio"], type="filepath")
        input = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Input")
        sampling = gr.Checkbox(value=False, label="Sampling")
        temperature = gr.Slider(minimum=0, maximum=2, step=0.01, value=1, label="Temperature")
        top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top P")
        top_k = gr.Slider(minimum=0, maximum=1000, step=1, value=50, label="Top K")
        repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.01, value=1.05, label="Repetition Penalty")
        max_tokens = gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="Max Tokens")
        submit = gr.Button("β–Ά")
        maintain = gr.Button("☁️")
    with gr.Column():
        output = gr.Textbox(lines=1, value="", label="Output")
    submit.click(fn=generate, inputs=[file, input, sampling, temperature, top_p, top_k, repetition_penalty, max_tokens], outputs=[output], queue=False)
    maintain.click(cloud, inputs=[], outputs=[], queue=False)
main.launch(show_api=True)