Update README.md
Browse files
README.md
CHANGED
@@ -43,7 +43,14 @@ model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
|
|
43 |
# Download the vocab_remi.pkl file
|
44 |
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
|
45 |
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
# Load the tokenizer dictionary
|
49 |
with open(tokenizer_path, "rb") as f:
|
@@ -57,12 +64,20 @@ model.load_state_dict(torch.load(model_path, map_location=device))
|
|
57 |
model.eval()
|
58 |
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
59 |
|
|
|
|
|
|
|
|
|
60 |
src = "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."
|
|
|
|
|
61 |
inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
|
62 |
input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
|
63 |
input_ids = input_ids.to(device)
|
64 |
attention_mask =nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
|
65 |
attention_mask = attention_mask.to(device)
|
|
|
|
|
66 |
output = model.generate(input_ids, attention_mask, max_len=2000,temperature = 1.0)
|
67 |
output_list = output[0].tolist()
|
68 |
generated_midi = r_tokenizer.decode(output_list)
|
|
|
43 |
# Download the vocab_remi.pkl file
|
44 |
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
|
45 |
|
46 |
+
if torch.cuda.is_available():
|
47 |
+
device = 'cuda'
|
48 |
+
elif torch.backends.mps.is_available():
|
49 |
+
device = 'mps'
|
50 |
+
else:
|
51 |
+
device = 'cpu'
|
52 |
+
|
53 |
+
print(f"Using device: {device}")
|
54 |
|
55 |
# Load the tokenizer dictionary
|
56 |
with open(tokenizer_path, "rb") as f:
|
|
|
64 |
model.eval()
|
65 |
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
66 |
|
67 |
+
print('Model loaded.')
|
68 |
+
|
69 |
+
|
70 |
+
# Enter the text prompt and tokenize it
|
71 |
src = "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."
|
72 |
+
print('Generating for prompt: ' + src)
|
73 |
+
|
74 |
inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
|
75 |
input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
|
76 |
input_ids = input_ids.to(device)
|
77 |
attention_mask =nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
|
78 |
attention_mask = attention_mask.to(device)
|
79 |
+
|
80 |
+
# Generate the MIDI
|
81 |
output = model.generate(input_ids, attention_mask, max_len=2000,temperature = 1.0)
|
82 |
output_list = output[0].tolist()
|
83 |
generated_midi = r_tokenizer.decode(output_list)
|