File size: 664 Bytes
8e44141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# inference.py

from transformers import pipeline

# This will be called once, at container startup
def init():
    """Load the text2text-generation pipeline once at container startup.

    Stores the pipeline in the module-level global ``generator`` so that
    ``run`` can reuse it for every request without reloading the model.

    Side effects:
        Sets the global ``generator``; reads model/tokenizer files from the
        current working directory (".").
    """
    global generator
    import torch  # local import: only needed for device detection

    # Fix: device was hard-coded to GPU 0, which crashes at startup on
    # CPU-only hosts. Fall back to CPU (-1) when no CUDA device exists.
    device = 0 if torch.cuda.is_available() else -1

    generator = pipeline(
        "text2text-generation",
        model=".",          # model weights bundled in the working directory
        tokenizer=".",
        device=device,
        max_length=128,
        do_sample=True,     # sampling enabled: outputs are non-deterministic by design
        top_p=0.9,
        temperature=0.7,
    )

# This will be called for every request
def run(request: dict):
    """Serve one inference request using the globally initialized pipeline.

    Expects: { "inputs": "<your-prompt>" }
    Returns: { "generated_text": "..." }
    """
    prompt = request.get("inputs", "")
    # The pipeline returns a list of candidate generations; take the first.
    results = generator(prompt)
    first = results[0]
    return {"generated_text": first["generated_text"]}