import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import ME_LLAMA_MODEL, FALLBACK_MODEL


class ModelManager:
    """Lazily loads the primary model and tokenizer, falling back to a backup model."""

    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load(self):
        # Skip loading if the model and tokenizer are already in memory.
        if self.model is not None and self.tokenizer is not None:
            return
        try:
            logging.info(f"ME_LLAMA_MODEL type: {type(ME_LLAMA_MODEL)}, value: {ME_LLAMA_MODEL}")
            self.tokenizer = AutoTokenizer.from_pretrained(ME_LLAMA_MODEL, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                ME_LLAMA_MODEL,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
        except Exception as e:
            # If the primary model cannot be loaded, fall back to the backup model.
            logging.error(f"Error loading model: {e}")
            logging.info(f"Falling back to {FALLBACK_MODEL}...")
            self.tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
            self.model = AutoModelForCausalLM.from_pretrained(
                FALLBACK_MODEL,
                torch_dtype=torch.float16,
                device_map="auto",
            )

    def generate(self, prompt, max_new_tokens=1000, temperature=0.5, top_p=0.9):
        self.load()
        inputs = self.tokenizer(prompt, return_tensors="pt")
        # Move input tensors onto the same device as the model (e.g. the GPU).
        if torch.cuda.is_available():
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Strip the prompt tokens so only the newly generated text is returned.
        return self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
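
# Minimal usage sketch (illustration only, not part of the original file).
# It assumes this module sits in the same package as config.py and is saved
# under a hypothetical name such as model_manager.py; adjust the import to
# match the actual module name in the Space.
#
#     from .model_manager import ModelManager
#
#     manager = ModelManager()
#     # First call triggers load(); later calls reuse the cached model.
#     answer = manager.generate("Summarize the patient's symptoms:", max_new_tokens=256)
#     print(answer)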