File size: 4,018 Bytes
543fe4a
 
de0bfd0
543fe4a
 
 
 
 
de0bfd0
 
 
 
 
 
543fe4a
 
 
de0bfd0
32cbc00
de0bfd0
 
 
543fe4a
de0bfd0
 
543fe4a
de0bfd0
 
 
 
543fe4a
 
 
de0bfd0
 
 
 
 
 
 
 
543fe4a
de0bfd0
 
 
4804703
 
f715cab
4804703
 
 
 
 
543fe4a
4804703
 
 
 
 
 
 
543fe4a
4804703
 
f715cab
4804703
 
 
 
de0bfd0
543fe4a
 
 
 
de0bfd0
 
 
 
 
 
 
 
 
 
 
 
543fe4a
 
f715cab
 
543fe4a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor

def safe_int_cast(str, default):
    """Convert *str* to an ``int``, returning *default* if conversion fails.

    Args:
        str: Value to convert (typically a request field; may be ``None``
            when the field is absent). The name shadows the builtin ``str``
            and is kept only so existing keyword callers keep working.
        default: Value returned when *str* cannot be converted.

    Returns:
        ``int(str)`` on success, otherwise *default*.
    """
    try:
        return int(str)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / non-numeric objects (e.g. a missing dict key).
        # ValueError: unparsable text such as "abc" or float('nan').
        # OverflowError: float('inf'), which JSON parsers may produce.
        return default

class EndpointHandler():
    """Request handler for sampling text, optionally constrained by a grammar.

    Appears to follow the Hugging Face Inference Endpoints custom-handler
    interface (``__init__(path)`` / ``__call__(data)``) — confirm against the
    deployment. The tokenizer and causal LM are loaded once at construction.
    Each call wraps the request text in a single-turn chat prompt, samples a
    completion, and returns the decoded generated text. When the request
    carries a non-blank ``grammar`` string, sampling is constrained with
    ``transformers_gad``'s grammar-aligned oracle logits processor.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* onto the best available device.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        # Preload
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        DTYPE = torch.float32

        self.device = torch.device(DEVICE)

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Reuse EOS as the pad token so padding is always defined (presumably
        # the tokenizer ships without one — confirm for the deployed model).
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.to(self.device)
        self.model.to(dtype=DTYPE)
        self.model.resize_token_embeddings(len(self.tokenizer))
        # NOTE(review): torch.compile returns a wrapper that forwards attribute
        # lookups (``generate``, ``config``, ``device``) to the original module,
        # so ``generate`` likely runs the uncompiled forward. Verify that
        # compilation actually has an effect before relying on it.
        self.model = torch.compile(self.model, mode='reduce-overhead', fullgraph=True)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """Generate one completion for the request payload *data*.

        Recognised keys:
            ``inputs``: user message text; when absent, *data* itself is used.
            ``grammar``: optional grammar string whose start rule is ``root``;
                missing, empty or whitespace-only disables constrained decoding.
            ``max-new-tokens``: cap on generated tokens (default 512; invalid
                values fall back to the default).
            ``max-time``: time budget passed to ``generate`` as ``max_time``
                (default 30; invalid values fall back to the default).

        Returns:
            The decoded generated text, one string per returned sequence (a
            single sequence here), with the prompt removed and special tokens
            skipped.
        """
        MAX_NEW_TOKENS = 512
        MAX_TIME = 30
        TEMPERATURE = 1.0
        REPETITION_PENALTY = 1.0
        TOP_P = 1.0
        TOP_K = 0

        inputs = data.get("inputs", data)
        grammar_str = data.get("grammar", "")
        # Supply the defaults to .get() as well, so an absent key never reaches
        # int() as None (which raises TypeError rather than ValueError).
        max_new_tokens = safe_int_cast(
            data.get("max-new-tokens", MAX_NEW_TOKENS), MAX_NEW_TOKENS
        )
        max_time = safe_int_cast(data.get("max-time", MAX_TIME), MAX_TIME)

        logits_processors = None
        gad_oracle_processor = None
        if grammar_str and not grammar_str.isspace():
            print("=== GOT GRAMMAR ===")
            print(grammar_str)
            print("===================")
            grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)

            # Sanitise inf/nan logits first, then apply the grammar oracle.
            gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
            logits_processors = LogitsProcessorList([
                InfNanRemoveLogitsProcessor(),
                gad_oracle_processor,
            ])

        input_ids = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": inputs}],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        input_ids = input_ids.to(self.model.device)

        # output_scores is intentionally not requested: the per-step score
        # tensors were never read and only held memory for the whole generation.
        output = self.model.generate(
            input_ids,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_time=max_time,
            max_new_tokens=max_new_tokens,
            top_p=TOP_P,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            temperature=TEMPERATURE,
            logits_processor=logits_processors,
            num_return_sequences=1,
            return_dict_in_generate=True,
        )

        if gad_oracle_processor is not None:
            gad_oracle_processor.reset()

        # Strip the prompt: decoder-only outputs echo the full prompt, while
        # encoder-decoder outputs start with a single decoder start token.
        input_length = 1 if self.model.config.is_encoder_decoder else input_ids.shape[1]
        sequences = output.sequences if hasattr(output, "sequences") else output
        generated_tokens = sequences[:, input_length:]

        return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)