from typing import Dict, List, Any, Optional, Union import os import json import time import torch from threading import Thread import logging from transformers import ( AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria, BitsAndBytesConfig ) from peft import PeftModel # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler("lora_inference.log"), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) class ImprovedJSONStoppingCriteria(StoppingCriteria): """ Stopping criteria that ensures JSON is complete before stopping. Only stops generation when a valid, complete JSON object is detected. """ def __init__(self, tokenizer): self.tokenizer = tokenizer self.generated = "" self.json_complete = False def __call__(self, input_ids, scores, **kwargs): # If we already found complete JSON, stop immediately if self.json_complete: return True # Decode current text text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True) # Skip early if no JSON structure detected if '{' not in text: return False # Don't stop if we don't have at least one closing brace if '}' not in text: return False # Check for complete JSON structure try: # First, try to find a valid JSON object start_pos = text.find('{') # Progressively validate from the first opening brace stack = [] end_pos = -1 for i, char in enumerate(text[start_pos:], start_pos): if char == '{': stack.append('{') elif char == '}': if stack: stack.pop() if not stack: # We have balanced braces end_pos = i potential_json = text[start_pos:end_pos+1] # Make sure this is actually valid JSON # and not just balanced braces try: # Parse JSON to validate parsed = json.loads(potential_json) # We need to make sure we have all required fields # For search_web or tool calls, verify arguments are complete if "calls" in parsed: for call in parsed.get("calls", []): # If we have a call with arguments, make sure they're complete if "arguments" in call: args = call.get("arguments", "") # If arguments is a string, it might be JSON itself if isinstance(args, str) and args.startswith("{"): # If the argument string starts with { but doesn't have a # closing }, it's incomplete if not args.endswith("}"): return False # Try to parse the arguments as JSON try: json.loads(args) except: # If we can't parse, the JSON is incomplete return False # All checks passed - we have valid, complete JSON self.json_complete = True return True except: # Not valid JSON, continue looking continue # Only stop with excessive braces if we already have a valid structure open_count = text.count('{') close_count = text.count('}') if close_count > open_count: # Check if we have a valid JSON by balancing fixed_text = text[start_pos:] stack = [] for i, char in enumerate(fixed_text): if char == '{': stack.append('{') elif char == '}': if stack: stack.pop() if not stack: try: potential_json = fixed_text[:i+1] parsed = json.loads(potential_json) self.json_complete = True return True except: pass except Exception: # Error in parsing or validation, don't stop pass return False class ExcessBraceStoppingCriteria(StoppingCriteria): """Stop generation if we're generating excessive closing braces""" def __init__(self, tokenizer): self.tokenizer = tokenizer def __call__(self, input_ids, scores, **kwargs): text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True) # Only trigger if we have JSON content if '{' in text and '}' in text: # Check if we're generating excessive closing braces open_count = text.count('{') close_count = text.count('}') # If we have more closing than opening braces, stop generation if close_count > open_count + 3: # Allow a small buffer return True return False def fix_json_output(text): """Fix malformed JSON with excessive closing braces.""" if '{' not in text or '}' not in text: return text # Count opening and closing braces open_count = text.count('{') close_count = text.count('}') # If balanced or too few closing braces, return as-is if open_count >= close_count: return text # Track JSON depth to find valid JSON object start_pos = text.find('{') depth = 0 for i, char in enumerate(text[start_pos:], start_pos): if char == '{': depth += 1 elif char == '}': depth -= 1 if depth == 0: # Found balanced JSON, return up to this point return text[:i+1] # If we can't balance it with depth tracking, simply truncate return text[:start_pos + text[start_pos:].find('}')+1] def create_stopping_criteria(tokenizer, stop_tokens): """Create stopping criteria from tokens""" stop_token_ids = [] for stop_token in stop_tokens: token_ids = tokenizer.encode(stop_token, add_special_tokens=False) if len(token_ids) > 0: stop_token_ids.append(token_ids[-1]) return StoppingCriteriaList([StopOnTokens(tokenizer, stop_token_ids)]) class StopOnTokens(StoppingCriteria): """Custom stopping criteria for text generation.""" def __init__(self, tokenizer, stop_token_ids): self.tokenizer = tokenizer self.stop_token_ids = stop_token_ids def __call__(self, input_ids, scores, **kwargs): for stop_id in self.stop_token_ids: if input_ids[0][-1] == stop_id: return True return False class EndpointHandler: def __init__(self, path=""): """ Initialize the handler by loading model and tokenizer Args: path (str): Path to the model directory (uses environment variable if not provided) """ # Get model path from environment or from argument model_path = path if path else os.environ.get("MODEL_PATH", "") adapter_path = os.environ.get("ADAPTER_PATH", None) logger.info(f"Loading model from {model_path}") # Determine quantization settings from environment use_8bit = os.environ.get("USE_8BIT", "False").lower() == "true" use_4bit = os.environ.get("USE_4BIT", "False").lower() == "true" device = os.environ.get("DEVICE", "auto") # Load tokenizer logger.info(f"Loading tokenizer from {model_path}") self.tokenizer = AutoTokenizer.from_pretrained(model_path) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Load model with appropriate configuration if use_4bit: logger.info("Using 4-bit quantization for inference...") quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) base_model = AutoModelForCausalLM.from_pretrained( model_path, quantization_config=quantization_config, device_map=device, low_cpu_mem_usage=True ) elif use_8bit: logger.info("Using 8-bit quantization for inference...") base_model = AutoModelForCausalLM.from_pretrained( model_path, load_in_8bit=True, device_map=device, low_cpu_mem_usage=True ) else: logger.info("Loading model in float16 precision...") base_model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float16, device_map=device, low_cpu_mem_usage=True ) # Apply adapter if specified if adapter_path: logger.info(f"Loading LoRA adapter from {adapter_path}") self.model = PeftModel.from_pretrained(base_model, adapter_path) else: self.model = base_model logger.info("No adapter path provided, using base model only") self.model.eval() # Try to use torch.compile for additional performance if available if torch.__version__ >= "2.0.0" and os.environ.get("USE_COMPILE", "False").lower() == "true": try: logger.info("Applying torch.compile for additional optimization...") self.model = torch.compile(self.model) logger.info("Model successfully compiled!") except Exception as e: logger.warning(f"Could not compile model: {e}") logger.info("Model and tokenizer loaded successfully!") def format_conversation(self, messages, add_generation_prompt=True): """Format a conversation using the tokenizer's chat template""" return self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=add_generation_prompt ) def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: """ Process inference request Args: data (Dict[str, Any]): Request data containing inputs and parameters Returns: List[Dict[str, Any]]: List of response dictionaries """ start_time = time.time() # Extract input data and parameters inputs = data.get("inputs", []) parameters = data.get("parameters", {}) # Parse generation parameters with defaults max_new_tokens = parameters.get("max_new_tokens", 512) temperature = parameters.get("temperature", 0.7) top_p = parameters.get("top_p", 0.95) do_sample = parameters.get("do_sample", temperature > 0.1) stream = parameters.get("stream", False) json_mode = parameters.get("json_mode", False) system_prompt = """ <|begin_of_text|><|start_header_id|>system<|end_header_id|> Cutting Knowledge Date: December 2023 Today Date: 16 March 2025 When you receive a tool call response, use the output to format an answer to the orginal user question. You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|> Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. { "type": "function", "function": { "name": "llm", "description": "Access your internal knowledge as an LLM to provide general information, explanations, and guidance without searching the web.", }, "type": "function", "function": { "name": "search_web", "description": "Fetch up-to-date, specific, or contextual information that may not be stable or broadly known.", "arguments": { "type": "object", "properties": { "query": { "type": "string", "description": "A search query used to find relevant information on the web" }, "required": ["query"] } }, "type": "function", "function": { "name": "calculate", "description": "used for precise mathematical computations", "arguments": { "type": "object", "properties": { "expression": { "type": "string", "description": "An executable mathmatical javascript expression" }, "required": ["expression"] } }, "type": "function", "function": { "name": "open_url", "description": "Opens or shows a website to the user with the specified URL", "arguments": { "type": "object", "properties": { "url": { "type": "string", "description": "The URL of the website to open or show to the user" }, "required": ["url"] } }, "type": "function", "function": { "name": "fetch_web_content", "description": "The URL or webiste of the content to fetch, get, summarize, or analyze" "arguments": { "type": "object", "properties": { "url": { "type": "string", "description": "The URL of the content to fetch, get, summarize, or analyze" }, "required": ["url"] } }, "type": "function", "function": { "name": "unsupported_capability", "description": "Use this function to indicate that the requested action is not supported or not possible.", "arguments": { "type": "object", "properties": { "capability": { "type": "string", "description": "The capability requested by the user that is not supported" }, "required": ["capability"] } }, } Question: """ # Check if input is in various formats and normalize to messages format if isinstance(inputs, str): # Create simple chat with user message messages = [{"role": "user", "content": inputs}] elif isinstance(inputs, dict) and "messages" in inputs: # Input is already in chat format messages = inputs["messages"] elif isinstance(inputs, list): # Assume this is a list of message dicts messages = inputs else: # Invalid input format return [{"error": "Invalid input format. Please provide a string, a list of messages, or a dict with 'messages' key."}] # Prepare conversation with system prompt if provided conversation = [] if system_prompt: conversation.append({"role": "system", "content": system_prompt}) conversation.extend(messages) # Format the conversation prompt = self.format_conversation(conversation) # Tokenize the prompt inputs_dict = self.tokenizer(prompt, return_tensors="pt") inputs_dict = {k: v.to(self.model.device) for k, v in inputs_dict.items()} # Configure generation parameters generation_config = { "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "do_sample": do_sample, "pad_token_id": self.tokenizer.pad_token_id, } # Add JSON-specific settings if needed if json_mode: stop_tokens = ["\n\n", "\n}", "}\n", "}}", "} }", "}\n]", "}\n{"] stopping_criteria = create_stopping_criteria(self.tokenizer, stop_tokens) generation_config["stopping_criteria"] = stopping_criteria # Lower temperature for JSON mode to get more reliable outputs # but don't set to 0 as that might cause truncation issues temperature = min(temperature, 0.1) do_sample = False generation_config["do_sample"] = do_sample generation_config["temperature"] = temperature # Record input length for proper decoding input_length = inputs_dict["input_ids"].shape[1] generated_text = "" stream = False if stream: # Use streaming for interactive responses streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) generation_config["streamer"] = streamer # Start generation in a thread thread = Thread(target=self.model.generate, kwargs={**inputs_dict, **generation_config}) thread.start() # Stream the output (for local testing) for text in streamer: generated_text += text # Apply JSON cleaning if needed and json_mode is enabled if json_mode and '{' in generated_text and '}' in generated_text: if generated_text.count('}') > generated_text.count('{'): fixed_text = fix_json_output(generated_text) if fixed_text != generated_text: logger.info("Fixed malformed JSON in response") generated_text = fixed_text else: # Non-streaming generation with torch.no_grad(): outputs = self.model.generate(**inputs_dict, **generation_config) # Decode the output generated_ids = outputs[0][input_length:] generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) # Apply JSON cleaning if needed and json_mode is enabled if json_mode and '{' in generated_text and '}' in generated_text: if generated_text.count('}') > generated_text.count('{'): fixed_text = fix_json_output(generated_text) if fixed_text != generated_text: logger.info("Fixed malformed JSON in response") generated_text = fixed_text # Calculate processing time end_time = time.time() processing_time = end_time - start_time # Create response dictionary response = { "generated_text": generated_text, "processing_time": processing_time } # Include input token count if requested if parameters.get("return_token_count", False): response["input_token_count"] = input_length response["output_token_count"] = len(generated_text.split()) return [response] # For local testing if __name__ == "__main__": # Test the handler model_path = os.environ.get("MODEL_PATH", "./model") handler = EndpointHandler(model_path) # Test with a simple query test_data = { "inputs": "Explain the concept of machine learning in simple terms.", "parameters": { "max_new_tokens": 100, "temperature": 0.7, "system_prompt": "You are a helpful AI assistant." } } response = handler(test_data) print("\nTest Response:") print(json.dumps(response, indent=2)) # Test with chat format and JSON mode test_chat_data = { "inputs": { "messages": [ {"role": "user", "content": "Create a JSON object with information about the solar system. Include at least 3 planets with their name, diameter, and distance from the sun."} ] }, "parameters": { "max_new_tokens": 512, "temperature": 0.1, "json_mode": True, "system_prompt": "You are a helpful AI assistant that responds in JSON format." } } chat_response = handler(test_chat_data) print("\nJSON Format Response:") print(json.dumps(chat_response, indent=2))