# jina-code-debugger / model_wrapper.py
# Uploaded by Girinath11 ("Upload 7 files"), commit 87ce049 (verified).
import os
import threading
class CodeDebuggerWrapper:
    """
    Simple wrapper that loads the same HF model and exposes debug(code: str) -> str
    This is used by app.py (Gradio).

    Environment variables read by this class:
      * ``SKIP_MODEL_LOAD`` -- when ``"1"``, no weights are loaded and
        ``debug`` returns an explanatory message instead of model output.
      * ``MAX_NEW_TOKENS`` -- generation budget per ``debug`` call
        (default ``"256"``; must parse as an int).
    """

    def __init__(self, model_name: str = "Girinath11/aiml_code_debug_model"):
        """Configure the wrapper and eagerly load the model (unless skipped).

        Args:
            model_name: Hugging Face Hub id (or local path) passed to
                ``from_pretrained`` for both tokenizer and model.

        Raises:
            ValueError: if ``MAX_NEW_TOKENS`` is set to a non-integer string.
        """
        self.model_name = model_name
        # Serializes the one-time load so concurrent callers don't download twice.
        self._lock = threading.Lock()
        self.tokenizer = None
        self.model = None
        self.max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", "256"))
        self._ensure_model()

    def _ensure_model(self):
        """Load the tokenizer and seq2seq model once, unless SKIP_MODEL_LOAD=1.

        Both objects are loaded into locals first and only published on the
        instance after *both* loads succeed, so a failure while loading the
        model never leaves a tokenizer-only, half-initialized wrapper behind.
        """
        # allow skipping in environments where you don't want to download weights
        if os.environ.get("SKIP_MODEL_LOAD", "0") == "1":
            print("SKIP_MODEL_LOAD=1 -> not loading model.")
            return
        if self.model is not None and self.tokenizer is not None:
            return  # already loaded: skip the import and the lock
        # Imported lazily so the module can be imported without transformers
        # when SKIP_MODEL_LOAD=1.
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        with self._lock:
            # Re-check under the lock: another thread may have finished
            # loading while this one was waiting.
            if self.model is None or self.tokenizer is None:
                print(f"Loading model {self.model_name} ...")
                tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
                self.tokenizer, self.model = tokenizer, model
                print("Model loaded.")

    def debug(self, code: str) -> str:
        """Run the model on ``code`` and return the decoded generated text.

        Args:
            code: source text to feed to the model.

        Returns:
            The decoded first generated sequence (special tokens stripped), or
            a fixed explanatory message if the model was not loaded.
        """
        # Snapshot once so this call uses a consistent tokenizer/model pair.
        model, tokenizer = self.model, self.tokenizer
        if model is None or tokenizer is None:
            return "Model not loaded. Set SKIP_MODEL_LOAD=0 and ensure HF token is available if model is private."
        inputs = tokenizer(code, return_tensors="pt", padding=True, truncation=True)
        # Keep inputs on the model's device (e.g. if a caller moved
        # self.model to a GPU); a no-op for the default CPU load.
        device = getattr(model, "device", None)
        if device is not None:
            inputs = inputs.to(device)
        outputs = model.generate(**inputs, max_new_tokens=self.max_new_tokens)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)