Yuchan5386 committed on
Commit
7d6a990
·
verified ·
1 Parent(s): bbe88ff

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +21 -7
api.py CHANGED
@@ -124,14 +124,16 @@ _ = model(dummy_input)
124
  model.load_weights("InteractGPT.weights.h5")
125
  print("모델 가중치 로드 완료!")
126
 
127
- async def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
 
128
  temperature=1.0, min_len=20,
129
- repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
130
  model_input = text_to_ids(f"<start> {prompt} <sep>")
131
  model_input = model_input[:max_len]
132
  generated = list(model_input)
133
 
134
  tau = 5.0 # 초기 목표 surprise
 
135
 
136
  for step in range(max_gen):
137
  pad_length = max(0, max_len - len(generated))
@@ -177,18 +179,30 @@ async def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
177
 
178
  final_token = np.random.choice(filtered_indices, p=filtered_probs)
179
 
180
- # 특수 토큰 처리
181
  if final_token == end_id:
 
 
 
 
 
 
182
  break
 
183
  if final_token in [start_id, pad_id] or sp.id_to_piece(final_token) == "<sep>":
184
  continue
185
 
186
- # 정상 토큰만 추가 및 yield
187
  generated.append(int(final_token))
188
- decoded = sp.decode([final_token])
189
- yield decoded
 
 
 
 
 
 
 
190
 
191
  @app.get("/generate")
192
  async def generate(request: Request):
193
  prompt = request.query_params.get("prompt", "안녕하세요")
194
- return StreamingResponse(generate_text_stream(prompt), media_type="text/plain")
 
124
  model.load_weights("InteractGPT.weights.h5")
125
  print("모델 가중치 로드 완료!")
126
 
127
+
128
+ async def generate_text_mirostat_top_p_with_buffer(model, prompt, max_len=100, max_gen=98,
129
  temperature=1.0, min_len=20,
130
+ repetition_penalty=1.2, eta=0.1, m=100, p=0.9, buffer_size=3):
131
  model_input = text_to_ids(f"<start> {prompt} <sep>")
132
  model_input = model_input[:max_len]
133
  generated = list(model_input)
134
 
135
  tau = 5.0 # 초기 목표 surprise
136
+ buffer_tokens = []
137
 
138
  for step in range(max_gen):
139
  pad_length = max(0, max_len - len(generated))
 
179
 
180
  final_token = np.random.choice(filtered_indices, p=filtered_probs)
181
 
 
182
  if final_token == end_id:
183
+ # 버퍼에 남은 거 다 출력
184
+ if buffer_tokens:
185
+ decoded = sp.decode(buffer_tokens)
186
+ for token in ["<start>", "<sep>", "<end>"]:
187
+ decoded = decoded.replace(token, "")
188
+ yield decoded.strip()
189
  break
190
+
191
  if final_token in [start_id, pad_id] or sp.id_to_piece(final_token) == "<sep>":
192
  continue
193
 
 
194
  generated.append(int(final_token))
195
+ buffer_tokens.append(final_token)
196
+
197
+ if len(buffer_tokens) >= buffer_size or sp.id_to_piece(final_token).endswith("▁"):
198
+ # 띄어쓰기 있는 토큰 나오거나 버퍼 꽉 찼으면 출력
199
+ decoded = sp.decode(buffer_tokens)
200
+ for token in ["<start>", "<sep>", "<end>"]:
201
+ decoded = decoded.replace(token, "")
202
+ yield decoded.strip()
203
+ buffer_tokens = []
204
 
205
  @app.get("/generate")
206
  async def generate(request: Request):
207
  prompt = request.query_params.get("prompt", "안녕하세요")
208
+ return StreamingResponse(generate_text_mirostat_top_p_with_buffer(prompt), media_type="text/plain")