update
app.py CHANGED
@@ -36,17 +36,17 @@ setup_eval_logging()
 
 OUTPUT_DIR = Path("./output/gradio")
 OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
-NUM_SAMPLE=1
+NUM_SAMPLE = 1
 
-snapshot_download(repo_id="google/flan-t5-large")
-a = AutoModel.from_pretrained('bert-base-uncased')
-b = AutoModel.from_pretrained('roberta-base')
+# snapshot_download(repo_id="google/flan-t5-large")
+# a = AutoModel.from_pretrained('bert-base-uncased')
+# b = AutoModel.from_pretrained('roberta-base')
 
-snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
-_clap_ckpt_path='./weights/music_speech_audioset_epoch_15_esc_89.98.pt'
-laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base').cuda().eval()
+# snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
+# _clap_ckpt_path='./weights/music_speech_audioset_epoch_15_esc_89.98.pt'
+# laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base').cuda().eval()
 
-laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)
 
+# laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)
 
 @spaces.GPU(duration=10)
@@ -116,14 +116,15 @@ def generate_audio_gradio(
         cfg_strength=cfg_strength,
         **{sampler_arg_name: sampler},
     )
-    text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
-    audio_embed = laion_clap_model.get_audio_embedding_from_data(audios[:,0,:].float().cpu(), use_tensor=True).squeeze()
-    scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
-                                     audio_embed,
-                                     dim=-1)
-    log.info(scores)
-    log.info(torch.argmax(scores).item())
-    audio = audios[torch.argmax(scores).item()].float().cpu()
+    audio = audios[0].float().cpu()
+    # text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
+    # audio_embed = laion_clap_model.get_audio_embedding_from_data(audios[:,0,:].float().cpu(), use_tensor=True).squeeze()
+    # scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
+    #                                  audio_embed,
+    #                                  dim=-1)
+    # log.info(scores)
+    # log.info(torch.argmax(scores).item())
+    # audio = audios[torch.argmax(scores).item()].float().cpu()
     safe_prompt = (
         "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
         .rstrip()
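For context, the lines commented out in this commit implemented a best-of-N reranking step: with NUM_SAMPLE above 1, the app generated several candidates, embedded the prompt and each waveform with LAION-CLAP, and kept the candidate with the highest text-audio cosine similarity. The sketch below reconstructs that flow as a standalone snippet under a few assumptions: laion_clap and huggingface_hub are installed, the CLAP checkpoint ships with the MeanAudio weights downloaded here, audios is a [num_samples, channels, num_frames] tensor as in generate_audio_gradio above, and the helper name pick_best_by_clap is illustrative rather than part of app.py.

from huggingface_hub import snapshot_download
import laion_clap
import torch

# Fetch the MeanAudio weights (mirrors the commented-out snapshot_download above).
snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights", allow_patterns=["*.pt", "*.pth"])

# LAION-CLAP scorer; checkpoint path mirrors the one referenced in the diff.
_clap_ckpt_path = "./weights/music_speech_audioset_epoch_15_esc_89.98.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"
laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base").to(device).eval()
laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)

def pick_best_by_clap(prompt: str, audios: torch.Tensor) -> torch.Tensor:
    # audios: [num_samples, channels, num_frames]; only the first channel is scored,
    # as in the disabled code above.
    with torch.no_grad():
        # [1, D] prompt embedding, squeezed to [D].
        text_embed = laion_clap_model.get_text_embedding([prompt], use_tensor=True).squeeze(0)
        # [num_samples, D] embeddings computed from the raw waveforms.
        audio_embed = laion_clap_model.get_audio_embedding_from_data(
            audios[:, 0, :].float().cpu(), use_tensor=True
        )
        # Cosine similarity of every candidate against the prompt embedding.
        scores = torch.cosine_similarity(
            text_embed.expand(audio_embed.shape[0], -1), audio_embed, dim=-1
        )
    # Keep the highest-scoring candidate, like the line this commit replaced
    # with `audio = audios[0].float().cpu()`.
    return audios[torch.argmax(scores).item()].float().cpu()

In the app this helper would stand in for the new `audio = audios[0].float().cpu()` line once NUM_SAMPLE is raised above 1 and the commented-out model setup is restored.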