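# Custom handler for a Hugging Face Inference Endpoint. It loads a causal LM
# once at startup and serves requests whose payload may include a "grammar"
# field; when present, generation is constrained with transformers-gad's
# grammar-aligned oracle so sampled text conforms to the grammar.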
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor

def safe_int_cast(value, default):
    # Cast to int, falling back to the default when the value is missing
    # (None raises TypeError) or malformed (ValueError).
    try:
        return int(value)
    except (TypeError, ValueError):
        return default

class EndpointHandler:
    def __init__(self, path=""):
        # Preload the tokenizer and model once, at endpoint startup.
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        DTYPE = torch.float32
        self.device = torch.device(DEVICE)

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.to(self.device)
        self.model.to(dtype=DTYPE)
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model = torch.compile(self.model, mode="reduce-overhead", fullgraph=True)
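
    # NOTE: torch.compile with mode="reduce-overhead" trades a longer cold
    # start for faster steady-state decoding; fullgraph=True makes compilation
    # fail fast if the model cannot be captured as a single graph.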

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        # Per-request generation defaults; callers may override the token and
        # time budgets via the request payload.
        MAX_NEW_TOKENS = 512
        MAX_TIME = 30
        TEMPERATURE = 1.0
        REPETITION_PENALTY = 1.0
        TOP_P = 1.0
        TOP_K = 0

        # Endpoint convention: the prompt arrives under "inputs".
        inputs = data.get("inputs", data)
        grammar_str = data.get("grammar", "")
        max_new_tokens = safe_int_cast(data.get("max-new-tokens"), MAX_NEW_TOKENS)
        max_time = safe_int_cast(data.get("max-time"), MAX_TIME)

        if grammar_str is None or not grammar_str.strip():
            # No grammar supplied: run unconstrained sampling.
            logits_processors = None
            gad_oracle_processor = None
        else:
            print("=== GOT GRAMMAR ===")
            print(grammar_str)
            print("===================")
            # Compile the grammar (start symbol "root") against this
            # tokenizer's vocabulary.
            grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)
            # Initialize the grammar-aligned logits processor, plus a guard
            # that strips inf/NaN logits before sampling.
            gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
            inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
            logits_processors = LogitsProcessorList([
                inf_nan_remove_processor,
                gad_oracle_processor,
            ])
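
        # Note: unlike plain grammar-constrained decoding, which only masks
        # grammar-violating tokens, the GAD oracle also adaptively reweights
        # the surviving tokens so samples better approximate the model's own
        # distribution conditioned on grammaticality.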

        #input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
        # Wrap the raw prompt in the model's chat template and append the
        # assistant generation prompt before tokenizing.
        input_ids = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": inputs}],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        input_ids = input_ids.to(self.model.device)

        # Sample one sequence under the (optional) grammar constraint.
        output = self.model.generate(
            input_ids,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_time=max_time,
            max_new_tokens=max_new_tokens,
            top_p=TOP_P,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            temperature=TEMPERATURE,
            logits_processor=logits_processors,
            num_return_sequences=1,
            return_dict_in_generate=True,
            output_scores=True,
        )

        if gad_oracle_processor is not None:
            # Clear the oracle's accumulated state so the next request starts fresh.
            gad_oracle_processor.reset()

        # Decode only the newly generated tokens, excluding the prompt.
        input_length = 1 if self.model.config.is_encoder_decoder else input_ids.shape[1]
        if hasattr(output, "sequences"):
            generated_tokens = output.sequences[:, input_length:]
        else:
            generated_tokens = output[:, input_length:]
        generations = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return generations
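
# Minimal local usage sketch (illustrative only): the model path and grammar
# below are placeholder assumptions, not part of the endpoint contract, and
# the model must ship a chat template for apply_chat_template to work.
if __name__ == "__main__":
    handler = EndpointHandler(path="path/to/chat-model")  # hypothetical path
    result = handler({
        "inputs": "Answer with yes or no: is the sky blue?",
        "grammar": 'root ::= "yes" | "no"',  # hypothetical GBNF-style grammar
        "max-new-tokens": "8",
    })
    print(result)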