laurelix committed (verified)
Commit 3883928 · Parent(s): f3b48d8

Update app.py

Files changed (1)
  1. app.py +32 -77
app.py CHANGED
@@ -1,80 +1,35 @@
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
- import gradio as gr

  model_id = "GoToCompany/gemma2-9b-cpt-sahabatai-v1-instruct"
  tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id)
-
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
- def chat_fn(prompt):
-     result = pipe(prompt, max_new_tokens=200, temperature=0.7)[0]["generated_text"]
-     return result
-
- gr.Interface(fn=chat_fn, inputs="text", outputs="text").launch()
-
-
- # import gradio as gr
- # from huggingface_hub import InferenceClient
-
- # """
- # For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- # """
- # client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- # def respond(
- #     message,
- #     history: list[tuple[str, str]],
- #     system_message,
- #     max_tokens,
- #     temperature,
- #     top_p,
- # ):
- #     messages = [{"role": "system", "content": system_message}]
-
- #     for val in history:
- #         if val[0]:
- #             messages.append({"role": "user", "content": val[0]})
- #         if val[1]:
- #             messages.append({"role": "assistant", "content": val[1]})
-
- #     messages.append({"role": "user", "content": message})
-
- #     response = ""
-
- #     for message in client.chat_completion(
- #         messages,
- #         max_tokens=max_tokens,
- #         stream=True,
- #         temperature=temperature,
- #         top_p=top_p,
- #     ):
- #         token = message.choices[0].delta.content
-
- #         response += token
- #         yield response
-
-
- # """
- # For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- # """
- # demo = gr.ChatInterface(
- #     respond,
- #     additional_inputs=[
- #         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
- #         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
- #         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
- #         gr.Slider(
- #             minimum=0.1,
- #             maximum=1.0,
- #             value=0.95,
- #             step=0.05,
- #             label="Top-p (nucleus sampling)",
- #         ),
- #     ],
- # )
-
-
- # if __name__ == "__main__":
- #     demo.launch()

+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import uvicorn
+
+ app = FastAPI()

  model_id = "GoToCompany/gemma2-9b-cpt-sahabatai-v1-instruct"
  tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+
+ class ChatRequest(BaseModel):
+     prompt: str
+     max_new_tokens: int = 256
+     temperature: float = 0.7
+     top_p: float = 0.95
+
+ @app.post("/chat")
+ async def chat(request: ChatRequest):
+     inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=request.max_new_tokens,
+         temperature=request.temperature,
+         top_p=request.top_p,
+         do_sample=True,
+         pad_token_id=tokenizer.eos_token_id,
+     )
+     result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return {"response": result}
+
+ # This will only run locally or in Spaces, not if you import this module
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
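
A minimal sketch (not part of the commit) of how the new /chat endpoint could be called once app.py is running, assuming the server is reachable on localhost:7860 as configured in the uvicorn.run call; the requests dependency and the sample prompt are illustrative only.

# Hypothetical client for the /chat endpoint added in this commit.
# Assumes the FastAPI app is serving on localhost:7860 (see uvicorn.run in app.py).
import requests

payload = {
    "prompt": "Halo, apa kabar?",   # free-form prompt, the only required ChatRequest field
    "max_new_tokens": 128,          # optional, server default is 256
    "temperature": 0.7,             # optional sampling temperature
    "top_p": 0.95,                  # optional nucleus-sampling cutoff
}

resp = requests.post("http://localhost:7860/chat", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["response"])      # decoded generation returned by the endpoint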