from typing import Dict, List, Any, Optional, Union
import os
import json
import time
import logging
from threading import Thread

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
    StoppingCriteriaList,
    StoppingCriteria,
    BitsAndBytesConfig
)
from peft import PeftModel


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("lora_inference.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
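
# Configuration is read from environment variables (see EndpointHandler.__init__):
# MODEL_PATH, ADAPTER_PATH, USE_8BIT, USE_4BIT, DEVICE, and USE_COMPILE.
# A minimal sketch of how a deployment might set them before loading this
# handler (values below are illustrative, not from the original source):
#
#     export MODEL_PATH=/models/base-model
#     export ADAPTER_PATH=/models/lora-adapter
#     export USE_4BIT=true
#     export DEVICE=auto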


class ImprovedJSONStoppingCriteria(StoppingCriteria):
    """
    Stopping criteria that ensures JSON is complete before stopping.
    Only stops generation when a valid, complete JSON object is detected.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.generated = ""
        self.json_complete = False

    def __call__(self, input_ids, scores, **kwargs):
        # Once a complete JSON object has been detected, keep signalling "stop".
        if self.json_complete:
            return True

        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        # A complete JSON object needs at least one opening and one closing brace.
        if '{' not in text or '}' not in text:
            return False

        try:
            start_pos = text.find('{')

            # Track brace depth; each time the braces balance we have a candidate object.
            stack = []
            for i, char in enumerate(text[start_pos:], start_pos):
                if char == '{':
                    stack.append('{')
                elif char == '}':
                    if stack:
                        stack.pop()
                    if not stack:
                        potential_json = text[start_pos:i + 1]
                        try:
                            parsed = json.loads(potential_json)

                            # For tool-call payloads, also make sure any string-encoded
                            # "arguments" field is itself complete, valid JSON.
                            if "calls" in parsed:
                                for call in parsed.get("calls", []):
                                    if "arguments" in call:
                                        args = call.get("arguments", "")
                                        if isinstance(args, str) and args.startswith("{"):
                                            if not args.endswith("}"):
                                                return False
                                            try:
                                                json.loads(args)
                                            except json.JSONDecodeError:
                                                return False

                            self.json_complete = True
                            return True
                        except json.JSONDecodeError:
                            # Balanced braces but not valid JSON yet; keep scanning.
                            continue

            # If there are more closing than opening braces, try to parse the prefix
            # that ends at the first point where the braces balance.
            open_count = text.count('{')
            close_count = text.count('}')
            if close_count > open_count:
                fixed_text = text[start_pos:]
                stack = []
                for i, char in enumerate(fixed_text):
                    if char == '{':
                        stack.append('{')
                    elif char == '}':
                        if stack:
                            stack.pop()
                        if not stack:
                            try:
                                potential_json = fixed_text[:i + 1]
                                json.loads(potential_json)
                                self.json_complete = True
                                return True
                            except json.JSONDecodeError:
                                pass
        except Exception:
            # Never let a parsing problem abort generation.
            pass

        return False
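
# A minimal usage sketch (not part of the original handler): wiring this
# criteria into `generate` directly, assuming `tokenizer`, `model`, and
# `prompt` are already defined elsewhere.
#
#     criteria = StoppingCriteriaList([ImprovedJSONStoppingCriteria(tokenizer)])
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#     output_ids = model.generate(**inputs, max_new_tokens=256,
#                                 stopping_criteria=criteria)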


class ExcessBraceStoppingCriteria(StoppingCriteria):
    """Stop generation if we're generating excessive closing braces."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        if '{' in text and '}' in text:
            open_count = text.count('{')
            close_count = text.count('}')

            if close_count > open_count + 3:
                return True

        return False


def fix_json_output(text):
    """Fix malformed JSON with excessive closing braces."""
    if '{' not in text or '}' not in text:
        return text

    open_count = text.count('{')
    close_count = text.count('}')

    # Nothing to fix unless there are more closing braces than opening ones.
    if open_count >= close_count:
        return text

    # Truncate the text right after the first balanced top-level object.
    start_pos = text.find('{')
    depth = 0
    for i, char in enumerate(text[start_pos:], start_pos):
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return text[:i + 1]

    # Fallback: cut at the first closing brace after the opening brace.
    return text[:start_pos + text[start_pos:].find('}') + 1]
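
# Illustrative example (not from the original source) of what fix_json_output
# does with trailing, unmatched closing braces:
#
#     fix_json_output('{"name": "search_web", "parameters": {"query": "llamas"}}}}')
#     # -> '{"name": "search_web", "parameters": {"query": "llamas"}}'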


def create_stopping_criteria(tokenizer, stop_tokens):
    """Create stopping criteria from a list of stop strings."""
    stop_token_ids = []
    for stop_token in stop_tokens:
        token_ids = tokenizer.encode(stop_token, add_special_tokens=False)
        if len(token_ids) > 0:
            # Only the last token id of a multi-token stop string is kept,
            # which keeps the per-step check cheap at generation time.
            stop_token_ids.append(token_ids[-1])

    return StoppingCriteriaList([StopOnTokens(tokenizer, stop_token_ids)])


class StopOnTokens(StoppingCriteria):
    """Custom stopping criteria for text generation."""

    def __init__(self, tokenizer, stop_token_ids):
        self.tokenizer = tokenizer
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs):
        # Stop as soon as the most recently generated token is a stop token.
        for stop_id in self.stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
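
# A hedged usage sketch (not from the original source): this is the same helper
# the handler uses internally in JSON mode; the stop strings below are
# illustrative only.
#
#     stopping_criteria = create_stopping_criteria(tokenizer, ["\n\n", "}\n"])
#     output_ids = model.generate(**inputs, stopping_criteria=stopping_criteria)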


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler by loading the model and tokenizer.

        Args:
            path (str): Path to the model directory (uses the MODEL_PATH
                environment variable if not provided)
        """
        model_path = path if path else os.environ.get("MODEL_PATH", "")
        adapter_path = os.environ.get("ADAPTER_PATH", None)
        logger.info(f"Loading model from {model_path}")

        # Runtime options come from environment variables.
        use_8bit = os.environ.get("USE_8BIT", "False").lower() == "true"
        use_4bit = os.environ.get("USE_4BIT", "False").lower() == "true"
        device = os.environ.get("DEVICE", "auto")

        logger.info(f"Loading tokenizer from {model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load the base model with the requested precision/quantization.
        if use_4bit:
            logger.info("Using 4-bit quantization for inference...")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device,
                low_cpu_mem_usage=True
            )
        elif use_8bit:
            logger.info("Using 8-bit quantization for inference...")
            base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                load_in_8bit=True,
                device_map=device,
                low_cpu_mem_usage=True
            )
        else:
            logger.info("Loading model in float16 precision...")
            base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                device_map=device,
                low_cpu_mem_usage=True
            )

        # Attach the LoRA adapter on top of the base model, if one is configured.
        if adapter_path:
            logger.info(f"Loading LoRA adapter from {adapter_path}")
            self.model = PeftModel.from_pretrained(base_model, adapter_path)
        else:
            self.model = base_model
            logger.info("No adapter path provided, using base model only")

        self.model.eval()

        # Optionally compile the model (requires PyTorch 2.x).
        if hasattr(torch, "compile") and os.environ.get("USE_COMPILE", "False").lower() == "true":
            try:
                logger.info("Applying torch.compile for additional optimization...")
                self.model = torch.compile(self.model)
                logger.info("Model successfully compiled!")
            except Exception as e:
                logger.warning(f"Could not compile model: {e}")

        logger.info("Model and tokenizer loaded successfully!")

    def format_conversation(self, messages, add_generation_prompt=True):
        """Format a conversation using the tokenizer's chat template."""
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt
        )
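
    # Illustrative input (not from the original source): `messages` follows the
    # standard chat format expected by apply_chat_template, e.g.
    #
    #     [
    #         {"role": "system", "content": "You are a helpful assistant."},
    #         {"role": "user", "content": "What is a LoRA adapter?"},
    #     ]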

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Args:
            data (Dict[str, Any]): Request data containing inputs and parameters

        Returns:
            List[Dict[str, Any]]: List of response dictionaries
        """
        start_time = time.time()

        # Unpack the request payload.
        inputs = data.get("inputs", [])
        parameters = data.get("parameters", {})

        # Generation parameters with sensible defaults.
        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.95)
        do_sample = parameters.get("do_sample", temperature > 0.1)
        stream = parameters.get("stream", False)
        json_mode = parameters.get("json_mode", False)

        # Fixed tool-calling system prompt. Note: it already embeds Llama 3-style
        # chat-template header tokens.
        system_prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 16 March 2025

When you receive a tool call response, use the output to format an answer to the original user question.

You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|>

Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.

Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
[
  {
    "type": "function",
    "function": {
      "name": "llm",
      "description": "Access your internal knowledge as an LLM to provide general information, explanations, and guidance without searching the web."
    }
  },
  {
    "type": "function",
    "function": {
      "name": "search_web",
      "description": "Fetch up-to-date, specific, or contextual information that may not be stable or broadly known.",
      "arguments": {
        "type": "object",
        "properties": {
          "query": {
            "type": "string",
            "description": "A search query used to find relevant information on the web"
          }
        },
        "required": ["query"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "calculate",
      "description": "Used for precise mathematical computations.",
      "arguments": {
        "type": "object",
        "properties": {
          "expression": {
            "type": "string",
            "description": "An executable mathematical JavaScript expression"
          }
        },
        "required": ["expression"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "open_url",
      "description": "Opens or shows a website to the user with the specified URL.",
      "arguments": {
        "type": "object",
        "properties": {
          "url": {
            "type": "string",
            "description": "The URL of the website to open or show to the user"
          }
        },
        "required": ["url"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "fetch_web_content",
      "description": "Fetch the content of a URL or website so it can be summarized or analyzed.",
      "arguments": {
        "type": "object",
        "properties": {
          "url": {
            "type": "string",
            "description": "The URL of the content to fetch, get, summarize, or analyze"
          }
        },
        "required": ["url"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "unsupported_capability",
      "description": "Use this function to indicate that the requested action is not supported or not possible.",
      "arguments": {
        "type": "object",
        "properties": {
          "capability": {
            "type": "string",
            "description": "The capability requested by the user that is not supported"
          }
        },
        "required": ["capability"]
      }
    }
  }
]

Question:

"""

        # Normalize the supported input formats into a list of chat messages.
        if isinstance(inputs, str):
            messages = [{"role": "user", "content": inputs}]
        elif isinstance(inputs, dict) and "messages" in inputs:
            messages = inputs["messages"]
        elif isinstance(inputs, list):
            messages = inputs
        else:
            return [{"error": "Invalid input format. Please provide a string, a list of messages, or a dict with 'messages' key."}]

        # Prepend the system prompt and build the full conversation.
        conversation = []
        if system_prompt:
            conversation.append({"role": "system", "content": system_prompt})
        conversation.extend(messages)

        prompt = self.format_conversation(conversation)

        # Tokenize and move the inputs to the model's device.
        inputs_dict = self.tokenizer(prompt, return_tensors="pt")
        inputs_dict = {k: v.to(self.model.device) for k, v in inputs_dict.items()}

        generation_config = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
        }

        if json_mode:
            # Stop early on patterns that typically follow a closed JSON object.
            stop_tokens = ["\n\n", "\n}", "}\n", "}}", "} }", "}\n]", "}\n{"]
            stopping_criteria = create_stopping_criteria(self.tokenizer, stop_tokens)
            generation_config["stopping_criteria"] = stopping_criteria

            # Use near-greedy decoding so the JSON structure stays stable.
            temperature = min(temperature, 0.1)
            do_sample = False
            generation_config["do_sample"] = do_sample
            generation_config["temperature"] = temperature

        input_length = inputs_dict["input_ids"].shape[1]

        generated_text = ""
        if stream:
            # Stream tokens as they are generated; the full text is still
            # accumulated and returned in a single response at the end.
            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            generation_config["streamer"] = streamer

            thread = Thread(target=self.model.generate, kwargs={**inputs_dict, **generation_config})
            thread.start()

            for text in streamer:
                generated_text += text
            thread.join()

            if json_mode and '{' in generated_text and '}' in generated_text:
                if generated_text.count('}') > generated_text.count('{'):
                    fixed_text = fix_json_output(generated_text)
                    if fixed_text != generated_text:
                        logger.info("Fixed malformed JSON in response")
                        generated_text = fixed_text
        else:
            with torch.no_grad():
                outputs = self.model.generate(**inputs_dict, **generation_config)

            # Keep only the newly generated tokens (drop the prompt).
            generated_ids = outputs[0][input_length:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

            if json_mode and '{' in generated_text and '}' in generated_text:
                if generated_text.count('}') > generated_text.count('{'):
                    fixed_text = fix_json_output(generated_text)
                    if fixed_text != generated_text:
                        logger.info("Fixed malformed JSON in response")
                        generated_text = fixed_text

        end_time = time.time()
        processing_time = end_time - start_time

        response = {
            "generated_text": generated_text,
            "processing_time": processing_time
        }

        if parameters.get("return_token_count", False):
            response["input_token_count"] = input_length
            response["output_token_count"] = len(
                self.tokenizer.encode(generated_text, add_special_tokens=False)
            )

        return [response]
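
    # Shape of the returned payload (illustrative values, not from a real run):
    #
    #     [
    #         {
    #             "generated_text": "...model output...",
    #             "processing_time": 1.23,
    #             "input_token_count": 512,   # only if return_token_count is set
    #             "output_token_count": 87,   # only if return_token_count is set
    #         }
    #     ]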


if __name__ == "__main__":
    model_path = os.environ.get("MODEL_PATH", "./model")
    handler = EndpointHandler(model_path)

    test_data = {
        "inputs": "Explain the concept of machine learning in simple terms.",
        "parameters": {
            "max_new_tokens": 100,
            "temperature": 0.7,
            "system_prompt": "You are a helpful AI assistant."
        }
    }

    response = handler(test_data)
    print("\nTest Response:")
    print(json.dumps(response, indent=2))

    test_chat_data = {
        "inputs": {
            "messages": [
                {"role": "user", "content": "Create a JSON object with information about the solar system. Include at least 3 planets with their name, diameter, and distance from the sun."}
            ]
        },
        "parameters": {
            "max_new_tokens": 512,
            "temperature": 0.1,
            "json_mode": True,
            "system_prompt": "You are a helpful AI assistant that responds in JSON format."
        }
    }

    chat_response = handler(test_chat_data)
    print("\nJSON Format Response:")
    print(json.dumps(chat_response, indent=2))