ankandrew commited on
Commit
fa60b30
·
verified ·
1 Parent(s): d5e6127

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +277 -155
  2. infer.py +40 -32
app.py CHANGED
@@ -1,167 +1,289 @@
 
 
1
  import gradio as gr
2
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
3
- from transformers.image_utils import load_image
4
- from threading import Thread
5
- import time
6
- import torch
7
  import spaces
8
- import cv2
9
- import numpy as np
10
- from PIL import Image
11
 
12
- def progress_bar_html(label: str) -> str:
13
- """
14
- Returns an HTML snippet for a thin progress bar with a label.
15
- The progress bar is styled as a dark animated bar.
16
- """
17
- return f'''
18
- <div style="display: flex; align-items: center;">
19
- <span style="margin-right: 10px; font-size: 14px;">{label}</span>
20
- <div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
21
- <div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
22
- </div>
23
- </div>
24
- <style>
25
- @keyframes loading {{
26
- 0% {{ transform: translateX(-100%); }}
27
- 100% {{ transform: translateX(100%); }}
28
- }}
29
- </style>
30
- '''
31
 
32
- def downsample_video(video_path):
33
- """
34
- Downsamples the video to 10 evenly spaced frames.
35
- Each frame is converted to a PIL Image along with its timestamp.
36
- """
37
- vidcap = cv2.VideoCapture(video_path)
38
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
39
- fps = vidcap.get(cv2.CAP_PROP_FPS)
40
- frames = []
41
- if total_frames <= 0 or fps <= 0:
42
- vidcap.release()
43
- return frames
44
- # Sample 10 evenly spaced frames.
45
- frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
46
- for i in frame_indices:
47
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
48
- success, image = vidcap.read()
49
- if success:
50
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
51
- pil_image = Image.fromarray(image)
52
- timestamp = round(i / fps, 2)
53
- frames.append((pil_image, timestamp))
54
- vidcap.release()
55
- return frames
 
 
 
 
 
 
56
 
57
- # MODEL_ID = "XiaomiMiMo/MiMo-VL-7B-RL"
58
- MODEL_ID = "XiaomiMiMo/MiMo-VL-7B-RL-2508"
59
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
60
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
61
- MODEL_ID,
62
- trust_remote_code=True,
63
- torch_dtype=torch.bfloat16
64
- ).to("cuda").eval()
65
 
66
- @spaces.GPU
67
- def model_inference(input_dict, history):
68
- text = input_dict["text"]
69
- files = input_dict["files"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- if text.strip().lower().startswith("@video-infer"):
72
- # Remove the tag from the query.
73
- text = text[len("@video-infer"):].strip()
74
- if not files:
75
- yield "⚠️ Please upload a video file along with your `@video-infer` query."
76
- return
77
- # Assume the first file is a video.
78
- video_path = files[0]
79
- frames = downsample_video(video_path)
80
- if not frames:
81
- yield "⚠️ Could not process the video (no frames were read)."
82
- return
83
- # Build messages: start with the text prompt.
84
- messages = [
85
- {
86
- "role": "user",
87
- "content": [{"type": "text", "text": text}]
88
- }
89
- ]
90
- # Append each frame with a timestamp label.
91
- for image, timestamp in frames:
92
- messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
93
- messages[0]["content"].append({"type": "image", "image": image})
94
- # Collect only the images from the frames.
95
- video_images = [image for image, _ in frames]
96
- # Prepare the prompt.
97
- prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
98
- inputs = processor(
99
- text=[prompt],
100
- images=video_images,
101
- return_tensors="pt",
102
- padding=True,
103
- ).to("cuda")
104
- # Set up streaming generation.
105
- streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
106
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
107
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
108
- thread.start()
109
- buffer = ""
110
- yield progress_bar_html("Processing video with MiMo-VL-7B-RL Model")
111
- for new_text in streamer:
112
- buffer += new_text
113
- time.sleep(0.01)
114
- yield buffer
115
- return
116
 
117
- if len(files) > 1:
118
- images = [load_image(image) for image in files]
119
- elif len(files) == 1:
120
- images = [load_image(files[0])]
121
- else:
122
- images = []
 
 
 
 
 
 
 
 
 
 
123
 
124
- if text == "" and not images:
125
- yield "⚠️ Please enter a question and/or upload image(s)."
126
- return
127
- if text == "" and images:
128
- yield "⚠️ Please enter a text prompt along with the image(s)."
129
- return
130
 
131
- messages = [
132
- {
133
- "role": "user",
134
- "content": [
135
- *[{"type": "image", "image": image} for image in images],
136
- {"type": "text", "text": text},
137
- ],
138
- }
139
- ]
140
- prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
141
- inputs = processor(
142
- text=[prompt],
143
- images=images if images else None,
144
- return_tensors="pt",
145
- padding=True,
146
- ).to("cuda")
147
- streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
148
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
149
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
150
- thread.start()
151
- buffer = ""
152
- yield progress_bar_html("Processing with MiMo-VL-7B-RL Model")
153
- for new_text in streamer:
154
- buffer += new_text
155
- time.sleep(0.01)
156
- yield buffer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
- demo = gr.ChatInterface(
159
- fn=model_inference,
160
- description="# **MiMo-VL-7B-RL (2508) `@video-infer for video understanding`**",
161
- fill_height=True,
162
- textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
163
- stop_btn="Stop Generation",
164
- multimodal=True,
165
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
- demo.launch(debug=True)
 
1
+ # modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/app.py
2
+ import os
3
  import gradio as gr
4
+ from infer import MiMoVLInfer
 
 
 
 
5
  import spaces
 
 
 
6
 
7
+ infer = MiMoVLInfer(checkpoint_path=os.environ.get('CKPT_PATH'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ label_translations = {
10
+ "gr_chatinterface_ofl": {
11
+ "English": "Chatbot",
12
+ },
13
+ "gr_chatinterface_ol": {
14
+ "English": "Chatbot",
15
+ },
16
+ "gr_tab_ol": {
17
+ "English": "Online",
18
+ },
19
+ "gr_tab_ofl": {
20
+ "English": "Offline",
21
+ },
22
+ "gr_temperature": {
23
+ "English": "Temperature",
24
+ },
25
+ "gr_webcam_image": {
26
+ "English": "🤳 Open Webcam",
27
+ },
28
+ "gr_webcam_images": {
29
+ "English": "📹 Recorded Frames",
30
+ },
31
+ "gr_chatinterface_ofl.textbox.placeholder": {
32
+ "English":
33
+ "Ask me anything. You can also drop in images and .mp4 videos.",
34
+ },
35
+ "gr_chatinterface_ol.textbox.placeholder": {
36
+ "English": "Ask me anything...",
37
+ }
38
+ }
39
 
 
 
 
 
 
 
 
 
40
 
41
+ @spaces.GPU(duration=120) # bump if your requests take >60s
42
+ def offline_chat(gr_inputs: dict, gr_history: list, infer_history: list, temperature: float):
43
+ infer.to_device("cuda")
44
+ try:
45
+ yield [{"role": "assistant", "content": "⏳ Reserving GPU & preparing inference…"}], infer_history
46
+ for response_text, infer_history in infer(inputs=gr_inputs,
47
+ history=infer_history,
48
+ temperature=temperature):
49
+ if response_text.startswith('<think>') and '</think>' not in response_text:
50
+ reasoning_text = response_text.lstrip('<think>')
51
+ response_message = [{
52
+ "role": "assistant",
53
+ "content": reasoning_text,
54
+ 'metadata': {'title': '🤔 Thinking'}
55
+ }]
56
+ yield response_message, infer_history
57
+ elif '<think>' in response_text and '</think>' in response_text:
58
+ reasoning_text, response_text2 = response_text.split('</think>', 1)
59
+ reasoning_text = reasoning_text.lstrip('<think>')
60
+ response_message = [{
61
+ "role": "assistant",
62
+ "content": reasoning_text,
63
+ 'metadata': {'title': '🤔 Thinking'}
64
+ }, {
65
+ "role": "assistant",
66
+ "content": response_text2
67
+ }]
68
+ yield response_message, infer_history
69
+ else:
70
+ yield [{"role": "assistant", "content": response_text}], infer_history
71
+ finally:
72
+ infer.to_device("cpu")
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ @spaces.GPU(duration=120)
76
+ def online_record_chat(text: str, gr_history: list, gr_webcam_images: list, gr_counter: int,
77
+ infer_history: list, temperature: float):
78
+ infer.to_device("cuda")
79
+ try:
80
+ if not gr_webcam_images:
81
+ gr_webcam_images = []
82
+ gr_webcam_images = gr_webcam_images[gr_counter:]
83
+ inputs = {'text': text, 'files': [webp for webp, _ in gr_webcam_images]}
84
+ # send an immediate chunk
85
+ yield f'received {len(gr_webcam_images)} new frames, processing…', gr_counter + len(gr_webcam_images), infer_history
86
+ for response_message, infer_history in offline_chat(
87
+ inputs, gr_history, infer_history, temperature):
88
+ yield response_message, gr.skip(), infer_history
89
+ finally:
90
+ infer.to_device("cpu")
91
 
 
 
 
 
 
 
92
 
93
+ with gr.Blocks() as demo:
94
+ gr.Markdown("""<center><font size=8>MiMo-7b-VL</center>""")
95
+ with gr.Column():
96
+ # gr_title = gr.Markdown('# MiMo-VL')
97
+
98
+ with gr.Row():
99
+ gr_lang_selector = gr.Dropdown(choices=["English"],
100
+ value="English",
101
+ label="🌐 Interface",
102
+ interactive=True,
103
+ min_width=250,
104
+ scale=0)
105
+ with gr.Tabs():
106
+ with gr.Tab("Offline") as gr_tab_ofl:
107
+ gr_infer_history = gr.State([])
108
+ gr_temperature_hidden = gr.Slider(minimum=0.0,
109
+ maximum=2.0,
110
+ step=0.1,
111
+ value=1.0,
112
+ interactive=True,
113
+ visible=False)
114
+ gr_chatinterface_ofl = gr.ChatInterface(
115
+ fn=offline_chat,
116
+ type="messages",
117
+ multimodal=True,
118
+ chatbot=gr.Chatbot(height=800),
119
+ textbox=gr.MultimodalTextbox(
120
+ file_count="multiple",
121
+ file_types=["image", ".mp4"],
122
+ sources=["upload"],
123
+ stop_btn=True,
124
+ placeholder=label_translations[
125
+ 'gr_chatinterface_ofl.textbox.placeholder']['English'],
126
+ ),
127
+ additional_inputs=[
128
+ gr_infer_history, gr_temperature_hidden
129
+ ],
130
+ additional_outputs=[gr_infer_history],
131
+ )
132
+ gr.on(triggers=[gr_chatinterface_ofl.chatbot.clear],
133
+ fn=lambda: [],
134
+ outputs=[gr_infer_history])
135
+ with gr.Row():
136
+ with gr.Column(scale=1, min_width=200):
137
+ gr_temperature_ofl = gr.Slider(
138
+ minimum=0.0,
139
+ maximum=2.0,
140
+ step=0.1,
141
+ value=0.4,
142
+ label=label_translations['gr_temperature']['English'],
143
+ interactive=True)
144
+ gr_temperature_ofl.change(lambda x: x,
145
+ inputs=gr_temperature_ofl,
146
+ outputs=gr_temperature_hidden)
147
+ with gr.Column(scale=8):
148
+ with gr.Column(visible=True) as gr_examples_en:
149
+ gr.Examples(
150
+ examples=[
151
+ {
152
+ "text": "Who are you?",
153
+ "files": []
154
+ },
155
+ {
156
+ "text": "OCR and return markdown",
157
+ "files": ["examples/24-25-pl.png"]
158
+ },
159
+ {
160
+ "text":
161
+ """describe the video""",
162
+ "files":
163
+ ["examples/hitting_baseball.mp4"]
164
+ },
165
+ {
166
+ "text":
167
+ "For the model ranked first on WebSRC, what is its score on MathVision?",
168
+ "files": [
169
+ "examples/mimovl_gui.png",
170
+ "examples/mimovl_reason.png"
171
+ ]
172
+ },
173
+ ],
174
+ inputs=[gr_chatinterface_ofl.textbox],
175
+ )
176
+ with gr.Tab("Online") as gr_tab_ol:
177
+ with gr.Row():
178
+ with gr.Column(scale=1):
179
+ gr_infer_history = gr.State([])
180
+ gr_temperature_hidden = gr.Slider(minimum=0.0,
181
+ maximum=2.0,
182
+ step=0.1,
183
+ value=1.0,
184
+ interactive=True,
185
+ visible=False)
186
+ with gr.Row():
187
+ with gr.Column(scale=1):
188
+ gr_webcam_image = gr.Image(
189
+ label=label_translations['gr_webcam_image']
190
+ ['English'],
191
+ sources="webcam",
192
+ height=250,
193
+ type='filepath')
194
+ gr_webcam_images = gr.Gallery(
195
+ label=label_translations['gr_webcam_images']
196
+ ['English'],
197
+ show_label=True,
198
+ format='webp',
199
+ columns=1,
200
+ height=250,
201
+ preview=True,
202
+ interactive=False)
203
+ gr_counter = gr.Number(value=0, visible=False)
204
+ with gr.Column(scale=3):
205
+ gr_chatinterface_ol = gr.ChatInterface(
206
+ fn=online_record_chat,
207
+ type="messages",
208
+ multimodal=False,
209
+ chatbot=gr.Chatbot(height=800),
210
+ textbox=gr.
211
+ Textbox(placeholder=label_translations[
212
+ 'gr_chatinterface_ol.textbox.placeholder']
213
+ ['English'],
214
+ submit_btn=True,
215
+ stop_btn=True),
216
+ additional_inputs=[
217
+ gr_webcam_images, gr_counter,
218
+ gr_infer_history, gr_temperature_hidden
219
+ ],
220
+ additional_outputs=[
221
+ gr_counter, gr_infer_history
222
+ ],
223
+ )
224
 
225
+ def cache_webcam(recorded_image: str,
226
+ recorded_images: list):
227
+ if not recorded_images:
228
+ recorded_images = []
229
+ return recorded_images + [recorded_image]
230
+
231
+ gr_webcam_image.stream(
232
+ fn=cache_webcam,
233
+ inputs=[gr_webcam_image, gr_webcam_images],
234
+ outputs=[gr_webcam_images],
235
+ stream_every=1,
236
+ concurrency_limit=30,
237
+ )
238
+ with gr.Row():
239
+ gr_temperature_ol = gr.Slider(
240
+ minimum=0.0,
241
+ maximum=2.0,
242
+ step=0.1,
243
+ value=0.4,
244
+ label=label_translations['gr_temperature']
245
+ ['English'],
246
+ interactive=True)
247
+ gr_temperature_ol.change(
248
+ lambda x: x,
249
+ inputs=gr_temperature_ol,
250
+ outputs=gr_temperature_hidden)
251
+
252
+ def update_lang(lang: str):
253
+ return (
254
+ gr.update(label=label_translations['gr_chatinterface_ofl'][lang]),
255
+ gr.update(label=label_translations['gr_chatinterface_ol'][lang]),
256
+ gr.update(placeholder=label_translations[
257
+ 'gr_chatinterface_ofl.textbox.placeholder'][lang]),
258
+ gr.update(placeholder=label_translations[
259
+ 'gr_chatinterface_ol.textbox.placeholder'][lang]),
260
+ gr.update(label=label_translations['gr_tab_ofl'][lang]),
261
+ gr.update(label=label_translations['gr_tab_ol'][lang]),
262
+ gr.update(label=label_translations['gr_temperature'][lang]),
263
+ gr.update(label=label_translations['gr_temperature'][lang]),
264
+ gr.update(visible=lang == 'English'),
265
+ gr.update(visible=lang != 'English'),
266
+ gr.update(label=label_translations['gr_webcam_image'][lang]),
267
+ gr.update(label=label_translations['gr_webcam_images'][lang]),
268
+ )
269
+
270
+ gr_lang_selector.change(fn=update_lang,
271
+ inputs=[gr_lang_selector],
272
+ outputs=[
273
+ gr_chatinterface_ofl.chatbot,
274
+ gr_chatinterface_ol.chatbot,
275
+ gr_chatinterface_ofl.textbox,
276
+ gr_chatinterface_ol.textbox,
277
+ gr_tab_ofl,
278
+ gr_tab_ol,
279
+ gr_temperature_ofl,
280
+ gr_temperature_ol,
281
+ gr_examples_en,
282
+ gr_webcam_image,
283
+ gr_webcam_images,
284
+ ])
285
+ demo.queue(default_concurrency_limit=2, max_size=50)
286
+
287
+ if __name__ == "__main__":
288
+ demo.launch()
289
 
 
infer.py CHANGED
@@ -1,4 +1,6 @@
1
- # modified from https://github.com/ByteDance-Seed/Seed1.5-VL/blob/main/GradioDemo/infer.py
 
 
2
  from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
3
  from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
4
  from qwen_vl_utils import process_vision_info
@@ -6,67 +8,73 @@ from threading import Thread
6
 
7
 
8
  class MiMoVLInfer:
9
- def __init__(self, checkpoint_path, device='cuda', **kwargs):
 
10
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
11
- checkpoint_path, torch_dtype='auto', device_map=device, attn_implementation='flash_attention_2',
12
- )
13
- self.processor = AutoProcessor.from_pretrained(checkpoint_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def __call__(self, inputs: dict, history: list = [], temperature: float = 1.0):
16
  messages = self.construct_messages(inputs)
17
  updated_history = history + messages
18
  text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
19
  image_inputs, video_inputs = process_vision_info(updated_history)
 
20
  model_inputs = self.processor(
21
  text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
22
  ).to(self.model.device)
 
23
  tokenizer = self.processor.tokenizer
24
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
 
 
25
  gen_kwargs = {
26
- 'max_new_tokens': 16000,
 
 
 
27
  'streamer': streamer,
28
  'stopping_criteria': StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
29
  'pad_token_id': self.model.config.eos_token_id,
30
  **model_inputs
31
  }
32
- thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
 
33
  thread.start()
34
  partial_response = ""
35
  for new_text in streamer:
36
  partial_response += new_text
37
  yield partial_response, updated_history + [{
38
  'role': 'assistant',
39
- 'content': [{
40
- 'type': 'text',
41
- 'text': partial_response
42
- }]
43
  }]
44
 
45
  def _is_video_file(self, filename):
46
- video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
47
- return any(filename.lower().endswith(ext) for ext in video_extensions)
48
 
49
  def construct_messages(self, inputs: dict) -> list:
50
  content = []
51
- for i, path in enumerate(inputs.get('files', [])):
52
  if self._is_video_file(path):
53
- content.append({
54
- "type": "video",
55
- "video": f'file://{path}'
56
- })
57
  else:
58
- content.append({
59
- "type": "image",
60
- "image": f'file://{path}'
61
- })
62
  query = inputs.get('text', '')
63
  if query:
64
- content.append({
65
- "type": "text",
66
- "text": query,
67
- })
68
- messages = [{
69
- "role": "user",
70
- "content": content,
71
- }]
72
- return messages
 
1
+ # modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
2
+ import os
3
+ import torch
4
  from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
5
  from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
6
  from qwen_vl_utils import process_vision_info
 
8
 
9
 
10
  class MiMoVLInfer:
11
+ def __init__(self, checkpoint_path, **kwargs):
12
+ dtype = torch.float16
13
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
14
+ checkpoint_path,
15
+ torch_dtype=dtype,
16
+ device_map={"": "cpu"},
17
+ attn_implementation="eager",
18
+ trust_remote_code=True,
19
+ ).eval()
20
+ self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
21
+ self._on_cuda = False
22
+
23
+ def to_device(self, device: str):
24
+ if device == "cuda" and not self._on_cuda:
25
+ self.model.to("cuda")
26
+ self._on_cuda = True
27
+ elif device == "cpu" and self._on_cuda:
28
+ self.model.to("cpu")
29
+ self._on_cuda = False
30
 
31
  def __call__(self, inputs: dict, history: list = [], temperature: float = 1.0):
32
  messages = self.construct_messages(inputs)
33
  updated_history = history + messages
34
  text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
35
  image_inputs, video_inputs = process_vision_info(updated_history)
36
+
37
  model_inputs = self.processor(
38
  text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
39
  ).to(self.model.device)
40
+
41
  tokenizer = self.processor.tokenizer
42
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
43
+
44
+ max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
45
  gen_kwargs = {
46
+ 'max_new_tokens': max_new,
47
+ 'do_sample': True,
48
+ 'temperature': max(0.0, float(temperature)),
49
+ 'top_p': 0.95,
50
  'streamer': streamer,
51
  'stopping_criteria': StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
52
  'pad_token_id': self.model.config.eos_token_id,
53
  **model_inputs
54
  }
55
+
56
+ thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
57
  thread.start()
58
  partial_response = ""
59
  for new_text in streamer:
60
  partial_response += new_text
61
  yield partial_response, updated_history + [{
62
  'role': 'assistant',
63
+ 'content': [{'type': 'text', 'text': partial_response}]
 
 
 
64
  }]
65
 
66
  def _is_video_file(self, filename):
67
+ return any(filename.lower().endswith(ext) for ext in
68
+ ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg'])
69
 
70
  def construct_messages(self, inputs: dict) -> list:
71
  content = []
72
+ for path in inputs.get('files', []):
73
  if self._is_video_file(path):
74
+ content.append({"type": "video", "video": f'file://{path}'})
 
 
 
75
  else:
76
+ content.append({"type": "image", "image": f'file://{path}'})
 
 
 
77
  query = inputs.get('text', '')
78
  if query:
79
+ content.append({"type": "text", "text": query})
80
+ return [{"role": "user", "content": content}]