# NOTE(review): the lines above this file's imports were Hugging Face Spaces
# page-scrape residue (status header, file size, git blame hashes, and a
# line-number gutter), not Python source; replaced with this comment.
# Imports
import gradio as gr
import spaces
import torch
import os
import math
import gc
import librosa
import tempfile
from PIL import Image, ImageSequence
from decord import VideoReader, cpu
from transformers import AutoModel, AutoTokenizer, AutoProcessor
# Variables
# Compute device: "auto" resolves to CUDA when available, otherwise CPU.
DEVICE = "auto"
if DEVICE == "auto":
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")

# Prompt used when the user leaves the input textbox unchanged.
DEFAULT_INPUT = "Describe in one short sentence."
# Upper bound on frames sampled from a video or GIF before inference.
MAX_FRAMES = 64
# Sample rate (Hz) all audio is resampled to; presumably what the model
# expects — TODO confirm against the MiniCPM-o model card.
AUDIO_SR = 16000

# Omni-modal (image/video/audio) chat checkpoint. Loading happens at import
# time so the Space is ready before the first request.
model_name = "openbmb/MiniCPM-o-2_6"
repo = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="sdpa", torch_dtype=torch.bfloat16).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
# Custom CSS: narrow the app column, center the title, hide the Gradio footer.
css = '''
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
visibility: hidden
}
'''

# System-style preamble prepended to every request (see generate()).
global_instruction = "You will analyze image, GIF, video, and audio input, then use as much keywords to describe the given content and take as much guesses of what it could be."

# Dispatch table keyed by media kind. For each kind:
#   extensions  — lowercased file suffixes that map to this kind,
#   instruction — per-kind prompt; the 'β' placeholder is replaced with the
#                 uploaded file's basename in generate(),
#   function    — name of the module-level builder resolved via globals().
filetypes = {
    "Image": {
        "extensions": [".jpg",".jpeg",".png",".bmp"],
        "instruction": "Analyze the 'β' image.",
        "function": "build_image"
    },
    "GIF":{
        "extensions": [".gif"],
        "instruction": "Analyze the 'β' GIF.",
        "function": "build_gif"
    },
    "Video": {
        "extensions": [".mp4",".mov",".avi",".mkv"],
        "instruction": "Analyze the 'β' video including the audio associated with the video.",
        "function": "build_video"
    },
    "Audio": {
        "extensions": [".wav",".mp3",".flac",".aac"],
        "instruction": "Analyze the 'β' audio.",
        "function": "build_audio"
    },
}
# Functions
uniform_sample=lambda seq, n: seq[::max(len(seq) // n,1)][:n]
def build_video(filepath):
    """Decode a video into interleaved ``[frame, audio_chunk]`` omni content.

    Samples at most MAX_FRAMES frames uniformly across the video, loads the
    soundtrack at AUDIO_SR, then pairs frame ``i`` with second ``i`` of audio
    so the model receives roughly synchronized visual/audio slices.

    Returns a flat list: [frame_0, chunk_0, frame_1, chunk_1, ...].
    """
    vr = VideoReader(filepath, ctx=cpu(0))
    # Fix: the original reused `i` for both the sampled frame indices and the
    # loop counter below; use distinct names and pass a concrete list.
    frame_indices = list(uniform_sample(range(len(vr)), MAX_FRAMES))
    batch = vr.get_batch(frame_indices).asnumpy()
    frames = [Image.fromarray(frame.astype("uint8")) for frame in batch]

    audio = build_audio(filepath)
    audio_length = math.ceil(len(audio) / AUDIO_SR)  # duration in whole seconds
    # One (frame, chunk) pair per second, capped by the shorter track;
    # at least one pair is always attempted.
    total_length = max(1, min(len(frames), audio_length))

    contents = []
    for second in range(total_length):
        frame = frames[second] if second < len(frames) else frames[-1]
        start = second * AUDIO_SR
        end = min((second + 1) * AUDIO_SR, len(audio))
        chunk = audio[start:end]
        if chunk.size == 0:
            # Ran out of audio (e.g. silent/short track) — stop pairing.
            break
        contents.extend([frame, chunk])
    return contents
def build_image(filepath):
    """Open the image at *filepath* and return it converted to RGB mode."""
    return Image.open(filepath).convert("RGB")
def build_gif(filepath):
    """Extract a GIF's frames as RGB images, thinned to at most MAX_FRAMES."""
    gif = Image.open(filepath)
    all_frames = []
    for raw_frame in ImageSequence.Iterator(gif):
        all_frames.append(raw_frame.copy().convert("RGB"))
    return uniform_sample(all_frames, MAX_FRAMES)
def build_audio(filepath):
    """Load an audio track as a mono waveform resampled to AUDIO_SR."""
    waveform, _sr = librosa.load(filepath, sr=AUDIO_SR, mono=True)
    return waveform
@spaces.GPU(duration=30)
def generate(filepath, input=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
    """Run omni-modal chat on an uploaded file plus a text prompt.

    Resolves the file's media kind from its extension, builds the matching
    content list (frames/audio/image) via the filetypes dispatch table, and
    asks the model to describe it. Returns the model's text output, or a
    short error string for missing/unsupported input.
    """
    if not input: return "No input provided."
    # gr.File yields None when nothing was uploaded; guard before splitext.
    if not filepath: return "No file provided."
    extension = os.path.splitext(filepath)[1].lower()
    filetype = next((k for k, v in filetypes.items() if extension in v["extensions"]), None)
    if not filetype: return "Unsupported file type."

    filetype_data = filetypes[filetype]
    # 'β' in the per-kind instruction is a placeholder for the file's name.
    input_prefix = filetype_data["instruction"].replace("β", os.path.basename(filepath))
    file_content = globals()[filetype_data["function"]](filepath)
    # BUG FIX: the original referenced an undefined name `instruction`
    # (NameError on every call); the user's prompt is the `input` parameter.
    full_instruction = f"{global_instruction}\n{input_prefix}\n{input}"
    content = (file_content if isinstance(file_content, list) else [file_content]) + [full_instruction]
    msgs = [{ "role": "user", "content": content }]
    print(msgs)

    output = repo.chat(
        msgs=msgs,
        tokenizer=tokenizer,
        sampling=sampling,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_tokens,
        omni_input=True,
        use_image_id=False,
        max_slice_nums=9
    )

    # Release GPU memory between requests; guard so CPU-only hosts don't fail.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return output
def cloud():
    """Log a keep-alive ping so the Space shows recent activity."""
    message = "[CLOUD] | Space maintained."
    print(message)
# Initialize
# Build the Gradio UI: one column of inputs/controls, one column for output.
# BUG FIX: the original last line ended with a stray " |" (scrape residue),
# which is a SyntaxError; removed.
with gr.Blocks(css=css) as main:
    with gr.Column():
        file = gr.File(label="File", file_types=["image", "video", "audio"], type="filepath")
        input = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Input")
        sampling = gr.Checkbox(value=False, label="Sampling")
        temperature = gr.Slider(minimum=0, maximum=2, step=0.01, value=1, label="Temperature")
        top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top P")
        top_k = gr.Slider(minimum=0, maximum=1000, step=1, value=50, label="Top K")
        repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.01, value=1.05, label="Repetition Penalty")
        max_tokens = gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="Max Tokens")
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")
    with gr.Column():
        output = gr.Textbox(lines=1, value="", label="Output")

    submit.click(fn=generate, inputs=[file, input, sampling, temperature, top_p, top_k, repetition_penalty, max_tokens], outputs=[output], queue=False)
    maintain.click(cloud, inputs=[], outputs=[], queue=False)

main.launch(show_api=True)