88ggg committed
Commit daf7c22 · verified · 1 Parent(s): b5d062d

Upload app.py

Files changed (1)
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
import os

# Install the transformers fork with BitNet support at startup; the model
# card for bitnet-b1.58-2B-4T points to this fork rather than stock transformers.
os.system("pip install git+https://github.com/shumingma/transformers.git")

import threading
import torch
import torch._dynamo

# Fall back to eager execution instead of raising when torch.compile/dynamo
# hits an unsupported graph.
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
import gradio as gr
import spaces

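# `spaces` supplies the @spaces.GPU decorator that Hugging Face ZeroGPU
# Spaces use to attach a GPU for the duration of a decorated call.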
model_id = "microsoft/bitnet-b1.58-2B-4T"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(model.device)

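# device_map="auto" defers placement to accelerate: the weights should land
# on a GPU when one is visible at load time, and on CPU otherwise; the
# print(model.device) above confirms where they ended up.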
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a chat response using streaming with TextIteratorStreamer.

    Args:
        message: User's current message.
        history: List of (user, assistant) tuples from previous turns.
        system_message: Initial system prompt guiding the assistant.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling probability.

    Yields:
        The growing response text as new tokens are generated.
    """
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

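    # For a first turn with the default system prompt, `messages` is e.g.:
    # [{"role": "system", "content": "You are a helpful AI assistant."},
    #  {"role": "user", "content": "Hello! How are you?"}]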
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # Run generation on a worker thread so this generator can yield partial
    # text as soon as the streamer produces it.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

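# A quick smoke test without the UI (hypothetical call, same signature the
# ChatInterface uses):
#   for partial in respond("Hi!", [], "You are helpful.", 64, 0.7, 0.95):
#       print(partial)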
demo = gr.ChatInterface(
    fn=respond,
    title="Bitnet-b1.58-2B-4T Chatbot",
    description="This chat application is powered by Microsoft's SOTA Bitnet-b1.58-2B-4T and designed for natural and fast conversations.",
    examples=[
        [
            "Hello! How are you?",
            "You are a helpful AI assistant for everyday tasks.",
            512,
            0.7,
            0.95,
        ],
        [
            "Can you code a snake game in Python?",
            "You are a helpful AI assistant for coding.",
            2048,
            0.7,
            0.95,
        ],
    ],
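    # Each example row is [message, system_message, max_tokens, temperature,
    # top_p]; everything after the message maps onto additional_inputs below,
    # in order.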
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful AI assistant.",
            label="System message",
        ),
        gr.Slider(
            minimum=1,
            maximum=8192,
            value=2048,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()