# AutoL2S-7b/examples/prefixLLM.py
import copy
import re
from typing import Any, Dict, List, Optional, Union
from vllm import LLM, SamplingParams
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages,
)
from vllm.inputs import TextPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import is_list_of
_TAIL_WS_RE = re.compile(r"(?:\r?\n|\s)+$")
def needs_newline(text: str) -> bool:
"""Return True when *text* does NOT already end with whitespace/newline."""
return _TAIL_WS_RE.search(text[-8:]) is None # inspect last few chars
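# A quick behaviour sketch on hypothetical template endings (the strings below
# are illustrative, not taken from a real chat template):
#   needs_newline("<|assistant|>")    -> True   (no trailing whitespace)
#   needs_newline("<|assistant|>\n")  -> False  (already ends with a newline)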
def add_prefix(prompt: str, prefix: str, eos_token: str) -> str:
"""Insert *prefix* before the first generated token.
Keeps EOS token at the very end if the template already appended it.
"""
if prompt.endswith(eos_token):
return prompt[:-len(eos_token)] + prefix + eos_token
return prompt + prefix
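# A quick behaviour sketch with hypothetical prompt/EOS strings (illustrative only):
#   add_prefix("<user>Hi</s>", "<specialLong>\n", "</s>")
#       -> "<user>Hi<specialLong>\n</s>"   (EOS stays at the very end)
#   add_prefix("<user>Hi", "<specialLong>\n", "</s>")
#       -> "<user>Hi<specialLong>\n"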
class PrefixLLM(LLM):
"""vLLM LLM subclass that conditionally prepends *trigger_word*."""
def route_chat(
self,
messages: Union[
List[ChatCompletionMessageParam],
List[List[ChatCompletionMessageParam]],
],
sampling_params_route: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
sampling_params_force_think: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
*,
trigger_word: Optional[str] = None,
) -> List[RequestOutput]:
"""Drop-in replacement for `LLM.chat` with one extra keyword:
Parameters
----------
trigger_word : str | None, default None
The prefix to inject. If ``None`` → no prefix injection.
"""
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
eos_token = tokenizer.eos_token
        orig_prompts: List[TextPrompt] = []
        pref_prompts: List[TextPrompt] = []
list_of_messages: List[List[ChatCompletionMessageParam]]
# Handle multi and single conversations
if is_list_of(messages, list):
# messages is List[List[...]]
list_of_messages = messages
else:
# messages is List[...]
list_of_messages = [messages]
for msgs in list_of_messages:
# ---- render chat template exactly once ----
if isinstance(tokenizer, MistralTokenizer):
prompt_data: Union[str, List[int]] = apply_mistral_chat_template(
tokenizer,
messages=msgs,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tools=tools,
)
                mm_data = None  # the Mistral chat helper embeds any image tokens in the prompt itself
else:
conversation, mm_data = parse_chat_messages(msgs, model_config, tokenizer)
prompt_data = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tools=tools,
)
            if is_list_of(prompt_data, int):
                raise NotImplementedError(
                    "token-id prompts are not supported by PrefixLLM.route_chat")
            orig_prompt = TextPrompt(prompt=prompt_data)

            if trigger_word is None:
                raise ValueError(
                    "trigger_word must be provided when using the force-think logic")
            # Build the force-think variant: inject the trigger word at the end of
            # the rendered prompt, keeping any trailing EOS token in place.
            need_nl = needs_newline(prompt_data)
            prefix = trigger_word + ("\n" if need_nl else "")
            pref_txt = add_prefix(prompt_data, prefix, eos_token)
            pref_prompt = TextPrompt(prompt=pref_txt)
if mm_data is not None:
orig_prompt["multi_modal_data"] = mm_data
pref_prompt["multi_modal_data"] = copy.deepcopy(mm_data)
orig_prompts.append(orig_prompt)
pref_prompts.append(pref_prompt)
        # First pass: generate from the unmodified prompts.
        results = self.generate(
orig_prompts,
sampling_params=sampling_params_route,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
        # Re-run any request whose first pass emitted the "<specialLong>" marker
        # near the start of its generation.
        need_force = [
            i for i, out in enumerate(results)
            if "<specialLong>" in out.outputs[0].text[:100]
        ]
        if not need_force:
            return results  # early exit, nothing to redo

        prompts_force = [pref_prompts[i] for i in need_force]
        results_force = self.generate(
            prompts_force,
            sampling_params=sampling_params_force_think,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
        # Splice the forced re-generations back into their original slots.
        for idx, new_out in zip(need_force, results_force):
            results[idx] = new_out
        return results
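
# Minimal usage sketch. The checkpoint path, trigger word, and sampling settings
# below are assumptions for illustration; replace them with the model and routing
# marker you actually serve.
if __name__ == "__main__":
    llm = PrefixLLM(model="path/to/AutoL2S-7b")  # placeholder checkpoint path
    route_params = SamplingParams(temperature=0.0, max_tokens=2048)
    force_params = SamplingParams(temperature=0.0, max_tokens=8192)
    messages = [{"role": "user", "content": "How many primes are there below 100?"}]
    outputs = llm.route_chat(
        messages,
        sampling_params_route=route_params,
        sampling_params_force_think=force_params,
        trigger_word="<specialLong>",  # assumption: same marker the routing pass checks for
    )
    print(outputs[0].outputs[0].text)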