import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .config import ME_LLAMA_MODEL, FALLBACK_MODEL
import logging


class ModelManager:
    """Lazily loads a causal LM and tokenizer, falling back to a secondary model on failure."""

    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load(self):
        # Skip loading if the model and tokenizer are already initialized.
        if self.model is not None and self.tokenizer is not None:
            return
        try:
            logging.info(f"Loading primary model: {ME_LLAMA_MODEL}")
            self.tokenizer = AutoTokenizer.from_pretrained(ME_LLAMA_MODEL, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                ME_LLAMA_MODEL,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
        except Exception as e:
            # If the primary model cannot be loaded, fall back to the secondary model.
            logging.error(f"Error loading model: {e}")
            logging.info(f"Falling back to {FALLBACK_MODEL}...")
            self.tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
            self.model = AutoModelForCausalLM.from_pretrained(
                FALLBACK_MODEL,
                torch_dtype=torch.float16,
                device_map="auto",
            )

    def generate(self, prompt, max_new_tokens=1000, temperature=0.5, top_p=0.9):
        # Ensure the model and tokenizer are loaded before generating.
        self.load()
        inputs = self.tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            # Move input tensors onto the same device as the model.
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, skipping the prompt portion of the output.
        return self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
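

# Illustrative usage sketch (assumptions: this module is part of a package that also
# contains config.py, and the package/module names below are hypothetical):
#
#     from mypackage.model_manager import ModelManager
#
#     manager = ModelManager()
#     # The first call to generate() triggers load(), which downloads/initializes the model.
#     answer = manager.generate("Summarize the key findings:", max_new_tokens=200)
#     print(answer)
#
# Note: because this module uses a relative import (from .config import ...), it must be
# imported as part of its package rather than run directly as a standalone script.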