from transformers import pipeline | |
def inference(text, model, tokenizer, args=None):
    """Generate text from a prompt with a Hugging Face text-generation pipeline.

    Parameters
    ----------
    text : str
        Prompt fed to the model.
    model
        Model object or model identifier accepted by ``pipeline``.
    tokenizer
        Tokenizer paired with *model*.
    args : dict, optional
        Generation-parameter overrides; merged on top of the defaults
        below, so any key present in *args* wins.

    Returns
    -------
    str
        The ``generated_text`` of the first returned sequence.
    """
    # NOTE(review): the pipeline is rebuilt on every call; hoist it out of
    # this function (or cache it) if inference is invoked repeatedly with
    # the same model/tokenizer.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto",
    )

    # Default sampling parameters that can be overridden by *args*.
    params = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        "do_sample": True,
        "repetition_penalty": 1.1,
    }
    # `args=None` instead of `args={}` avoids the mutable-default-argument
    # pitfall; behavior for callers passing a dict is unchanged.
    if args:
        params.update(args)

    # Run generation and return only the text of the first sequence.
    result = generator(text, **params)
    return result[0]["generated_text"]