from functools import lru_cache
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
@lru_cache(maxsize=2)
def get_model(model_name):
    """Load and cache a Hugging Face seq2seq model and its tokenizer.

    The lru_cache keeps at most two distinct (tokenizer, model, device)
    triples alive, so repeated calls with the same name reuse the loaded
    weights instead of hitting disk again.

    Returns:
        tuple: (tokenizer, model, device) where device is "cuda" when a
        GPU is available, otherwise "cpu"; the model has already been
        moved to that device.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # nn.Module.to() moves in place and returns self, so chaining is safe.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    return tokenizer, model, device