"""Jina Executor wrapping a Hugging Face seq2seq code-debugging model.

The model is lazy-loaded on the first request so that importing this module
stays cheap. Set JINA_SKIP_MODEL_LOAD=1 to skip loading entirely (useful in
CI and hub builds).
"""

import os
import threading

from docarray import BaseDoc, DocList
from jina import Executor, requests

# transformers imports are done lazily in _ensure_model to prevent heavy
# import on module load


class CodeInput(BaseDoc):
    """Request document: the source code to debug."""

    code: str


class CodeOutput(BaseDoc):
    """Response document: the model's debugging output."""

    result: str


class CodeDebugger(Executor):
    """
    Jina Executor that lazy-loads a Hugging Face seq2seq model on first request.
    Use environment variable JINA_SKIP_MODEL_LOAD=1 to skip model loading
    (useful in CI/builds).
    """

    def __init__(self, model_name: str = "Girinath11/aiml_code_debug_model", **kwargs):
        """Store configuration; the model itself is NOT loaded here.

        :param model_name: Hugging Face model id to load on first request.
        :param kwargs: forwarded to the Jina ``Executor`` base class.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        self._lock = threading.Lock()  # guards one-time model/tokenizer load
        self.tokenizer = None
        self.model = None
        # optional: allow overriding max_new_tokens via env var
        self.max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", "256"))

    def _ensure_model(self) -> None:
        """
        Load tokenizer & model once in a thread-safe manner.
        If JINA_SKIP_MODEL_LOAD is set to "1", skip loading (helpful for hub
        builds); both attributes then stay ``None`` and callers must handle it.
        """
        if os.environ.get("JINA_SKIP_MODEL_LOAD", "0") == "1":
            self.logger.warning("JINA_SKIP_MODEL_LOAD=1 set — skipping HF model load.")
            return
        # Fast path: already loaded — no import, no lock.
        if self.model is not None and self.tokenizer is not None:
            return

        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # lazy import

        with self._lock:
            # Double-checked locking: another thread may have finished the
            # load while we waited for the lock.
            if self.model is None or self.tokenizer is None:
                self.logger.info(f"Loading model {self.model_name} ...")
                # If HF_TOKEN is set, transformers will use it automatically
                # via huggingface-cli login.
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
                self.logger.info("Model loaded successfully.")

    @requests
    def debug(self, docs: DocList[CodeInput], **kwargs) -> DocList[CodeOutput]:
        """Generate one ``CodeOutput`` per input document.

        Lazily loads the model on the first call. If loading was skipped via
        JINA_SKIP_MODEL_LOAD=1, returns a placeholder message for every doc
        instead of raising.
        """
        # Lazy load model at request time
        self._ensure_model()
        if self.model is None or self.tokenizer is None:
            # Model was skipped — return a helpful message per input doc.
            return DocList[CodeOutput](
                CodeOutput(result="Model not loaded (JINA_SKIP_MODEL_LOAD=1).")
                for _ in docs
            )

        results = []
        for doc in docs:
            # Be defensive: coerce non-string payloads to str before tokenizing.
            code_text = doc.code if isinstance(doc.code, str) else str(doc.code)
            inputs = self.tokenizer(
                code_text, return_tensors="pt", padding=True, truncation=True
            )
            outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            results.append(CodeOutput(result=result))
        return DocList[CodeOutput](results)