Yuchan5386 committed (verified)
Commit 4675638 Β· 1 Parent(s): 8266e1d

Update api.py

Files changed (1):
  api.py  +3 -27
api.py CHANGED
@@ -146,27 +146,6 @@ _ = model(dummy_input)  # model is built
 model.load_weights("Flexi.weights.h5")
 print("λͺ¨λΈ κ°€μ€‘μΉ˜ λ‘œλ“œ μ™„λ£Œ!")
 
-
-def is_greedy_response_acceptable(text):
-    text = text.strip()
-
-    # Reject sentences that are too short
-    if len(text) < 5:
-        return False
-
-    # Also reject texts with too few words
-    if len(text.split()) < 3:
-        return False
-
-    # Reject runs of bare jamo such as 'γ…‹γ…‹γ…‹' (allowed when the text contains 'γ…‹γ…‹')
-    if re.search(r'[γ„±-γ…Žγ…-γ…£]{3,}', text) and 'γ…‹γ…‹' not in text:
-        return False
-
-    # Reject awkward endings (must end in a common form such as λ‹€/μš”/μ£  or punctuation)
-    if not re.search(r'(λ‹€|μš”|μ£ |λ‹€\.|μš”\.|μ£ \.|λ‹€!|μš”!|μ£ !|\!|\?|\.)$', text):
-        return False
-
-    return True
 
 def generate_text_sample(model, prompt, max_len=100, max_gen=98,
                          temperature=0.8, top_k=55, top_p=0.95, min_len=12):
@@ -197,10 +176,9 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
         sorted_indices = np.argsort(probs)[::-1]
         sorted_probs = probs[sorted_indices]
         cumulative_probs = np.cumsum(sorted_probs)
-        # Drop tokens whose cumulative probability exceeds top_p
         cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
         probs_to_keep = sorted_indices[:cutoff_index+1]
-
+
         mask = np.ones_like(probs, dtype=bool)
         mask[probs_to_keep] = False
         probs[mask] = 0
@@ -217,16 +195,14 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
         decoded = decoded.strip()
 
         if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('μš”', 'λ‹€', '.', '!', '?'))):
-            if is_greedy_response_acceptable(decoded):
-                return decoded
-            else:
-                continue
+            return decoded
 
     decoded = sp.decode(generated)
     for t in ["<start>", "<sep>", "<end>"]:
         decoded = decoded.replace(t, "")
     return decoded.strip()
 
+
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.decomposition import TruncatedSVD
 from sklearn.metrics.pairwise import cosine_similarity
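For reference, the top-p (nucleus) step touched by the second hunk keeps the smallest set of tokens whose cumulative probability reaches top_p and zeroes out everything else before sampling. Below is a minimal standalone sketch of that logic, mirroring the argsort/cumsum/searchsorted lines above; the function name filter_top_p and the example distribution are illustrative, not part of api.py.

import numpy as np

def filter_top_p(probs, top_p=0.95):
    # Sort token probabilities from most to least likely.
    sorted_indices = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_indices]
    cumulative_probs = np.cumsum(sorted_probs)
    # Number of sorted tokens whose cumulative mass stays <= top_p;
    # keeping one more guarantees the kept mass reaches top_p.
    cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
    keep = sorted_indices[:cutoff_index + 1]
    # Zero out everything outside the nucleus, then renormalize.
    mask = np.ones_like(probs, dtype=bool)
    mask[keep] = False
    filtered = probs.copy()
    filtered[mask] = 0
    return filtered / filtered.sum()

probs = np.array([0.5, 0.3, 0.1, 0.05, 0.05])
print(filter_top_p(probs, top_p=0.8))  # only the head of the distribution keeps nonzero mass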
 
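With is_greedy_response_acceptable removed, the early return in the generation loop rests solely on the length and sentence-ending test from the third hunk. A small illustrative sketch of that condition, with the loop variables (generated, next_token, end_id, decoded) passed in explicitly; the helper name should_stop is hypothetical and not part of api.py.

def should_stop(generated, next_token, end_id, decoded, min_len=12):
    # Stop once enough tokens were produced and the text either hit the
    # end token or already ends like a complete Korean sentence.
    return len(generated) >= min_len and (
        next_token == end_id or decoded.endswith(('μš”', 'λ‹€', '.', '!', '?'))
    )

# A 15-token sequence whose decoded text ends in 'μš”' now returns immediately,
# without the extra heuristic filter that this commit drops.
print(should_stop(list(range(15)), next_token=7, end_id=3, decoded="μ•ˆλ…•ν•˜μ„Έμš”"))  # True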