ginipick committed on
Commit
e507147
·
verified ·
1 Parent(s): 7f56bf7

Update app.py

Files changed (1)
  1. app.py +245 -34
app.py CHANGED
@@ -1,48 +1,259 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import spaces

-# Load the model and tokenizer
-model_name = "XiaomiMiMo/MiMo-7B-RL"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    trust_remote_code=True
-)
 @spaces.GPU()
 def predict(message, history):
-    # Build the input
     history_text = ""
-    for human, assistant in history:
-        history_text += f"Human: {human}\nAssistant: {assistant}\n"
     prompt = f"{history_text}Human: {message}\nAssistant:"
-
-    # Generate a reply
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=10000,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.9,
-        repetition_penalty=1.1,
-        pad_token_id=tokenizer.eos_token_id
-    )
-    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-
     return response.strip()

-# Create the Gradio interface
 demo = gr.ChatInterface(
-    predict,
-    title="MiMo-7B-RL Chatbot",
-    description="This is a chatbot based on Xiaomi's MiMo-7B-RL model.",
-    examples=["Hello!", "Please introduce yourself", "What can you do?"],
-    theme=gr.themes.Soft()
 )

 if __name__ == "__main__":
-    demo.launch(share=True)
 import gradio as gr
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import gc
+import os
+import datetime
+import time
 import spaces

+# --- Configuration ---
+MODEL_ID = "XiaomiMiMo/MiMo-7B-RL"
+MAX_NEW_TOKENS = 512
+CPU_THREAD_COUNT = 4  # adjust if needed
+
+# --- Optional: CPU thread settings ---
+# torch.set_num_threads(CPU_THREAD_COUNT)
+# os.environ["OMP_NUM_THREADS"] = str(CPU_THREAD_COUNT)
+# os.environ["MKL_NUM_THREADS"] = str(CPU_THREAD_COUNT)
+
+print("--- Environment ---")
+print(f"PyTorch version: {torch.__version__}")
+print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
+print(f"Torch threads: {torch.get_num_threads()}")
+
+# --- Load the model and tokenizer ---
+print(f"--- Loading model: {MODEL_ID} ---")
+print("The first run may take a few minutes...")
+
+model = None
+tokenizer = None
+load_successful = False
+stop_token_ids_list = []  # initialize stop_token_ids_list
+
+try:
+    start_load_time = time.time()
+    # Pick device_map and dtype based on the available hardware
+    device_map = "auto" if torch.cuda.is_available() else "cpu"
+    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_ID,
+        trust_remote_code=True
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        torch_dtype=dtype,
+        device_map=device_map,
+        trust_remote_code=True
+    )
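+    # device_map="auto" requires the `accelerate` package and places weights on
+    # the available device(s); float16 halves GPU memory use, while the CPU
+    # path above stays in float32 since half precision is poorly supported on CPU.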
+
+    model.eval()
+    load_time = time.time() - start_load_time
+    print(f"--- Model and tokenizer loaded in {load_time:.2f}s ---")
+    load_successful = True
+
+    # --- Stop token setup ---
+    stop_token_strings = ["</s>", "<|endoftext|>"]
+    temp_stop_ids = [tokenizer.convert_tokens_to_ids(token) for token in stop_token_strings]
+
+    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in temp_stop_ids:
+        temp_stop_ids.append(tokenizer.eos_token_id)
+    elif tokenizer.eos_token_id is None:
+        print("Warning: tokenizer.eos_token_id is None; it cannot be added to the stop tokens.")
+
+    stop_token_ids_list = [tid for tid in temp_stop_ids if tid is not None]
+
+    if not stop_token_ids_list:
+        print("Warning: no stop token IDs found. Falling back to the default EOS if available; otherwise generation may not stop correctly.")
+        if tokenizer.eos_token_id is not None:
+            stop_token_ids_list = [tokenizer.eos_token_id]
+        else:
+            print("Error: no stop tokens found, not even a default EOS. Generation may run indefinitely.")
+
+    print(f"Stop token IDs in use: {stop_token_ids_list}")
+
+except Exception as e:
+    print(f"!!! Model loading error: {e}")
+    if 'model' in locals() and model is not None: del model
+    if 'tokenizer' in locals() and tokenizer is not None: del tokenizer
+    gc.collect()
+    raise gr.Error(f"Failed to load model {MODEL_ID}; the application cannot start. Error: {e}")
+
+# --- System prompt definition ---
+def get_system_prompt():
+    current_date = datetime.datetime.now().strftime("%Y-%m-%d (%A)")
+    return (
+        f"- The AI language model is named \"MiMo\" and was built by XiaomiMiMo.\n"
+        f"- Today is {current_date}.\n"
+        f"- Answer the user's questions kindly, in detail, and in Korean."
+    )
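+# NOTE: get_system_prompt() is defined but its output is not yet prepended to
+# the prompts built in warmup_model() and predict() below; include it in
+# `prompt` if the model should actually see these instructions.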
+
+# --- Warmup function ---
+def warmup_model():
+    if not load_successful or model is None or tokenizer is None:
+        print("Skipping warmup: the model did not load successfully.")
+        return
+
+    print("--- Starting model warmup ---")
+    try:
+        start_warmup_time = time.time()
+        warmup_message = "Hello"
+
+        # Build the input in the format the model expects
+        system_prompt = get_system_prompt()
+
+        # Matched to the MiMo model's prompt format
+        prompt = f"Human: {warmup_message}\nAssistant:"
+
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        # Check whether the stop-token list is empty and handle it accordingly
+        gen_kwargs = {
+            "max_new_tokens": 10,
+            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
+            "do_sample": False
+        }
+
+        if stop_token_ids_list:
+            gen_kwargs["eos_token_id"] = stop_token_ids_list
+        else:
+            print("Warmup warning: no stop tokens defined for generation.")
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, **gen_kwargs)
+
+        del inputs
+        del output_ids
+        gc.collect()
+        warmup_time = time.time() - start_warmup_time
+        print(f"--- Model warmup finished in {warmup_time:.2f}s ---")
+
+    except Exception as e:
+        print(f"!!! Error during model warmup: {e}")
+    finally:
+        gc.collect()
+
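+# The first generate() call pays one-time costs (CUDA context creation, kernel
+# selection, cache allocation), so the short fixed warmup above keeps the first
+# real user request from absorbing that latency.
+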
+# --- Inference function ---
 @spaces.GPU()
 def predict(message, history):
+    """
+    Generate a response with the XiaomiMiMo/MiMo-7B-RL model.
+    'history' may arrive as (user, assistant) tuples or as Gradio
+    'messages'-format dicts; both shapes are handled below.
+    """
+    if model is None or tokenizer is None:
+        return "Error: the model is not loaded."
+
+    # Fold the conversation history into the prompt
     history_text = ""
+    if isinstance(history, list):
+        for turn in history:
+            if isinstance(turn, (tuple, list)) and len(turn) == 2:
+                history_text += f"Human: {turn[0]}\nAssistant: {turn[1]}\n"
+            elif isinstance(turn, dict) and "role" in turn and "content" in turn:
+                prefix = "Human" if turn["role"] == "user" else "Assistant"
+                history_text += f"{prefix}: {turn['content']}\n"
+
+    # Build the prompt in the MiMo model's input format
     prompt = f"{history_text}Human: {message}\nAssistant:"
+
+    inputs = None
+    output_ids = None
+
+    try:
+        # Prepare the inputs
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+        input_length = inputs.input_ids.shape[1]
+        print(f"\nInput token count: {input_length}")
+
+    except Exception as e:
+        print(f"!!! Error while processing the input: {e}")
+        return f"Error: there was a problem processing the input format. ({e})"
+
+    try:
+        print("Generating response...")
+        generation_start_time = time.time()
+
+        # Prepare generation arguments, handling an empty stop_token_ids_list
+        gen_kwargs = {
+            "max_new_tokens": MAX_NEW_TOKENS,
+            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
+            "do_sample": True,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "repetition_penalty": 1.1
+        }
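+        # Sampling setup: temperature 0.7 sharpens the distribution (less
+        # randomness than 1.0), top_p 0.9 samples from the smallest token set
+        # covering 90% of the probability mass, and repetition_penalty 1.1
+        # mildly discourages repeated tokens.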
+
+        if stop_token_ids_list:
+            gen_kwargs["eos_token_id"] = stop_token_ids_list
+        else:
+            print("Generation warning: no stop tokens defined.")
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, **gen_kwargs)
+
+        generation_time = time.time() - generation_start_time
+        print(f"Generation finished in {generation_time:.2f}s.")
+
+    except Exception as e:
+        print(f"!!! Error during model generation: {e}")
+        if inputs is not None: del inputs
+        if output_ids is not None: del output_ids
+        gc.collect()
+        return f"Error: there was a problem generating the response. ({e})"
+
+    # Decode the response
+    response = "Error: failed to generate a response."
+    if output_ids is not None:
+        try:
+            new_tokens = output_ids[0, input_length:]
+            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
+            print(f"Output token count: {len(new_tokens)}")
+            del new_tokens
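+            # The slice from input_length onward drops the echoed prompt
+            # tokens, so only newly generated text is decoded into the reply.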
+        except Exception as e:
+            print(f"!!! Error while decoding the response: {e}")
+            response = "Error: there was a problem decoding the response."
+
+    # Free memory
+    if inputs is not None: del inputs
+    if output_ids is not None: del output_ids
+    gc.collect()
+    print("Memory cleanup complete.")
+
     return response.strip()

+# --- Gradio interface setup ---
+print("--- Setting up the Gradio interface ---")
+
+examples = [
+    ["Hello! Please introduce yourself."],
+    ["What is the difference between artificial intelligence and machine learning?"],
+    ["Walk me through the steps of training a deep learning model."],
+    ["I'm planning a trip to Jeju Island; can you suggest a 4-day, 3-night itinerary?"],
+]
+
+# Let ChatInterface manage its own Chatbot component
 demo = gr.ChatInterface(
+    fn=predict,
+    title="🤖 XiaomiMiMo/MiMo-7B-RL Korean Demo",
+    description=(
+        f"**Model:** {MODEL_ID}\n"
+        f"**Environment:** {'GPU' if torch.cuda.is_available() else 'CPU'}\n"
+        f"**Note:** {'Running on GPU.' if torch.cuda.is_available() else 'Running on CPU, so responses may take a while to generate.'}\n"
+        f"Generation is capped at {MAX_NEW_TOKENS} new tokens."
+    ),
+    examples=examples,
+    cache_examples=False,
+    theme=gr.themes.Soft(),
 )
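+# ChatInterface calls `fn` as fn(message, history). Depending on the Gradio
+# version, history arrives as (user, assistant) tuples or as role/content
+# dicts in 'messages' format; predict() above handles both shapes.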

+# --- Run the application ---
 if __name__ == "__main__":
+    if load_successful:
+        warmup_model()
+    else:
+        print("Skipping warmup because the model failed to load.")
+
+    print("--- Launching the Gradio app ---")
+    demo.queue().launch(
+        # share=True  # uncomment for a public link
+        # server_name="0.0.0.0"  # uncomment for local-network access
+    )