import mlflow
import torch
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer


class InferenceBuilder:
    def __init__(self):
        # Load the necessary configuration from YAML
        self.model_config = mlflow.models.ModelConfig(development_config="model_config.yaml")
        self.cybersolve_config = self.model_config.get("cybersolve_config")
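
    # Illustrative sketch of the expected model_config.yaml structure. This is
    # inferred from the keys read above and is an assumption -- the real file is
    # not shown here, and the values are placeholders:
    #
    #   cybersolve_config:
    #     tokenizer_name: <FLAN-T5 tokenizer id>            # placeholder
    #     model_name: <fine-tuned CyberSolve checkpoint>    # placeholder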

    def load_tokenizer(self):
        tokenizer_name = self.cybersolve_config.get("tokenizer_name")

        # Cache the tokenizer so it isn't re-downloaded on every Streamlit rerun.
        # @st.cache_resource cannot be applied directly to a method (a function with a
        # `self` argument), so we wrap the cached loader in an inner function instead.
        @st.cache_resource  # https://docs.streamlit.io/develop/concepts/architecture/caching
        def load_and_cache_tokenizer(tokenizer_name):
            tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)  # CyberSolve uses the same tokenizer as the base FLAN-T5 model
            return tokenizer

        return load_and_cache_tokenizer(tokenizer_name)

    def load_model(self):
        model_name = self.cybersolve_config.get("model_name")

        # Cache the model so it isn't re-downloaded on every Streamlit rerun.
        # @st.cache_resource cannot be applied directly to a method (a function with a
        # `self` argument), so we wrap the cached loader in an inner function instead.
        @st.cache_resource  # https://docs.streamlit.io/develop/concepts/architecture/caching
        def load_and_cache_model(model_name):
            model = T5ForConditionalGeneration.from_pretrained(model_name).to("cuda")  # put the model on our Space's GPU
            # model = T5ForConditionalGeneration.from_pretrained(model_name)  # testing on CPU
            return model

        return load_and_cache_model(model_name)
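
# --- Illustrative usage sketch (not part of the original file) ---
# How the Space's entry point calls this builder is an assumption; a typical
# Streamlit flow would look roughly like:
#
#   builder = InferenceBuilder()
#   tokenizer = builder.load_tokenizer()
#   model = builder.load_model()
#
#   prompt = st.text_input("Enter a problem for CyberSolve")
#   if prompt:
#       inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#       outputs = model.generate(**inputs, max_new_tokens=64)
#       st.write(tokenizer.decode(outputs[0], skip_special_tokens=True))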