# inference.py
import torch
from transformers import pipeline
# This will be called once, at container startup
def init():
    """Load the text2text-generation pipeline once at container startup.

    Binds the pipeline to the module-level ``generator`` global so that
    ``run`` can reuse it for every request without reloading the model.
    """
    global generator
    generator = pipeline(
        "text2text-generation",
        model=".",      # model weights are expected in the container working dir
        tokenizer=".",
        # Use GPU 0 when CUDA is available; fall back to CPU (-1) instead of
        # crashing on CPU-only hosts with a hard-coded GPU index.
        device=0 if torch.cuda.is_available() else -1,
        max_length=128,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
# This will be called for every request
def run(request: dict) -> dict:
    """Handle one inference request.

    Expects: { "inputs": "<your-prompt>" }
    Returns: { "generated_text": "..." }

    A missing ``inputs`` key is treated as an empty prompt; ``init`` must
    have been called first so the module-level ``generator`` exists.
    """
    prompt = request.get("inputs", "")
    outputs = generator(prompt)
    # The pipeline returns a list of candidate dicts; expose only the first.
    return {"generated_text": outputs[0]["generated_text"]}