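# Custom Hugging Face Inference Endpoints handler: generation is constrained
# to an EBNF grammar supplied in the request payload, using grammar-aligned
# decoding (GAD) from transformers_gad.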
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor


class EndpointHandler:
    def __init__(self, path=""):
        # Preload the tokenizer and model from the repository path
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        # The request payload carries the prompt under "inputs" and the EBNF
        # grammar (with a "root" start rule) under "grammar".
        inputs = data.get("inputs", data)
        grammar_str = data.get("grammar", "")
        MAX_NEW_TOKENS = 4096
        MAX_TIME = 300
        print(grammar_str)

        grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)

        # Initialize the logits processors: the grammar-aligned oracle restricts
        # sampling to grammar-conforming tokens, and the inf/nan remover guards
        # against invalid logits during generation.
        gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
        inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
        logits_processors = LogitsProcessorList([
            inf_nan_remove_processor,
            gad_oracle_processor,
        ])

        input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt")["input_ids"]
        output = self.model.generate(
            input_ids,
            do_sample=True,
            max_time=MAX_TIME,
            max_new_tokens=MAX_NEW_TOKENS,
            logits_processor=logits_processors,
        )
        gad_oracle_processor.reset()

        # Detokenize the generated output, dropping the prompt tokens
        input_length = 1 if self.model.config.is_encoder_decoder else input_ids.shape[1]
        if hasattr(output, "sequences"):
            generated_tokens = output.sequences[:, input_length:]
        else:
            generated_tokens = output[:, input_length:]
        generations = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        return generations
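

# Minimal local smoke test (illustrative sketch, not part of the deployed
# handler): the model path and the EBNF grammar below are assumptions made
# for demonstration only.
if __name__ == "__main__":
    # Assumes a model and tokenizer are available in the current directory
    handler = EndpointHandler(path=".")
    example_grammar = r"""
root ::= "yes" | "no"
"""
    result = handler({"inputs": "Is water wet? Answer:", "grammar": example_grammar})
    print(result)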