update app.py
Browse files
app.py
CHANGED
@@ -32,7 +32,7 @@ if torch.cuda.is_available():
|
|
32 |
setup_eval_logging()
|
33 |
OUTPUT_DIR = Path("./output/gradio")
|
34 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
35 |
-
NUM_SAMPLE=
|
36 |
snapshot_download(repo_id="google/flan-t5-large")
|
37 |
a=AutoModel.from_pretrained('bert-base-uncased')
|
38 |
b=AutoModel.from_pretrained('roberta-base')
|
@@ -193,8 +193,6 @@ def generate_audio_gradio(
|
|
193 |
log.info("Using FlowMatching for generation.")
|
194 |
generation_func = generate_fm
|
195 |
sampler_arg_name = "fm"
|
196 |
-
all_audios=[]
|
197 |
-
all_scores=[]
|
198 |
audios = generation_func(
|
199 |
[prompt]*NUM_SAMPLE,
|
200 |
negative_text=[negative_prompt]*NUM_SAMPLE,
|
@@ -204,17 +202,12 @@ def generate_audio_gradio(
|
|
204 |
cfg_strength=cfg_strength,
|
205 |
**{sampler_arg_name: sampler},
|
206 |
)
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
dim=-1).mean()
|
214 |
-
all_audios.append(audio)
|
215 |
-
all_scores.append(score)
|
216 |
-
winner_idx = torch.argmax(torch.tensor(all_scores)).item()
|
217 |
-
audio=all_audios[winner_idx]
|
218 |
safe_prompt = (
|
219 |
"".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
|
220 |
.rstrip()
|
|
|
32 |
setup_eval_logging()
|
33 |
OUTPUT_DIR = Path("./output/gradio")
|
34 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
35 |
+
NUM_SAMPLE=5
|
36 |
snapshot_download(repo_id="google/flan-t5-large")
|
37 |
a=AutoModel.from_pretrained('bert-base-uncased')
|
38 |
b=AutoModel.from_pretrained('roberta-base')
|
|
|
193 |
log.info("Using FlowMatching for generation.")
|
194 |
generation_func = generate_fm
|
195 |
sampler_arg_name = "fm"
|
|
|
|
|
196 |
audios = generation_func(
|
197 |
[prompt]*NUM_SAMPLE,
|
198 |
negative_text=[negative_prompt]*NUM_SAMPLE,
|
|
|
202 |
cfg_strength=cfg_strength,
|
203 |
**{sampler_arg_name: sampler},
|
204 |
)
|
205 |
+
text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
|
206 |
+
audio_embed = laion_clap_model.get_audio_embedding_from_data(audios[:,0,:].float().cpu(), use_tensor=True).squeeze()
|
207 |
+
scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
|
208 |
+
audio_embed,
|
209 |
+
dim=-1)
|
210 |
+
audio=audios[torch.argmax(scores).item()].float().cpu()
|
|
|
|
|
|
|
|
|
|
|
211 |
safe_prompt = (
|
212 |
"".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
|
213 |
.rstrip()
|