junxiliu committed on
Commit
ac87bd7
·
1 Parent(s): 0f2d92b

update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -14
app.py CHANGED
@@ -32,7 +32,7 @@ if torch.cuda.is_available():
32
  setup_eval_logging()
33
  OUTPUT_DIR = Path("./output/gradio")
34
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
35
- NUM_SAMPLE=8
36
  snapshot_download(repo_id="google/flan-t5-large")
37
  a=AutoModel.from_pretrained('bert-base-uncased')
38
  b=AutoModel.from_pretrained('roberta-base')
@@ -193,8 +193,6 @@ def generate_audio_gradio(
193
  log.info("Using FlowMatching for generation.")
194
  generation_func = generate_fm
195
  sampler_arg_name = "fm"
196
- all_audios=[]
197
- all_scores=[]
198
  audios = generation_func(
199
  [prompt]*NUM_SAMPLE,
200
  negative_text=[negative_prompt]*NUM_SAMPLE,
@@ -204,17 +202,12 @@ def generate_audio_gradio(
204
  cfg_strength=cfg_strength,
205
  **{sampler_arg_name: sampler},
206
  )
207
- for i in range(NUM_SAMPLE):
208
- audio = audios.float().cpu()[i]
209
- text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
210
- audio_embed = laion_clap_model.get_audio_embedding_from_data(audio, use_tensor=True).squeeze()
211
- score = torch.cosine_similarity(text_embed,
212
- audio_embed,
213
- dim=-1).mean()
214
- all_audios.append(audio)
215
- all_scores.append(score)
216
- winner_idx = torch.argmax(torch.tensor(all_scores)).item()
217
- audio=all_audios[winner_idx]
218
  safe_prompt = (
219
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
220
  .rstrip()
 
32
  setup_eval_logging()
33
  OUTPUT_DIR = Path("./output/gradio")
34
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
35
+ NUM_SAMPLE=5
36
  snapshot_download(repo_id="google/flan-t5-large")
37
  a=AutoModel.from_pretrained('bert-base-uncased')
38
  b=AutoModel.from_pretrained('roberta-base')
 
193
  log.info("Using FlowMatching for generation.")
194
  generation_func = generate_fm
195
  sampler_arg_name = "fm"
 
 
196
  audios = generation_func(
197
  [prompt]*NUM_SAMPLE,
198
  negative_text=[negative_prompt]*NUM_SAMPLE,
 
202
  cfg_strength=cfg_strength,
203
  **{sampler_arg_name: sampler},
204
  )
205
+ text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
206
+ audio_embed = laion_clap_model.get_audio_embedding_from_data(audios[:,0,:].float().cpu(), use_tensor=True).squeeze()
207
+ scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
208
+ audio_embed,
209
+ dim=-1)
210
+ audio=audios[torch.argmax(scores).item()].float().cpu()
 
 
 
 
 
211
  safe_prompt = (
212
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
213
  .rstrip()