AndreasXi committed on
Commit
08a9c69
·
1 Parent(s): de50529
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -36,17 +36,17 @@ setup_eval_logging()
36
 
37
  OUTPUT_DIR = Path("./output/gradio")
38
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
39
- NUM_SAMPLE=1
40
 
41
- snapshot_download(repo_id="google/flan-t5-large")
42
- a = AutoModel.from_pretrained('bert-base-uncased')
43
- b = AutoModel.from_pretrained('roberta-base')
44
 
45
- snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
46
- _clap_ckpt_path='./weights/music_speech_audioset_epoch_15_esc_89.98.pt'
47
- laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base').cuda().eval()
48
 
49
- laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)
50
 
51
 
52
  @spaces.GPU(duration=10)
@@ -116,14 +116,15 @@ def generate_audio_gradio(
116
  cfg_strength=cfg_strength,
117
  **{sampler_arg_name: sampler},
118
  )
119
- text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
120
- audio_embed = laion_clap_model.get_audio_embedding_from_data(audios[:,0,:].float().cpu(), use_tensor=True).squeeze()
121
- scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
122
- audio_embed,
123
- dim=-1)
124
- log.info(scores)
125
- log.info(torch.argmax(scores).item())
126
- audio = audios[torch.argmax(scores).item()].float().cpu()
 
127
  safe_prompt = (
128
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
129
  .rstrip()
 
36
 
37
  OUTPUT_DIR = Path("./output/gradio")
38
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
39
+ NUM_SAMPLE = 1
40
 
41
+ # snapshot_download(repo_id="google/flan-t5-large")
42
+ # a = AutoModel.from_pretrained('bert-base-uncased')
43
+ # b = AutoModel.from_pretrained('roberta-base')
44
 
45
+ # snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
46
+ # _clap_ckpt_path='./weights/music_speech_audioset_epoch_15_esc_89.98.pt'
47
+ # laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base').cuda().eval()
48
 
49
+ # laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)
50
 
51
 
52
  @spaces.GPU(duration=10)
 
116
  cfg_strength=cfg_strength,
117
  **{sampler_arg_name: sampler},
118
  )
119
+ audio = audios[0].float().cpu()
120
+ # text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
121
+ # audio_embed = laion_clap_model.get_audio_embedding_from_data(audios[:,0,:].float().cpu(), use_tensor=True).squeeze()
122
+ # scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
123
+ # audio_embed,
124
+ # dim=-1)
125
+ # log.info(scores)
126
+ # log.info(torch.argmax(scores).item())
127
+ # audio = audios[torch.argmax(scores).item()].float().cpu()
128
  safe_prompt = (
129
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
130
  .rstrip()