from typing import Dict, List, Any, Optional, Union
import os
import json
import time
import torch
from threading import Thread
import logging
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
StoppingCriteriaList,
StoppingCriteria,
BitsAndBytesConfig
)
from peft import PeftModel
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("lora_inference.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
class ImprovedJSONStoppingCriteria(StoppingCriteria):
"""
Stopping criteria that ensures JSON is complete before stopping.
Only stops generation when a valid, complete JSON object is detected.
"""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.generated = ""
self.json_complete = False
def __call__(self, input_ids, scores, **kwargs):
# If we already found complete JSON, stop immediately
if self.json_complete:
return True
# Decode current text
text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
# Skip early if no JSON structure detected
if '{' not in text:
return False
# Don't stop if we don't have at least one closing brace
if '}' not in text:
return False
# Check for complete JSON structure
try:
# First, try to find a valid JSON object
start_pos = text.find('{')
# Progressively validate from the first opening brace
stack = []
end_pos = -1
for i, char in enumerate(text[start_pos:], start_pos):
if char == '{':
stack.append('{')
elif char == '}':
if stack:
stack.pop()
if not stack: # We have balanced braces
end_pos = i
potential_json = text[start_pos:end_pos+1]
# Make sure this is actually valid JSON
# and not just balanced braces
try:
# Parse JSON to validate
parsed = json.loads(potential_json)
# We need to make sure we have all required fields
# For search_web or tool calls, verify arguments are complete
if "calls" in parsed:
for call in parsed.get("calls", []):
# If we have a call with arguments, make sure they're complete
if "arguments" in call:
args = call.get("arguments", "")
# If arguments is a string, it might be JSON itself
if isinstance(args, str) and args.startswith("{"):
# If the argument string starts with { but doesn't have a
# closing }, it's incomplete
if not args.endswith("}"):
return False
# Try to parse the arguments as JSON
try:
json.loads(args)
except:
# If we can't parse, the JSON is incomplete
return False
# All checks passed - we have valid, complete JSON
self.json_complete = True
return True
except:
# Not valid JSON, continue looking
continue
# Only stop with excessive braces if we already have a valid structure
open_count = text.count('{')
close_count = text.count('}')
if close_count > open_count:
# Check if we have a valid JSON by balancing
fixed_text = text[start_pos:]
stack = []
for i, char in enumerate(fixed_text):
if char == '{':
stack.append('{')
elif char == '}':
if stack:
stack.pop()
if not stack:
try:
potential_json = fixed_text[:i+1]
parsed = json.loads(potential_json)
self.json_complete = True
return True
except:
pass
except Exception:
# Error in parsing or validation, don't stop
pass
return False
class ExcessBraceStoppingCriteria(StoppingCriteria):
"""Stop generation if we're generating excessive closing braces"""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs):
text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
# Only trigger if we have JSON content
if '{' in text and '}' in text:
# Check if we're generating excessive closing braces
open_count = text.count('{')
close_count = text.count('}')
# If we have more closing than opening braces, stop generation
if close_count > open_count + 3: # Allow a small buffer
return True
return False
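# Example usage (illustrative): either criteria class above can be attached to generate()
# through a StoppingCriteriaList, e.g.:
#   stopping = StoppingCriteriaList([ImprovedJSONStoppingCriteria(tokenizer)])
#   model.generate(**inputs, stopping_criteria=stopping)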
def fix_json_output(text):
"""Fix malformed JSON with excessive closing braces."""
if '{' not in text or '}' not in text:
return text
# Count opening and closing braces
open_count = text.count('{')
close_count = text.count('}')
# If balanced or too few closing braces, return as-is
if open_count >= close_count:
return text
# Track JSON depth to find valid JSON object
start_pos = text.find('{')
depth = 0
for i, char in enumerate(text[start_pos:], start_pos):
if char == '{':
depth += 1
elif char == '}':
depth -= 1
if depth == 0:
# Found balanced JSON, return up to this point
return text[:i+1]
# If we can't balance it with depth tracking, simply truncate
return text[:start_pos + text[start_pos:].find('}')+1]
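# Example (hypothetical input): fix_json_output('{"name": "search_web"}}}') scans braces from
# the first '{' and returns the balanced prefix '{"name": "search_web"}'.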
def create_stopping_criteria(tokenizer, stop_tokens):
"""Create stopping criteria from tokens"""
stop_token_ids = []
for stop_token in stop_tokens:
token_ids = tokenizer.encode(stop_token, add_special_tokens=False)
if len(token_ids) > 0:
stop_token_ids.append(token_ids[-1])
return StoppingCriteriaList([StopOnTokens(tokenizer, stop_token_ids)])
class StopOnTokens(StoppingCriteria):
"""Custom stopping criteria for text generation."""
def __init__(self, tokenizer, stop_token_ids):
self.tokenizer = tokenizer
self.stop_token_ids = stop_token_ids
def __call__(self, input_ids, scores, **kwargs):
for stop_id in self.stop_token_ids:
if input_ids[0][-1] == stop_id:
return True
return False
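# Example usage (illustrative): stop once a chosen token (e.g. a closing brace) is generated.
#   criteria = create_stopping_criteria(tokenizer, ["}\n", "\n\n"])
#   model.generate(**inputs, stopping_criteria=criteria)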
class EndpointHandler:
def __init__(self, path=""):
"""
Initialize the handler by loading model and tokenizer
Args:
path (str): Path to the model directory (uses environment variable if not provided)
"""
# Get model path from environment or from argument
model_path = path if path else os.environ.get("MODEL_PATH", "")
adapter_path = os.environ.get("ADAPTER_PATH", None)
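        # Further configuration is read from environment variables below:
        # USE_8BIT, USE_4BIT, DEVICE, and USE_COMPILE.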
logger.info(f"Loading model from {model_path}")
# Determine quantization settings from environment
use_8bit = os.environ.get("USE_8BIT", "False").lower() == "true"
use_4bit = os.environ.get("USE_4BIT", "False").lower() == "true"
device = os.environ.get("DEVICE", "auto")
# Load tokenizer
logger.info(f"Loading tokenizer from {model_path}")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Load model with appropriate configuration
if use_4bit:
logger.info("Using 4-bit quantization for inference...")
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
base_model = AutoModelForCausalLM.from_pretrained(
model_path,
quantization_config=quantization_config,
device_map=device,
low_cpu_mem_usage=True
)
        elif use_8bit:
            logger.info("Using 8-bit quantization for inference...")
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device,
                low_cpu_mem_usage=True
            )
else:
logger.info("Loading model in float16 precision...")
base_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
device_map=device,
low_cpu_mem_usage=True
)
# Apply adapter if specified
if adapter_path:
logger.info(f"Loading LoRA adapter from {adapter_path}")
self.model = PeftModel.from_pretrained(base_model, adapter_path)
else:
self.model = base_model
logger.info("No adapter path provided, using base model only")
self.model.eval()
# Try to use torch.compile for additional performance if available
if torch.__version__ >= "2.0.0" and os.environ.get("USE_COMPILE", "False").lower() == "true":
try:
logger.info("Applying torch.compile for additional optimization...")
self.model = torch.compile(self.model)
logger.info("Model successfully compiled!")
except Exception as e:
logger.warning(f"Could not compile model: {e}")
logger.info("Model and tokenizer loaded successfully!")
def format_conversation(self, messages, add_generation_prompt=True):
"""Format a conversation using the tokenizer's chat template"""
return self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=add_generation_prompt
)
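    # Illustrative example of the message structure expected by apply_chat_template:
    #   [{"role": "system", "content": "..."}, {"role": "user", "content": "Hello!"}]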
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Process inference request
Args:
data (Dict[str, Any]): Request data containing inputs and parameters
Returns:
List[Dict[str, Any]]: List of response dictionaries
"""
start_time = time.time()
# Extract input data and parameters
inputs = data.get("inputs", [])
parameters = data.get("parameters", {})
# Parse generation parameters with defaults
max_new_tokens = parameters.get("max_new_tokens", 512)
temperature = parameters.get("temperature", 0.7)
top_p = parameters.get("top_p", 0.95)
do_sample = parameters.get("do_sample", temperature > 0.1)
stream = parameters.get("stream", False)
json_mode = parameters.get("json_mode", False)
system_prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
Today Date: 16 March 2025
When you receive a tool call response, use the output to format an answer to the original user question.
You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|>
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
[
  {
    "type": "function",
    "function": {
      "name": "llm",
      "description": "Access your internal knowledge as an LLM to provide general information, explanations, and guidance without searching the web."
    }
  },
  {
    "type": "function",
    "function": {
      "name": "search_web",
      "description": "Fetch up-to-date, specific, or contextual information that may not be stable or broadly known.",
      "arguments": {
        "type": "object",
        "properties": {
          "query": {
            "type": "string",
            "description": "A search query used to find relevant information on the web"
          }
        },
        "required": ["query"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "calculate",
      "description": "Used for precise mathematical computations.",
      "arguments": {
        "type": "object",
        "properties": {
          "expression": {
            "type": "string",
            "description": "An executable mathematical JavaScript expression"
          }
        },
        "required": ["expression"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "open_url",
      "description": "Opens or shows a website to the user with the specified URL.",
      "arguments": {
        "type": "object",
        "properties": {
          "url": {
            "type": "string",
            "description": "The URL of the website to open or show to the user"
          }
        },
        "required": ["url"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "fetch_web_content",
      "description": "Fetch, summarize, or analyze the content of a web page at the specified URL.",
      "arguments": {
        "type": "object",
        "properties": {
          "url": {
            "type": "string",
            "description": "The URL of the content to fetch, get, summarize, or analyze"
          }
        },
        "required": ["url"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "unsupported_capability",
      "description": "Use this function to indicate that the requested action is not supported or not possible.",
      "arguments": {
        "type": "object",
        "properties": {
          "capability": {
            "type": "string",
            "description": "The capability requested by the user that is not supported"
          }
        },
        "required": ["capability"]
      }
    }
  }
]
Question:
"""
# Check if input is in various formats and normalize to messages format
if isinstance(inputs, str):
# Create simple chat with user message
messages = [{"role": "user", "content": inputs}]
elif isinstance(inputs, dict) and "messages" in inputs:
# Input is already in chat format
messages = inputs["messages"]
elif isinstance(inputs, list):
# Assume this is a list of message dicts
messages = inputs
else:
# Invalid input format
return [{"error": "Invalid input format. Please provide a string, a list of messages, or a dict with 'messages' key."}]
# Prepare conversation with system prompt if provided
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
conversation.extend(messages)
# Format the conversation
prompt = self.format_conversation(conversation)
# Tokenize the prompt
inputs_dict = self.tokenizer(prompt, return_tensors="pt")
inputs_dict = {k: v.to(self.model.device) for k, v in inputs_dict.items()}
# Configure generation parameters
generation_config = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"pad_token_id": self.tokenizer.pad_token_id,
}
# Add JSON-specific settings if needed
if json_mode:
stop_tokens = ["\n\n", "\n}", "}\n", "}}", "} }", "}\n]", "}\n{"]
stopping_criteria = create_stopping_criteria(self.tokenizer, stop_tokens)
generation_config["stopping_criteria"] = stopping_criteria
# Lower temperature for JSON mode to get more reliable outputs
# but don't set to 0 as that might cause truncation issues
temperature = min(temperature, 0.1)
do_sample = False
generation_config["do_sample"] = do_sample
generation_config["temperature"] = temperature
# Record input length for proper decoding
input_length = inputs_dict["input_ids"].shape[1]
generated_text = ""
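        # NOTE: streaming is force-disabled just below, so the handler always returns a
        # complete response even if the request sets "stream": true.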
stream = False
if stream:
# Use streaming for interactive responses
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_config["streamer"] = streamer
# Start generation in a thread
thread = Thread(target=self.model.generate, kwargs={**inputs_dict, **generation_config})
thread.start()
# Stream the output (for local testing)
for text in streamer:
generated_text += text
# Apply JSON cleaning if needed and json_mode is enabled
if json_mode and '{' in generated_text and '}' in generated_text:
if generated_text.count('}') > generated_text.count('{'):
fixed_text = fix_json_output(generated_text)
if fixed_text != generated_text:
logger.info("Fixed malformed JSON in response")
generated_text = fixed_text
else:
# Non-streaming generation
with torch.no_grad():
outputs = self.model.generate(**inputs_dict, **generation_config)
# Decode the output
generated_ids = outputs[0][input_length:]
generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
# Apply JSON cleaning if needed and json_mode is enabled
if json_mode and '{' in generated_text and '}' in generated_text:
if generated_text.count('}') > generated_text.count('{'):
fixed_text = fix_json_output(generated_text)
if fixed_text != generated_text:
logger.info("Fixed malformed JSON in response")
generated_text = fixed_text
# Calculate processing time
end_time = time.time()
processing_time = end_time - start_time
# Create response dictionary
response = {
"generated_text": generated_text,
"processing_time": processing_time
}
# Include input token count if requested
if parameters.get("return_token_count", False):
response["input_token_count"] = input_length
            response["output_token_count"] = len(self.tokenizer.encode(generated_text, add_special_tokens=False))
return [response]
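        # A successful call returns a single-element list, e.g. (illustrative):
        #   [{"generated_text": "...", "processing_time": 1.23}]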
# For local testing
if __name__ == "__main__":
# Test the handler
model_path = os.environ.get("MODEL_PATH", "./model")
handler = EndpointHandler(model_path)
# Test with a simple query
test_data = {
"inputs": "Explain the concept of machine learning in simple terms.",
"parameters": {
"max_new_tokens": 100,
"temperature": 0.7,
"system_prompt": "You are a helpful AI assistant."
}
}
response = handler(test_data)
print("\nTest Response:")
print(json.dumps(response, indent=2))
# Test with chat format and JSON mode
test_chat_data = {
"inputs": {
"messages": [
{"role": "user", "content": "Create a JSON object with information about the solar system. Include at least 3 planets with their name, diameter, and distance from the sun."}
]
},
"parameters": {
"max_new_tokens": 512,
"temperature": 0.1,
"json_mode": True,
"system_prompt": "You are a helpful AI assistant that responds in JSON format."
}
}
chat_response = handler(test_chat_data)
print("\nJSON Format Response:")
print(json.dumps(chat_response, indent=2))
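    # Optional extra check (illustrative): exercise the return_token_count parameter so the
    # response also reports input/output token counts alongside the generated text.
    test_token_data = {
        "inputs": "List three prime numbers.",
        "parameters": {
            "max_new_tokens": 64,
            "temperature": 0.2,
            "return_token_count": True
        }
    }
    token_response = handler(test_token_data)
    print("\nToken Count Response:")
    print(json.dumps(token_response, indent=2))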