Staticaliza committed on
Commit
c81c545
·
verified ·
1 Parent(s): 38e087a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -20
app.py CHANGED
@@ -36,11 +36,13 @@ footer {
36
  }
37
  '''
38
 
 
 
39
  input_prefixes = {
40
- "Image": "(A image file called █ has been attached, describe the image content) ",
41
- "GIF": "(A GIF file called █ has been attached, describe the GIF content) ",
42
- "Video": "(A audio video file called █ has been attached, describe the video content and the audio content) ",
43
- "Audio": "(A audio file called █ has been attached, describe the audio content) ",
44
  }
45
 
46
  filetypes = {
@@ -67,8 +69,7 @@ def frames_from_video(path):
67
  def audio_from_video(path):
68
  clip = VideoFileClip(path)
69
  with tempfile.NamedTemporaryFile(suffix = ".wav", delete = True) as tmp:
70
- clip.audio.write_audiofile(tmp.name, codec = "pcm_s16le",
71
- fps = AUDIO_SR, verbose = False, logger = None)
72
  audio_np, _ = librosa.load(tmp.name, sr = AUDIO_SR, mono = True)
73
  clip.close()
74
  return audio_np
@@ -77,10 +78,10 @@ def load_audio(path):
77
  audio_np, _ = librosa.load(path, sr = AUDIO_SR, mono = True)
78
  return audio_np
79
 
80
- def build_video_omni(path, prefix, instruction):
81
  frames = frames_from_video(path)
82
  audio = audio_from_video(path)
83
- contents = [prefix + instruction]
84
  total = max(len(frames), math.ceil(len(audio) / AUDIO_SR))
85
  for i in range(total):
86
  frame = frames[i] if i < len(frames) else frames[-1]
@@ -88,21 +89,21 @@ def build_video_omni(path, prefix, instruction):
88
  contents.extend(["<unit>", frame, chunk])
89
  return contents
90
 
91
- def build_image_omni(path, prefix, instruction):
92
  image = Image.open(path).convert("RGB")
93
- return [prefix + instruction, image]
94
 
95
- def build_gif_omni(path, prefix, instruction):
96
  img = Image.open(path)
97
  frames = [f.copy().convert("RGB") for f in ImageSequence.Iterator(img)]
98
  frames = uniform_sample(frames, MAX_FRAMES)
99
- return [prefix + instruction, *frames]
100
 
101
- def build_audio_omni(path, prefix, instruction):
102
  audio = load_audio(path)
103
- return [prefix + instruction, audio]
104
 
105
- @spaces.GPU(duration = 60)
106
  def generate(input,
107
  instruction = DEFAULT_INPUT,
108
  sampling = False,
@@ -111,12 +112,12 @@ def generate(input,
111
  top_k = 100,
112
  repetition_penalty = 1.05,
113
  max_tokens = 512):
114
- if not input:
115
- return "no input provided."
116
  extension = os.path.splitext(input)[1].lower()
117
  filetype = infer_filetype(extension)
118
- if not filetype:
119
- return "unsupported file type."
120
  filename = os.path.basename(input)
121
  prefix = input_prefixes[filetype].replace("█", filename)
122
  builder_map = {
@@ -125,9 +126,14 @@ def generate(input,
125
  "Video": build_video_omni,
126
  "Audio": build_audio_omni
127
  }
128
- omni_content = builder_map[filetype](input, prefix, instruction)
 
 
129
  sys_msg = repo.get_sys_prompt(mode = "omni", language = "en")
130
  msgs = [sys_msg, { "role": "user", "content": omni_content }]
 
 
 
131
  output = repo.chat(
132
  msgs = msgs,
133
  tokenizer = tokenizer,
@@ -141,8 +147,10 @@ def generate(input,
141
  use_image_id = False,
142
  max_slice_nums = 2
143
  )
 
144
  torch.cuda.empty_cache()
145
  gc.collect()
 
146
  return output
147
 
148
  def cloud():
 
36
  }
37
  '''
38
 
39
+ global_instruction = "Describe the given content with as much keywords and always take a guess."
40
+
41
  input_prefixes = {
42
+ "Image": "A image file called █ has been attached, describe the image content.",
43
+ "GIF": "A GIF file called █ has been attached, describe the GIF content.",
44
+ "Video": "A audio video file called █ has been attached, describe the video content and the audio content.",
45
+ "Audio": "A audio file called █ has been attached, describe the audio content.",
46
  }
47
 
48
  filetypes = {
 
69
  def audio_from_video(path):
70
  clip = VideoFileClip(path)
71
  with tempfile.NamedTemporaryFile(suffix = ".wav", delete = True) as tmp:
72
+ clip.audio.write_audiofile(tmp.name, codec = "pcm_s16le", fps = AUDIO_SR, verbose = False, logger = None)
 
73
  audio_np, _ = librosa.load(tmp.name, sr = AUDIO_SR, mono = True)
74
  clip.close()
75
  return audio_np
 
78
  audio_np, _ = librosa.load(path, sr = AUDIO_SR, mono = True)
79
  return audio_np
80
 
81
+ def build_video_omni(path, instruction):
82
  frames = frames_from_video(path)
83
  audio = audio_from_video(path)
84
+ contents = [instruction]
85
  total = max(len(frames), math.ceil(len(audio) / AUDIO_SR))
86
  for i in range(total):
87
  frame = frames[i] if i < len(frames) else frames[-1]
 
89
  contents.extend(["<unit>", frame, chunk])
90
  return contents
91
 
92
+ def build_image_omni(path, instruction):
93
  image = Image.open(path).convert("RGB")
94
+ return [instruction, image]
95
 
96
+ def build_gif_omni(path, instruction):
97
  img = Image.open(path)
98
  frames = [f.copy().convert("RGB") for f in ImageSequence.Iterator(img)]
99
  frames = uniform_sample(frames, MAX_FRAMES)
100
+ return [instruction, *frames]
101
 
102
+ def build_audio_omni(path, instruction):
103
  audio = load_audio(path)
104
+ return [instruction, audio]
105
 
106
+ @spaces.GPU(duration=30)
107
  def generate(input,
108
  instruction = DEFAULT_INPUT,
109
  sampling = False,
 
112
  top_k = 100,
113
  repetition_penalty = 1.05,
114
  max_tokens = 512):
115
+ if not input: return "no input provided."
116
+
117
  extension = os.path.splitext(input)[1].lower()
118
  filetype = infer_filetype(extension)
119
+ if not filetype: return "unsupported file type."
120
+
121
  filename = os.path.basename(input)
122
  prefix = input_prefixes[filetype].replace("█", filename)
123
  builder_map = {
 
126
  "Video": build_video_omni,
127
  "Audio": build_audio_omni
128
  }
129
+
130
+ instruction = f"{global_instruction}\n{prefix}\n{instruction}"
131
+ omni_content = builder_map[filetype](input, instruction)
132
  sys_msg = repo.get_sys_prompt(mode = "omni", language = "en")
133
  msgs = [sys_msg, { "role": "user", "content": omni_content }]
134
+
135
+ print(msgs)
136
+
137
  output = repo.chat(
138
  msgs = msgs,
139
  tokenizer = tokenizer,
 
147
  use_image_id = False,
148
  max_slice_nums = 2
149
  )
150
+
151
  torch.cuda.empty_cache()
152
  gc.collect()
153
+
154
  return output
155
 
156
  def cloud():