prithivMLmods committed
Commit 17b06a6 · verified · 1 Parent(s): bb456cd

Update app.py

Files changed (1):
  1. app.py +91 -13
app.py CHANGED
@@ -1,5 +1,11 @@
 import gradio as gr
-from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
+from transformers import (
+    AutoProcessor,
+    Qwen2_5_VLForConditionalGeneration,
+    TextIteratorStreamer,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+)
 from transformers.image_utils import load_image
 from threading import Thread
 import time
@@ -9,6 +15,12 @@ import cv2
 import numpy as np
 from PIL import Image
 
+# A constant for token length limit
+MAX_INPUT_TOKEN_LENGTH = 4096
+
+# -----------------------
+# Progress Bar Helper
+# -----------------------
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
@@ -29,6 +41,9 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
+# -----------------------
+# Video Downsampling Helper
+# -----------------------
 def downsample_video(video_path):
     """
     Downsamples the video to 10 evenly spaced frames.
@@ -54,19 +69,40 @@ def downsample_video(video_path):
     vidcap.release()
     return frames
 
-MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
+# -----------------------
+# Qwen2.5-VL Multimodal Setup
+# -----------------------
+MODEL_ID_QWEN = "Qwen/Qwen2.5-VL-7B-Instruct"  # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
+qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_QWEN,
     trust_remote_code=True,
    torch_dtype=torch.bfloat16
 ).to("cuda").eval()
 
+# -----------------------
+# DeepHermes Text Generation Setup
+# -----------------------
+text_model_id = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
+text_tokenizer = AutoTokenizer.from_pretrained(text_model_id)
+text_model = AutoModelForCausalLM.from_pretrained(
+    text_model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+text_model.eval()
+
+# -----------------------
+# Main Inference Function
+# -----------------------
 @spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"]
-    files = input_dict["files"]
+    files = input_dict.get("files", [])
 
+    # -----------------------
+    # Video Inference Branch
+    # -----------------------
     if text.strip().lower().startswith("@video-infer"):
         # Remove the tag from the query.
         text = text[len("@video-infer"):].strip()
@@ -103,7 +139,7 @@ def model_inference(input_dict, history):
         # Set up streaming generation.
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
         thread.start()
         buffer = ""
         yield progress_bar_html("Processing video with Qwen2.5VL Model")
@@ -113,6 +149,46 @@ def model_inference(input_dict, history):
             yield buffer
         return
 
+    # -----------------------
+    # Text-Only Inference Branch (using DeepHermes text generation)
+    # -----------------------
+    if not files:
+        # Prepare a simple conversation for text-only input.
+        conversation = [{"role": "user", "content": text}]
+        # Here we use the text tokenizer's chat template method.
+        input_ids = text_tokenizer.apply_chat_template(
+            conversation, add_generation_prompt=True, return_tensors="pt"
+        )
+        # Trim if necessary.
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        input_ids = input_ids.to(text_model.device)
+        streamer = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            "input_ids": input_ids,
+            "streamer": streamer,
+            "max_new_tokens": 1024,
+            "do_sample": True,
+            "top_p": 0.9,
+            "top_k": 50,
+            "temperature": 0.6,
+            "num_beams": 1,
+            "repetition_penalty": 1.2,
+        }
+        thread = Thread(target=text_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing with DeepHermes Text Generation Model")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
+
+    # -----------------------
+    # Multimodal (Image) Inference Branch with Qwen2.5-VL
+    # -----------------------
     if len(files) > 1:
         images = [load_image(image) for image in files]
     elif len(files) == 1:
@@ -120,9 +196,6 @@ def model_inference(input_dict, history):
     else:
         images = []
 
-    if text == "" and not images:
-        gr.Error("Please input a query and optionally image(s).")
-        return
     if text == "" and images:
         gr.Error("Please input a text query along with the image(s).")
         return
@@ -145,7 +218,7 @@ def model_inference(input_dict, history):
     ).to("cuda")
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
     yield progress_bar_html("Processing with Qwen2.5VL Model")
@@ -154,11 +227,15 @@ def model_inference(input_dict, history):
         time.sleep(0.01)
         yield buffer
 
+# -----------------------
+# Gradio Chat Interface
+# -----------------------
 examples = [
     [{"text": "Describe the Image?", "files": ["example_images/document.jpg"]}],
+    [{"text": "Tell me a story about a brave knight in a faraway kingdom."}],
     [{"text": "@video-infer Explain the content of the Advertisement", "files": ["example_images/videoplayback.mp4"]}],
     [{"text": "@video-infer Explain the content of the video in detail", "files": ["example_images/breakfast.mp4"]}],
-    [{"text": "@video-infer Explain the content of the video.", "files": ["example_images/sky.mp4"]}],
+
 ]
 
 demo = gr.ChatInterface(
@@ -172,4 +249,5 @@ demo = gr.ChatInterface(
     cache_examples=False,
 )
 
-demo.launch(debug=True)
+if __name__ == "__main__":
+    demo.launch(debug=True)
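
With this change, model_inference dispatches on the shape of the incoming message: a query starting with the @video-infer tag goes to the Qwen2.5-VL video branch, a query with no attached files goes to the new DeepHermes text-only branch, and a query with image files falls through to the Qwen2.5-VL image branch. Inputs shaped like the input_dict the handler receives would exercise each branch (file names here are placeholders, not files from this Space):

video_query = {"text": "@video-infer Summarize this clip", "files": ["clip.mp4"]}
text_query  = {"text": "Tell me a story about a brave knight.", "files": []}
image_query = {"text": "Describe the Image?", "files": ["document.jpg"]}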
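
All three branches stream output the same way: generate() is launched on a worker thread while the handler yields a growing buffer read from a TextIteratorStreamer. A minimal, self-contained sketch of that pattern (the gpt2 checkpoint is a stand-in so the example runs anywhere; it is not a model used by this Space):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")          # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until completion, so it runs on a background thread
# while this thread consumes decoded text as tokens arrive.
Thread(target=model.generate,
       kwargs=dict(inputs, streamer=streamer, max_new_tokens=32)).start()

buffer = ""
for piece in streamer:
    buffer += piece   # the Space yields this buffer to Gradio on each step
print(buffer)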
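
The new MAX_INPUT_TOKEN_LENGTH guard trims from the left, so the oldest context is dropped and only the most recent 4096 token ids are kept. The slice behaves like this (plain tensors, no model download needed):

import torch

MAX_INPUT_TOKEN_LENGTH = 4096  # the constant this commit introduces

input_ids = torch.arange(5000).unsqueeze(0)  # pretend prompt of 5000 tokens, shape (1, 5000)
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # keep the newest 4096
print(input_ids.shape)  # torch.Size([1, 4096])
print(input_ids[0, 0])  # tensor(904): ids 0-903 were dropped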
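
downsample_video appears in the diff only as unchanged context, so its body is not shown here. For readers without the full file, an implementation consistent with its docstring ("10 evenly spaced frames") and the visible vidcap.release() / return frames lines might look like the following sketch; this is an assumption, not the file's actual code:

import cv2
import numpy as np
from PIL import Image

def downsample_video(video_path):
    """Return 10 evenly spaced frames as (PIL.Image, timestamp) pairs."""
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0
    frames = []
    for idx in np.linspace(0, total - 1, 10, dtype=int):  # 10 evenly spaced indices
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame = vidcap.read()
        if ok:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes BGR
            frames.append((Image.fromarray(frame), round(idx / fps, 2)))
    vidcap.release()
    return frames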