nyasukun committed on
Commit d75daa2 · 1 Parent(s): aacf53d
Files changed (1):
  1. app.py +136 -14
app.py CHANGED
@@ -6,6 +6,8 @@ from enum import Enum, auto
 import torch
 from transformers import AutoTokenizer, pipeline
 import spaces
+import concurrent.futures
+import time
 
 # Logger configuration
 logging.basicConfig(
@@ -160,32 +162,152 @@ def classify_text_api(model_id, text):
         logger.error(f"Error in API classification with {model_id}: {str(e)}")
         return f"Error: {str(e)}"
 
+@spaces.GPU
+def parallel_text_generation(model_paths, texts):
+    """Optimized helper for running multiple local models within a single GPU allocation."""
+    try:
+        logger.info(f"Running parallel text generation for {len(model_paths)} models")
+        results = {}
+
+        # Assumes each model's pipeline has already been loaded
+        for i, (model_path, text) in enumerate(zip(model_paths, texts)):
+            try:
+                logger.info(f"Processing model {i+1}/{len(model_paths)}: {model_path}")
+                outputs = pipelines[model_path](
+                    text,
+                    max_new_tokens=40,
+                    do_sample=False,
+                    num_return_sequences=1
+                )
+                results[model_path] = outputs[0]["generated_text"]
+            except Exception as e:
+                logger.error(f"Error in text generation with {model_path}: {str(e)}")
+                results[model_path] = f"Error: {str(e)}"
+
+        return results
+    except Exception as e:
+        logger.error(f"Error in parallel text generation: {str(e)}")
+        return {model_path: f"Error: {str(e)}" for model_path in model_paths}
+
+@spaces.GPU
+def parallel_text_classification(model_paths, texts):
+    """Optimized helper for running multiple local classification models within a single GPU allocation."""
+    try:
+        logger.info(f"Running parallel text classification for {len(model_paths)} models")
+        results = {}
+
+        # Assumes each model's pipeline has already been loaded
+        for i, (model_path, text) in enumerate(zip(model_paths, texts)):
+            try:
+                logger.info(f"Processing classification model {i+1}/{len(model_paths)}: {model_path}")
+                result = pipelines[model_path](text)
+                results[model_path] = str(result)
+            except Exception as e:
+                logger.error(f"Error in classification with {model_path}: {str(e)}")
+                results[model_path] = f"Error: {str(e)}"
+
+        return results
+    except Exception as e:
+        logger.error(f"Error in parallel text classification: {str(e)}")
+        return {model_path: f"Error: {str(e)}" for model_path in model_paths}
+
 # Handler for the Invoke button
 def handle_invoke(text, selected_types):
-    """Handler for the Invoke button"""
-    results = []
+    """Handler for the Invoke button - parallel version"""
+    start_time = time.time()
+    logger.info("Starting parallel model execution")
 
-    # Run the text generation models
-    for model in TEXT_GENERATION_MODELS:
+    # Array that holds the results (created up front with empty entries so ordering is preserved)
+    results = [""] * (len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS))
+
+    # Staging lists for batching the local generation models
+    local_gen_models = []
+    local_gen_texts = []
+    local_gen_indices = []
+
+    # Staging lists for batching the local classification models
+    local_cls_models = []
+    local_cls_texts = []
+    local_cls_indices = []
+
+    # API models and other tasks
+    api_tasks = []
+
+    # Sort the text generation models into buckets
+    for i, model in enumerate(TEXT_GENERATION_MODELS):
         if model["type"] in selected_types:
             if model["type"] == LOCAL:
-                result = generate_text_local(model["model_path"], text)
+                local_gen_models.append(model["model_path"])
+                local_gen_texts.append(text)
+                local_gen_indices.append(i)
             else: # api
-                result = generate_text_api(model["model_id"], text)
-            results.append(f"{model['name']}: {result}")
+                api_tasks.append((i, model, "gen_api"))
 
-    # Run the classification models
-    for model in CLASSIFICATION_MODELS:
+    # Sort the classification models into buckets
+    for i, model in enumerate(CLASSIFICATION_MODELS):
+        idx = i + len(TEXT_GENERATION_MODELS)
         if model["type"] in selected_types:
             if model["type"] == LOCAL:
-                result = classify_text_local(model["model_path"], text)
+                local_cls_models.append(model["model_path"])
+                local_cls_texts.append(text)
+                local_cls_indices.append(idx)
             else: # api
+                api_tasks.append((idx, model, "cls_api"))
+
+    # Worker that processes a single API task
+    def process_api_task(task_data):
+        idx, model, task_type = task_data
+        try:
+            if task_type == "gen_api":
+                result = generate_text_api(model["model_id"], text)
+                return idx, f"{model['name']}: {result}"
+            elif task_type == "cls_api":
                 result = classify_text_api(model["model_id"], text)
-            results.append(f"{model['name']}: {result}")
+                return idx, f"{model['name']}: {result}"
+        except Exception as e:
+            logger.error(f"Error in {model['name']}: {str(e)}")
+            return idx, f"{model['name']}: Error - {str(e)}"
+
+    # Run the API calls in parallel
+    futures = []
+    if api_tasks:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=len(api_tasks)) as executor:
+            futures = [executor.submit(process_api_task, task) for task in api_tasks]
+
+    # Process the local generation models in one batch
+    if local_gen_models:
+        try:
+            local_gen_results = parallel_text_generation(local_gen_models, local_gen_texts)
+            for model_path, idx in zip(local_gen_models, local_gen_indices):
+                model_name = next(m["name"] for m in TEXT_GENERATION_MODELS if m["model_path"] == model_path)
+                results[idx] = f"{model_name}: {local_gen_results[model_path]}"
+        except Exception as e:
+            logger.error(f"Error in parallel text generation: {str(e)}")
+            for model_path, idx in zip(local_gen_models, local_gen_indices):
+                model_name = next(m["name"] for m in TEXT_GENERATION_MODELS if m["model_path"] == model_path)
+                results[idx] = f"{model_name}: Error - {str(e)}"
+
+    # Process the local classification models in one batch
+    if local_cls_models:
+        try:
+            local_cls_results = parallel_text_classification(local_cls_models, local_cls_texts)
+            for model_path, idx in zip(local_cls_models, local_cls_indices):
+                model_name = next(m["name"] for m in CLASSIFICATION_MODELS if m["model_path"] == model_path)
+                results[idx] = f"{model_name}: {local_cls_results[model_path]}"
+        except Exception as e:
+            logger.error(f"Error in parallel text classification: {str(e)}")
+            for model_path, idx in zip(local_cls_models, local_cls_indices):
+                model_name = next(m["name"] for m in CLASSIFICATION_MODELS if m["model_path"] == model_path)
+                results[idx] = f"{model_name}: Error - {str(e)}"
+
+    # Collect the results of the API tasks
+    for future in concurrent.futures.as_completed(futures):
+        idx, result = future.result()
+        results[idx] = result
 
-    # Pad the results list to the expected length
-    while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
-        results.append("")
+    # Record the execution time
+    elapsed_time = time.time() - start_time
+    logger.info(f"Parallel model execution completed in {elapsed_time:.2f} seconds")
 
     return results
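
For reference, below is a minimal, self-contained sketch of the concurrency pattern this commit adopts, stripped of the app's specifics. The names `run_api_task`, `run_local_batch`, and `invoke_all` are illustrative stand-ins, not functions from app.py. The idea: network-bound API calls are submitted to a thread pool while the batched local (GPU) work runs on the main thread, and every task writes into a pre-sized results list so display order never depends on completion order. One caveat worth noting: a `ThreadPoolExecutor` used as a context manager calls `shutdown(wait=True)` on exit, so the executor must stay open until both the main-thread work and the result collection are done for the two to actually overlap; in the committed `handle_invoke`, the `with` block closes immediately after `submit`, which waits for every API future before the local models run.

import concurrent.futures
import time


def run_api_task(idx, name):
    """Stand-in for a network-bound API call."""
    time.sleep(0.2)  # simulate request latency
    return idx, f"{name}: api result"


def run_local_batch(names):
    """Stand-in for the batched local (GPU) work."""
    return [f"{name}: local result" for name in names]


def invoke_all(api_models, local_models):
    # Fixed-size result list: every task writes into its own slot.
    results = [""] * (len(api_models) + len(local_models))

    # Keep the executor open while the local batch runs on this thread;
    # exiting the `with` block would block on the API futures first.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, len(api_models))) as executor:
        futures = [executor.submit(run_api_task, i, name) for i, name in enumerate(api_models)]

        # The main thread handles the local batch while the pool runs the API calls.
        for offset, value in enumerate(run_local_batch(local_models)):
            results[len(api_models) + offset] = value

        # Collect API results as they complete; idx restores the ordering.
        for future in concurrent.futures.as_completed(futures):
            idx, value = future.result()
            results[idx] = value

    return results


if __name__ == "__main__":
    print(invoke_all(["api-a", "api-b"], ["local-x"]))

In this sketch the overlap is real because collection happens inside the `with` block; moving the `as_completed` loop (or any slow main-thread work) outside it would serialize the two phases again.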