import threading
import os
from jina import Executor, requests
from docarray import BaseDoc, DocList

# transformers is imported lazily inside _ensure_model so that importing this
# module stays lightweight (no torch/transformers import at module load time).


class CodeInput(BaseDoc):
    code: str

class CodeOutput(BaseDoc):
    result: str

class CodeDebugger(Executor):
    """
    Jina Executor that lazy-loads a Hugging Face seq2seq model on first request.
    Use environment variable JINA_SKIP_MODEL_LOAD=1 to skip model loading (useful in CI/builds).
    """
    def __init__(self, model_name: str = "Girinath11/aiml_code_debug_model", **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self._lock = threading.Lock()
        self.tokenizer = None
        self.model = None
        # Optional: override max_new_tokens via the MAX_NEW_TOKENS env var.
        self.max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", "256"))

    def _ensure_model(self):
        """
        Load tokenizer & model once in a thread-safe manner.
        If JINA_SKIP_MODEL_LOAD is set to "1", skip loading (helpful for hub builds).
        """
        skip = os.environ.get("JINA_SKIP_MODEL_LOAD", "0") == "1"
        if skip:
            self.logger.warning("JINA_SKIP_MODEL_LOAD=1 set; skipping HF model load.")
            return

        if self.model is None or self.tokenizer is None:
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # lazy import
            with self._lock:
                # Double-checked locking: re-verify after acquiring the lock so
                # concurrent first requests load the model only once.
                if self.model is None or self.tokenizer is None:
                    self.logger.info(f"Loading model {self.model_name} ...")
                    # transformers picks up credentials automatically from the
                    # HF_TOKEN env var or a prior `huggingface-cli login`.
                    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                    self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
                    self.logger.info("Model loaded successfully.")

    @requests
    def debug(self, docs: DocList[CodeInput], **kwargs) -> DocList[CodeOutput]:
        # Lazy-load the model on the first request
        self._ensure_model()

        results = []
        if self.model is None or self.tokenizer is None:
            # Model loading was skipped; return an explanatory message per doc.
            for _ in docs:
                results.append(CodeOutput(result="Model not loaded (JINA_SKIP_MODEL_LOAD=1)."))
            return DocList[CodeOutput](results)

        for doc in docs:
            # Coerce the input to a string before tokenization
            code_text = doc.code if isinstance(doc.code, str) else str(doc.code)
            inputs = self.tokenizer(code_text, return_tensors="pt", padding=True, truncation=True)
            outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            results.append(CodeOutput(result=result))

        return DocList[CodeOutput](results)
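

# ---------------------------------------------------------------------------
# Minimal local usage sketch, assuming the jina>=3.x Deployment/Client API and
# docarray v2. The port number and the buggy sample snippet are arbitrary.
# Set JINA_SKIP_MODEL_LOAD=1 first for a fast smoke test without the HF
# model download.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from jina import Deployment

    with Deployment(uses=CodeDebugger, port=12345) as dep:
        responses = dep.post(
            on="/",
            inputs=DocList[CodeInput]([CodeInput(code="def add(a, b):\n    return a - b")]),
            return_type=DocList[CodeOutput],
        )
        print(responses[0].result)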