PyTorch
music
text-to-music
symbolic-music
dorienh commited on
Commit
ee1cbd9
·
verified ·
1 Parent(s): 70e9e4d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -1
README.md CHANGED
@@ -43,7 +43,14 @@ model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
43
  # Download the vocab_remi.pkl file
44
  tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
45
 
46
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
 
47
 
48
  # Load the tokenizer dictionary
49
  with open(tokenizer_path, "rb") as f:
@@ -57,12 +64,20 @@ model.load_state_dict(torch.load(model_path, map_location=device))
57
  model.eval()
58
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
59
 
 
 
 
 
60
  src = "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."
 
 
61
  inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
62
  input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
63
  input_ids = input_ids.to(device)
64
  attention_mask =nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
65
  attention_mask = attention_mask.to(device)
 
 
66
  output = model.generate(input_ids, attention_mask, max_len=2000,temperature = 1.0)
67
  output_list = output[0].tolist()
68
  generated_midi = r_tokenizer.decode(output_list)
 
43
  # Download the vocab_remi.pkl file
44
  tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
45
 
46
+ if torch.cuda.is_available():
47
+ device = 'cuda'
48
+ elif torch.backends.mps.is_available():
49
+ device = 'mps'
50
+ else:
51
+ device = 'cpu'
52
+
53
+ print(f"Using device: {device}")
54
 
55
  # Load the tokenizer dictionary
56
  with open(tokenizer_path, "rb") as f:
 
64
  model.eval()
65
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
66
 
67
+ print('Model loaded.')
68
+
69
+
70
+ # Enter the text prompt and tokenize it
71
  src = "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."
72
+ print('Generating for prompt: ' + src)
73
+
74
  inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
75
  input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
76
  input_ids = input_ids.to(device)
77
  attention_mask =nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
78
  attention_mask = attention_mask.to(device)
79
+
80
+ # Generate the midi
81
  output = model.generate(input_ids, attention_mask, max_len=2000,temperature = 1.0)
82
  output_list = output[0].tolist()
83
  generated_midi = r_tokenizer.decode(output_list)