import copy
import re
from typing import Any, Dict, List, Optional, Union

from vllm import LLM, SamplingParams
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import is_list_of

_TAIL_WS_RE = re.compile(r"(?:\r?\n|\s)+$")


def needs_newline(text: str) -> bool:
    """Return True when *text* does NOT already end with whitespace/newline."""
    return _TAIL_WS_RE.search(text[-8:]) is None  # inspect only the last few chars


def add_prefix(prompt: str, prefix: str, eos_token: str) -> str:
    """Insert *prefix* before the first generated token.

    Keeps the EOS token at the very end if the template already appended it.
    """
    if prompt.endswith(eos_token):
        return prompt[: -len(eos_token)] + prefix + eos_token
    return prompt + prefix


class PrefixLLM(LLM):
    """vLLM ``LLM`` subclass that conditionally prepends *trigger_word*."""

    def route_chat(
        self,
        messages: Union[
            List[ChatCompletionMessageParam],
            List[List[ChatCompletionMessageParam]],
        ],
        sampling_params_route: Optional[Union[SamplingParams, List[SamplingParams]]] = None,
        sampling_params_force_think: Optional[Union[SamplingParams, List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
        *,
        trigger_word: Optional[str] = None,
    ) -> List[RequestOutput]:
        """Drop-in replacement for ``LLM.chat`` with one extra keyword.

        Parameters
        ----------
        trigger_word : str | None, default None
            The prefix to inject before the first generated token. Required:
            a ``ValueError`` is raised if it is ``None``.
        """
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        eos_token = tokenizer.eos_token

        orig_prompts: List[Union[TokensPrompt, TextPrompt]] = []
        pref_prompts: List[Union[TokensPrompt, TextPrompt]] = []

        # Handle multi- and single-conversation inputs uniformly.
        list_of_messages: List[List[ChatCompletionMessageParam]]
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = messages
        else:
            # messages is List[...]
            list_of_messages = [messages]

        for msgs in list_of_messages:
            # ---- render the chat template exactly once ----
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data: Union[str, List[int]] = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )
                mm_data = None  # the Mistral util already embeds image tokens
            else:
                conversation, mm_data = parse_chat_messages(msgs, model_config, tokenizer)
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )

            if is_list_of(prompt_data, int):
                # Token-ID prompts cannot be prefixed with the string logic below.
                raise NotImplementedError("route_chat only supports text prompts")

            orig_prompt = TextPrompt(prompt=prompt_data)

            if trigger_word is None:
                raise ValueError("trigger_word must be provided when using force_think logic")

            need_nl = needs_newline(prompt_data)
            prefix = trigger_word + ("\n" if need_nl else "")
            pref_txt = add_prefix(prompt_data, prefix, eos_token)
            pref_prompt = TextPrompt(prompt=pref_txt)

            if mm_data is not None:
                orig_prompt["multi_modal_data"] = mm_data
                pref_prompt["multi_modal_data"] = copy.deepcopy(mm_data)

            orig_prompts.append(orig_prompt)
            pref_prompts.append(pref_prompt)

        # First pass: generate with the routing sampling params.
        results = self.generate(
            orig_prompts,
            sampling_params=sampling_params_route,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        # Outputs whose first 100 characters contain the routing marker are
        # re-generated with the forced prefix.
        # NOTE: "<think>" is an assumed marker here; set it to whatever
        # marker your router actually emits.
        need_force = [
            i for i, out in enumerate(results)
            if "<think>" in out.outputs[0].text[:100]
        ]
        if len(need_force) == 0:
            return results  # early exit, nothing to redo

        # Second pass: regenerate only the flagged conversations, with the
        # trigger prefix already baked into the prompt text.
        prompts_force = [pref_prompts[i] for i in need_force]
        results_force = self.generate(
            prompts_force,
            sampling_params=sampling_params_force_think,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
        for idx, new_out in zip(need_force, results_force):
            results[idx] = new_out
        return results
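

# Minimal usage sketch. Assumptions: the model name, trigger word, and
# sampling settings below are illustrative only and not part of the module
# above; any vLLM-supported chat model with a compatible template should work.
if __name__ == "__main__":
    llm = PrefixLLM(model="Qwen/Qwen2.5-7B-Instruct")  # hypothetical model choice
    route_params = SamplingParams(temperature=0.0, max_tokens=256)
    think_params = SamplingParams(temperature=0.6, max_tokens=2048)

    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 17 * 23?"},
    ]

    outputs = llm.route_chat(
        conversation,
        sampling_params_route=route_params,
        sampling_params_force_think=think_params,
        trigger_word="<think>",  # assumed marker; match your chat template
    )
    print(outputs[0].outputs[0].text)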