medbot_2/medbot/model.py
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import ME_LLAMA_MODEL, FALLBACK_MODEL

class ModelManager:
    """Lazily loads the primary model, falling back to a smaller model on failure."""

    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load(self):
        # Skip loading if the model and tokenizer are already in memory.
        if self.model is not None and self.tokenizer is not None:
            return
        try:
            logging.info(f"ME_LLAMA_MODEL type: {type(ME_LLAMA_MODEL)}, value: {ME_LLAMA_MODEL}")
            self.tokenizer = AutoTokenizer.from_pretrained(ME_LLAMA_MODEL, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                ME_LLAMA_MODEL,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
        except Exception as e:
            logging.warning(f"Error loading model: {e}")
            logging.warning(f"Falling back to {FALLBACK_MODEL}...")
            self.tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
            self.model = AutoModelForCausalLM.from_pretrained(
                FALLBACK_MODEL,
                torch_dtype=torch.float16,
                device_map="auto",
            )
    def generate(self, prompt, max_new_tokens=1000, temperature=0.5, top_p=0.9):
        """Generate a completion for `prompt`, returning only the newly generated text."""
        self.load()
        inputs = self.tokenizer(prompt, return_tensors="pt")
        # Move input tensors onto the model's device when a GPU is in use.
        if torch.cuda.is_available():
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Strip the prompt tokens from the output before decoding.
        return self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
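

# Minimal usage sketch (assumptions: this module lives in the `medbot` package, `config.py`
# defines ME_LLAMA_MODEL and FALLBACK_MODEL as Hugging Face model IDs, and the module is
# run as `python -m medbot.model` so the relative import resolves). The prompt below is
# illustrative only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = ModelManager()
    # The first generate() call triggers load(); later calls reuse the cached model/tokenizer.
    reply = manager.generate("What are common symptoms of dehydration?", max_new_tokens=200)
    print(reply)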