Staticaliza committed on
Commit
bcbc1e7
Β·
verified Β·
1 Parent(s): 0bb8329

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -35
app.py CHANGED
@@ -39,7 +39,7 @@ footer {
39
  input_prefixes = {
40
  "Image": "(A image file called β–ˆ has been attached, describe the image content) ",
41
  "GIF": "(A GIF file called β–ˆ has been attached, describe the GIF content) ",
42
- "Video": "(A video with audio file called β–ˆ has been attached, describe the video content and the audio content embedded into the video) ",
43
  "Audio": "(A audio file called β–ˆ has been attached, describe the audio content) ",
44
  }
45
 
@@ -94,42 +94,119 @@ def build_audio_omni(path, prefix, instruction, sr=AUDIO_SR):
94
  audio_np, _ = librosa.load(path, sr=sr, mono=True)
95
  return ["<unit>", audio_np, prefix + instruction]
96
 
97
- @spaces.GPU(duration=60)
98
- def generate(input, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
99
- if not input: return "No input provided."
100
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  extension = os.path.splitext(input)[1].lower()
102
- filetype = next((k for k, v in filetypes.items() if extension in v), None)
103
-
104
- if not filetype: return "Unsupported file type."
105
-
106
- filename = os.path.basename(input)
107
- prefix = input_prefixes[filetype].replace("β–ˆ", filename)
108
- if filetype == "Video":
109
- omni_content = build_omni_chunks(input, prefix, instruction)
110
- elif filetype == "Image":
111
- omni_content = build_image_omni(input, prefix, instruction)
112
- elif filetype == "GIF":
113
- omni_content = build_gif_omni(input, prefix, instruction)
114
- elif filetype == "Audio":
115
- omni_content = build_audio_omni(input, prefix, instruction)
116
-
117
- sys_msg = repo.get_sys_prompt(mode="omni", language="en")
118
- msgs = [sys_msg, {"role": "user", "content": omni_content}]
119
-
120
- params = {
121
- "msgs": msgs,
122
- "tokenizer": tokenizer,
123
- "sampling": sampling,
124
- "temperature": temperature,
125
- "top_p": top_p,
126
- "top_k": top_k,
127
- "repetition_penalty": repetition_penalty,
128
- "max_new_tokens": max_tokens,
129
- "omni_input": True,
130
  }
131
-
132
- output = repo.chat(**params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  torch.cuda.empty_cache()
134
  gc.collect()
135
  return output
 
39
  input_prefixes = {
40
  "Image": "(A image file called β–ˆ has been attached, describe the image content) ",
41
  "GIF": "(A GIF file called β–ˆ has been attached, describe the GIF content) ",
42
+ "Video": "(A audio video file called β–ˆ has been attached, describe the video content and the audio content) ",
43
  "Audio": "(A audio file called β–ˆ has been attached, describe the audio content) ",
44
  }
45
 
 
94
  audio_np, _ = librosa.load(path, sr=sr, mono=True)
95
  return ["<unit>", audio_np, prefix + instruction]
96
 
97
+ def infer_filetype(ext):
98
+ return next((k for k, v in filetypes.items() if ext in v), None)
99
+
100
+
101
+ def uniform_sample(seq, n):
102
+ step = max(len(seq) // n, 1)
103
+ return seq[::step][:n]
104
+
105
+
106
+ def frames_from_video(path):
107
+ vr = VideoReader(path, ctx = cpu(0))
108
+ idx = uniform_sample(range(len(vr)), MAX_FRAMES)
109
+ batch = vr.get_batch(idx).asnumpy()
110
+ return [Image.fromarray(frame.astype("uint8")) for frame in batch]
111
+
112
+
113
+ def audio_from_video(path):
114
+ clip = VideoFileClip(path)
115
+ audio = clip.audio.to_soundarray(fps = AUDIO_SR)
116
+ clip.close()
117
+ return librosa.to_mono(audio.T)
118
+
119
+
120
+ def load_audio(path):
121
+ audio_np, _ = librosa.load(path, sr = AUDIO_SR, mono = True)
122
+ return audio_np
123
+
124
+
125
+ def build_video_omni(path, prefix, instruction):
126
+ frames = frames_from_video(path)
127
+ audio = audio_from_video(path)
128
+ return processor.build_omni_input(
129
+ frames = frames,
130
+ audio = audio,
131
+ prefix = prefix,
132
+ instruction = instruction,
133
+ max_frames = MAX_FRAMES,
134
+ sr = AUDIO_SR
135
+ )
136
+
137
+
138
+ def build_image_omni(path, prefix, instruction):
139
+ image = Image.open(path).convert("RGB")
140
+ return processor.build_omni_input(
141
+ frames = [image],
142
+ audio = None,
143
+ prefix = prefix,
144
+ instruction = instruction
145
+ )
146
+
147
+
148
+ def build_gif_omni(path, prefix, instruction):
149
+ img = Image.open(path)
150
+ frames = [frame.copy().convert("RGB") for frame in ImageSequence.Iterator(img)]
151
+ frames = uniform_sample(frames, MAX_FRAMES)
152
+ return processor.build_omni_input(
153
+ frames = frames,
154
+ audio = None,
155
+ prefix = prefix,
156
+ instruction = instruction
157
+ )
158
+
159
+
160
+ def build_audio_omni(path, prefix, instruction):
161
+ audio = load_audio(path)
162
+ return processor.build_omni_input(
163
+ frames = None,
164
+ audio = audio,
165
+ prefix = prefix,
166
+ instruction = instruction,
167
+ sr = AUDIO_SR
168
+ )
169
+
170
+
171
+ @spaces.GPU(duration = 60)
172
+ def generate(input,
173
+ instruction = DEFAULT_INPUT,
174
+ sampling = False,
175
+ temperature = 0.7,
176
+ top_p = 0.8,
177
+ top_k = 100,
178
+ repetition_penalty = 1.05,
179
+ max_tokens = 512):
180
+ if not input:
181
+ return "no input provided."
182
  extension = os.path.splitext(input)[1].lower()
183
+ filetype = infer_filetype(extension)
184
+ if not filetype:
185
+ return "unsupported file type."
186
+ filename = os.path.basename(input)
187
+ prefix = input_prefixes[filetype].replace("β–ˆ", filename)
188
+ builder_map = {
189
+ "Image": build_image_omni,
190
+ "GIF" : build_gif_omni,
191
+ "Video": build_video_omni,
192
+ "Audio": build_audio_omni
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  }
194
+ omni_content = builder_map[filetype](input, prefix, instruction)
195
+ sys_msg = repo.get_sys_prompt(mode = "omni", language = "en")
196
+ msgs = [sys_msg, { "role": "user", "content": omni_content }]
197
+ output = repo.chat(
198
+ msgs = msgs,
199
+ tokenizer = tokenizer,
200
+ sampling = sampling,
201
+ temperature = temperature,
202
+ top_p = top_p,
203
+ top_k = top_k,
204
+ repetition_penalty = repetition_penalty,
205
+ max_new_tokens = max_tokens,
206
+ omni_input = True,
207
+ use_image_id = False,
208
+ max_slice_nums = 2
209
+ )
210
  torch.cuda.empty_cache()
211
  gc.collect()
212
  return output