import copy
import re
from typing import Any, Dict, List, Optional, Union

from vllm import LLM, SamplingParams
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import is_list_of

# Matches any run of trailing whitespace, including a final newline.
_TAIL_WS_RE = re.compile(r"(?:\r?\n|\s)+$")


def needs_newline(text: str) -> bool:
    """Return True when *text* does NOT already end with whitespace/newline."""
    return _TAIL_WS_RE.search(text[-8:]) is None
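
# Illustrative examples (hypothetical inputs):
#   needs_newline("ASSISTANT:")    -> True   (no trailing whitespace)
#   needs_newline("ASSISTANT:\n")  -> False  (already ends with a newline)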


def add_prefix(prompt: str, prefix: str, eos_token: str) -> str:
    """Insert *prefix* before the first generated token.

    Keeps the EOS token at the very end if the template already appended it.
    """
    if prompt.endswith(eos_token):
        return prompt[:-len(eos_token)] + prefix + eos_token
    return prompt + prefix
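
# Illustrative examples (hypothetical prompt, trigger prefix, and EOS token;
# the real values depend on the model and its chat template):
#   add_prefix("USER: Hi ASSISTANT:</s>", "<think>\n", "</s>")
#       -> "USER: Hi ASSISTANT:<think>\n</s>"
#   add_prefix("USER: Hi ASSISTANT:", "<think>\n", "</s>")
#       -> "USER: Hi ASSISTANT:<think>\n"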


class PrefixLLM(LLM):
    """vLLM ``LLM`` subclass that conditionally prepends *trigger_word* to the prompt."""

    def route_chat(
        self,
        messages: Union[
            List[ChatCompletionMessageParam],
            List[List[ChatCompletionMessageParam]],
        ],
        sampling_params_route: Optional[Union[
            SamplingParams, List[SamplingParams]]] = None,
        sampling_params_force_think: Optional[Union[
            SamplingParams, List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
        *,
        trigger_word: Optional[str] = None,
    ) -> List[RequestOutput]:
        """Drop-in replacement for ``LLM.chat`` with one extra keyword.

        Requests whose first-pass output contains ``<specialLong>`` near the
        start are regenerated from a prompt with *trigger_word* prepended,
        using *sampling_params_force_think*.

        Parameters
        ----------
        trigger_word : str | None, default None
            The prefix to inject for the forced second pass. Required:
            a ``ValueError`` is raised if it is ``None``.
        """

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        eos_token = tokenizer.eos_token

        # One plain prompt and one trigger-prefixed prompt per conversation.
        orig_prompts: List[Union[TokensPrompt, TextPrompt]] = []
        pref_prompts: List[Union[TokensPrompt, TextPrompt]] = []

        # Normalize the input to a batch (list of conversations).
        list_of_messages: List[List[ChatCompletionMessageParam]]
        if is_list_of(messages, list):
            list_of_messages = messages
        else:
            list_of_messages = [messages]

        for msgs in list_of_messages:
            # Render the conversation with the chat template that matches the
            # tokenizer in use.
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data: Union[str, List[int]] = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )
                mm_data = None
            else:
                conversation, mm_data = parse_chat_messages(
                    msgs, model_config, tokenizer)
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )

            if is_list_of(prompt_data, int):
                raise NotImplementedError(
                    "route_chat does not support token-id prompts")
            orig_prompt = TextPrompt(prompt=prompt_data)

            if trigger_word is None:
                raise ValueError(
                    "trigger_word must be provided when using force_think logic")

            # Build the prefixed variant: append the trigger word (plus a newline
            # when the prompt does not already end in whitespace), keeping any
            # trailing EOS token at the very end.
            need_nl = needs_newline(prompt_data)
            prefix = trigger_word + ("\n" if need_nl else "")
            pref_txt = add_prefix(prompt_data, prefix, eos_token)
            pref_prompt = TextPrompt(prompt=pref_txt)

            if mm_data is not None:
                orig_prompt["multi_modal_data"] = mm_data
                pref_prompt["multi_modal_data"] = copy.deepcopy(mm_data)

            orig_prompts.append(orig_prompt)
            pref_prompts.append(pref_prompt)

        # First pass: generate from the plain prompts.
        results = self.generate(
            orig_prompts,
            sampling_params=sampling_params_route,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        # Requests whose output contains the "<specialLong>" marker within the
        # first 100 characters are re-run with the trigger-word prefix.
        need_force = [
            i for i, out in enumerate(results)
            if "<specialLong>" in out.outputs[0].text[:100]
        ]
        if not need_force:
            return results

        # Second pass: regenerate only the flagged requests from the prefixed
        # prompts, using the force-think sampling parameters.
        prompts_force = [pref_prompts[i] for i in need_force]
        results_force = self.generate(
            prompts_force,
            sampling_params=sampling_params_force_think,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        # Splice the regenerated outputs back into their original positions.
        for idx, new_out in zip(need_force, results_force):
            results[idx] = new_out

        return results
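

# Minimal usage sketch (hypothetical model name, trigger word, and sampling
# settings; adjust for your deployment). First-pass outputs that contain the
# "<specialLong>" marker near the start are regenerated from the prefixed prompt.
if __name__ == "__main__":
    llm = PrefixLLM(model="my-org/my-chat-model")  # hypothetical model

    conversation = [
        {"role": "user", "content": "Summarize the plot of Hamlet in two sentences."},
    ]

    outputs = llm.route_chat(
        conversation,
        sampling_params_route=SamplingParams(temperature=0.0, max_tokens=256),
        sampling_params_force_think=SamplingParams(temperature=0.0, max_tokens=1024),
        trigger_word="<think>",  # assumed prefix; use whatever your template expects
    )
    for request_output in outputs:
        print(request_output.outputs[0].text)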