Staticaliza committed on
Commit
b5f3a95
·
verified ·
1 Parent(s): cb40697

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -63
app.py CHANGED
@@ -52,8 +52,8 @@ def uniform_sample(idxs, n):
52
  gap = len(idxs) / n
53
  return [idxs[int(i * gap + gap / 2)] for i in range(n)]
54
 
55
- def build_omni_chunks(path, sr=16000, seconds_per_unit=1):
56
- clip = VideoFileClip(path)
57
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
58
  clip.audio.write_audiofile(tmp.name, fps=sr, codec="pcm_s16le", verbose=False, logger=None)
59
  audio_np, _ = librosa.load(tmp.name, sr=sr, mono=True)
@@ -61,19 +61,17 @@ def build_omni_chunks(path, sr=16000, seconds_per_unit=1):
61
  content = []
62
  for i in range(total_units):
63
  t = min(i * seconds_per_unit, clip.duration - 1e-3)
64
- frame = Image.fromarray(clip.get_frame(t).astype("uint8"))
65
  audio_chunk = audio_np[sr * i * seconds_per_unit : sr * (i + 1) * seconds_per_unit]
66
  content.extend(["<unit>", frame, audio_chunk])
 
 
 
67
  return content
68
-
69
- def encode_video(path):
70
- vr = VideoReader(path, ctx=cpu(0))
71
- fps = round(vr.get_avg_fps())
72
- idxs = list(range(0, len(vr), fps))
73
- if len(idxs) > MAX_FRAMES:
74
- idxs = uniform_sample(idxs, MAX_FRAMES)
75
- frames = vr.get_batch(idxs).asnumpy()
76
- return [Image.fromarray(f.astype("uint8")) for f in frames]
77
 
78
  def encode_gif(path):
79
  img = Image.open(path)
@@ -81,59 +79,44 @@ def encode_gif(path):
81
  if len(frames) > MAX_FRAMES:
82
  frames = uniform_sample(frames, MAX_FRAMES)
83
  return frames
84
-
85
- @spaces.GPU(duration=60)
86
- def generate(input, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
87
- print(input)
88
- print(instruction)
89
 
90
- if not input:
91
- return "No input provided."
 
 
 
 
 
92
 
93
- extension = os.path.splitext(input)[1].lower()
94
- filetype = None
95
- for category, extensions in filetypes.items():
96
- if extension in extensions:
97
- filetype = category
98
- break
99
 
100
- content = []
101
- if filetype == "Image":
102
- image = Image.open(input).convert("RGB")
103
- content.append(image)
104
- elif filetype == "GIF":
105
- frames = encode_gif(input)
106
- content.extend(frames)
107
- elif filetype == "Video":
108
- omni_content = build_omni_chunks(input) + [instruction]
109
- sys_msg = repo.get_sys_prompt(mode="omni", language="en")
110
- msgs = [sys_msg, {"role": "user", "content": omni_content}]
111
- print(msgs)
112
- elif filetype == "Audio":
113
- audio_np, sample_rate = librosa.load(input, sr=16000, mono=True)
114
- chunk_tensor = torch.from_numpy(audio_np).float().to(DEVICE)
115
- content.append({"array": chunk_tensor, "sampling_rate": sample_rate})
116
-
117
- """
118
- elif filetype == "Video":
119
- frames = encode_video(input)
120
- content.extend(frames)
121
- audio, _ = librosa.load(input, sr=16000, mono=True)
122
- content.append(audio)
123
- elif filetype == "Audio":
124
- audio, _ = librosa.load(input, sr=16000, mono=True)
125
- content.append(audio)
126
- else:
127
- return "Unsupported file type."
128
- """
129
 
130
  filename = os.path.basename(input)
131
  prefix = input_prefixes[filetype].replace("█", filename)
132
- content.append(prefix + instruction)
133
- inputs_payload = [{"role": "user", "content": content}]
134
-
 
 
 
 
 
 
 
 
 
135
  params = {
136
- "msgs": msgs or inputs_payload,
137
  "tokenizer": tokenizer,
138
  "sampling": sampling,
139
  "temperature": temperature,
@@ -141,13 +124,12 @@ def generate(input, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7,
141
  "top_k": top_k,
142
  "repetition_penalty": repetition_penalty,
143
  "max_new_tokens": max_tokens,
144
- "omni_input": filetype == "Video",
145
  }
146
-
147
- output = repo.chat(**params)
148
 
149
- print(output)
150
-
 
151
  return output
152
 
153
  def cloud():
 
52
  gap = len(idxs) / n
53
  return [idxs[int(i * gap + gap / 2)] for i in range(n)]
54
 
55
+ def build_omni_chunks(path, prefix, instruction, sr=AUDIO_SR, seconds_per_unit=1):
56
+ clip = VideoFileClip(path, audio_fps=sr)
57
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
58
  clip.audio.write_audiofile(tmp.name, fps=sr, codec="pcm_s16le", verbose=False, logger=None)
59
  audio_np, _ = librosa.load(tmp.name, sr=sr, mono=True)
 
61
  content = []
62
  for i in range(total_units):
63
  t = min(i * seconds_per_unit, clip.duration - 1e-3)
64
+ frame = Image.fromarray(clip.get_frame(t).astype("uint8")).convert("RGB")
65
  audio_chunk = audio_np[sr * i * seconds_per_unit : sr * (i + 1) * seconds_per_unit]
66
  content.extend(["<unit>", frame, audio_chunk])
67
+ clip.close()
68
+ os.remove(tmp.name)
69
+ content.append(prefix + instruction)
70
  return content
71
+
72
def build_image_omni(path, prefix, instruction):
    """Build an omni-format content list for a single still image.

    Opens the file at *path*, normalizes it to RGB, and returns a list of
    the form ["<unit>", <RGB PIL image>, prompt], where the prompt is the
    prefix concatenated with the instruction text.
    """
    rgb_frame = Image.open(path).convert("RGB")
    prompt = prefix + instruction
    return ["<unit>", rgb_frame, prompt]
 
 
 
 
 
75
 
76
  def encode_gif(path):
77
  img = Image.open(path)
 
79
  if len(frames) > MAX_FRAMES:
80
  frames = uniform_sample(frames, MAX_FRAMES)
81
  return frames
 
 
 
 
 
82
 
83
def build_gif_omni(path, prefix, instruction):
    """Build an omni-format content list from a GIF's sampled frames.

    Each frame returned by encode_gif is preceded by a "<unit>" marker;
    the combined prefix + instruction prompt is appended as the final
    element of the returned list.
    """
    sampled = encode_gif(path)
    interleaved = [part for frame in sampled for part in ("<unit>", frame)]
    interleaved.append(prefix + instruction)
    return interleaved
90
 
91
def build_audio_omni(path, prefix, instruction, sr=AUDIO_SR):
    """Build an omni-format content list from an audio file.

    Loads the file at *path* as a mono waveform resampled to *sr* Hz and
    returns ["<unit>", <waveform array>, prefix + instruction].
    """
    waveform, _sample_rate = librosa.load(path, sr=sr, mono=True)
    prompt = prefix + instruction
    return ["<unit>", waveform, prompt]
 
 
 
94
 
95
+ @spaces.GPU(duration=60)
96
+ def generate(input, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
97
+ if not input: return "No input provided."
98
+
99
+ extension = os.path.splitext(input)[1].lower()
100
+ filetype = next((k for k, v in filetypes.items() if extension in v), None)
101
+
102
+ if not filetype: return "Unsupported file type."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  filename = os.path.basename(input)
105
  prefix = input_prefixes[filetype].replace("█", filename)
106
+ if filetype == "Video":
107
+ omni_content = build_omni_chunks(input, prefix, instruction)
108
+ elif filetype == "Image":
109
+ omni_content = build_image_omni(input, prefix, instruction)
110
+ elif filetype == "GIF":
111
+ omni_content = build_gif_omni(input, prefix, instruction)
112
+ elif filetype == "Audio":
113
+ omni_content = build_audio_omni(input, prefix, instruction)
114
+
115
+ sys_msg = repo.get_sys_prompt(mode="omni", language="en")
116
+ msgs = [sys_msg, {"role": "user", "content": omni_content}]
117
+
118
  params = {
119
+ "msgs": msgs,
120
  "tokenizer": tokenizer,
121
  "sampling": sampling,
122
  "temperature": temperature,
 
124
  "top_k": top_k,
125
  "repetition_penalty": repetition_penalty,
126
  "max_new_tokens": max_tokens,
127
+ "omni_input": True,
128
  }
 
 
129
 
130
+ output = repo.chat(**params)
131
+ torch.cuda.empty_cache()
132
+ gc.collect()
133
  return output
134
 
135
  def cloud():