nyasukun committed on
Commit
9975206
·
1 Parent(s): 0a54f6b
Files changed (2) hide show
  1. app.py +20 -20
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from huggingface_hub import AsyncInferenceClient
3
  from typing import List, Dict, Optional, Union
4
  import logging
5
  from dataclasses import dataclass
@@ -100,7 +100,7 @@ class LocalModelManager:
100
  logger.error(f"Error preloading model {model_path}: {str(e)}")
101
  # 続行するが、エラーをログに記録
102
 
103
- async def load_model(self, model_path: str, task: str = "text-generation"):
104
  """モデルが既にロードされているか確認し、なければロード"""
105
  if model_path not in self.pipelines:
106
  logger.info(f"Loading model on demand: {model_path}")
@@ -142,10 +142,10 @@ class LocalModelManager:
142
  )
143
  return outputs[0]["generated_text"]
144
 
145
- async def generate_text(self, model_path: str, text: str) -> str:
146
- """テキスト生成の実行(非同期ラッパー)"""
147
- if model_path not in self.models:
148
- await self.load_model(model_path, "text-generation")
149
 
150
  try:
151
  return self._generate_text_sync(self.pipelines[model_path], text)
@@ -159,10 +159,10 @@ class LocalModelManager:
159
  result = pipeline(text)
160
  return str(result)
161
 
162
- async def classify_text(self, model_path: str, text: str) -> str:
163
- """テキスト分類の実行(非同期ラッパー)"""
164
- if model_path not in self.models:
165
- await self.load_model(model_path, "text-classification")
166
 
167
  try:
168
  return self._classify_text_sync(self.pipelines[model_path], text)
@@ -181,7 +181,7 @@ class ModelManager:
181
  """Inference APIクライアントの初期化"""
182
  for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
183
  if model.type == ModelType.INFERENCE_API and model.model_id:
184
- self.api_clients[model.model_id] = AsyncInferenceClient(
185
  model.model_id,
186
  token=True # これによりHFトークンを使用
187
  )
@@ -206,7 +206,7 @@ class ModelManager:
206
  # 事前ロード実行
207
  self.local_manager.preload_models(models_to_preload, tasks)
208
 
209
- async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
210
  """テキスト生成モデルの実行"""
211
  results = []
212
  for model in TEXT_GENERATION_MODELS:
@@ -214,20 +214,20 @@ class ModelManager:
214
  try:
215
  if model.type == ModelType.INFERENCE_API:
216
  logger.info(f"Running API text generation: {model.name}")
217
- response = await self.api_clients[model.model_id].text_generation(
218
  text, max_new_tokens=100, temperature=0.7
219
  )
220
  results.append(f"{model.name}: {response}")
221
  else:
222
  logger.info(f"Running local text generation: {model.name}")
223
- response = await self.local_manager.generate_text(model.model_path, text)
224
  results.append(f"{model.name}: {response}")
225
  except Exception as e:
226
  logger.error(f"Error in {model.name}: {str(e)}")
227
  results.append(f"{model.name}: Error - {str(e)}")
228
  return results
229
 
230
- async def run_classification(self, text: str, selected_types: List[str]) -> List[str]:
231
  """分類モデルの実行"""
232
  results = []
233
  for model in CLASSIFICATION_MODELS:
@@ -235,11 +235,11 @@ class ModelManager:
235
  try:
236
  if model.type == ModelType.INFERENCE_API:
237
  logger.info(f"Running API classification: {model.name}")
238
- response = await self.api_clients[model.model_id].text_classification(text)
239
  results.append(f"{model.name}: {response}")
240
  else:
241
  logger.info(f"Running local classification: {model.name}")
242
- response = await self.local_manager.classify_text(model.model_path, text)
243
  results.append(f"{model.name}: {response}")
244
  except Exception as e:
245
  logger.error(f"Error in {model.name}: {str(e)}")
@@ -349,10 +349,10 @@ class ToxicityApp:
349
  updates.append(gr.update(visible=visible))
350
  return updates
351
 
352
- async def handle_invoke(self, text: str, selected_types: List[str]) -> List[str]:
353
  """Invokeボタンのハンドラ"""
354
- gen_results = await self.model_manager.run_text_generation(text, selected_types)
355
- class_results = await self.model_manager.run_classification(text, selected_types)
356
 
357
  # 結果リストの長さを調整
358
  gen_results.extend([""] * (len(TEXT_GENERATION_MODELS) - len(gen_results)))
 
1
  import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
  from typing import List, Dict, Optional, Union
4
  import logging
5
  from dataclasses import dataclass
 
100
  logger.error(f"Error preloading model {model_path}: {str(e)}")
101
  # 続行するが、エラーをログに記録
102
 
103
+ def load_model(self, model_path: str, task: str = "text-generation"):
104
  """モデルが既にロードされているか確認し、なければロード"""
105
  if model_path not in self.pipelines:
106
  logger.info(f"Loading model on demand: {model_path}")
 
142
  )
143
  return outputs[0]["generated_text"]
144
 
145
+ def generate_text(self, model_path: str, text: str) -> str:
146
+ """テキスト生成の実行"""
147
+ if model_path not in self.pipelines:
148
+ self.load_model(model_path, "text-generation")
149
 
150
  try:
151
  return self._generate_text_sync(self.pipelines[model_path], text)
 
159
  result = pipeline(text)
160
  return str(result)
161
 
162
+ def classify_text(self, model_path: str, text: str) -> str:
163
+ """テキスト分類の実行"""
164
+ if model_path not in self.pipelines:
165
+ self.load_model(model_path, "text-classification")
166
 
167
  try:
168
  return self._classify_text_sync(self.pipelines[model_path], text)
 
181
  """Inference APIクライアントの初期化"""
182
  for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
183
  if model.type == ModelType.INFERENCE_API and model.model_id:
184
+ self.api_clients[model.model_id] = InferenceClient(
185
  model.model_id,
186
  token=True # これによりHFトークンを使用
187
  )
 
206
  # 事前ロード実行
207
  self.local_manager.preload_models(models_to_preload, tasks)
208
 
209
+ def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
210
  """テキスト生成モデルの実行"""
211
  results = []
212
  for model in TEXT_GENERATION_MODELS:
 
214
  try:
215
  if model.type == ModelType.INFERENCE_API:
216
  logger.info(f"Running API text generation: {model.name}")
217
+ response = self.api_clients[model.model_id].text_generation(
218
  text, max_new_tokens=100, temperature=0.7
219
  )
220
  results.append(f"{model.name}: {response}")
221
  else:
222
  logger.info(f"Running local text generation: {model.name}")
223
+ response = self.local_manager.generate_text(model.model_path, text)
224
  results.append(f"{model.name}: {response}")
225
  except Exception as e:
226
  logger.error(f"Error in {model.name}: {str(e)}")
227
  results.append(f"{model.name}: Error - {str(e)}")
228
  return results
229
 
230
+ def run_classification(self, text: str, selected_types: List[str]) -> List[str]:
231
  """分類モデルの実行"""
232
  results = []
233
  for model in CLASSIFICATION_MODELS:
 
235
  try:
236
  if model.type == ModelType.INFERENCE_API:
237
  logger.info(f"Running API classification: {model.name}")
238
+ response = self.api_clients[model.model_id].text_classification(text)
239
  results.append(f"{model.name}: {response}")
240
  else:
241
  logger.info(f"Running local classification: {model.name}")
242
+ response = self.local_manager.classify_text(model.model_path, text)
243
  results.append(f"{model.name}: {response}")
244
  except Exception as e:
245
  logger.error(f"Error in {model.name}: {str(e)}")
 
349
  updates.append(gr.update(visible=visible))
350
  return updates
351
 
352
+ def handle_invoke(self, text: str, selected_types: List[str]) -> List[str]:
353
  """Invokeボタンのハンドラ"""
354
+ gen_results = self.model_manager.run_text_generation(text, selected_types)
355
+ class_results = self.model_manager.run_classification(text, selected_types)
356
 
357
  # 結果リストの長さを調整
358
  gen_results.extend([""] * (len(TEXT_GENERATION_MODELS) - len(gen_results)))
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
  transformers>=4.30.0
2
  torch==2.4.0
3
  accelerate>=0.26.0
4
- aiohttp>=3.9.0
 
1
  transformers>=4.30.0
2
  torch==2.4.0
3
  accelerate>=0.26.0