# Sense / app.py
# Staticaliza's picture
# Update app.py
# 18799c5 verified
# Imports
import gradio as gr
import spaces
import torch
import os
import math
import gc
import librosa
from PIL import Image, ImageSequence
from decord import VideoReader, cpu
from transformers import AutoModel, AutoTokenizer, AutoProcessor
# Variables
# Compute device. Leaving this at "auto" picks CUDA when torch reports a usable
# GPU and falls back to the CPU otherwise.
DEVICE = "auto"
if DEVICE == "auto":
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")

# Default prompt text.
DEFAULT_INPUT = "Describe in one short sentence."

# Upper bound on the number of frames sampled from animated or video input.
MAX_FRAMES = 64

# Sample rate, in Hz, that audio is resampled to when loaded.
AUDIO_SR = 16000
# Hugging Face repository that the model, tokenizer and processor are loaded from.
model_name = "openbmb/MiniCPM-o-2_6"

# Load the model in bfloat16 with SDPA attention, with the vision and audio
# parts initialised and TTS disabled, then move it to the selected device.
repo = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    init_vision=True,
    init_audio=True,
    init_tts=False,
)
repo = repo.to(DEVICE)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
# Stylesheet handed to gr.Blocks: caps the container width, centres <h1>
# headings and hides the page footer.
css = """
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
visibility: hidden
}
"""
# Fixed instruction placed at the start of every prompt, ahead of the
# per-file-type line and the user's own text.
instruction = (
    "You will analyze image, GIF, video, and audio input, "
    "then use as much keywords to describe the given content "
    "and take as much guesses of what it could be."
)
# Supported input kinds. For each kind:
#   "extensions"  - lower-case file suffixes that select it,
#   "instruction" - prompt line; the quoted placeholder is replaced with the
#                   uploaded file's base name before the prompt is sent,
#   "function"    - name of the module-level builder that turns the file into
#                   model content (looked up by name at call time).
filetypes = {
    "Image": {
        "extensions": [".jpg", ".jpeg", ".png", ".bmp"],
        "instruction": "Analyze the 'β–ˆ' image.",
        "function": "build_image",
    },
    "GIF": {
        "extensions": [".gif"],
        "instruction": "Analyze the 'β–ˆ' GIF.",
        "function": "build_gif",
    },
    "Video": {
        "extensions": [".mp4", ".mov", ".avi", ".mkv"],
        "instruction": "Analyze the 'β–ˆ' video including the audio associated with the video.",
        "function": "build_video",
    },
    "Audio": {
        "extensions": [".wav", ".mp3", ".flac", ".aac"],
        "instruction": "Analyze the 'β–ˆ' audio.",
        "function": "build_audio",
    },
}
# Functions
def uniform_sample(sequence, n):
    """Return at most ``n`` items of ``sequence``, spread evenly over its whole length.

    Args:
        sequence: Any sized, integer-indexable sequence (list, range, ...).
        n: Maximum number of items to return.

    Returns:
        A new list. When the sequence has ``n`` items or fewer, all of them are
        returned in order. Otherwise item ``i`` of the result is taken from
        position ``i * len(sequence) // n``, so the picks cover the full range
        instead of clustering at the start. An empty list is returned when the
        sequence is empty or ``n`` is not positive.
    """
    total = len(sequence)
    if n <= 0 or total == 0:
        return []
    if total <= n:
        return list(sequence)
    # Integer arithmetic keeps the indices strictly increasing and always < total.
    return [sequence[i * total // n] for i in range(n)]
def build_image(path):
    """Load the image at ``path`` and return it as a one-element list of RGB images.

    The file is opened in a ``with`` block so the underlying file handle is
    closed as soon as the pixel data has been converted; ``convert`` returns a
    separate image that stays usable afterwards.
    """
    with Image.open(path) as image:
        return [image.convert("RGB")]
def build_gif(path):
    """Return up to MAX_FRAMES RGB frames sampled from the GIF at ``path``.

    Frame indices are chosen with ``uniform_sample`` first, and only those
    frames are converted to RGB, so a long animation is never held fully
    expanded in memory. The file handle is closed when the function returns.
    """
    with Image.open(path) as gif:
        # Single-frame images may not expose ``n_frames``; treat them as one frame.
        frame_count = getattr(gif, "n_frames", 1)
        frames = []
        for index in uniform_sample(range(frame_count), MAX_FRAMES):
            gif.seek(index)
            # convert() yields an independent image, safe to keep after close.
            frames.append(gif.convert("RGB"))
    return frames
def build_video(path):
    """Build interleaved ``"<unit>", frame, audio_chunk`` content for a video file.

    The video is divided into whole seconds. At most MAX_FRAMES of those
    seconds are selected with ``uniform_sample``; for each selected second the
    result contains the marker string ``"<unit>"``, the frame found at that
    second (as a PIL image) and the AUDIO_SR-sample slice of audio covering the
    same second. Pairing by timestamp keeps each frame next to the sound that
    accompanies it. Seconds beyond the end of the audio are dropped, and only
    the selected frames are decoded.

    Args:
        path: Path to the video file; it is read both by decord (frames) and,
            through ``build_audio``, by librosa (sound).

    Returns:
        A flat list ``["<unit>", frame, chunk, "<unit>", frame, chunk, ...]``,
        empty when the video has no frames or no audio samples.

    NOTE(review): a file without a decodable audio stream is not handled here;
    whatever ``build_audio`` raises in that case propagates to the caller.
    """
    vr = VideoReader(path, ctx=cpu(0))
    total_frames = len(vr)
    if not total_frames:
        return []

    # Guard against a reported frame rate of 0 so the divisions below are safe.
    fps = vr.get_avg_fps() or 1.0
    audio = build_audio(path)[0]

    # Number of (possibly partial) seconds present in both picture and sound.
    video_seconds = math.ceil(total_frames / fps)
    audio_seconds = math.ceil(len(audio) / AUDIO_SR)
    seconds = uniform_sample(range(min(video_seconds, audio_seconds)), MAX_FRAMES)
    if not seconds:
        return []

    # Frame shown at the start of each selected second, clamped to the last frame.
    idx = [min(int(second * fps), total_frames - 1) for second in seconds]
    decoded = vr.get_batch(idx).asnumpy()

    units = []
    for second, pixels in zip(seconds, decoded):
        frame = Image.fromarray(pixels.astype("uint8"))
        chunk = audio[second * AUDIO_SR:(second + 1) * AUDIO_SR]
        units.extend(["<unit>", frame, chunk])
    return units
def build_audio(path):
    """Load ``path`` as mono audio resampled to AUDIO_SR; return it in a one-element list."""
    # librosa.load returns (samples, sample_rate); only the samples are needed.
    waveform = librosa.load(path, sr=AUDIO_SR, mono=True)[0]
    return [waveform]
@spaces.GPU(duration=30)
def generate(filepath, input=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
    """Describe an uploaded image, GIF, video or audio file with the loaded model.

    Args:
        filepath: Path of the uploaded file; may be None/empty when nothing
            was uploaded.
        input: User prompt appended after the fixed instruction lines.
        sampling, temperature, top_p, top_k, repetition_penalty: Forwarded
            unchanged to ``repo.chat``.
        max_tokens: Forwarded to ``repo.chat`` as ``max_new_tokens``.

    Returns:
        The value returned by ``repo.chat``, or a short message string when
        the prompt is empty, no file was given, or the file extension is not
        listed in ``filetypes``.
    """
    if not input:
        return "No input provided."
    # Without this guard os.path.splitext(None) raises TypeError when the
    # request arrives before any file has been uploaded.
    if not filepath:
        return "No file provided."

    extension = os.path.splitext(filepath)[1].lower()
    filetype = next((k for k, v in filetypes.items() if extension in v["extensions"]), None)
    if not filetype:
        return "Unsupported file type."

    filetype_data = filetypes[filetype]
    input_prefix = filetype_data["instruction"].replace("β–ˆ", os.path.basename(filepath))
    # Builders are referenced by name in ``filetypes`` and resolved here.
    builder = globals()[filetype_data["function"]]
    content = builder(filepath) + [f"{instruction}\n{input_prefix}\n{input}"]
    messages = [{"role": "user", "content": content}]
    print(messages)

    try:
        return repo.chat(
            msgs=messages,
            tokenizer=tokenizer,
            sampling=sampling,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            max_new_tokens=max_tokens,
            omni_input=True,
            use_image_id=False,
            max_slice_nums=9,
        )
    finally:
        # Runs on success and on failure, so memory is released either way.
        # Collect Python garbage first so freed tensors can leave the CUDA cache.
        gc.collect()
        torch.cuda.empty_cache()
def cloud():
    """Print a fixed status line and return nothing."""
    message = "[CLOUD] | Space maintained."
    print(message)
# Initialize
# Build the Gradio interface: one column of inputs and buttons, one column
# with the output box, then wire the buttons to their handlers.
with gr.Blocks(css=css) as main:
    with gr.Column():
        # Uploaded file is passed to the handler as a filesystem path.
        file = gr.File(label="File", file_types=["image", "video", "audio"], type="filepath")
        input = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Input")
        # NOTE(review): the UI defaults differ from generate()'s own defaults
        # (sampling True vs False, top_k 50 vs 100) - confirm this is intended.
        sampling = gr.Checkbox(value=True, label="Sampling")
        # NOTE(review): the sliders below allow 0 for temperature, top-k and
        # repetition penalty; verify the model's chat() accepts those values.
        temperature = gr.Slider(minimum=0, maximum=2, step=0.01, value=0.7, label="Temperature")
        top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.8, label="Top P")
        top_k = gr.Slider(minimum=0, maximum=1000, step=1, value=50, label="Top K")
        repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.01, value=1.05, label="Repetition Penalty")
        max_tokens = gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="Max Tokens")
        submit = gr.Button("β–Ά")
        maintain = gr.Button("☁️")
    with gr.Column():
        output = gr.Textbox(lines=1, value="", label="Output")
    # Input order here must match generate()'s positional parameter order.
    submit.click(fn=generate, inputs=[file, input, sampling, temperature, top_p, top_k, repetition_penalty, max_tokens], outputs=[output], queue=False)
    # The second button only calls cloud(), which takes and returns nothing.
    maintain.click(cloud, inputs=[], outputs=[], queue=False)
# Start the app with the API page enabled.
main.launch(show_api=True)