howtomakepplragequit committed
Commit 0f9dc03 · verified · 1 Parent(s): b592524

Update main.py

Files changed (1)
  1. main.py +24 -26
main.py CHANGED
@@ -1,26 +1,24 @@
- import os
- os.environ["HF_HOME"] = "/tmp"
-
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from peft import PeftModel
- from fastapi import FastAPI
- from pydantic import BaseModel
-
- model_name = "microsoft/phi-2"
- adapter_path = "howtomakepplragequit/phi2-lora-instruct"
-
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- base_model = AutoModelForCausalLM.from_pretrained(model_name)
- model = PeftModel.from_pretrained(base_model, adapter_path)
-
- app = FastAPI()
-
- class Prompt(BaseModel):
-     input: str
-
- @app.post("/chat")
- def chat(prompt: Prompt):
-     inputs = tokenizer(prompt.input, return_tensors="pt")
-     output = model.generate(**inputs, max_new_tokens=50)
-     response = tokenizer.decode(output[0], skip_special_tokens=True)
-     return {"response": response}
+ from fastapi import FastAPI, Request
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel, PeftConfig
+ import torch
+
+ app = FastAPI()
+
+ model_name = "microsoft/phi-2"
+ peft_model_id = "howtomakepplragequit/phi2-lora-instruct"
+
+ # Load tokenizer and model with LoRA
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+ model = PeftModel.from_pretrained(base_model, peft_model_id)
+ model.eval()
+
+ @app.post("/generate")
+ async def generate(request: Request):
+     data = await request.json()
+     prompt = data.get("prompt", "")
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(**inputs, max_new_tokens=100)
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return {"response": response}
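For reference, a minimal client sketch for exercising the new /generate endpoint. The host, port, serving command, and use of the requests library are assumptions for illustration, not part of this commit:

# Minimal sketch of a client call to the updated endpoint.
# Assumes the app is being served locally, e.g. via:
#   uvicorn main:app --port 8000
# (host, port, and the requests dependency are assumptions)
import requests

resp = requests.post(
    "http://localhost:8000/generate",  # assumed local deployment
    json={"prompt": "Explain LoRA in one sentence."},
)
print(resp.json()["response"])

Note the design change in this commit: the handler now reads raw JSON through FastAPI's Request instead of a Pydantic model, so a request missing the "prompt" key falls back to an empty string rather than triggering a validation error.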