nyasukun committed on
Commit b711b66 · 1 Parent(s): 8ea290d
Files changed (1)
  1. app.py +14 -136
app.py CHANGED
@@ -6,8 +6,6 @@ from enum import Enum, auto
 import torch
 from transformers import AutoTokenizer, pipeline
 import spaces
-import concurrent.futures
-import time
 
 # Logger configuration
 logging.basicConfig(
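
The concurrent.futures import dropped here backed the thread-pool fan-out in the old handle_invoke below. As a standalone sketch of that pattern, with a hypothetical process_task worker standing in for the app's API calls:

import concurrent.futures

def process_task(task):
    # Hypothetical worker: takes (index, payload), returns (index, result).
    idx, payload = task
    return idx, payload.upper()

tasks = [(0, "alpha"), (1, "beta"), (2, "gamma")]
results = [""] * len(tasks)

# Fan the tasks out to a thread pool, then collect results as they finish.
# Keying each result by its index keeps the output ordered even though
# completion order is arbitrary.
with concurrent.futures.ThreadPoolExecutor(max_workers=len(tasks)) as executor:
    futures = [executor.submit(process_task, t) for t in tasks]
    for future in concurrent.futures.as_completed(futures):
        idx, result = future.result()
        results[idx] = result

print(results)  # ['ALPHA', 'BETA', 'GAMMA']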
@@ -162,152 +160,32 @@ def classify_text_api(model_id, text):
         logger.error(f"Error in API classification with {model_id}: {str(e)}")
         return f"Error: {str(e)}"
 
-@spaces.GPU
-def parallel_text_generation(model_paths, texts):
-    """Optimized function that runs multiple local models within a single GPU allocation"""
-    try:
-        logger.info(f"Running parallel text generation for {len(model_paths)} models")
-        results = {}
-
-        # Assumes each model's pipeline has already been loaded
-        for i, (model_path, text) in enumerate(zip(model_paths, texts)):
-            try:
-                logger.info(f"Processing model {i+1}/{len(model_paths)}: {model_path}")
-                outputs = pipelines[model_path](
-                    text,
-                    max_new_tokens=40,
-                    do_sample=False,
-                    num_return_sequences=1
-                )
-                results[model_path] = outputs[0]["generated_text"]
-            except Exception as e:
-                logger.error(f"Error in text generation with {model_path}: {str(e)}")
-                results[model_path] = f"Error: {str(e)}"
-
-        return results
-    except Exception as e:
-        logger.error(f"Error in parallel text generation: {str(e)}")
-        return {model_path: f"Error: {str(e)}" for model_path in model_paths}
-
-@spaces.GPU
-def parallel_text_classification(model_paths, texts):
-    """Optimized function that runs multiple local classification models within a single GPU allocation"""
-    try:
-        logger.info(f"Running parallel text classification for {len(model_paths)} models")
-        results = {}
-
-        # Assumes each model's pipeline has already been loaded
-        for i, (model_path, text) in enumerate(zip(model_paths, texts)):
-            try:
-                logger.info(f"Processing classification model {i+1}/{len(model_paths)}: {model_path}")
-                result = pipelines[model_path](text)
-                results[model_path] = str(result)
-            except Exception as e:
-                logger.error(f"Error in classification with {model_path}: {str(e)}")
-                results[model_path] = f"Error: {str(e)}"
-
-        return results
-    except Exception as e:
-        logger.error(f"Error in parallel text classification: {str(e)}")
-        return {model_path: f"Error: {str(e)}" for model_path in model_paths}
-
 # Invoke button handler
 def handle_invoke(text, selected_types):
-    """Invoke button handler - parallel version"""
-    start_time = time.time()
-    logger.info("Starting parallel model execution")
-
-    # Array for the results (created empty up front so ordering is preserved)
-    results = [""] * (len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS))
-
-    # Prepare the local generation models for batched processing
-    local_gen_models = []
-    local_gen_texts = []
-    local_gen_indices = []
+    """Invoke button handler"""
+    results = []
 
-    # Prepare the local classification models for batched processing
-    local_cls_models = []
-    local_cls_texts = []
-    local_cls_indices = []
-
-    # API models and other tasks
-    api_tasks = []
-
-    # Sort the text generation models
-    for i, model in enumerate(TEXT_GENERATION_MODELS):
+    # Run the text generation models
+    for model in TEXT_GENERATION_MODELS:
         if model["type"] in selected_types:
             if model["type"] == LOCAL:
-                local_gen_models.append(model["model_path"])
-                local_gen_texts.append(text)
-                local_gen_indices.append(i)
+                result = generate_text_local(model["model_path"], text)
             else: # api
-                api_tasks.append((i, model, "gen_api"))
+                result = generate_text_api(model["model_id"], text)
+            results.append(f"{model['name']}: {result}")
 
-    # Sort the classification models
-    for i, model in enumerate(CLASSIFICATION_MODELS):
-        idx = i + len(TEXT_GENERATION_MODELS)
+    # Run the classification models
+    for model in CLASSIFICATION_MODELS:
         if model["type"] in selected_types:
             if model["type"] == LOCAL:
-                local_cls_models.append(model["model_path"])
-                local_cls_texts.append(text)
-                local_cls_indices.append(idx)
+                result = classify_text_local(model["model_path"], text)
             else: # api
-                api_tasks.append((idx, model, "cls_api"))
-
-    # Function that processes one API task
-    def process_api_task(task_data):
-        idx, model, task_type = task_data
-        try:
-            if task_type == "gen_api":
-                result = generate_text_api(model["model_id"], text)
-                return idx, f"{model['name']}: {result}"
-            elif task_type == "cls_api":
                 result = classify_text_api(model["model_id"], text)
-                return idx, f"{model['name']}: {result}"
-        except Exception as e:
-            logger.error(f"Error in {model['name']}: {str(e)}")
-            return idx, f"{model['name']}: Error - {str(e)}"
-
-    # Run the API calls in parallel
-    futures = []
-    if api_tasks:
-        with concurrent.futures.ThreadPoolExecutor(max_workers=len(api_tasks)) as executor:
-            futures = [executor.submit(process_api_task, task) for task in api_tasks]
-
-    # Process the local generation models in parallel
-    if local_gen_models:
-        try:
-            local_gen_results = parallel_text_generation(local_gen_models, local_gen_texts)
-            for model_path, idx in zip(local_gen_models, local_gen_indices):
-                model_name = next(m["name"] for m in TEXT_GENERATION_MODELS if m["model_path"] == model_path)
-                results[idx] = f"{model_name}: {local_gen_results[model_path]}"
-        except Exception as e:
-            logger.error(f"Error in parallel text generation: {str(e)}")
-            for model_path, idx in zip(local_gen_models, local_gen_indices):
-                model_name = next(m["name"] for m in TEXT_GENERATION_MODELS if m["model_path"] == model_path)
-                results[idx] = f"{model_name}: Error - {str(e)}"
-
-    # Process the local classification models in parallel
-    if local_cls_models:
-        try:
-            local_cls_results = parallel_text_classification(local_cls_models, local_cls_texts)
-            for model_path, idx in zip(local_cls_models, local_cls_indices):
-                model_name = next(m["name"] for m in CLASSIFICATION_MODELS if m["model_path"] == model_path)
-                results[idx] = f"{model_name}: {local_cls_results[model_path]}"
-        except Exception as e:
-            logger.error(f"Error in parallel text classification: {str(e)}")
-            for model_path, idx in zip(local_cls_models, local_cls_indices):
-                model_name = next(m["name"] for m in CLASSIFICATION_MODELS if m["model_path"] == model_path)
-                results[idx] = f"{model_name}: Error - {str(e)}"
-
-    # Collect the API task results
-    for future in concurrent.futures.as_completed(futures):
-        idx, result = future.result()
-        results[idx] = result
+            results.append(f"{model['name']}: {result}")
 
-    # Record the elapsed time
-    elapsed_time = time.time() - start_time
-    logger.info(f"Parallel model execution completed in {elapsed_time:.2f} seconds")
+    # Pad the results list to the expected length
+    while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
+        results.append("")
 
     return results
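
Both versions of the handler assume the local pipelines are loaded ahead of time: the removed helpers index into a module-level pipelines dict (pipelines[model_path](...)), and the sequential path presumably does the same inside generate_text_local and classify_text_local. That registry sits outside this diff; a minimal sketch of the assumed shape, with placeholder model IDs rather than the app's actual models:

from transformers import pipeline

# Assumed module-level registry mapping model path -> preloaded pipeline.
# The pipelines[model_path](...) calls in the diff imply something like this.
pipelines = {
    "gpt2": pipeline("text-generation", model="gpt2"),
    "distilbert-base-uncased-finetuned-sst-2-english": pipeline(
        "text-classification",
        model="distilbert-base-uncased-finetuned-sst-2-english",
    ),
}

# Same call shape as the removed parallel_text_generation helper.
outputs = pipelines["gpt2"](
    "Hello world",
    max_new_tokens=40,
    do_sample=False,
    num_return_sequences=1,
)
print(outputs[0]["generated_text"])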