import logging import os import uuid import torch from .constants import ( AUD_END_TOKEN, AUD_START_TOKEN, AUD_TAG_TOKEN, BOX_END_TOKEN, BOX_START_TOKEN, IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN, IMG_TAG_TOKEN, PATCH_CONTEXT_TOKEN, PATCH_END_TOKEN, PATCH_START_TOKEN, QUAD_END_TOKEN, QUAD_START_TOKEN, REF_END_TOKEN, REF_START_TOKEN, VID_CONTEXT_TOKEN, VID_END_TOKEN, VID_START_TOKEN, VID_TAG_TOKEN, ) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) def update_tokenizer_for_cosyvoice2(tokenizer): token_list = [ IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, VID_START_TOKEN, VID_END_TOKEN, VID_CONTEXT_TOKEN, PATCH_START_TOKEN, PATCH_END_TOKEN, PATCH_CONTEXT_TOKEN, AUD_START_TOKEN, AUD_END_TOKEN, QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN, REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN, IMG_TAG_TOKEN, VID_TAG_TOKEN, AUD_TAG_TOKEN, ] num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True) token_list = [f"<|audio_{i}|>" for i in range(6561)] num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=False) # logger.info(f"tokenizer {tokenizer}") return tokenizer class CosyVoice2Tokenizer: def __init__(self, model_name_or_path, rank=None): self.model_name_or_path = model_name_or_path if rank is None and torch.distributed.is_initialized(): rank = torch.distributed.get_rank() self.rank = rank % 8 else: self.rank = rank logger.info(f"{self.rank=}") self.is_discrete = True self.is_contiguous = False # T A text_audio_interval_ratio = [13, 26] self.text_audio_interval_ratio = text_audio_interval_ratio def load_model(self): if hasattr(self, "cosyvoice"): return logger.info("Loading CosyVoice2Tokenizer") from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 from cosyvoice.utils.file_utils import load_wav if self.rank is not None: torch.cuda.set_device(self.rank) else: import os os.environ["CUDA_VISIBLE_DEVICES"] = "" print(f"{self.rank}") self.cosyvoice = CosyVoice2( self.model_name_or_path, load_jit=False, load_trt=False, fp16=True ) del self.cosyvoice.model.llm self.load_wav = load_wav def encode(self, audio_path, **kwargs): if not hasattr(self, "cosyvoice"): self.load_model() speech_16k = self.load_wav(audio_path, 16000) try: speech_token, speech_token_len = self.cosyvoice.frontend._extract_speech_token( speech_16k ) speech_token = speech_token[0].cpu().tolist() except Exception as error: # logger.info("error", error) speech_token = [] # logger.info(f"speech_token {speech_token}") return speech_token def decode(self, prompt_speech_token, source_speech_16k=None): if not hasattr(self, "cosyvoice"): self.load_model() prompt_speech_token = torch.tensor(prompt_speech_token).unsqueeze(0) flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32) prompt_speech_feat = torch.zeros(1, 0, 80) if source_speech_16k is None: flow_embedding = torch.zeros(1, 192) else: flow_embedding = self.cosyvoice.frontend._extract_spk_embedding(source_speech_16k) this_uuid = str(uuid.uuid1()) this_uuid = "abc" self.cosyvoice.model.hift_cache_dict[this_uuid] = None token_offset = 0 tts_speech = self.cosyvoice.model.token2wav( token=prompt_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, embedding=flow_embedding, uuid=this_uuid, token_offset=token_offset, finalize=True, ) tts_speech = tts_speech.squeeze().cpu() return tts_speech def apply_to_role(self, role, **kwargs): is_discrete = kwargs.get("is_discrete", False) if is_discrete: return True is_contiguous = kwargs.get("is_contiguous", False) if is_contiguous: return False return True