- app.py +20 -20
- requirements.txt +0 -1
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
from huggingface_hub import
|
3 |
from typing import List, Dict, Optional, Union
|
4 |
import logging
|
5 |
from dataclasses import dataclass
|
@@ -100,7 +100,7 @@ class LocalModelManager:
|
|
100 |
logger.error(f"Error preloading model {model_path}: {str(e)}")
|
101 |
# 続行するが、エラーをログに記録
|
102 |
|
103 |
-
|
104 |
"""モデルが既にロードされているか確認し、なければロード"""
|
105 |
if model_path not in self.pipelines:
|
106 |
logger.info(f"Loading model on demand: {model_path}")
|
@@ -142,10 +142,10 @@ class LocalModelManager:
|
|
142 |
)
|
143 |
return outputs[0]["generated_text"]
|
144 |
|
145 |
-
|
146 |
-
"""
|
147 |
-
if model_path not in self.
|
148 |
-
|
149 |
|
150 |
try:
|
151 |
return self._generate_text_sync(self.pipelines[model_path], text)
|
@@ -159,10 +159,10 @@ class LocalModelManager:
|
|
159 |
result = pipeline(text)
|
160 |
return str(result)
|
161 |
|
162 |
-
|
163 |
-
"""
|
164 |
-
if model_path not in self.
|
165 |
-
|
166 |
|
167 |
try:
|
168 |
return self._classify_text_sync(self.pipelines[model_path], text)
|
@@ -181,7 +181,7 @@ class ModelManager:
|
|
181 |
"""Inference APIクライアントの初期化"""
|
182 |
for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
|
183 |
if model.type == ModelType.INFERENCE_API and model.model_id:
|
184 |
-
self.api_clients[model.model_id] =
|
185 |
model.model_id,
|
186 |
token=True # これによりHFトークンを使用
|
187 |
)
|
@@ -206,7 +206,7 @@ class ModelManager:
|
|
206 |
# 事前ロード実行
|
207 |
self.local_manager.preload_models(models_to_preload, tasks)
|
208 |
|
209 |
-
|
210 |
"""テキスト生成モデルの実行"""
|
211 |
results = []
|
212 |
for model in TEXT_GENERATION_MODELS:
|
@@ -214,20 +214,20 @@ class ModelManager:
|
|
214 |
try:
|
215 |
if model.type == ModelType.INFERENCE_API:
|
216 |
logger.info(f"Running API text generation: {model.name}")
|
217 |
-
response =
|
218 |
text, max_new_tokens=100, temperature=0.7
|
219 |
)
|
220 |
results.append(f"{model.name}: {response}")
|
221 |
else:
|
222 |
logger.info(f"Running local text generation: {model.name}")
|
223 |
-
response =
|
224 |
results.append(f"{model.name}: {response}")
|
225 |
except Exception as e:
|
226 |
logger.error(f"Error in {model.name}: {str(e)}")
|
227 |
results.append(f"{model.name}: Error - {str(e)}")
|
228 |
return results
|
229 |
|
230 |
-
|
231 |
"""分類モデルの実行"""
|
232 |
results = []
|
233 |
for model in CLASSIFICATION_MODELS:
|
@@ -235,11 +235,11 @@ class ModelManager:
|
|
235 |
try:
|
236 |
if model.type == ModelType.INFERENCE_API:
|
237 |
logger.info(f"Running API classification: {model.name}")
|
238 |
-
response =
|
239 |
results.append(f"{model.name}: {response}")
|
240 |
else:
|
241 |
logger.info(f"Running local classification: {model.name}")
|
242 |
-
response =
|
243 |
results.append(f"{model.name}: {response}")
|
244 |
except Exception as e:
|
245 |
logger.error(f"Error in {model.name}: {str(e)}")
|
@@ -349,10 +349,10 @@ class ToxicityApp:
|
|
349 |
updates.append(gr.update(visible=visible))
|
350 |
return updates
|
351 |
|
352 |
-
|
353 |
"""Invokeボタンのハンドラ"""
|
354 |
-
gen_results =
|
355 |
-
class_results =
|
356 |
|
357 |
# 結果リストの長さを調整
|
358 |
gen_results.extend([""] * (len(TEXT_GENERATION_MODELS) - len(gen_results)))
|
|
|
1 |
import gradio as gr
|
2 |
+
from huggingface_hub import InferenceClient
|
3 |
from typing import List, Dict, Optional, Union
|
4 |
import logging
|
5 |
from dataclasses import dataclass
|
|
|
100 |
logger.error(f"Error preloading model {model_path}: {str(e)}")
|
101 |
# 続行するが、エラーをログに記録
|
102 |
|
103 |
+
def load_model(self, model_path: str, task: str = "text-generation"):
|
104 |
"""モデルが既にロードされているか確認し、なければロード"""
|
105 |
if model_path not in self.pipelines:
|
106 |
logger.info(f"Loading model on demand: {model_path}")
|
|
|
142 |
)
|
143 |
return outputs[0]["generated_text"]
|
144 |
|
145 |
+
def generate_text(self, model_path: str, text: str) -> str:
|
146 |
+
"""テキスト生成の実行"""
|
147 |
+
if model_path not in self.pipelines:
|
148 |
+
self.load_model(model_path, "text-generation")
|
149 |
|
150 |
try:
|
151 |
return self._generate_text_sync(self.pipelines[model_path], text)
|
|
|
159 |
result = pipeline(text)
|
160 |
return str(result)
|
161 |
|
162 |
+
def classify_text(self, model_path: str, text: str) -> str:
|
163 |
+
"""テキスト分類の実行"""
|
164 |
+
if model_path not in self.pipelines:
|
165 |
+
self.load_model(model_path, "text-classification")
|
166 |
|
167 |
try:
|
168 |
return self._classify_text_sync(self.pipelines[model_path], text)
|
|
|
181 |
"""Inference APIクライアントの初期化"""
|
182 |
for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
|
183 |
if model.type == ModelType.INFERENCE_API and model.model_id:
|
184 |
+
self.api_clients[model.model_id] = InferenceClient(
|
185 |
model.model_id,
|
186 |
token=True # これによりHFトークンを使用
|
187 |
)
|
|
|
206 |
# 事前ロード実行
|
207 |
self.local_manager.preload_models(models_to_preload, tasks)
|
208 |
|
209 |
+
def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
|
210 |
"""テキスト生成モデルの実行"""
|
211 |
results = []
|
212 |
for model in TEXT_GENERATION_MODELS:
|
|
|
214 |
try:
|
215 |
if model.type == ModelType.INFERENCE_API:
|
216 |
logger.info(f"Running API text generation: {model.name}")
|
217 |
+
response = self.api_clients[model.model_id].text_generation(
|
218 |
text, max_new_tokens=100, temperature=0.7
|
219 |
)
|
220 |
results.append(f"{model.name}: {response}")
|
221 |
else:
|
222 |
logger.info(f"Running local text generation: {model.name}")
|
223 |
+
response = self.local_manager.generate_text(model.model_path, text)
|
224 |
results.append(f"{model.name}: {response}")
|
225 |
except Exception as e:
|
226 |
logger.error(f"Error in {model.name}: {str(e)}")
|
227 |
results.append(f"{model.name}: Error - {str(e)}")
|
228 |
return results
|
229 |
|
230 |
+
def run_classification(self, text: str, selected_types: List[str]) -> List[str]:
|
231 |
"""分類モデルの実行"""
|
232 |
results = []
|
233 |
for model in CLASSIFICATION_MODELS:
|
|
|
235 |
try:
|
236 |
if model.type == ModelType.INFERENCE_API:
|
237 |
logger.info(f"Running API classification: {model.name}")
|
238 |
+
response = self.api_clients[model.model_id].text_classification(text)
|
239 |
results.append(f"{model.name}: {response}")
|
240 |
else:
|
241 |
logger.info(f"Running local classification: {model.name}")
|
242 |
+
response = self.local_manager.classify_text(model.model_path, text)
|
243 |
results.append(f"{model.name}: {response}")
|
244 |
except Exception as e:
|
245 |
logger.error(f"Error in {model.name}: {str(e)}")
|
|
|
349 |
updates.append(gr.update(visible=visible))
|
350 |
return updates
|
351 |
|
352 |
+
def handle_invoke(self, text: str, selected_types: List[str]) -> List[str]:
|
353 |
"""Invokeボタンのハンドラ"""
|
354 |
+
gen_results = self.model_manager.run_text_generation(text, selected_types)
|
355 |
+
class_results = self.model_manager.run_classification(text, selected_types)
|
356 |
|
357 |
# 結果リストの長さを調整
|
358 |
gen_results.extend([""] * (len(TEXT_GENERATION_MODELS) - len(gen_results)))
|
requirements.txt
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
transformers>=4.30.0
|
2 |
torch==2.4.0
|
3 |
accelerate>=0.26.0
|
4 |
-
aiohttp>=3.9.0
|
|
|
1 |
transformers>=4.30.0
|
2 |
torch==2.4.0
|
3 |
accelerate>=0.26.0
|
|