nyasukun commited on
Commit
b9ebd64
·
1 Parent(s): de5e4b0
Files changed (1) hide show
  1. app.py +85 -26
app.py CHANGED
@@ -66,9 +66,22 @@ class LocalModelManager:
66
  self.models = {}
67
  self.tokenizers = {}
68
  self.pipelines = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- async def load_model(self, model_path: str, task: str = "text-generation"):
71
- """モデルの遅延ロード"""
72
  if model_path not in self.models:
73
  logger.info(f"Loading model: {model_path}")
74
  try:
@@ -102,6 +115,11 @@ class LocalModelManager:
102
  logger.error(f"Error loading model {model_path}: {str(e)}")
103
  raise
104
 
 
 
 
 
 
105
  @spaces.GPU()
106
  def _generate_text_sync(self, pipeline, text: str) -> str:
107
  """同期的なテキスト生成の実行"""
@@ -157,6 +175,12 @@ class ModelManager:
157
  model.model_id,
158
  token=True # これによりHFトークンを使用
159
  )
 
 
 
 
 
 
160
 
161
  async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
162
  """テキスト生成モデルの実行"""
@@ -288,6 +312,7 @@ class ToxicityApp:
288
  def __init__(self):
289
  self.ui = UIComponents()
290
  self.model_manager = ModelManager()
 
291
 
292
  def update_model_visibility(self, selected_types: List[str]) -> List[gr.update]:
293
  """モデルの表示状態を更新"""
@@ -311,40 +336,74 @@ class ToxicityApp:
311
  class_results.extend([""] * (len(CLASSIFICATION_MODELS) - len(class_results)))
312
 
313
  return gen_results + class_results
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  def create_ui(self):
316
  """UIの作成"""
317
  with gr.Blocks() as demo:
318
- self.ui.create_header()
319
- self.ui.create_input_section()
320
- self.ui.create_filter_section()
321
- self.ui.create_invoke_button()
322
- self.ui.create_model_tabs()
323
-
324
- # イベントハンドラの設定
325
- self.ui.filter_checkboxes.change(
326
- fn=self.update_model_visibility,
327
- inputs=[self.ui.filter_checkboxes],
328
- outputs=[
329
- output["group"]
330
- for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
331
- for output in outputs
332
- ]
333
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
- self.ui.invoke_button.click(
336
- fn=self.handle_invoke,
337
- inputs=[self.ui.input_text, self.ui.filter_checkboxes],
338
- outputs=[
339
- output["output"]
340
- for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
341
- for output in outputs
342
- ]
 
 
 
 
 
 
 
343
  )
344
 
345
  return demo
346
 
347
  def main():
 
348
  app = ToxicityApp()
349
  demo = app.create_ui()
350
  demo.launch()
 
66
  self.models = {}
67
  self.tokenizers = {}
68
  self.pipelines = {}
69
+
70
+ def preload_models(self):
71
+ """起動時にすべてのローカルモデルを事前にロード"""
72
+ logger.info("Preloading all local models...")
73
+ for model in TEXT_GENERATION_MODELS:
74
+ if model.type == ModelType.LOCAL and model.model_path:
75
+ self.load_model_sync(model.model_path, "text-generation")
76
+
77
+ for model in CLASSIFICATION_MODELS:
78
+ if model.type == ModelType.LOCAL and model.model_path:
79
+ self.load_model_sync(model.model_path, "text-classification")
80
+
81
+ logger.info("All local models preloaded successfully")
82
 
83
+ def load_model_sync(self, model_path: str, task: str = "text-generation"):
84
+ """モデルの同期ロード"""
85
  if model_path not in self.models:
86
  logger.info(f"Loading model: {model_path}")
87
  try:
 
115
  logger.error(f"Error loading model {model_path}: {str(e)}")
116
  raise
117
 
118
+ async def load_model(self, model_path: str, task: str = "text-generation"):
119
+ """モデルの遅延ロード(バックワードコンパティビリティのために維持)"""
120
+ if model_path not in self.models:
121
+ self.load_model_sync(model_path, task)
122
+
123
  @spaces.GPU()
124
  def _generate_text_sync(self, pipeline, text: str) -> str:
125
  """同期的なテキスト生成の実行"""
 
175
  model.model_id,
176
  token=True # これによりHFトークンを使用
177
  )
178
+
179
+ def preload_models(self):
180
+ """起動時にすべてのモデルを事前にロード"""
181
+ logger.info("Preloading models...")
182
+ self.local_manager.preload_models()
183
+ logger.info("Model preloading complete")
184
 
185
  async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
186
  """テキスト生成モデルの実行"""
 
312
  def __init__(self):
313
  self.ui = UIComponents()
314
  self.model_manager = ModelManager()
315
+ self.models_loaded = False
316
 
317
  def update_model_visibility(self, selected_types: List[str]) -> List[gr.update]:
318
  """モデルの表示状態を更新"""
 
336
  class_results.extend([""] * (len(CLASSIFICATION_MODELS) - len(class_results)))
337
 
338
  return gen_results + class_results
339
+
340
+ def load_models_and_update_ui(self):
341
+ """モデルをロードしUIを更新する"""
342
+ try:
343
+ # モデルのロード
344
+ self.model_manager.preload_models()
345
+ self.models_loaded = True
346
+ logger.info("Models loaded successfully")
347
+ # ロード完了メッセージを返して、UIのロード中表示を非表示にする
348
+ return gr.update(visible=False), gr.update(visible=True)
349
+ except Exception as e:
350
+ logger.error(f"Error loading models: {e}")
351
+ return gr.update(value=f"Error loading models: {e}"), gr.update(visible=False)
352
 
353
  def create_ui(self):
354
  """UIの作成"""
355
  with gr.Blocks() as demo:
356
+ # ロード中コンポーネント
357
+ with gr.Group(visible=True) as loading_group:
358
+ gr.Markdown("""
359
+ # Toxic Eye
360
+
361
+ ### Loading models... This may take a few minutes.
362
+
363
+ The application is initializing and preloading all models.
364
+ Please wait while the models are being loaded...
365
+ """)
366
+
367
+ # メインUIコンポーネント(初期状態では非表示)
368
+ with gr.Group(visible=False) as main_ui_group:
369
+ self.ui.create_header()
370
+ self.ui.create_input_section()
371
+ self.ui.create_filter_section()
372
+ self.ui.create_invoke_button()
373
+ self.ui.create_model_tabs()
374
+
375
+ # イベントハンドラの設定
376
+ self.ui.filter_checkboxes.change(
377
+ fn=self.update_model_visibility,
378
+ inputs=[self.ui.filter_checkboxes],
379
+ outputs=[
380
+ output["group"]
381
+ for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
382
+ for output in outputs
383
+ ]
384
+ )
385
 
386
+ self.ui.invoke_button.click(
387
+ fn=self.handle_invoke,
388
+ inputs=[self.ui.input_text, self.ui.filter_checkboxes],
389
+ outputs=[
390
+ output["output"]
391
+ for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
392
+ for output in outputs
393
+ ]
394
+ )
395
+
396
+ # 起動時にモデルロード処理を実行
397
+ demo.load(
398
+ fn=self.load_models_and_update_ui,
399
+ inputs=None,
400
+ outputs=[loading_group, main_ui_group]
401
  )
402
 
403
  return demo
404
 
405
  def main():
406
+ logger.info("Starting Toxic Eye application")
407
  app = ToxicityApp()
408
  demo = app.create_ui()
409
  demo.launch()