eat2fit / model.py
DurgaDeepak's picture
Create model.py
54619a0 verified
raw
history blame
401 Bytes
from transformers import pipeline
def load_llm(model_name="mistralai/Mistral-7B-Instruct-v0.1"):
    """Build and return a Hugging Face text-generation pipeline.

    Args:
        model_name: Hub id of the causal LM to load. Defaults to
            Mistral-7B-Instruct-v0.1, matching the original behavior.

    Returns:
        A ``transformers`` text-generation pipeline, with layers placed
        automatically across available devices (``device_map="auto"``).
    """
    # SECURITY NOTE(review): trust_remote_code=True will execute arbitrary
    # Python shipped inside the model repository. Keep it only for repos
    # you explicitly trust; Mistral's official repo does not require it —
    # TODO confirm and consider dropping.
    return pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",
        trust_remote_code=True
    )
def get_response(pipe, prompt, max_new_tokens=256):
    """Run *prompt* through *pipe* and return the text after the last "User:" tag.

    Args:
        pipe: A text-generation pipeline (callable as ``pipe(prompt, ...)``).
        prompt: The full prompt string to generate from.
        max_new_tokens: Sampling budget forwarded to the pipeline.

    Returns:
        The generated text following the final "User:" marker (the whole
        generation if the marker is absent), stripped of surrounding whitespace.
    """
    generations = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=True)
    full_text = generations[0]["generated_text"]
    # The pipeline echoes the prompt; rpartition yields the original string
    # in slot 2 when "User:" is absent, matching split("User:")[-1].
    tail = full_text.rpartition("User:")[2]
    return tail.strip()