Staticaliza committed on
Commit d7a2675 · verified · 1 Parent(s): 0350ec7

Update app.py

Files changed (1): app.py +55 -84
app.py CHANGED
@@ -9,7 +9,6 @@ import librosa
 import tempfile
 from PIL import Image, ImageSequence
 from decord import VideoReader, cpu
-from moviepy.editor import VideoFileClip
 from transformers import AutoModel, AutoTokenizer, AutoProcessor

 # Variables
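Review note: this commit removes the moviepy import, so video audio is now decoded by librosa directly (see build_audio below). A minimal sketch of that path, assuming AUDIO_SR is the app's target sample rate and that an ffmpeg-backed audioread fallback is installed so librosa can open container formats such as .mp4:

import librosa

AUDIO_SR = 16000  # assumed value for illustration; the real constant is defined earlier in app.py

def load_audio_track(path):  # hypothetical helper mirroring the new build_audio
    # librosa resamples to sr and downmixes to mono in one call,
    # replacing moviepy's temporary .wav round trip
    audio, _ = librosa.load(path, sr=AUDIO_SR, mono=True)
    return audio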
@@ -39,10 +38,10 @@ footer {
 global_instruction = "You will analyze video, audio and text input and output your description of the given content with as much keywords and always take a guess."

 input_prefixes = {
-    "Image": "A image file called █ has been attached, describe the image content.",
-    "GIF": "A GIF file called █ has been attached, describe the GIF content.",
-    "Video": "A audio video file called █ has been attached, describe the video content and the audio content.",
-    "Audio": "A audio file called █ has been attached, describe the audio content.",
+    "Image": "Analyze the '█' image.",
+    "GIF": "Analyze the '█' GIF.",
+    "Video": "Analyze the '█' video including the audio associated with the video.",
+    "Audio": "Analyze the '█' audio.",
 }

 filetypes = {
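The █ character in each prefix is a placeholder that generate() later replaces with the uploaded file's basename, so the model is told the real filename. For example:

prefix = input_prefixes["Video"].replace("█", "clip.mp4")
# -> "Analyze the 'clip.mp4' video including the audio associated with the video."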
@@ -53,109 +52,81 @@ filetypes = {
 }

 # Functions
-def infer_filetype(ext):
-    return next((k for k, v in filetypes.items() if ext in v), None)
-
 def uniform_sample(seq, n):
     step = max(len(seq) // n, 1)
     return seq[::step][:n]
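uniform_sample is the shared downsampler for video frame indices and GIF frames: it keeps every step-th element, then truncates so at most n items survive. A quick worked example:

frames = list(range(10))           # stand-in for 10 decoded frames
picked = uniform_sample(frames, 4)
# step = max(10 // 4, 1) = 2 -> [0, 2, 4, 6, 8][:4] = [0, 2, 4, 6]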
 
-def frames_from_video(path):
+def build_video(path):
     vr = VideoReader(path, ctx = cpu(0))
     idx = uniform_sample(range(len(vr)), MAX_FRAMES)
     batch = vr.get_batch(idx).asnumpy()
-    return [Image.fromarray(frame.astype("uint8")) for frame in batch]
-
-def audio_from_video(path):
-    clip = VideoFileClip(path)
-    with tempfile.NamedTemporaryFile(suffix = ".wav", delete = True) as tmp:
-        clip.audio.write_audiofile(tmp.name,
-                                   codec = "pcm_s16le",
-                                   fps = AUDIO_SR,
-                                   verbose = False,
-                                   logger = None)
-        audio_np, _ = librosa.load(tmp.name, sr = AUDIO_SR, mono = True)
-    clip.close()
-    return audio_np
-
-def load_audio(path):
-    audio_np, _ = librosa.load(path, sr = AUDIO_SR, mono = True)
-    return audio_np
-
-def build_video_omni(path, instruction):
-    frames = frames_from_video(path)
-    audio = audio_from_video(path)
-    contents = [instruction]
-
-    audio_secs = math.ceil(len(audio) / AUDIO_SR)
+    frames = [Image.fromarray(frame.astype("uint8")) for frame in batch]
+
+    audio = build_audio(path)
+
+    audio_secs = math.ceil(len(audio) / AUDIO_SR)
     total_units = max(1, min(len(frames), audio_secs))

+    contents = []
     for i in range(total_units):
         frame = frames[i] if i < len(frames) else frames[-1]
         start = i * AUDIO_SR
         end = min((i + 1) * AUDIO_SR, len(audio))
         chunk = audio[start:end]
         if chunk.size == 0: break
         contents.extend(["<unit>", frame, chunk])

     return contents
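The rebuilt build_video pairs each sampled frame with one second of audio, emitting "<unit>"-delimited frame/chunk pairs that the model reads as synchronized omni input. The arithmetic, assuming AUDIO_SR = 16000 and a 5-second clip from which 3 frames were sampled:

audio_secs  = math.ceil(80000 / 16000)  # 5 one-second chunks available
total_units = max(1, min(3, 5))         # 3, capped by the frame count
# unit 0: frames[0] + audio[0:16000]
# unit 1: frames[1] + audio[16000:32000]
# unit 2: frames[2] + audio[32000:48000]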

-def build_image_omni(path, instruction):
+def build_image(path):
     image = Image.open(path).convert("RGB")
-    return [instruction, image]
+    return image

-def build_gif_omni(path, instruction):
-    img = Image.open(path)
-    frames = [f.copy().convert("RGB") for f in ImageSequence.Iterator(img)]
+def build_gif(path):
+    image = Image.open(path)
+    frames = [f.copy().convert("RGB") for f in ImageSequence.Iterator(image)]
     frames = uniform_sample(frames, MAX_FRAMES)
-    return [instruction, *frames]
+    return frames

-def build_audio_omni(path, instruction):
-    audio = load_audio(path)
-    return [instruction, audio]
+def build_audio(path):
+    audio, _ = librosa.load(path, sr=AUDIO_SR, mono=True)
+    return audio

 @spaces.GPU(duration=30)
-def generate(input,
-             instruction = DEFAULT_INPUT,
-             sampling = False,
-             temperature = 0.7,
-             top_p = 0.8,
-             top_k = 100,
-             repetition_penalty = 1.05,
-             max_tokens = 512):
-    if not input: return "no input provided."
+def generate(input, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
+    if not input: return "No input provided."

     extension = os.path.splitext(input)[1].lower()
-    filetype = infer_filetype(extension)
-    if not filetype: return "unsupported file type."
+    filetype = next((k for k, v in filetypes.items() if extension in v), None)
+    if not filetype: return "Unsupported file type."

     filename = os.path.basename(input)
     prefix = input_prefixes[filetype].replace("█", filename)
     builder_map = {
-        "Image": build_image_omni,
-        "GIF" : build_gif_omni,
-        "Video": build_video_omni,
-        "Audio": build_audio_omni
+        "Image": build_image,
+        "GIF" : build_gif,
+        "Video": build_video,
+        "Audio": build_audio,
     }

-    instruction = f"{prefix}\n{instruction}"
-    omni_content = builder_map[filetype](input, instruction)
-    msgs = [{ "role": "user", "content": global_instruction }, { "role": "user", "content": omni_content }]
+    instruction = f"{global_instruction}\n{prefix}\n{instruction}"
+    omni_content = builder_map[filetype](input)
+    msgs = [{ "role": "user", "content": [omni_content, instruction] }]

     print(msgs)

-    output = repo.chat(
-        msgs = msgs,
-        tokenizer = tokenizer,
-        sampling = sampling,
-        temperature = temperature,
-        top_p = top_p,
-        top_k = top_k,
-        repetition_penalty = repetition_penalty,
-        max_new_tokens = max_tokens,
-        omni_input = True,
-        use_image_id = False,
-        max_slice_nums = 2
+    output = repo.chat(
+        msgs=msgs,
+        tokenizer=tokenizer,
+        sampling=sampling,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        max_new_tokens=max_tokens,
+        omni_input=True,
+        use_image_id=False,
+        max_slice_nums=9
     )

     torch.cuda.empty_cache()
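The rewritten generate() also collapses the former two user messages into a single turn whose content list pairs the media payload with the combined global-plus-prefix instruction. A hypothetical call, assuming ".mp4" is registered under "Video" in the filetypes map:

result = generate("clip.mp4", instruction="Describe what happens.")
# dispatch: ".mp4" -> "Video" -> build_video("clip.mp4") -> ["<unit>", frame, chunk, ...]
# msgs: [{ "role": "user", "content": [omni_content, instruction] }]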
@@ -171,11 +142,11 @@ with gr.Blocks(css=css) as main:
         with gr.Column():
             input = gr.File(label="Input", file_types=["image", "video", "audio"], type="filepath")
             instruction = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Instruction")
-            sampling = gr.Checkbox(value=False, label="Sampling")
-            temperature = gr.Slider(minimum=0.01, maximum=1.99, step=0.01, value=0.7, label="Temperature")
-            top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.8, label="Top P")
-            top_k = gr.Slider(minimum=0, maximum=1000, step=1, value=100, label="Top K")
-            repetition_penalty = gr.Slider(minimum=0.01, maximum=1.99, step=0.01, value=1.05, label="Repetition Penalty")
+            sampling = gr.Checkbox(value=True, label="Sampling")
+            temperature = gr.Slider(minimum=0, maximum=2, step=0.01, value=1, label="Temperature")
+            top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top P")
+            top_k = gr.Slider(minimum=0, maximum=1000, step=1, value=50, label="Top K")
+            repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.01, value=1.05, label="Repetition Penalty")
             max_tokens = gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="Max Tokens")
             submit = gr.Button("▶")
             maintain = gr.Button("☁️")
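The click wiring sits outside this hunk, but the retuned controls (sampling on by default, wider temperature and repetition-penalty ranges) still map one-to-one onto generate()'s keyword arguments. A minimal sketch of how such a hookup typically looks in Gradio (hypothetical; the repository's actual wiring may differ):

output = gr.Textbox(label="Output")  # assumed output component
submit.click(
    fn=generate,
    inputs=[input, instruction, sampling, temperature, top_p, top_k, repetition_penalty, max_tokens],
    outputs=output,
)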
 