Staticaliza committed
Commit 957abbb · verified · 1 Parent(s): d2ef006

Update app.py

Files changed (1)
  1. app.py +35 -58
app.py CHANGED
@@ -47,79 +47,54 @@ filetypes = {
     "Audio": [".wav", ".mp3", ".flac", ".aac"],
 }
 
+# Functions
 def infer_filetype(ext):
     return next((k for k, v in filetypes.items() if ext in v), None)
 
-
 def uniform_sample(seq, n):
     step = max(len(seq) // n, 1)
     return seq[::step][:n]
 
-
 def frames_from_video(path):
     vr = VideoReader(path, ctx = cpu(0))
     idx = uniform_sample(range(len(vr)), MAX_FRAMES)
     batch = vr.get_batch(idx).asnumpy()
     return [Image.fromarray(frame.astype("uint8")) for frame in batch]
 
-
 def audio_from_video(path):
     clip = VideoFileClip(path)
-    audio = clip.audio.to_soundarray(fps = AUDIO_SR)
+    wav = clip.audio.to_soundarray(fps = AUDIO_SR)
     clip.close()
-    return librosa.to_mono(audio.T)
-
+    return librosa.to_mono(wav.T)
 
 def load_audio(path):
     audio_np, _ = librosa.load(path, sr = AUDIO_SR, mono = True)
     return audio_np
 
-
 def build_video_omni(path, prefix, instruction):
     frames = frames_from_video(path)
    audio = audio_from_video(path)
-    return processor.build_omni_input(
-        frames = frames,
-        audio = audio,
-        prefix = prefix,
-        instruction = instruction,
-        max_frames = MAX_FRAMES,
-        sr = AUDIO_SR
-    )
-
+    contents = [prefix + instruction]
+    total = max(len(frames), math.ceil(len(audio) / AUDIO_SR))
+    for i in range(total):
+        frame = frames[i] if i < len(frames) else frames[-1]
+        chunk = audio[AUDIO_SR * i : AUDIO_SR * (i + 1)]
+        contents.extend(["<unit>", frame, chunk])
+    return contents
 
 def build_image_omni(path, prefix, instruction):
     image = Image.open(path).convert("RGB")
-    return processor.build_omni_input(
-        frames = [image],
-        audio = None,
-        prefix = prefix,
-        instruction = instruction
-    )
-
+    return [prefix + instruction, image]
 
 def build_gif_omni(path, prefix, instruction):
-    img = Image.open(path)
-    frames = [frame.copy().convert("RGB") for frame in ImageSequence.Iterator(img)]
+    img = Image.open(path)
+    frames = [f.copy().convert("RGB") for f in ImageSequence.Iterator(img)]
     frames = uniform_sample(frames, MAX_FRAMES)
-    return processor.build_omni_input(
-        frames = frames,
-        audio = None,
-        prefix = prefix,
-        instruction = instruction
-    )
-
+    return [prefix + instruction, *frames]
 
 def build_audio_omni(path, prefix, instruction):
     audio = load_audio(path)
-    return processor.build_omni_input(
-        frames = None,
-        audio = audio,
-        prefix = prefix,
-        instruction = instruction,
-        sr = AUDIO_SR
-    )
-
+    return [prefix + instruction, audio]
 
 @spaces.GPU(duration = 60)
 def generate(input,
@@ -136,30 +111,32 @@ def generate(input,
     filetype = infer_filetype(extension)
     if not filetype:
         return "unsupported file type."
-    filename = os.path.basename(input)
-    prefix = input_prefixes[filetype].replace("█", filename)
-    builder_map = {
+    filename = os.path.basename(input)
+    prefix = input_prefixes[filetype].replace("█", filename)
+    builder_map = {
         "Image": build_image_omni,
         "GIF"  : build_gif_omni,
         "Video": build_video_omni,
         "Audio": build_audio_omni
     }
-    omni_content = builder_map[filetype](input, prefix, instruction)
-    sys_msg = repo.get_sys_prompt(mode = "omni", language = "en")
-    msgs = [sys_msg, { "role": "user", "content": omni_content }]
-    output = repo.chat(
-        msgs = msgs,
-        tokenizer = tokenizer,
-        sampling = sampling,
-        temperature = temperature,
-        top_p = top_p,
-        top_k = top_k,
-        repetition_penalty = repetition_penalty,
-        max_new_tokens = max_tokens,
-        omni_input = True,
-        use_image_id = False,
-        max_slice_nums = 2
-    )
+    omni_content = builder_map[filetype](input, prefix, instruction)
+    sys_msg = repo.get_sys_prompt(mode = "omni", language = "en")
+    msgs = [sys_msg, { "role": "user", "content": omni_content }]
+    output = repo.chat(
+        msgs = msgs,
+        tokenizer = tokenizer,
+        sampling = sampling,
+        temperature = temperature,
+        top_p = top_p,
+        top_k = top_k,
+        repetition_penalty = repetition_penalty,
+        max_new_tokens = max_tokens,
+        omni_input = True,
+        use_image_id = False,
+        max_slice_nums = 2
+    )
+    torch.cuda.empty_cache()
+    gc.collect()
     return output
 
 def cloud():
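
For reference, a minimal self-contained sketch of the per-second interleaving the new build_video_omni performs, using synthetic frames and audio; the AUDIO_SR value, frame size, and placeholder prefix text below are assumptions for illustration, not app.py's actual constants:

import math
import numpy as np
from PIL import Image

AUDIO_SR = 16000                                            # assumed sample rate for this sketch
frames = [Image.new("RGB", (64, 64)) for _ in range(3)]     # stand-in for frames_from_video(path)
audio = np.zeros(int(AUDIO_SR * 2.5), dtype = np.float32)   # stand-in for audio_from_video(path), 2.5 s of mono

contents = ["<placeholder prefix + instruction>"]            # text goes first
total = max(len(frames), math.ceil(len(audio) / AUDIO_SR))   # number of one-second units (3 here)
for i in range(total):
    frame = frames[i] if i < len(frames) else frames[-1]     # reuse the last frame if the audio runs longer
    chunk = audio[AUDIO_SR * i : AUDIO_SR * (i + 1)]         # i-th second of samples
    contents.extend(["<unit>", frame, chunk])

# contents is the flat list that generate() hands to repo.chat as the user content:
# [text, "<unit>", frame0, chunk0, "<unit>", frame1, chunk1, "<unit>", frame2, chunk2]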
 
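As a side note on the sampling helper: uniform_sample keeps every (len(seq) // n)-th element and then truncates to n, so long inputs come back as roughly evenly spaced picks. A quick illustration, where the 300-frame length and the value 32 for MAX_FRAMES are made up for the example:

def uniform_sample(seq, n):               # as defined in the diff above
    step = max(len(seq) // n, 1)
    return seq[::step][:n]

idx = uniform_sample(range(300), 32)      # hypothetical 300-frame clip, MAX_FRAMES assumed to be 32
# step = max(300 // 32, 1) = 9 -> indices 0, 9, 18, ..., 279 (32 evenly spaced picks)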