import copy
import re
from typing import Any, Dict, List, Optional, Union

from vllm import LLM, SamplingParams
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import is_list_of


# Matches a run of trailing whitespace (including newlines) at the end of a string.
_TAIL_WS_RE = re.compile(r"(?:\r?\n|\s)+$")

def needs_newline(text: str) -> bool:
    """Return True when *text* does NOT already end with whitespace/newline."""
    return _TAIL_WS_RE.search(text[-8:]) is None  # inspect last few chars


def add_prefix(prompt: str, prefix: str, eos_token: str) -> str:
    """Insert *prefix* before the first generated token.

    Keeps EOS token at the very end if the template already appended it.
    """
    if prompt.endswith(eos_token):
        return prompt[:-len(eos_token)] + prefix + eos_token
    return prompt + prefix

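
# A quick illustration of how the two helpers above compose (a minimal sketch;
# the "<think>" prefix and "</s>" EOS token below are placeholder examples,
# not values taken from any particular model):
#
#     needs_newline("User: hi")        -> True
#     needs_newline("User: hi\n")      -> False
#     add_prefix("User: hi\n", "<think>\n", "</s>")
#         -> "User: hi\n<think>\n"
#     add_prefix("User: hi\n</s>", "<think>\n", "</s>")
#         -> "User: hi\n<think>\n</s>"   (prefix stays before the EOS token)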

class PrefixLLM(LLM):
    """vLLM LLM subclass that conditionally prepends *trigger_word*."""

    def route_chat(
        self,
        messages: Union[
            List[ChatCompletionMessageParam],
            List[List[ChatCompletionMessageParam]],
        ],
        sampling_params_route: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        sampling_params_force_think: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
        *,
        trigger_word: Optional[str] = None,
    ) -> List[RequestOutput]:
        """Drop-in replacement for `LLM.chat` with one extra keyword:

        Parameters
        ----------
        trigger_word : str | None, default None
            The prefix to inject.  If ``None`` → no prefix injection.
        """

        if trigger_word is None:
            raise ValueError(
                "trigger_word must be provided when using force_think logic")

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        eos_token = tokenizer.eos_token

        orig_prompts: List[Union[TokensPrompt, TextPrompt]] = []
        pref_prompts: List[Union[TokensPrompt, TextPrompt]] = []

        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = messages
        else:
            # messages is List[...]
            list_of_messages = [messages]

        for msgs in list_of_messages:
            # ---- render chat template exactly once ----
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data: Union[str, List[int]] = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )
                mm_data = None  # the Mistral template already embeds any image tokens
            else:
                conversation, mm_data = parse_chat_messages(msgs, model_config, tokenizer)
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )

            if is_list_of(prompt_data, int):
                raise NotImplementedError(
                    "Token-id prompts are not supported by route_chat")
            else:
                orig_prompt = TextPrompt(prompt=prompt_data)

                # Append a newline after the trigger word only when the
                # rendered prompt does not already end with whitespace.
                need_nl = needs_newline(prompt_data)
                prefix = trigger_word + ("\n" if need_nl else "")
                pref_txt = add_prefix(prompt_data, prefix, eos_token)
                pref_prompt = TextPrompt(prompt=pref_txt)

            # Attach multi-modal payloads to both prompt variants.
            if mm_data is not None:
                orig_prompt["multi_modal_data"] = mm_data
                pref_prompt["multi_modal_data"] = copy.deepcopy(mm_data)

            orig_prompts.append(orig_prompt)
            pref_prompts.append(pref_prompt)

        results = self.generate(
            orig_prompts,
            sampling_params=sampling_params_route,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        # Re-run with the forced prefix whenever the routed output emits the
        # "<specialLong>" marker near the start of its text.
        need_force = [
            i for i, out in enumerate(results)
            if "<specialLong>" in out.outputs[0].text[:100]
        ]

        if not need_force:
            return results  # early exit, nothing to redo

        prompts_force = [pref_prompts[i] for i in need_force]

        results_force = self.generate(
            prompts_force,
            sampling_params=sampling_params_force_think,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
        
        for idx, new_out in zip(need_force, results_force):
            results[idx] = new_out

        return results
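

# Minimal usage sketch (assumptions: the model path, the trigger word, and the
# sampling settings below are placeholders; substitute whatever your routed /
# force-think setup actually uses).
if __name__ == "__main__":
    llm = PrefixLLM(model="path/to/your-model")  # placeholder model path

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain the Pythagorean theorem."},
    ]

    outputs = llm.route_chat(
        messages,
        sampling_params_route=SamplingParams(temperature=0.0, max_tokens=256),
        sampling_params_force_think=SamplingParams(temperature=0.6, max_tokens=2048),
        trigger_word="<think>",  # placeholder trigger word
    )
    print(outputs[0].outputs[0].text)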