# A mirror to gradio launch stream
# By Xuan Phi Nguyen at DAMO Academy, Alibaba Group
"""
Load FasterLlama with the original vLLM codebase.
Requires changing the architecture name in the checkpoint config to LlamaForCausalLM,
and tensor_parallel must be 1.
"""
import os
import argparse
import glob
import json

import numpy as np
import torch
import filelock
import gradio as gr

from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from gradio_client.documentation import document, set_documentation_group
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download
DEBUG = True

if not DEBUG:
    # vllm imports
    from vllm import LLM, SamplingParams
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
    from vllm.engine.arg_utils import EngineArgs
    from vllm.engine.llm_engine import LLMEngine
    from vllm.outputs import RequestOutput
    from vllm.sampling_params import SamplingParams
    from vllm.utils import Counter
    from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                               SequenceGroupMetadata, SequenceOutputs,
                               SequenceStatus)

    # ! reconfigure vllm to load FasterLlama checkpoints with the stock Llama implementation
    from vllm.model_executor.model_loader import _MODEL_REGISTRY
    from vllm.model_executor.models import LlamaForCausalLM

    _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM

def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    from vllm.model_executor.weight_utils import Disabledtqdm

    # Prepare file lock directory to prevent multiple processes from
    # downloading the same model weights at the same time.
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))

    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        with lock:
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.bin",
                                          cache_dir=cache_dir,
                                          local_files_only=True,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path

    hf_bin_files = [
        # x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
        x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
        if not x.endswith("training_args.bin")
    ]
    hf_safetensors_files = [
        x for x in glob.glob(os.path.join(hf_folder, "*model*.safetensors"))
        if not x.endswith("training_args.bin")
    ]
    # print(f'Load bin files: {hf_bin_files} // safetensors: {hf_safetensors_files}')

    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        with lock:
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_bin_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    else:
        if len(hf_bin_files) > 0:
            print(f'Load bin files: {hf_bin_files}')
            for bin_file in hf_bin_files:
                state = torch.load(bin_file, map_location="cpu")
                for name, param in state.items():
                    yield name, param
                del state
                torch.cuda.empty_cache()
        elif len(hf_safetensors_files) > 0:
            print(f'Load safetensor files: {hf_safetensors_files}')
            from safetensors.torch import load_file
            for safe_file in hf_safetensors_files:
                # state = torch.load(bin_file, map_location="cpu")
                state = load_file(safe_file)
                for name, param in state.items():
                    yield name, param
                del state
                torch.cuda.empty_cache()
        else:
            raise ValueError('No weight files found: neither *.bin nor *.safetensors are present.')
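
# A minimal, illustrative way to consume the iterator (hypothetical checkpoint path),
# e.g. to check which tensors a local checkpoint ships with:
#
#   for name, tensor in hf_model_weights_iterator("/path/to/local/llama-checkpoint"):
#       print(name, tuple(tensor.shape))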

def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Convert a PySafeSlice object from safetensors to a torch.Tensor.

    A PySafeSlice supports indexing, which happens before the actual tensor
    is loaded and so reduces the amount of data read into memory. However, it
    does not support more advanced operations such as `.view()` or `.t()`, so
    if the loaded tensor needs those operators it must be converted to a
    plain tensor first.
    """
    if not isinstance(x, torch.Tensor):
        x = x[:]
    return x
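
# Illustrative sketch (hypothetical file name): PySafeSlice objects come from safetensors'
# safe_open, which hands back lazy slices that only read the requested rows from disk:
#
#   from safetensors import safe_open
#   with safe_open("model.safetensors", framework="pt", device="cpu") as f:
#       lazy = f.get_slice("lm_head.weight")    # PySafeSlice, nothing materialized yet
#       shard = lazy[0:1000]                    # reads only the requested rows
#       full = convert_pyslice_to_tensor(lazy)  # materializes the whole tensor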

def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    shard_size = param.shape[0]
    start_idx = tensor_model_parallel_rank * shard_size
    end_idx = (tensor_model_parallel_rank + 1) * shard_size
    loaded_weight = loaded_weight[start_idx:end_idx]
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    param[:loaded_weight.shape[0]].copy_(loaded_weight)
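
# Illustrative numbers (assumed, not from the original): with a padded vocab of 32256 rows,
# tp_size=2 and a checkpoint vocab of 32000, each rank's shard has param.shape[0] = 16128 rows.
#   rank 0 copies checkpoint rows [0:16128)     -> fills its shard completely
#   rank 1 copies checkpoint rows [16128:32000) -> only 15872 rows; the remaining 256 padded
#   rows of its shard are left untouched by the `[:loaded_weight.shape[0]]` copy above.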

def llama_load_weights(
    self,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
    load_format: str = "auto",
    # load_format: str = "pt",
    revision: Optional[str] = None
):
    from vllm.model_executor.weight_utils import (
        load_tensor_parallel_weights
    )
    from vllm.model_executor.parallel_utils.parallel_state import (
        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

    tp_size = get_tensor_model_parallel_world_size()
    tensor_model_parallel_rank = get_tensor_model_parallel_rank()

    q_proj_shard_size = (self.config.hidden_size // tp_size)
    kv_proj_shard_size = (self.config.hidden_size //
                          self.config.num_attention_heads *
                          getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) // tp_size)
    attention_weight_specs = [
        # (weight_name, shard_size, offset)
        ("q_proj", q_proj_shard_size, 0),
        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
        ("v_proj", kv_proj_shard_size,
         q_proj_shard_size + kv_proj_shard_size),
    ]

    state_dict = self.state_dict()
    need_to_load = len(state_dict)
    loaded = 0
    # try:
    #     iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
    # except Exception as e:
    #     iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, load_format, revision)
    iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)

    # for name, loaded_weight in hf_model_weights_iterator(
    #         model_name_or_path, cache_dir, load_format, revision):
    #         model_name_or_path, cache_dir, use_np_cache):
    for name, loaded_weight in iterator:
        if "rotary_emb.inv_freq" in name:
            continue

        # if "embed_tokens" in name or "lm_head" in name:
        #     param = state_dict[name]
        #     # Consider padding in the vocab size.
        #     padded_vocab_size = (param.shape[0] * tp_size)
        #     # num_extra_rows = padded_vocab_size - self.config.vocab_size
        #     num_extra_rows = padded_vocab_size - loaded_weight.size(0)
        #     load_size = loaded_weight.size()
        #     extra_rows = torch.empty(num_extra_rows,
        #                              loaded_weight.shape[1])
        #     extra_rows = extra_rows.to(loaded_weight)
        #     loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
        #     if num_extra_rows > 0:
        #         print(f'Add empty to {num_extra_rows} extra row for {name}')
        #     print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')

        if "embed_tokens" in name or "lm_head" in name:
            param = state_dict[name]
            load_padded_tensor_parallel_vocab(param, loaded_weight, tensor_model_parallel_rank)
            loaded += 1
            continue

        is_attention_weight = False
        for weight_name, shard_size, offset in attention_weight_specs:
            if weight_name not in name or "qkv_proj" in name:
                continue
            param = state_dict[name.replace(weight_name, "qkv_proj")]

            loaded_weight = loaded_weight[
                shard_size * tensor_model_parallel_rank:shard_size *
                (tensor_model_parallel_rank + 1)]
            param_slice = param.data[offset:offset + shard_size]
            assert param_slice.shape == loaded_weight.shape

            param_slice.copy_(loaded_weight)
            loaded += 1.0 / 3
            is_attention_weight = True
            break
        if is_attention_weight:
            continue

        # ! qkv_proj is sharded differently if concatenated into qkv
        # qkv: qqqq kkkk vvvv
        # lweight: qq0qq1 kk0kk1 vv0vv1
        # q_shard_size: hidden_size // tp_size = qq
        # qkv_s0: qq0_kk0_vv0
        # qkv_s1: qq1_kk1_vv1
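        # Illustrative numbers (assumed, not from the original): with hidden_size=5120,
        # num_attention_heads=num_key_value_heads=40 and tp_size=2, both q_proj_shard_size and
        # kv_proj_shard_size are 2560, so rank 1 concatenates rows [2560:5120) (q), [7680:10240) (k)
        # and [12800:15360) (v) of the fused checkpoint tensor into its local qkv_proj shard.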
| if "qkv_proj" in name: | |
| param = state_dict[name] | |
| # loaded_weight | |
| qsize = self.config.hidden_size | |
| kvsize = self.config.hidden_size // self.config.num_attention_heads * getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) | |
| q_offsets = ( | |
| q_proj_shard_size * tensor_model_parallel_rank, | |
| q_proj_shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| k_offsets = ( | |
| qsize + kv_proj_shard_size * tensor_model_parallel_rank, | |
| qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| v_offsets = ( | |
| qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank, | |
| qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| _loaded_weight = torch.cat( | |
| [ | |
| loaded_weight[q_offsets[0]:q_offsets[1]], | |
| loaded_weight[k_offsets[0]:k_offsets[1]], | |
| loaded_weight[v_offsets[0]:v_offsets[1]], | |
| ], 0 | |
| ) | |
| # print(f'{name} | {q_offsets} | {k_offsets} | {v_offsets}') | |
| assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}' | |
| param.data.copy_(_loaded_weight) | |
| loaded += 1.0 | |
| is_attention_weight = True | |
| if is_attention_weight: | |
| continue | |
| is_gate_up_weight = False | |
| for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): | |
| if weight_name not in name or "gate_up_proj" in name: | |
| continue | |
| param = state_dict[name.replace(weight_name, "gate_up_proj")] | |
| shard_size = param.shape[0] // 2 | |
| loaded_weight = loaded_weight[ | |
| shard_size * tensor_model_parallel_rank:shard_size * | |
| (tensor_model_parallel_rank + 1)] | |
| param_slice = param.data[shard_size * stride_id:shard_size * | |
| (stride_id + 1)] | |
| assert param_slice.shape == loaded_weight.shape | |
| param_slice.copy_(loaded_weight) | |
| loaded += 1.0 / 2 | |
| is_gate_up_weight = True | |
| break | |
| if is_gate_up_weight: | |
| continue | |
| if "gate_up_proj" in name: | |
| param = state_dict[name] | |
| shard_size = param.shape[0] // 2 | |
| intermediate_size = self.config.intermediate_size | |
| g_offsets = ( | |
| shard_size * tensor_model_parallel_rank, | |
| shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| u_offsets = ( | |
| intermediate_size + shard_size * tensor_model_parallel_rank, | |
| intermediate_size + shard_size * (tensor_model_parallel_rank + 1) | |
| ) | |
| # print(f'{name} {param.size()} | {g_offsets} | {u_offsets}') | |
| _loaded_weight = torch.cat( | |
| [ | |
| loaded_weight[g_offsets[0]:g_offsets[1]], | |
| loaded_weight[u_offsets[0]:u_offsets[1]], | |
| ], 0 | |
| ) | |
| assert param.shape == _loaded_weight.shape | |
| param.data.copy_(_loaded_weight) | |
| loaded += 1.0 | |
| is_gate_up_weight = True | |
| if is_gate_up_weight: | |
| continue | |
| param = state_dict[name] | |
| load_tensor_parallel_weights(param, loaded_weight, name, | |
| self._column_parallel_weights, | |
| self._row_parallel_weights, | |
| tensor_model_parallel_rank) | |
| loaded += 1 | |
| if np.abs(loaded - need_to_load) < 0.01: | |
| print(f'WARNING: only {loaded} params loaded out of {need_to_load}') | |
| else: | |
| print(f'Loaded all {loaded} params loaded out of {need_to_load}') | |

# Reassign LlamaForCausalLM.load_weights with llama_load_weights
if not DEBUG:
    LlamaForCausalLM.load_weights = llama_load_weights


# ! ==================================================================
set_documentation_group("component")

DATA_ROOT = os.environ.get("dataroot", "/mnt/workspace/workgroup/phi")
MODEL_CACHE_DIR = os.path.join(DATA_ROOT, "pret_models")

DTYPES = {
    'float16': torch.float16,
    'bfloat16': torch.bfloat16
}

llm = None
demo = None

RELOAD_SIGNAL = '<<<reload:'

BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information.
As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
Your response should adapt to the norms and customs of the respective language and culture.
"""

RES_PRINTED = False

def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
    return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {text} {E_INST}"

def llama_chat_multiturn_sys_input_seq_constructor(
    message: str,
    history: List[Tuple[str, str]],
    sys_prompt=SYSTEM_PROMPT_1,
    bos_token=BOS_TOKEN,
    eos_token=EOS_TOKEN,
):
    """
    ```
    <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST]
    ```
    """
    text = ''
    for i, (prompt, res) in enumerate(history):
        if i == 0:
            text += f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt} {E_INST}"
        else:
            text += f"{bos_token}{B_INST} {prompt} {E_INST}"
        if res is not None:
            text += f" {res} {eos_token} "
    if len(history) == 0 or text.strip() == '':
        text = f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message} {E_INST}"
    else:
        text += f"{bos_token}{B_INST} {message} {E_INST}"
    return text
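
# Illustrative two-turn construction (system prompt abbreviated, hypothetical history):
#   llama_chat_multiturn_sys_input_seq_constructor(
#       "And in Thai?", [("Say hi in Vietnamese.", "Xin chào!")], sys_prompt="You are helpful.")
#   -> '<s>[INST] <<SYS>>\n You are helpful. \n<</SYS>>\n\n Say hi in Vietnamese. [/INST]'
#      ' Xin chào! </s> <s>[INST] And in Thai? [/INST]'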

class ChatBot(gr.Chatbot):
    def _postprocess_chat_messages(
        self, chat_message
    ):
        x = super()._postprocess_chat_messages(chat_message)
        if isinstance(x, str):
            x = x.replace("\n", "<br>")
        return x

def load_ckpt(ckpt_file: str) -> str:
    global llm
    status = "Failed"
    if not os.path.exists(ckpt_file):
        status = f"Failed - file not found: {ckpt_file}"
    elif not ckpt_file.endswith(".bin"):
        status = f"Failed - file not .bin: {ckpt_file}"
    else:
        try:
            state_dict = torch.load(ckpt_file, map_location='cpu')
            print(f'loaded state_dict: {ckpt_file}')
            llm.llm_engine.workers[0].model.load_state_dict(state_dict)
            status = f'Success. Loaded {ckpt_file}'
        except Exception as e:
            status = f'Failed - {str(e)}'
    return status

def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
    global llm
    assert llm is not None
    temperature = float(temperature)
    max_tokens = int(max_tokens)
    if system_prompt.strip() != '':
        # chat version, add system prompt
        message = llama_chat_sys_input_seq_constructor(
            message.strip(),
            sys_prompt=system_prompt
        )
    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
    gen = llm.generate(message, sampling_params)
    out = gen[0].outputs[0].text
    # print(f'{message}<<<{out}>>>')
    return f'{out}'

def vllm_abort(self: Any):
    scheduler = self.llm_engine.scheduler
    for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
        for seq_group in state_queue:
            # if seq_group.request_id == request_id:
            # Remove the sequence group from the state queue.
            state_queue.remove(seq_group)
            for seq in seq_group.seqs:
                if seq.is_finished():
                    continue
                scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)

# def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]:
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
    # Initialize tqdm.
    if use_tqdm:
        num_requests = self.llm_engine.get_num_unfinished_requests()
        pbar = tqdm(total=num_requests, desc="Processed prompts")
    # Run the engine.
    outputs: Dict[str, RequestOutput] = {}
    while self.llm_engine.has_unfinished_requests():
        step_outputs = self.llm_engine.step()
        for output in step_outputs:
            # if output.finished:
            #     outputs.append(output)
            #     if use_tqdm:
            #         pbar.update(1)
            outputs[output.request_id] = output
        # outputs = sorted(outputs, key=lambda x: int(x.request_id))
        if len(outputs) > 0:
            yield outputs
    # if use_tqdm:
    #     pbar.close()
    # Sort the outputs by request ID.
    # This is necessary because some requests may be finished earlier than
    # its previous requests.
    # outputs = sorted(outputs, key=lambda x: int(x.request_id))
    # return outputs

def vllm_generate_stream(
    self: Any,
    prompts: Optional[Union[str, List[str]]] = None,
    sampling_params: Optional[Any] = None,
    prompt_token_ids: Optional[List[List[int]]] = None,
    use_tqdm: bool = False,
) -> Dict[str, Any]:
    """Generate completions for the input prompts, streaming partial results.

    NOTE: This function automatically batches the given prompts, considering
    the memory constraint. For the best performance, put all of your prompts
    into a single list and pass it to this method.

    Args:
        prompts: A list of prompts to generate completions for.
        sampling_params: The sampling parameters for text generation. If
            None, we use the default sampling parameters.
        prompt_token_ids: A list of token IDs for the prompts. If None, we
            use the tokenizer to convert the prompts to token IDs.
        use_tqdm: Whether to use tqdm to display the progress bar.

    Yields:
        A dict mapping request IDs to partially generated `RequestOutput`
        objects, updated after each engine step.
    """
    if prompts is None and prompt_token_ids is None:
        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")
    if isinstance(prompts, str):
        # Convert a single prompt to a list.
        prompts = [prompts]
    if prompts is not None and prompt_token_ids is not None:
        if len(prompts) != len(prompt_token_ids):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
    if sampling_params is None:
        # Use default sampling params.
        sampling_params = SamplingParams()

    # Add requests to the engine.
    if prompts is not None:
        num_requests = len(prompts)
    else:
        num_requests = len(prompt_token_ids)
    for i in range(num_requests):
        prompt = prompts[i] if prompts is not None else None
        if prompt_token_ids is None:
            token_ids = None
        else:
            token_ids = prompt_token_ids[i]
        self._add_request(prompt, sampling_params, token_ids)
    # return self._run_engine(use_tqdm)
    yield from _vllm_run_engine(self, use_tqdm)
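
# A minimal sketch of how the streaming generator is consumed (mirrors chat_response_stream below):
#   params = SamplingParams(temperature=0.2, max_tokens=128)
#   for partial in vllm_generate_stream(llm, "Hello", params):
#       request_output = next(iter(partial.values()))
#       print(request_output.outputs[0].text)  # grows as more tokens are decoded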

def chat_response_stream(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float,
    max_tokens: int,
    frequency_penalty: float,
    system_prompt: str
) -> str:
    global llm, RES_PRINTED
    assert llm is not None
    # force removing all
    vllm_abort(llm)
    temperature = float(temperature)
    frequency_penalty = float(frequency_penalty)
    max_tokens = int(max_tokens)
    if system_prompt.strip() != '':
        # chat version, add system prompt
        message = llama_chat_sys_input_seq_constructor(
            message.strip(),
            sys_prompt=system_prompt
        )
    sampling_params = SamplingParams(
        temperature=temperature, max_tokens=max_tokens,
        frequency_penalty=frequency_penalty,
    )
    cur_out = None
    for gen in vllm_generate_stream(llm, message, sampling_params):
        if cur_out is not None:
            yield cur_out
        assert len(gen) == 1, f'{gen}'
        item = next(iter(gen.values()))
        cur_out = item.outputs[0].text
    if not RES_PRINTED:
        print(f'{message}<<<{cur_out}>>>')
        RES_PRINTED = True
    if cur_out is not None:
        yield cur_out

def chat_response_stream_multiturn(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float,
    max_tokens: int,
    frequency_penalty: float,
    system_prompt: str
) -> str:
    """Build the multi-turn prompt:
    <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST]
    `message` is the incoming prompt; `history` does not yet contain the current message.
    """
    global llm, RES_PRINTED
    assert llm is not None
    assert system_prompt.strip() != '', 'system prompt is empty'
    # force removing all
    vllm_abort(llm)
    temperature = float(temperature)
    frequency_penalty = float(frequency_penalty)
    max_tokens = int(max_tokens)
    # history.append([message, None])
    # history will be appended with message later on
    full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
        message, history, sys_prompt=system_prompt
    )
    sampling_params = SamplingParams(
        temperature=temperature, max_tokens=max_tokens,
        frequency_penalty=frequency_penalty,
    )
    cur_out = None
    for gen in vllm_generate_stream(llm, full_prompt, sampling_params):
        if cur_out is not None:
            yield cur_out
        assert len(gen) == 1, f'{gen}'
        item = next(iter(gen.values()))
        cur_out = item.outputs[0].text
    if not RES_PRINTED:
        print(f'{full_prompt}<<<{cur_out}>>>')
        RES_PRINTED = True
    if cur_out is not None:
        yield cur_out

def debug_chat_response_echo(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.0,
    max_tokens: int = 4096,
    frequency_penalty: float = 0.4,
    system_prompt: str = SYSTEM_PROMPT_1,
) -> str:
    yield f"repeat: {message}"

# ============ CONSTANT ============
MODEL_NAME = "DAMO-SeaL-13B"
MODEL_TITLE = "DAMO-SeaL-13B - An Assistant for South East Asian Languages"
MODEL_DESC = """
This is a 13B DAMO-SeaL-Chat assistant model built by DAMO Academy, Alibaba Group. It can produce helpful responses in English, Vietnamese, Indonesian and Thai.
<br>
#### Citation
If you find our project useful, we hope you will star our repo and cite our paper as follows:
```
@article{damonlpsg2023seallm,
  author = {???},
  title = {SeaL: A language model for South East Asian Languages},
  year = 2023,
}
```
""".strip()

cite_markdown = """
"""
# journal = {arXiv preprint arXiv:2306.02858}
# url = {https://arxiv.org/abs/2306.02858}

TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
DTYPE = 'bfloat16'
DTYPE = 'float16'

MODEL_PATH = os.environ.get("MODEL_PATH", "notfound, please set `export MODEL_PATH=`")

def launch():
    global demo, llm, DEBUG
    model_desc = MODEL_DESC
    model_path = MODEL_PATH
    model_title = MODEL_TITLE
    tensor_parallel = TENSOR_PARALLEL
    assert tensor_parallel > 0, f'{tensor_parallel} invalid'
    dtype = DTYPE
    sys_prompt = SYSTEM_PROMPT_1
    max_tokens = 4096

    if DEBUG:
        model_desc += "\n<br>!!!!! This is in debug mode: responses will simply echo the original message."
        response_fn = debug_chat_response_echo
    else:
        # ! load the model
        assert os.path.exists(model_path), f'{model_path} not found'
        llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel)

        print(f'Use system prompt:\n{sys_prompt}')

        # response_fn = chat_response_stream_multiturn if args.multiturn else chat_response_stream
        response_fn = chat_response_stream_multiturn
        print(f'respond: {response_fn}')

    demo = gr.ChatInterface(
        response_fn,
        chatbot=ChatBot(
            # value=MODEL_NAME,
            bubble_full_width=False,
            latex_delimiters=[
                {"left": "$", "right": "$", "display": False},
                {"left": "$$", "right": "$$", "display": True},
            ]
        ),
        textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
        submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
        # stop_btn=None,
        title=f"{model_title}",
        description=f"{model_desc}",
        # ! decide whether the system prompt can be changed.
        additional_inputs=[
            gr.Number(value=0, label='Temperature (higher -> more random)'),
            gr.Number(value=max_tokens, label='Max generated tokens (increase for longer generations)'),
            gr.Number(value=0.4, label='Frequency penalty (> 0 encourages new tokens)'),
            gr.Textbox(value=sys_prompt, label='System prompt', lines=8)],
    )
    # with gr.Blocks() as demo:
    #     gr.ChatInterface(
    #         response_fn,
    #         chatbot=ChatBot(
    #             bubble_full_width=False,
    #             latex_delimiters=[
    #                 {"left": "$", "right": "$", "display": False},
    #                 {"left": "$$", "right": "$$", "display": True},
    #             ]
    #         ),
    #         textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
    #         submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
    #         # stop_btn=None,
    #         title=f"{model_title}",
    #         description=f"{model_desc}",
    #         # ! decide whether the system prompt can be changed.
    #         additional_inputs=[
    #             gr.Number(value=0, label='Temperature (higher -> more random)'),
    #             gr.Number(value=max_tokens, label='Max generated tokens (increase for longer generations)'),
    #             gr.Number(value=0.4, label='Frequency penalty (> 0 encourages new tokens)'),
    #             gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
    #         ],
    #     )
    #     gr.Markdown(cite_markdown)

    demo.queue()
    # demo.launch(server_port=args.port)
    demo.launch()

def main():
    # launch(parser.parse_args())
    launch()


if __name__ == "__main__":
    main()
| """ | |
| export CUDA_VISIBLE_DEVICES=0 | |
| export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000 | |
| export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster | |
| export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp | |
| python app.py | |
| """ |