import copy
import re
from typing import Any, Dict, List, Optional, Union

from vllm import LLM, SamplingParams
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import is_list_of

# Matches trailing whitespace (including a final newline) at the end of a prompt.
_TAIL_WS_RE = re.compile(r"(?:\r?\n|\s)+$")

def needs_newline(text: str) -> bool:
    """Return True when *text* does NOT already end with whitespace/newline."""
    return _TAIL_WS_RE.search(text[-8:]) is None  # inspect only the last few chars

def add_prefix(prompt: str, prefix: str, eos_token: str) -> str:
    """Insert *prefix* before the first generated token.

    Keeps the EOS token at the very end if the template already appended it.
    """
    if prompt.endswith(eos_token):
        return prompt[: -len(eos_token)] + prefix + eos_token
    return prompt + prefix
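
# Illustrative behaviour of ``add_prefix`` (the "</s>" EOS token below is an
# assumption for the example; the real token comes from the tokenizer at runtime):
#   add_prefix("Hi there</s>", "<think>\n", "</s>")  -> "Hi there<think>\n</s>"
#   add_prefix("Hi there",     "<think>\n", "</s>")  -> "Hi there<think>\n"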

class PrefixLLM(LLM):
    """vLLM ``LLM`` subclass that conditionally prepends *trigger_word*."""

    def route_chat(
        self,
        messages: Union[
            List[ChatCompletionMessageParam],
            List[List[ChatCompletionMessageParam]],
        ],
        sampling_params_route: Optional[Union[
            SamplingParams, List[SamplingParams]]] = None,
        sampling_params_force_think: Optional[Union[
            SamplingParams, List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
        *,
        trigger_word: Optional[str] = None,
    ) -> List[RequestOutput]:
        """Drop-in replacement for ``LLM.chat`` with one extra keyword.

        Parameters
        ----------
        trigger_word : str | None, default None
            The prefix to inject before the first generated token. The current
            implementation requires it: a ``ValueError`` is raised when ``None``.
        """
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        eos_token = tokenizer.eos_token

        # Prompts without / with the injected prefix, kept index-aligned.
        orig_prompts: List[Union[TokensPrompt, TextPrompt]] = []
        pref_prompts: List[Union[TokensPrompt, TextPrompt]] = []

        # Accept either a single conversation or a batch of conversations.
        list_of_messages: List[List[ChatCompletionMessageParam]]
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = messages
        else:
            # messages is List[...]
            list_of_messages = [messages]
        for msgs in list_of_messages:
            # ---- render the chat template exactly once per conversation ----
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data: Union[str, List[int]] = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )
                mm_data = None  # the Mistral util already embeds image tokens
            else:
                conversation, mm_data = parse_chat_messages(
                    msgs, model_config, tokenizer)
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )
            if is_list_of(prompt_data, int):
                # Token-id prompts are not supported by the string-based
                # prefix injection below.
                raise NotImplementedError
            orig_prompt = TextPrompt(prompt=prompt_data)

            if trigger_word is None:
                raise ValueError(
                    "trigger_word must be provided when using force_think logic")
            # Append a newline to the trigger word when the rendered prompt
            # does not already end with whitespace.
            need_nl = needs_newline(prompt_data)
            prefix = trigger_word + ("\n" if need_nl else "")
            pref_txt = add_prefix(prompt_data, prefix, eos_token)
            pref_prompt = TextPrompt(prompt=pref_txt)

            if mm_data is not None:
                orig_prompt["multi_modal_data"] = mm_data
                pref_prompt["multi_modal_data"] = copy.deepcopy(mm_data)

            orig_prompts.append(orig_prompt)
            pref_prompts.append(pref_prompt)
        # First pass: generate from the unmodified prompts.
        results = self.generate(
            orig_prompts,
            sampling_params=sampling_params_route,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        # Requests whose first ~100 generated characters contain the routing
        # marker are re-run with the trigger_word prefix injected.
        need_force = [
            i for i, out in enumerate(results)
            if "<specialLong>" in out.outputs[0].text[:100]
        ]
        if len(need_force) == 0:
            return results  # early exit, nothing to redo
        # Second pass: regenerate only the flagged requests with the prefixed
        # prompts, then splice the new outputs back into the original results.
        prompts_force = [pref_prompts[i] for i in need_force]
        results_force = self.generate(
            prompts_force,
            sampling_params=sampling_params_force_think,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
        for idx, new_out in zip(need_force, results_force):
            results[idx] = new_out
        return results
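
# Minimal usage sketch (illustrative only): the model name, trigger word, and
# sampling settings below are placeholders, not part of this module.
if __name__ == "__main__":
    llm = PrefixLLM(model="Qwen/Qwen2.5-7B-Instruct")
    conversation = [
        {"role": "user", "content": "Prove that sqrt(2) is irrational."},
    ]
    outputs = llm.route_chat(
        conversation,
        sampling_params_route=SamplingParams(max_tokens=256),
        sampling_params_force_think=SamplingParams(max_tokens=2048),
        trigger_word="<think>",
    )
    print(outputs[0].outputs[0].text)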