ElectricAlexis committed on
Commit
c86d7b7
·
verified ·
1 Parent(s): d900b7e

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +388 -388
inference.py CHANGED
@@ -1,388 +1,388 @@
1
- import os
2
- import time
3
- import torch
4
- import re
5
- import difflib
6
- from utils import *
7
- from config import *
8
- from transformers import GPT2Config
9
- from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
10
- from abctoolkit.transpose import Note_list, Pitch_sign_list
11
- from abctoolkit.duration import calculate_bartext_duration
12
- import requests
13
- import torch
14
- from huggingface_hub import hf_hub_download
15
- import logging
16
-
17
- # Setup logging
18
- logging.basicConfig(level=logging.INFO)
19
- logger = logging.getLogger(__name__)
20
-
21
- Note_list = Note_list + ['z', 'x']
22
-
23
- if torch.cuda.is_available():
24
- device = torch.device("cuda")
25
- else:
26
- device = torch.device("cpu")
27
-
28
- patchilizer = Patchilizer()
29
-
30
- patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS,
31
- max_length=PATCH_LENGTH,
32
- max_position_embeddings=PATCH_LENGTH,
33
- n_embd=HIDDEN_SIZE,
34
- num_attention_heads=HIDDEN_SIZE // 64,
35
- vocab_size=1)
36
- byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
37
- max_length=PATCH_SIZE + 1,
38
- max_position_embeddings=PATCH_SIZE + 1,
39
- hidden_size=HIDDEN_SIZE,
40
- num_attention_heads=HIDDEN_SIZE // 64,
41
- vocab_size=128)
42
-
43
- model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
44
-
45
-
46
- def download_model_weights():
47
- weights_path = "weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth"
48
- local_weights_path = os.path.join(os.getcwd(), weights_path)
49
-
50
- # Check if weights already exist locally
51
- if os.path.exists(local_weights_path):
52
- logger.info(f"Model weights already exist at {local_weights_path}")
53
- return local_weights_path
54
-
55
- logger.info("Downloading model weights from HuggingFace Hub...")
56
- try:
57
- # Download from HuggingFace
58
- downloaded_path = hf_hub_download(
59
- repo_id="ElectricAlexis/NotaGen",
60
- filename=weights_path,
61
- local_dir=os.getcwd(),
62
- local_dir_use_symlinks=False
63
- )
64
- logger.info(f"Model weights downloaded successfully to {downloaded_path}")
65
- return downloaded_path
66
- except Exception as e:
67
- logger.error(f"Error downloading model weights: {str(e)}")
68
- raise RuntimeError(f"Failed to download model weights: {str(e)}")
69
-
70
-
71
- def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
72
- """
73
- Prepare model for k-bit training.
74
- Features include:
75
- 1. Convert model to mixed precision (FP16).
76
- 2. Disable unnecessary gradient computations.
77
- 3. Enable gradient checkpointing (optional).
78
- """
79
- # Convert model to mixed precision
80
- model = model.to(dtype=torch.float16)
81
-
82
- # Disable gradients for embedding layers
83
- for param in model.parameters():
84
- if param.dtype == torch.float32:
85
- param.requires_grad = False
86
-
87
- # Enable gradient checkpointing
88
- if use_gradient_checkpointing:
89
- model.gradient_checkpointing_enable()
90
-
91
- return model
92
-
93
-
94
- model = prepare_model_for_kbit_training(
95
- model,
96
- use_gradient_checkpointing=False
97
- )
98
-
99
- print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
100
-
101
- # Download weights at startup
102
- model_weights_path = download_model_weights()
103
- checkpoint = torch.load(model_weights_path, map_location=torch.device(device))
104
- model.load_state_dict(checkpoint['model'], strict=False)
105
-
106
- model = model.to(device)
107
- model.eval()
108
-
109
-
110
- def postprocess_inst_names(abc_text):
111
- with open('standard_inst_names.txt', 'r', encoding='utf-8') as f:
112
- standard_instruments_list = [line.strip() for line in f if line.strip()]
113
-
114
- with open('instrument_mapping.json', 'r', encoding='utf-8') as f:
115
- instrument_mapping = json.load(f)
116
-
117
- abc_lines = abc_text.split('\n')
118
- abc_lines = list(filter(None, abc_lines))
119
- abc_lines = [line + '\n' for line in abc_lines]
120
-
121
- for i, line in enumerate(abc_lines):
122
- if line.startswith('V:') and 'nm=' in line:
123
- match = re.search(r'nm="([^"]*)"', line)
124
- if match:
125
- inst_name = match.group(1)
126
-
127
- # Check if the instrument name is already standard
128
- if inst_name in standard_instruments_list:
129
- continue
130
-
131
- # Find the most similar key in instrument_mapping
132
- matching_key = difflib.get_close_matches(inst_name, list(instrument_mapping.keys()), n=1, cutoff=0.6)
133
-
134
- if matching_key:
135
- # Replace the instrument name with the standardized version
136
- replacement = instrument_mapping[matching_key[0]]
137
- new_line = line.replace(f'nm="{inst_name}"', f'nm="{replacement}"')
138
- abc_lines[i] = new_line
139
-
140
- # Combine the lines back into a single string
141
- processed_abc_text = ''.join(abc_lines)
142
- return processed_abc_text
143
-
144
-
145
- def complete_brackets(s):
146
- stack = []
147
- bracket_map = {'{': '}', '[': ']', '(': ')'}
148
-
149
- # Iterate through each character, handle bracket matching
150
- for char in s:
151
- if char in bracket_map:
152
- stack.append(char)
153
- elif char in bracket_map.values():
154
- # Find the corresponding left bracket
155
- for key, value in bracket_map.items():
156
- if value == char:
157
- if stack and stack[-1] == key:
158
- stack.pop()
159
- break # Found matching right bracket, process next character
160
-
161
- # Complete missing right brackets (in reverse order of remaining left brackets in stack)
162
- completion = ''.join(bracket_map[c] for c in reversed(stack))
163
- return s + completion
164
-
165
-
166
- def rest_unreduce(abc_lines):
167
- tunebody_index = None
168
- for i in range(len(abc_lines)):
169
- if abc_lines[i].startswith('%%score'):
170
- abc_lines[i] = complete_brackets(abc_lines[i])
171
- if '[V:' in abc_lines[i]:
172
- tunebody_index = i
173
- break
174
-
175
- metadata_lines = abc_lines[: tunebody_index]
176
- tunebody_lines = abc_lines[tunebody_index:]
177
-
178
- part_symbol_list = []
179
- voice_group_list = []
180
- for line in metadata_lines:
181
- if line.startswith('%%score'):
182
- for round_bracket_match in re.findall(r'\((.*?)\)', line):
183
- voice_group_list.append(round_bracket_match.split())
184
- existed_voices = [item for sublist in voice_group_list for item in sublist]
185
- if line.startswith('V:'):
186
- symbol = line.split()[0]
187
- part_symbol_list.append(symbol)
188
- if symbol[2:] not in existed_voices:
189
- voice_group_list.append([symbol[2:]])
190
- z_symbol_list = [] # voices that use z as rest
191
- x_symbol_list = [] # voices that use x as rest
192
- for voice_group in voice_group_list:
193
- z_symbol_list.append('V:' + voice_group[0])
194
- for j in range(1, len(voice_group)):
195
- x_symbol_list.append('V:' + voice_group[j])
196
-
197
- part_symbol_list.sort(key=lambda x: int(x[2:]))
198
-
199
- unreduced_tunebody_lines = []
200
-
201
- for i, line in enumerate(tunebody_lines):
202
- unreduced_line = ''
203
-
204
- line = re.sub(r'^\[r:[^\]]*\]', '', line)
205
-
206
- pattern = r'\[V:(\d+)\](.*?)(?=\[V:|$)'
207
- matches = re.findall(pattern, line)
208
-
209
- line_bar_dict = {}
210
- for match in matches:
211
- key = f'V:{match[0]}'
212
- value = match[1]
213
- line_bar_dict[key] = value
214
-
215
- # calculate duration and collect barline
216
- dur_dict = {}
217
- for symbol, bartext in line_bar_dict.items():
218
- right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
219
- bartext = bartext[:-len(right_barline)]
220
- try:
221
- bar_dur = calculate_bartext_duration(bartext)
222
- except:
223
- bar_dur = None
224
- if bar_dur is not None:
225
- if bar_dur not in dur_dict.keys():
226
- dur_dict[bar_dur] = 1
227
- else:
228
- dur_dict[bar_dur] += 1
229
-
230
- try:
231
- ref_dur = max(dur_dict, key=dur_dict.get)
232
- except:
233
- pass # use last ref_dur
234
-
235
- if i == 0:
236
- prefix_left_barline = line.split('[V:')[0]
237
- else:
238
- prefix_left_barline = ''
239
-
240
- for symbol in part_symbol_list:
241
- if symbol in line_bar_dict.keys():
242
- symbol_bartext = line_bar_dict[symbol]
243
- else:
244
- if symbol in z_symbol_list:
245
- symbol_bartext = prefix_left_barline + 'z' + str(ref_dur) + right_barline
246
- elif symbol in x_symbol_list:
247
- symbol_bartext = prefix_left_barline + 'x' + str(ref_dur) + right_barline
248
- unreduced_line += '[' + symbol + ']' + symbol_bartext
249
-
250
- unreduced_tunebody_lines.append(unreduced_line + '\n')
251
-
252
- unreduced_lines = metadata_lines + unreduced_tunebody_lines
253
-
254
- return unreduced_lines
255
-
256
-
257
- def inference_patch(period, composer, instrumentation):
258
- prompt_lines = [
259
- '%' + period + '\n',
260
- '%' + composer + '\n',
261
- '%' + instrumentation + '\n']
262
-
263
- while True:
264
-
265
- failure_flag = False
266
-
267
- bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]
268
-
269
- start_time = time.time()
270
-
271
- prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
272
- byte_list = list(''.join(prompt_lines))
273
- context_tunebody_byte_list = []
274
- metadata_byte_list = []
275
-
276
- print(''.join(byte_list), end='')
277
-
278
- prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch
279
- in prompt_patches]
280
- prompt_patches.insert(0, bos_patch)
281
-
282
- input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)
283
-
284
- end_flag = False
285
- cut_index = None
286
-
287
- tunebody_flag = False
288
-
289
- with torch.inference_mode():
290
-
291
- while True:
292
- with torch.autocast(device_type='cuda', dtype=torch.float16):
293
- predicted_patch = model.generate(input_patches.unsqueeze(0),
294
- top_k=TOP_K,
295
- top_p=TOP_P,
296
- temperature=TEMPERATURE)
297
- if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith(
298
- '[r:'): # 初次进入tunebody,必须以[r:0/开头
299
- tunebody_flag = True
300
- r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
301
- temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
302
- predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
303
- top_k=TOP_K,
304
- top_p=TOP_P,
305
- temperature=TEMPERATURE)
306
- predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
307
- if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
308
- end_flag = True
309
- break
310
- next_patch = patchilizer.decode([predicted_patch])
311
-
312
- for char in next_patch:
313
- byte_list.append(char)
314
- if tunebody_flag:
315
- context_tunebody_byte_list.append(char)
316
- else:
317
- metadata_byte_list.append(char)
318
- print(char, end='')
319
-
320
- patch_end_flag = False
321
- for j in range(len(predicted_patch)):
322
- if patch_end_flag:
323
- predicted_patch[j] = patchilizer.special_token_id
324
- if predicted_patch[j] == patchilizer.eos_token_id:
325
- patch_end_flag = True
326
-
327
- predicted_patch = torch.tensor([predicted_patch], device=device) # (1, 16)
328
- input_patches = torch.cat([input_patches, predicted_patch], dim=1) # (1, 16 * patch_len)
329
-
330
- if len(byte_list) > 102400:
331
- failure_flag = True
332
- break
333
- if time.time() - start_time > 10 * 60:
334
- failure_flag = True
335
- break
336
-
337
- if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
338
- print('Stream generating...')
339
-
340
- metadata = ''.join(metadata_byte_list)
341
- context_tunebody = ''.join(context_tunebody_byte_list)
342
-
343
- if '\n' not in context_tunebody:
344
- break # Generated content is all metadata, abandon
345
-
346
- context_tunebody_lines = context_tunebody.strip().split('\n')
347
-
348
- if not context_tunebody.endswith('\n'):
349
- context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in
350
- range(len(context_tunebody_lines) - 1)] + [context_tunebody_lines[-1]]
351
- else:
352
- context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in
353
- range(len(context_tunebody_lines))]
354
-
355
- cut_index = len(context_tunebody_lines) // 2
356
- abc_code_slice = metadata + ''.join(context_tunebody_lines[-cut_index:])
357
-
358
- input_patches = patchilizer.encode_generate(abc_code_slice)
359
-
360
- input_patches = [item for sublist in input_patches for item in sublist]
361
- input_patches = torch.tensor([input_patches], device=device)
362
- input_patches = input_patches.reshape(1, -1)
363
-
364
- context_tunebody_byte_list = list(''.join(context_tunebody_lines[-cut_index:]))
365
-
366
- if not failure_flag:
367
- abc_text = ''.join(byte_list)
368
-
369
- # unreduce
370
- abc_lines = abc_text.split('\n')
371
- abc_lines = list(filter(None, abc_lines))
372
- abc_lines = [line + '\n' for line in abc_lines]
373
- try:
374
- unreduced_abc_lines = rest_unreduce(abc_lines)
375
- except:
376
- failure_flag = True
377
- pass
378
- else:
379
- unreduced_abc_lines = [line for line in unreduced_abc_lines if
380
- not (line.startswith('%') and not line.startswith('%%'))]
381
- unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
382
- unreduced_abc_text = ''.join(unreduced_abc_lines)
383
- return unreduced_abc_text
384
-
385
-
386
- if __name__ == '__main__':
387
- inference_patch('Classical', 'Beethoven, Ludwig van', 'Orchestral')
388
-
 
1
+ import os
2
+ import time
3
+ import torch
4
+ import re
5
+ import difflib
6
+ from utils import *
7
+ from config import *
8
+ from transformers import GPT2Config
9
+ from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
10
+ from abctoolkit.transpose import Note_list, Pitch_sign_list
11
+ from abctoolkit.duration import calculate_bartext_duration
12
+ import requests
13
+ import torch
14
+ from huggingface_hub import hf_hub_download
15
+ import logging
16
+
# Emit log records at INFO level and above, and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Extend the Note_list imported from abctoolkit.transpose with the two
# ABC rest symbols so rests are handled alongside pitched notes.
Note_list = Note_list + ['z', 'x']

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

patchilizer = Patchilizer()

# Configuration of the patch-level network (passed as the encoder config).
# The upper-case constants are presumably provided by the star-imported
# config module -- TODO confirm.
patch_config = GPT2Config(
    num_hidden_layers=PATCH_NUM_LAYERS,
    max_length=PATCH_LENGTH,
    max_position_embeddings=PATCH_LENGTH,
    n_embd=HIDDEN_SIZE,
    num_attention_heads=HIDDEN_SIZE // 64,
    vocab_size=1,
)

# Configuration of the character-level network (passed as the decoder config).
byte_config = GPT2Config(
    num_hidden_layers=CHAR_NUM_LAYERS,
    max_length=PATCH_SIZE + 1,
    max_position_embeddings=PATCH_SIZE + 1,
    hidden_size=HIDDEN_SIZE,
    num_attention_heads=HIDDEN_SIZE // 64,
    vocab_size=128,
)

# Build the model from both configurations and place it on the chosen device.
model = NotaGenLMHeadModel(
    encoder_config=patch_config,
    decoder_config=byte_config,
).to(device)
44
+
45
+
def download_model_weights():
    """Return a local path to the model checkpoint, downloading it if needed.

    The checkpoint is looked up in the current working directory first.
    When it is not there, it is fetched from the ``ElectricAlexis/NotaGen``
    repository on the Hugging Face Hub into the current working directory.

    Returns:
        str: Path of the checkpoint file on the local file system.

    Raises:
        RuntimeError: If the download fails for any reason. The original
            exception is attached as ``__cause__``.
    """
    weights_path = "weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth"
    local_weights_path = os.path.join(os.getcwd(), weights_path)

    # Reuse a checkpoint that is already present locally.
    if os.path.exists(local_weights_path):
        logger.info(f"Model weights already exist at {local_weights_path}")
        return local_weights_path

    logger.info("Downloading model weights from HuggingFace Hub...")
    try:
        # Only the network call is guarded, so unrelated errors are not
        # mislabelled as download failures.
        downloaded_path = hf_hub_download(
            repo_id="ElectricAlexis/NotaGen",
            filename=weights_path,
            local_dir=os.getcwd(),
            # NOTE(review): kept from the original call; confirm this argument
            # is still accepted by the installed huggingface_hub version.
            local_dir_use_symlinks=False
        )
    except Exception as e:
        logger.error(f"Error downloading model weights: {str(e)}")
        # Chain the cause so the real failure stays visible in the traceback.
        raise RuntimeError(f"Failed to download model weights: {str(e)}") from e

    logger.info(f"Model weights downloaded successfully to {downloaded_path}")
    return downloaded_path
69
+
70
+
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    """Cast a model to FP16 and optionally enable gradient checkpointing.

    Steps performed, in order:
    1. ``model.to(dtype=torch.float16)`` is applied.
    2. Every parameter that is still ``torch.float32`` afterwards has
       ``requires_grad`` switched off.
    3. ``gradient_checkpointing_enable()`` is called when requested.

    Args:
        model: The model to prepare.
        use_gradient_checkpointing: Whether to call
            ``model.gradient_checkpointing_enable()``.

    Returns:
        The object returned by ``model.to(dtype=torch.float16)``.
    """
    half_model = model.to(dtype=torch.float16)

    # Freeze any parameter that is still stored in full precision.
    for weight in half_model.parameters():
        if weight.dtype == torch.float32:
            weight.requires_grad = False

    if use_gradient_checkpointing:
        half_model.gradient_checkpointing_enable()

    return half_model
92
+
93
+
# Cast the freshly built model to FP16; gradient checkpointing is left off.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=False
)

print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Download weights at startup
model_weights_path = download_model_weights()
# weights_only=True restricts unpickling to tensors and plain containers.
checkpoint = torch.load(model_weights_path, weights_only=True, map_location=torch.device(device))

# strict=False tolerates key mismatches; report them so that a checkpoint
# that does not fit the architecture is not loaded silently.
_load_result = model.load_state_dict(checkpoint['model'], strict=False)
if _load_result.missing_keys:
    logger.warning("Checkpoint is missing %d model keys: %s",
                   len(_load_result.missing_keys), _load_result.missing_keys)
if _load_result.unexpected_keys:
    logger.warning("Checkpoint has %d unexpected keys: %s",
                   len(_load_result.unexpected_keys), _load_result.unexpected_keys)

model = model.to(device)
model.eval()
108
+
109
+
def postprocess_inst_names(abc_text):
    """Replace non-standard instrument names in ABC voice headers.

    Reads ``standard_inst_names.txt`` (one name per line) and
    ``instrument_mapping.json`` (a mapping from name variants to replacement
    names) from the current working directory. For every line that starts
    with ``V:`` and carries an ``nm="..."`` field, a name that is not listed
    as standard is replaced by the mapping value of its closest mapping key
    (difflib similarity cutoff 0.6). Names with no close key are left as is.

    Args:
        abc_text (str): ABC notation text.

    Returns:
        str: The text with empty lines dropped, every remaining line ending
        in a newline, and instrument names replaced where a match was found.

    Raises:
        OSError: If either resource file cannot be opened.
    """
    # json is not imported explicitly at module level (it could only arrive
    # through the star imports), so import it here to avoid a NameError.
    import json

    with open('standard_inst_names.txt', 'r', encoding='utf-8') as f:
        # A set gives constant-time membership tests.
        standard_instruments = {line.strip() for line in f if line.strip()}

    with open('instrument_mapping.json', 'r', encoding='utf-8') as f:
        instrument_mapping = json.load(f)
    # Built once instead of once per voice line.
    mapping_keys = list(instrument_mapping.keys())

    abc_lines = [line + '\n' for line in abc_text.split('\n') if line]

    for i, line in enumerate(abc_lines):
        if not (line.startswith('V:') and 'nm=' in line):
            continue

        match = re.search(r'nm="([^"]*)"', line)
        if not match:
            continue

        inst_name = match.group(1)
        # Names that are already standard need no replacement.
        if inst_name in standard_instruments:
            continue

        # Find the most similar key in the mapping.
        matching_key = difflib.get_close_matches(inst_name, mapping_keys, n=1, cutoff=0.6)
        if matching_key:
            replacement = instrument_mapping[matching_key[0]]
            abc_lines[i] = line.replace(f'nm="{inst_name}"', f'nm="{replacement}"')

    # Combine the lines back into a single string.
    return ''.join(abc_lines)
143
+
144
+
def complete_brackets(s):
    """Append the closing brackets that are missing from ``s``.

    Opening brackets ``{``, ``[`` and ``(`` are tracked on a stack. A closing
    bracket removes the top of the stack only when it matches it; any other
    closing bracket is ignored. Brackets still open at the end are closed in
    reverse order of opening.

    Args:
        s (str): Text to complete.

    Returns:
        str: ``s`` followed by the missing closing brackets.
    """
    closer_for = {'{': '}', '[': ']', '(': ')'}
    opener_for = {closer: opener for opener, closer in closer_for.items()}

    pending = []
    for ch in s:
        if ch in closer_for:
            pending.append(ch)
        elif ch in opener_for and pending and pending[-1] == opener_for[ch]:
            pending.pop()

    # Close whatever is still open, innermost bracket first.
    return s + ''.join(closer_for[opener] for opener in reversed(pending))
164
+
165
+
def rest_unreduce(abc_lines):
    """Re-insert the voices that were left out of each tune-body line.

    The tune body starts at the first line containing ``[V:``; everything
    before it is metadata. The metadata declares the voices (``V:`` lines)
    and their grouping (round brackets on ``%%score`` lines). In each
    tune-body line, every declared voice that is absent gets a rest of the
    line's most common bar duration: ``z`` for the first voice of a group,
    ``x`` for the other voices of that group.

    Args:
        abc_lines (list[str]): Lines of the score. ``%%score`` lines scanned
            before the tune body are bracket-completed in place.

    Returns:
        list[str]: The metadata lines followed by the expanded tune-body
        lines, each ending with a newline.

    Raises:
        ValueError: If there is no tune body, if a voice has to be filled in
            before any bar duration or barline is known, or if a voice symbol
            is not numeric.
    """
    tunebody_index = None
    for i, line in enumerate(abc_lines):
        if line.startswith('%%score'):
            # Complete the brackets in front of the line ending. Appending
            # them after the newline would hide a completed ')' from the
            # group regex below, because '.' does not match a newline.
            content = line.rstrip('\n')
            abc_lines[i] = complete_brackets(content) + line[len(content):]
        if '[V:' in abc_lines[i]:
            tunebody_index = i
            break

    if tunebody_index is None:
        raise ValueError("no tune body found: no line contains '[V:'")

    metadata_lines = abc_lines[:tunebody_index]
    tunebody_lines = abc_lines[tunebody_index:]

    part_symbol_list = []
    voice_group_list = []
    # Voices named in a %%score group. Initialised so that V: lines work
    # even when no %%score line came first.
    existed_voices = []
    for line in metadata_lines:
        if line.startswith('%%score'):
            for round_bracket_match in re.findall(r'\((.*?)\)', line):
                voices = round_bracket_match.split()
                if voices:  # skip an empty "()" group
                    voice_group_list.append(voices)
            existed_voices = [item for sublist in voice_group_list for item in sublist]
        if line.startswith('V:'):
            symbol = line.split()[0]
            part_symbol_list.append(symbol)
            # A voice outside every %%score group forms a group of its own.
            if symbol[2:] not in existed_voices:
                voice_group_list.append([symbol[2:]])

    z_symbols = set()  # voices that use z as rest
    x_symbols = set()  # voices that use x as rest
    for voice_group in voice_group_list:
        z_symbols.add('V:' + voice_group[0])
        x_symbols.update('V:' + voice for voice in voice_group[1:])

    part_symbol_list.sort(key=lambda x: int(x[2:]))

    voice_pattern = re.compile(r'\[V:(\d+)\](.*?)(?=\[V:|$)')
    unreduced_tunebody_lines = []
    ref_dur = None        # most common bar duration; carried over between lines
    right_barline = None  # barline of the last parsed voice; carried over too

    for i, line in enumerate(tunebody_lines):
        # Drop the leading "[r:...]" marker, if any.
        line = re.sub(r'^\[r:[^\]]*\]', '', line)

        line_bar_dict = {f'V:{number}': text for number, text in voice_pattern.findall(line)}

        # Calculate durations and collect the closing barline.
        dur_dict = {}
        for symbol, bartext in line_bar_dict.items():
            right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
            bartext = bartext[:-len(right_barline)]
            try:
                bar_dur = calculate_bartext_duration(bartext)
            except Exception:
                # A bar whose duration cannot be computed does not vote.
                bar_dur = None
            if bar_dur is not None:
                dur_dict[bar_dur] = dur_dict.get(bar_dur, 0) + 1

        if dur_dict:
            ref_dur = max(dur_dict, key=dur_dict.get)
        # Otherwise keep the reference duration of the previous line.

        # Text in front of the first voice marker is only kept for line 0.
        prefix_left_barline = line.split('[V:')[0] if i == 0 else ''

        unreduced_line = ''
        for symbol in part_symbol_list:
            if symbol in line_bar_dict:
                symbol_bartext = line_bar_dict[symbol]
            else:
                if ref_dur is None or right_barline is None:
                    raise ValueError(
                        f"cannot fill rest for {symbol}: no bar duration or barline known yet")
                if symbol in z_symbols:
                    rest_char = 'z'
                elif symbol in x_symbols:
                    rest_char = 'x'
                else:
                    raise ValueError(f"voice {symbol} belongs to no voice group")
                symbol_bartext = prefix_left_barline + rest_char + str(ref_dur) + right_barline
            unreduced_line += '[' + symbol + ']' + symbol_bartext

        unreduced_tunebody_lines.append(unreduced_line + '\n')

    return metadata_lines + unreduced_tunebody_lines
255
+
256
+
def inference_patch(period, composer, instrumentation):
    """Generate an ABC score for the given prompt, retrying until one succeeds.

    The three arguments become the ``%``-prefixed prompt lines. Patches are
    sampled from the module-level ``model`` until it emits the end patch.
    Generated text is echoed to stdout as it is produced. When the context
    reaches ``PATCH_LENGTH * PATCH_SIZE`` tokens, it is rebuilt from the
    metadata plus the most recent half of the tune-body lines. An attempt is
    abandoned and restarted when it exceeds 102400 characters or 10 minutes,
    when the context fills up before any complete tune-body line exists, or
    when ``rest_unreduce`` fails on the result.

    Args:
        period (str): Period prompt, e.g. ``'Classical'``.
        composer (str): Composer prompt.
        instrumentation (str): Instrumentation prompt.

    Returns:
        str: The generated score after ``rest_unreduce``, with the single-``%``
        prompt lines removed and an ``X:1`` header line prepended.
    """
    prompt_lines = [
        '%' + period + '\n',
        '%' + composer + '\n',
        '%' + instrumentation + '\n']

    # Each pass is one generation attempt; failed attempts start over.
    while True:

        failure_flag = False

        bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]

        start_time = time.time()

        prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
        byte_list = list(''.join(prompt_lines))
        context_tunebody_byte_list = []
        metadata_byte_list = []

        print(''.join(byte_list), end='')

        # Pad every prompt patch to PATCH_SIZE tokens with the special token.
        prompt_patches = [
            [ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch))
            for patch in prompt_patches]
        prompt_patches.insert(0, bos_patch)

        input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)

        end_flag = False
        tunebody_flag = False

        with torch.inference_mode():

            while True:
                # NOTE(review): device_type is fixed to 'cuda' although `device`
                # may be the CPU -- confirm the behaviour on CPU-only hosts.
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    predicted_patch = model.generate(input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                # When the tune body is entered for the first time, it must
                # start with "[r:0/": regenerate the patch with that prefix forced.
                # NOTE(review): the indentation of this branch is not visible in
                # the reviewed source; it is placed outside the autocast block --
                # confirm against the original file.
                if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'):
                    tunebody_flag = True
                    r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
                    temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
                    predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
                                                     top_k=TOP_K,
                                                     top_p=TOP_P,
                                                     temperature=TEMPERATURE)
                    predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
                # A patch that starts with BOS followed by EOS ends the piece.
                if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
                    end_flag = True
                    break
                next_patch = patchilizer.decode([predicted_patch])

                for char in next_patch:
                    byte_list.append(char)
                    if tunebody_flag:
                        context_tunebody_byte_list.append(char)
                    else:
                        metadata_byte_list.append(char)
                    print(char, end='')

                # Overwrite everything after the first EOS with the special token.
                patch_end_flag = False
                for j in range(len(predicted_patch)):
                    if patch_end_flag:
                        predicted_patch[j] = patchilizer.special_token_id
                    if predicted_patch[j] == patchilizer.eos_token_id:
                        patch_end_flag = True

                predicted_patch = torch.tensor([predicted_patch], device=device)  # (1, 16)
                input_patches = torch.cat([input_patches, predicted_patch], dim=1)  # (1, 16 * patch_len)

                # Give up on attempts that run too long or produce too much text.
                if len(byte_list) > 102400:
                    failure_flag = True
                    break
                if time.time() - start_time > 10 * 60:
                    failure_flag = True
                    break

                if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
                    print('Stream generating...')

                    metadata = ''.join(metadata_byte_list)
                    context_tunebody = ''.join(context_tunebody_byte_list)

                    if '\n' not in context_tunebody:
                        # Generated content is all metadata: abandon this attempt.
                        # Mark it as failed so it is retried, not post-processed.
                        failure_flag = True
                        break

                    stripped_lines = context_tunebody.strip().split('\n')
                    if context_tunebody.endswith('\n'):
                        context_tunebody_lines = [text + '\n' for text in stripped_lines]
                    else:
                        # The last line is still being generated: leave it open.
                        context_tunebody_lines = ([text + '\n' for text in stripped_lines[:-1]]
                                                  + [stripped_lines[-1]])

                    # Keep the most recent half of the lines. With a single
                    # line cut_index is 0 and the slice keeps everything.
                    cut_index = len(context_tunebody_lines) // 2
                    kept_tunebody = ''.join(context_tunebody_lines[-cut_index:])
                    abc_code_slice = metadata + kept_tunebody

                    input_patches = patchilizer.encode_generate(abc_code_slice)

                    input_patches = [item for sublist in input_patches for item in sublist]
                    input_patches = torch.tensor([input_patches], device=device)
                    input_patches = input_patches.reshape(1, -1)

                    context_tunebody_byte_list = list(kept_tunebody)

        if not failure_flag:
            abc_text = ''.join(byte_list)

            # unreduce
            abc_lines = [line + '\n' for line in abc_text.split('\n') if line]
            try:
                unreduced_abc_lines = rest_unreduce(abc_lines)
            except Exception:
                # Text that cannot be unreduced is discarded; the outer loop
                # starts a new attempt.
                failure_flag = True
            else:
                # Drop the single-'%' prompt lines but keep '%%' directives.
                unreduced_abc_lines = [line for line in unreduced_abc_lines if
                                       not (line.startswith('%') and not line.startswith('%%'))]
                unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
                unreduced_abc_text = ''.join(unreduced_abc_lines)
                return unreduced_abc_text
384
+
385
+
if __name__ == '__main__':
    # Script entry point: run one generation for a fixed example prompt.
    inference_patch(
        'Classical',
        'Beethoven, Ludwig van',
        'Orchestral',
    )
388
+