import torch
import random
import bisect
import json
import re
from config import *
from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel, BitsAndBytesConfig
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
from tokenizers import Tokenizer

# The patch embedding is excluded from 8-bit quantization; PatchLevelDecoder
# keeps it in fp16.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["patch_embedding"],
)


class Patchilizer:
    """
    Convert ABC notation text into fixed-size character patches and back.
    """

    def __init__(self, stream=PATCH_STREAM):
        self.stream = stream
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.special_token_id = 0

    def split_bars(self, body_lines):
        """
        Split a body of music into individual bars.
        """
        new_bars = []
        try:
            for line in body_lines:
                line_bars = re.split(self.regexPattern, line)
                line_bars = list(filter(None, line_bars))
                new_line_bars = []

                if len(line_bars) == 1:
                    new_line_bars = line_bars
                else:
                    if line_bars[0] in self.delimiters:
                        new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
                    else:
                        new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
                    if 'V' not in new_line_bars[-1]:
                        # Fold a trailing delimiter-only fragment into the previous bar.
                        new_line_bars[-2] += new_line_bars[-1]
                        new_line_bars = new_line_bars[:-1]
                new_bars += new_line_bars
        except Exception:
            # A malformed line aborts the split; whatever was collected so far
            # is returned.
            pass

        return new_bars
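
    # Illustrative example (assumes the trailing newline added in `encode`):
    #   split_bars(['[V:1]CDEF|GABc|\n'])
    # re.split keeps the delimiters: ['[V:1]CDEF', '|', 'GABc', '|', '\n'];
    # regrouping gives ['[V:1]CDEF', '|GABc', '|\n'], and the delimiter-only
    # tail is merged back, yielding ['[V:1]CDEF', '|GABc|\n'].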

    def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
        if not generate_last and len(abc_text) % patch_size != 0:
            abc_text += chr(self.eos_token_id)
        patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)]
        return patches
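
    # Illustrative example with patch_size=4 (the real value comes from config):
    #   split_patches('CDEF|GA', patch_size=4) -> ['CDEF', '|GA\x02']
    # chr(2) (the eos id) is appended because the length is not a multiple of
    # the patch size; with generate_last=True the tail is left open instead.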

    def patch2chars(self, patch):
        """
        Convert a patch (a sequence of token ids) back into characters.
        """
        chars = ''
        for idx in patch:
            if idx == self.eos_token_id:
                break
            chars += chr(idx)
        return chars

    def patchilize_metadata(self, metadata_lines):
        metadata_patches = []
        for line in metadata_lines:
            metadata_patches += self.split_patches(line)
        return metadata_patches

    def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
        tunebody_patches = []
        bars = self.split_bars(tunebody_lines)
        if encode_mode == 'train':
            for bar in bars:
                tunebody_patches += self.split_patches(bar)
        elif encode_mode == 'generate':
            # Leave the final bar unterminated so generation can continue it.
            for bar in bars[:-1]:
                tunebody_patches += self.split_patches(bar)
            tunebody_patches += self.split_patches(bars[-1], generate_last=True)
        return tunebody_patches

    def encode(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
        lines = abc_text.split('\n')
        lines = list(filter(None, lines))
        lines = [line + '\n' for line in lines]

        # The tunebody starts at the first '[r:' line; everything before it is metadata.
        tunebody_index = -1
        for i, line in enumerate(lines):
            if line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[:tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)

            metadata_patches = [bos_patch] + metadata_patches
            tunebody_patches = tunebody_patches + [eos_patch]

        if self.stream:
            if len(metadata_patches) + len(tunebody_patches) > patch_length:
                # Patches that end a line are valid cut points.
                available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches)
                                               if '\n' in patch]
                # The k-th cut candidate begins at tunebody line k, so the
                # mapping from cut index to line index is the identity.
                line_index_for_cut_index = list(range(len(available_cut_indexes)))
                end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
                biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
                available_cut_indexes = available_cut_indexes[:biggest_index + 1]

                if len(available_cut_indexes) == 1:
                    choices = ['head']
                elif len(available_cut_indexes) == 2:
                    choices = ['head', 'tail']
                else:
                    choices = ['head', 'tail', 'middle']
                choice = random.choice(choices)

                if choice == 'head':
                    patches = metadata_patches + tunebody_patches
                else:
                    if choice == 'tail':
                        cut_index = len(available_cut_indexes) - 1
                    else:
                        cut_index = random.choice(range(1, len(available_cut_indexes) - 1))

                    line_index = line_index_for_cut_index[cut_index]
                    stream_tunebody_lines = tunebody_lines[line_index:]

                    stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
                    if add_special_patches:
                        stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
                    patches = metadata_patches + stream_tunebody_patches
            else:
                patches = metadata_patches + tunebody_patches
        else:
            patches = metadata_patches + tunebody_patches

        patches = patches[:patch_length]

        # Convert characters to ids and right-pad every patch to patch_size.
        id_patches = []
        for patch in patches:
            id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches
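
    # Patch layout sketch (illustrative, patch_size=4): the bos patch is
    # '\x01\x01\x01\x02' and the eos patch is '\x01\x02\x02\x02'; a patch such
    # as 'K:C\n' maps to [75, 58, 67, 10], with chr(0) padding when shorter.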

    def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
        lines = abc_code.split('\n')
        lines = list(filter(None, lines))

        tunebody_index = None
        for i, line in enumerate(lines):
            if line.startswith('[V:') or line.startswith('[r:'):
                tunebody_index = i
                break

        metadata_lines = lines[:tunebody_index]
        tunebody_lines = lines[tunebody_index:]

        metadata_lines = [line + '\n' for line in metadata_lines]
        if self.stream:
            if not abc_code.endswith('\n'):
                # The final line is still being generated, so leave it open.
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]]
            else:
                tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]
        else:
            tunebody_lines = [line + '\n' for line in tunebody_lines]

        metadata_patches = self.patchilize_metadata(metadata_lines)
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
            metadata_patches = [bos_patch] + metadata_patches

        patches = metadata_patches + tunebody_patches
        patches = patches[:patch_length]

        id_patches = []
        for patch in patches:
            if len(patch) < patch_size and patch[-1] != chr(self.eos_token_id):
                # An unfinished final patch stays unpadded for continuation.
                id_patch = [ord(c) for c in patch]
            else:
                id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
            id_patches.append(id_patch)

        return id_patches

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return ''.join(self.patch2chars(patch) for patch in patches)
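

# Usage sketch (illustrative; the interleaved-ABC prompt format is assumed):
#   patchilizer = Patchilizer()
#   abc = '%%score 1\nL:1/8\nM:4/4\nK:C\n[r:0/1][V:1]CDEF GABc |\n'
#   ids = patchilizer.encode(abc, add_special_patches=True)  # list of id lists
#   text = patchilizer.decode(ids)  # bos/padding ids decode to control chars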


class PatchLevelDecoder(PreTrainedModel):
    """
    A patch-level decoder model for generating patch features in an
    auto-regressive manner. It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        # Kept in fp16 and excluded from 8-bit quantization (see
        # quantization_config above).
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd).to(torch.float16)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded
        :param masks: the masks for the patches
        :return: the encoded patches
        """
        # One-hot over the 128-character vocabulary, then flatten each patch.
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))

        if masks is None:
            return self.base(inputs_embeds=patches)
        else:
            return self.base(inputs_embeds=patches,
                             attention_mask=masks)
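

# Shape sketch (illustrative): for a batch of B sequences of S patches,
# `patches` enters forward as (B, S, PATCH_SIZE) int ids in [0, 128); one-hot
# plus reshape gives (B, S, PATCH_SIZE * 128), and patch_embedding projects it
# to (B, S, n_embd) before the GPT-2 stack attends over the S patch positions.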


class CharLevelDecoder(PreTrainedModel):
    """
    A char-level decoder model for generating the chars within each patch in an
    auto-regressive manner based on the encoded patch features. It inherits
    PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        The forward pass of the char-level decoder model.
        :param encoded_patches: the encoded patches
        :param target_patches: the target patches
        :return: the summed log-probability of the target characters
        """
        # Prepend a bos token to each target patch.
        target_patches = torch.cat((torch.ones_like(target_patches[:, 0:1]) * self.bos_token_id,
                                    target_patches), dim=1)

        # Padding positions are masked out of both attention and the loss.
        target_masks = target_patches == self.special_token_id
        labels = target_patches.clone().masked_fill_(target_masks, -100)

        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # Replace the bos embedding with the encoded patch feature so the
        # char-level decoder is conditioned on the patch-level context.
        input_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
        input_embeds = torch.cat((encoded_patches.unsqueeze(1), input_embeds[:, 1:, :]), dim=1)

        logits = self.base(inputs_embeds=input_embeds,
                           attention_mask=target_masks).logits
        logits = logits[:, :-1, :]
        token_logps = torch.gather(logits.log_softmax(-1), dim=-1,
                                   index=target_patches[:, 1:].unsqueeze(-1)).squeeze(-1)
        token_logps = token_logps[target_masks[:, 1:] == 1]
        all_logps = token_logps.sum()

        return all_logps

    def generate(self,
                 encoded_patch: torch.Tensor,
                 tokens: torch.Tensor):
        """
        The generate function for generating a patch based on the encoded patch
        and already generated tokens.
        :param encoded_patch: the encoded patch
        :param tokens: already generated tokens in the patch
        :return: the probability distribution of the next token
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)
        tokens = tokens.reshape(1, -1)

        # Embed the tokens, then substitute the patch feature for the bos slot.
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
        tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)

        outputs = self.base(inputs_embeds=tokens)
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs
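

# Decoding sketch (illustrative; `char_decoder` and `feature` are assumed
# names): the patch-level decoder yields one feature per patch, and
#   probs = char_decoder.generate(feature, torch.tensor([1]))  # 1 = bos id
# returns a distribution over the next character given that feature and the
# characters emitted so far.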


class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure.
    It includes a patch-level decoder and a char-level decoder.
    The patch-level decoder is used to generate patch features in an
    auto-regressive manner.
    The char-level decoder is used to generate the chars within each patch in
    an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        The forward pass of the NotaGen model.
        :param patches: the patches to be encoded
        :param masks: the masks for the patches
        :return: the decoded patches
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # Pair each encoded patch with the *next* target patch: drop the last
        # valid position from the contexts and the first from the targets.
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)
        masks[:, 0] = 0

        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[masks == 1]

        return self.char_level_decoder(encoded_patches, patches)
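
    # Mask sketch (illustrative): for a row masks = [1, 1, 1, 0],
    # masks.flip(1).cumsum(1).flip(1) = [3, 2, 1, 0], so left_shift_masks =
    # [1, 1, 0, 0] selects contexts at positions 0-1, while the zeroed
    # masks = [0, 1, 1, 0] selects targets at positions 1-2.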

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        The generate function for generating patches based on patches.
        :param patches: the patches to be encoded
        :param top_k: the top k for sampling
        :param top_p: the top p for sampling
        :param temperature: the temperature for sampling
        :return: the generated patch (a list of token ids)
        """
        if patches.shape[-1] % PATCH_SIZE != 0:
            # The last patch is unfinished: strip it off and seed the
            # char-level decoder with its tokens (prefixed by bos).
            tokens = patches[:, :, -(patches.shape[-1] % PATCH_SIZE):].squeeze(0, 1)
            tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1)
            patches = patches[:, :, :-(patches.shape[-1] % PATCH_SIZE)]
        else:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
        generated_patch = []

        while True:
            prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy()
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
            token = temperature_sampling(prob, temperature=temperature)

            generated_patch.append(token)
            if len(tokens) >= PATCH_SIZE:
                break
            else:
                tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch
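

# Minimal end-to-end smoke test (illustrative sketch, not part of the original
# file). It assumes config.py provides PATCH_SIZE, PATCH_LENGTH and
# PATCH_STREAM, and it builds tiny untrained GPT-2 configs, so the output is
# random noise; real use loads a trained NotaGen checkpoint instead.
if __name__ == "__main__":
    from transformers import GPT2Config

    patch_config = GPT2Config(vocab_size=1, n_positions=PATCH_LENGTH,
                              n_embd=64, n_layer=2, n_head=2)
    char_config = GPT2Config(vocab_size=128, n_positions=PATCH_SIZE,
                             n_embd=64, n_layer=2, n_head=2)
    # .float() unifies dtypes for a CPU run; the real model keeps the patch
    # embedding in fp16 alongside 8-bit weights.
    model = NotaGenLMHeadModel(patch_config, char_config).float().eval()

    patchilizer = Patchilizer()
    prompt = 'X:1\nL:1/8\nM:4/4\nK:C\n[r:0/8][V:1]C'  # assumed prompt format
    id_patches = patchilizer.encode_generate(prompt)
    flat = [i for patch in id_patches for i in patch]
    input_patches = torch.tensor(flat).reshape(1, 1, -1)

    with torch.no_grad():
        next_patch = model.generate(input_patches, top_k=8, top_p=0.9, temperature=1.2)
    print(repr(patchilizer.decode([next_patch])))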