app.py
CHANGED
@@ -66,9 +66,22 @@ class LocalModelManager:
|
|
66 |
self.models = {}
|
67 |
self.tokenizers = {}
|
68 |
self.pipelines = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
|
71 |
-
"""
|
72 |
if model_path not in self.models:
|
73 |
logger.info(f"Loading model: {model_path}")
|
74 |
try:
|
@@ -102,6 +115,11 @@ class LocalModelManager:
|
|
102 |
logger.error(f"Error loading model {model_path}: {str(e)}")
|
103 |
raise
|
104 |
|
|
|
|
|
|
|
|
|
|
|
105 |
@spaces.GPU()
|
106 |
def _generate_text_sync(self, pipeline, text: str) -> str:
|
107 |
"""同期的なテキスト生成の実行"""
|
@@ -157,6 +175,12 @@ class ModelManager:
|
|
157 |
model.model_id,
|
158 |
token=True # これによりHFトークンを使用
|
159 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
|
162 |
"""テキスト生成モデルの実行"""
|
@@ -288,6 +312,7 @@ class ToxicityApp:
|
|
288 |
def __init__(self):
|
289 |
self.ui = UIComponents()
|
290 |
self.model_manager = ModelManager()
|
|
|
291 |
|
292 |
def update_model_visibility(self, selected_types: List[str]) -> List[gr.update]:
|
293 |
"""モデルの表示状態を更新"""
|
@@ -311,40 +336,74 @@ class ToxicityApp:
|
|
311 |
class_results.extend([""] * (len(CLASSIFICATION_MODELS) - len(class_results)))
|
312 |
|
313 |
return gen_results + class_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
def create_ui(self):
|
316 |
"""UIの作成"""
|
317 |
with gr.Blocks() as demo:
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
)
|
344 |
|
345 |
return demo
|
346 |
|
347 |
def main():
|
|
|
348 |
app = ToxicityApp()
|
349 |
demo = app.create_ui()
|
350 |
demo.launch()
|
|
|
66 |
self.models = {}
|
67 |
self.tokenizers = {}
|
68 |
self.pipelines = {}
|
69 |
+
|
70 |
+
def preload_models(self):
|
71 |
+
"""起動時にすべてのローカルモデルを事前にロード"""
|
72 |
+
logger.info("Preloading all local models...")
|
73 |
+
for model in TEXT_GENERATION_MODELS:
|
74 |
+
if model.type == ModelType.LOCAL and model.model_path:
|
75 |
+
self.load_model_sync(model.model_path, "text-generation")
|
76 |
+
|
77 |
+
for model in CLASSIFICATION_MODELS:
|
78 |
+
if model.type == ModelType.LOCAL and model.model_path:
|
79 |
+
self.load_model_sync(model.model_path, "text-classification")
|
80 |
+
|
81 |
+
logger.info("All local models preloaded successfully")
|
82 |
|
83 |
+
def load_model_sync(self, model_path: str, task: str = "text-generation"):
|
84 |
+
"""モデルの同期ロード"""
|
85 |
if model_path not in self.models:
|
86 |
logger.info(f"Loading model: {model_path}")
|
87 |
try:
|
|
|
115 |
logger.error(f"Error loading model {model_path}: {str(e)}")
|
116 |
raise
|
117 |
|
118 |
+
async def load_model(self, model_path: str, task: str = "text-generation"):
|
119 |
+
"""モデルの遅延ロード(バックワードコンパティビリティのために維持)"""
|
120 |
+
if model_path not in self.models:
|
121 |
+
self.load_model_sync(model_path, task)
|
122 |
+
|
123 |
@spaces.GPU()
|
124 |
def _generate_text_sync(self, pipeline, text: str) -> str:
|
125 |
"""同期的なテキスト生成の実行"""
|
|
|
175 |
model.model_id,
|
176 |
token=True # これによりHFトークンを使用
|
177 |
)
|
178 |
+
|
179 |
+
def preload_models(self):
|
180 |
+
"""起動時にすべてのモデルを事前にロード"""
|
181 |
+
logger.info("Preloading models...")
|
182 |
+
self.local_manager.preload_models()
|
183 |
+
logger.info("Model preloading complete")
|
184 |
|
185 |
async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
|
186 |
"""テキスト生成モデルの実行"""
|
|
|
312 |
def __init__(self):
|
313 |
self.ui = UIComponents()
|
314 |
self.model_manager = ModelManager()
|
315 |
+
self.models_loaded = False
|
316 |
|
317 |
def update_model_visibility(self, selected_types: List[str]) -> List[gr.update]:
|
318 |
"""モデルの表示状態を更新"""
|
|
|
336 |
class_results.extend([""] * (len(CLASSIFICATION_MODELS) - len(class_results)))
|
337 |
|
338 |
return gen_results + class_results
|
339 |
+
|
340 |
+
def load_models_and_update_ui(self):
|
341 |
+
"""モデルをロードしUIを更新する"""
|
342 |
+
try:
|
343 |
+
# モデルのロード
|
344 |
+
self.model_manager.preload_models()
|
345 |
+
self.models_loaded = True
|
346 |
+
logger.info("Models loaded successfully")
|
347 |
+
# ロード完了メッセージを返して、UIのロード中表示を非表示にする
|
348 |
+
return gr.update(visible=False), gr.update(visible=True)
|
349 |
+
except Exception as e:
|
350 |
+
logger.error(f"Error loading models: {e}")
|
351 |
+
return gr.update(value=f"Error loading models: {e}"), gr.update(visible=False)
|
352 |
|
353 |
def create_ui(self):
|
354 |
"""UIの作成"""
|
355 |
with gr.Blocks() as demo:
|
356 |
+
# ロード中コンポーネント
|
357 |
+
with gr.Group(visible=True) as loading_group:
|
358 |
+
gr.Markdown("""
|
359 |
+
# Toxic Eye
|
360 |
+
|
361 |
+
### Loading models... This may take a few minutes.
|
362 |
+
|
363 |
+
The application is initializing and preloading all models.
|
364 |
+
Please wait while the models are being loaded...
|
365 |
+
""")
|
366 |
+
|
367 |
+
# メインUIコンポーネント(初期状態では非表示)
|
368 |
+
with gr.Group(visible=False) as main_ui_group:
|
369 |
+
self.ui.create_header()
|
370 |
+
self.ui.create_input_section()
|
371 |
+
self.ui.create_filter_section()
|
372 |
+
self.ui.create_invoke_button()
|
373 |
+
self.ui.create_model_tabs()
|
374 |
+
|
375 |
+
# イベントハンドラの設定
|
376 |
+
self.ui.filter_checkboxes.change(
|
377 |
+
fn=self.update_model_visibility,
|
378 |
+
inputs=[self.ui.filter_checkboxes],
|
379 |
+
outputs=[
|
380 |
+
output["group"]
|
381 |
+
for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
|
382 |
+
for output in outputs
|
383 |
+
]
|
384 |
+
)
|
385 |
|
386 |
+
self.ui.invoke_button.click(
|
387 |
+
fn=self.handle_invoke,
|
388 |
+
inputs=[self.ui.input_text, self.ui.filter_checkboxes],
|
389 |
+
outputs=[
|
390 |
+
output["output"]
|
391 |
+
for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
|
392 |
+
for output in outputs
|
393 |
+
]
|
394 |
+
)
|
395 |
+
|
396 |
+
# 起動時にモデルロード処理を実行
|
397 |
+
demo.load(
|
398 |
+
fn=self.load_models_and_update_ui,
|
399 |
+
inputs=None,
|
400 |
+
outputs=[loading_group, main_ui_group]
|
401 |
)
|
402 |
|
403 |
return demo
|
404 |
|
405 |
def main():
|
406 |
+
logger.info("Starting Toxic Eye application")
|
407 |
app = ToxicityApp()
|
408 |
demo = app.create_ui()
|
409 |
demo.launch()
|