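"""Jina Executor that wraps the Girinath11/aiml_code_debug_model seq2seq model
for code debugging. The model is lazy-loaded on the first request so the Space
can start without downloading weights (see JINA_SKIP_MODEL_LOAD below)."""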
import os
import threading

from jina import Executor, requests
from docarray import BaseDoc, DocList

# transformers is imported lazily in _ensure_model to avoid the heavy import at module load time


class CodeInput(BaseDoc):
    code: str


class CodeOutput(BaseDoc):
    result: str
class CodeDebugger(Executor):
    """
    Jina Executor that lazy-loads a Hugging Face seq2seq model on the first request.

    Set the environment variable JINA_SKIP_MODEL_LOAD=1 to skip model loading
    (useful in CI or during Hub builds where no inference is needed).
    """

    def __init__(self, model_name: str = "Girinath11/aiml_code_debug_model", **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self._lock = threading.Lock()
        self.tokenizer = None
        self.model = None
        # optional: override the generation length via the MAX_NEW_TOKENS env var
        self.max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", "256"))
    def _ensure_model(self):
        """
        Load the tokenizer and model once, in a thread-safe manner.

        If JINA_SKIP_MODEL_LOAD is set to "1", loading is skipped (helpful for Hub builds).
        """
        if os.environ.get("JINA_SKIP_MODEL_LOAD", "0") == "1":
            self.logger.warning("JINA_SKIP_MODEL_LOAD=1 is set; skipping HF model load.")
            return
        if self.model is None or self.tokenizer is None:
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # lazy import

            with self._lock:
                # double-checked locking: re-test inside the lock so only one thread loads
                if self.model is None or self.tokenizer is None:
                    self.logger.info(f"Loading model {self.model_name} ...")
                    # For gated or private models, authenticate beforehand via
                    # `huggingface-cli login` or the HF_TOKEN environment variable.
                    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                    self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
                    self.logger.info("Model loaded successfully.")
    @requests(on="/debug")
    def debug(self, docs: DocList[CodeInput], **kwargs) -> DocList[CodeOutput]:
        # Lazy-load the model at request time
        self._ensure_model()
        if self.model is None or self.tokenizer is None:
            # Model loading was skipped; return a helpful message per input doc
            return DocList[CodeOutput](
                [CodeOutput(result="Model not loaded (JINA_SKIP_MODEL_LOAD=1).") for _ in docs]
            )
        results = []
        for doc in docs:
            # be defensive: make sure the input is a string
            code_text = doc.code if isinstance(doc.code, str) else str(doc.code)
            inputs = self.tokenizer(code_text, return_tensors="pt", padding=True, truncation=True)
            outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            results.append(CodeOutput(result=result))
        return DocList[CodeOutput](results)
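

# --- Usage sketch (not part of the original file) ---
# A minimal local smoke test, assuming the standard Jina 3.x Deployment/Client
# API and the /debug endpoint defined above. The sample code snippet passed as
# input is purely illustrative.
if __name__ == "__main__":
    from jina import Deployment

    with Deployment(uses=CodeDebugger) as dep:
        out = dep.post(
            on="/debug",
            inputs=DocList[CodeInput]([CodeInput(code="print(undefined_var)")]),
            return_type=DocList[CodeOutput],
        )
        print(out[0].result)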