File size: 647 Bytes
1b8d8e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

from transformers import pipeline

def inference(text, model, tokenizer, args=None):
    """Generate text for ``text`` with a ``transformers`` text-generation pipeline.

    A ``pipeline("text-generation", ...)`` is built on every call from the
    given ``model`` and ``tokenizer`` with ``device_map="auto"``, then invoked
    once on ``text`` with the merged generation parameters.

    Args:
        text: Prompt passed as the first argument to the pipeline.
        model: Passed through as ``pipeline(model=...)``.
        tokenizer: Passed through as ``pipeline(tokenizer=...)``.
        args: Optional mapping of generation keyword arguments. Its entries
            override the defaults (``max_new_tokens=256``, ``temperature=0.7``,
            ``top_p=0.9``, ``top_k=50``, ``do_sample=True``,
            ``repetition_penalty=1.1``). ``None`` (the default) or an empty
            mapping keeps the defaults. The mapping itself is never modified.

    Returns:
        The ``"generated_text"`` entry of the first item returned by the
        pipeline.
    """
    # Merge parameters first: it is cheap, so a malformed ``args`` fails
    # before any (expensive) pipeline construction happens.
    params = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        "do_sample": True,
        "repetition_penalty": 1.1,
    }
    if args is not None:
        # Copies entries into the fresh local dict; caller's mapping untouched.
        params.update(args)

    # NOTE(review): a new pipeline is created on every call, which can be
    # costly if this helper is called repeatedly; consider building it once.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto",
    )

    # NOTE(review): presumably the returned text includes the prompt unless
    # ``return_full_text=False`` is supplied via ``args`` — confirm against
    # the installed transformers version.
    result = generator(text, **params)
    return result[0]["generated_text"]