toxic-eye / app.py
nyasukun's picture
.
d75daa2
raw
history blame
18.8 kB
import concurrent.futures
import logging
import threading
import time
from enum import Enum, auto
from typing import Dict, List, Optional, Union

import gradio as gr
import spaces
import torch
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, pipeline
# Logger configuration: INFO level, timestamped single-line records.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Model source types: run in-process ("local") or through the HF Inference API ("api").
LOCAL, INFERENCE_API = "local", "api"

# Model registry. LOCAL entries carry a "model_path" (loaded with transformers);
# INFERENCE_API entries carry a "model_id" (served through InferenceClient).
TEXT_GENERATION_MODELS = [
    dict(
        name="Zephyr-7B",
        description="Specialized in understanding context and nuance",
        type=INFERENCE_API,
        model_id="HuggingFaceH4/zephyr-7b-beta",
    ),
    dict(
        name="Llama-2",
        description="Known for its robust performance in content analysis",
        type=LOCAL,
        model_path="meta-llama/Llama-2-7b-hf",
    ),
    dict(
        name="Mistral-7B",
        description="Offers precise and detailed text evaluation",
        type=LOCAL,
        model_path="mistralai/Mistral-7B-v0.1",
    ),
]

CLASSIFICATION_MODELS = [
    dict(
        name="Toxic-BERT",
        description="Fine-tuned for toxic content detection",
        type=LOCAL,
        model_path="unitary/toxic-bert",
    ),
]
# Module-level caches: local tokenizers/pipelines keyed by "model_path",
# Inference API clients keyed by "model_id".
tokenizers, pipelines, api_clients = {}, {}, {}
# Inference API client initialisation
def initialize_api_clients():
    """Create an ``InferenceClient`` for every registered Inference API model.

    Clients are stored in the module-level ``api_clients`` dict, keyed by the
    entry's ``model_id``. Entries of other types are skipped.
    """
    api_specs = (
        spec
        for spec in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS
        if spec["type"] == INFERENCE_API and "model_id" in spec
    )
    for spec in api_specs:
        logger.info(f"Initializing API client for {spec['name']}")
        # token=True: send the locally configured Hugging Face token.
        api_clients[spec["model_id"]] = InferenceClient(spec["model_id"], token=True)
    logger.info("API clients initialized")
# Preload local models
def _preload_local_pipelines(models, task, label):
    """Load a tokenizer and a ``task`` pipeline for each LOCAL entry of ``models``.

    Successful loads are cached in ``tokenizers`` and ``pipelines`` keyed by
    ``model_path``. A failing model is logged (with traceback) and skipped so
    the remaining models are still attempted.

    Args:
        models: Registry list (e.g. ``TEXT_GENERATION_MODELS``).
        task: transformers pipeline task name, e.g. ``"text-generation"``.
        label: Human-readable task name used in log messages.
    """
    for model in models:
        if model["type"] != LOCAL or "model_path" not in model:
            continue
        model_path = model["model_path"]
        try:
            logger.info(f"Preloading {label} model: {model_path}")
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            pipelines[model_path] = pipeline(
                task,
                model=model_path,
                tokenizer=tokenizer,
                torch_dtype=torch.bfloat16,
                # NOTE(review): trust_remote_code=True runs code shipped with the
                # model repository; confirm these models actually need it.
                trust_remote_code=True,
                device_map="auto",
            )
            # Cache the tokenizer only once its pipeline exists, so the two
            # caches never disagree about which models are usable.
            tokenizers[model_path] = tokenizer
            logger.info(f"Model preloaded successfully: {model_path}")
        except Exception as e:
            logger.exception(f"Error preloading model {model_path}: {str(e)}")


def preload_local_models():
    """Preload every LOCAL text-generation and classification model.

    Called from ``load_models_and_update_ui``. Per-model failures are logged
    and do not abort loading of the other models.
    """
    logger.info("Preloading local models at application startup...")
    _preload_local_pipelines(TEXT_GENERATION_MODELS, "text-generation", "text generation")
    _preload_local_pipelines(CLASSIFICATION_MODELS, "text-classification", "classification")
@spaces.GPU
def generate_text_local(model_path, text):
    """Generate a greedy continuation of ``text`` with a preloaded local pipeline.

    Requests a GPU through ``@spaces.GPU``. Returns the first output's
    ``generated_text``, or ``"Error: <message>"`` if anything fails (including
    a pipeline missing from ``pipelines``).
    """
    generation_options = {
        "max_new_tokens": 40,
        "do_sample": False,
        "num_return_sequences": 1,
    }
    try:
        logger.info(f"Running local text generation with {model_path}")
        generator = pipelines[model_path]
        generations = generator(text, **generation_options)
        return generations[0]["generated_text"]
    except Exception as err:
        logger.error(f"Error in local text generation with {model_path}: {str(err)}")
        return f"Error: {str(err)}"
def generate_text_api(model_id, text):
    """Generate text for ``text`` through the Inference API client for ``model_id``.

    Returns the client's response, or ``"Error: <message>"`` on any failure
    (including a missing client in ``api_clients``).
    """
    # NOTE(review): local generation uses do_sample=False (greedy) while this
    # call passes temperature=0.7; confirm the differing decoding is intended.
    try:
        logger.info(f"Running API text generation with {model_id}")
        client = api_clients[model_id]
        return client.text_generation(text, max_new_tokens=40, temperature=0.7)
    except Exception as err:
        message = str(err)
        logger.error(f"Error in API text generation with {model_id}: {message}")
        return f"Error: {message}"
@spaces.GPU
def classify_text_local(model_path, text):
    """Classify ``text`` with the preloaded local pipeline for ``model_path``.

    Requests a GPU through ``@spaces.GPU``. Returns ``str()`` of the pipeline
    output, or ``"Error: <message>"`` on any failure.
    """
    try:
        logger.info(f"Running local classification with {model_path}")
        classifier = pipelines[model_path]
        return str(classifier(text))
    except Exception as err:
        message = str(err)
        logger.error(f"Error in local classification with {model_path}: {message}")
        return f"Error: {message}"
def classify_text_api(model_id, text):
    """Classify ``text`` through the Inference API client for ``model_id``.

    Returns ``str()`` of the client response, or ``"Error: <message>"`` on any
    failure.
    """
    try:
        logger.info(f"Running API classification with {model_id}")
        classifier = api_clients[model_id]
        return str(classifier.text_classification(text))
    except Exception as err:
        message = str(err)
        logger.error(f"Error in API classification with {model_id}: {message}")
        return f"Error: {message}"
@spaces.GPU
def parallel_text_generation(model_paths, texts):
    """Run several preloaded local text-generation pipelines in one GPU request.

    Despite the name, the models run one after another. Grouping them under a
    single ``@spaces.GPU`` call means the GPU is requested once instead of once
    per model.

    Args:
        model_paths: Model paths whose pipelines are expected in ``pipelines``.
        texts: Prompts, paired positionally with ``model_paths`` (as with
            ``zip``, surplus items of the longer sequence are ignored).

    Returns:
        dict mapping each processed model path to its ``generated_text``, or to
        an ``"Error: ..."`` string if that model could not be run.
    """
    try:
        total = len(model_paths)
        logger.info(f"Running parallel text generation for {total} models")
        results = {}
        for i, (model_path, text) in enumerate(zip(model_paths, texts), start=1):
            logger.info(f"Processing model {i}/{total}: {model_path}")
            generator = pipelines.get(model_path)
            if generator is None:
                # Preloading failed or never ran. Say so plainly instead of
                # surfacing a bare KeyError such as "Error: 'meta-llama/...'".
                logger.error(f"Text generation model not loaded: {model_path}")
                results[model_path] = f"Error: model {model_path} is not loaded"
                continue
            try:
                outputs = generator(
                    text,
                    max_new_tokens=40,
                    do_sample=False,
                    num_return_sequences=1,
                )
                results[model_path] = outputs[0]["generated_text"]
            except Exception as e:
                logger.exception(f"Error in text generation with {model_path}: {str(e)}")
                results[model_path] = f"Error: {str(e)}"
        return results
    except Exception as e:
        logger.exception(f"Error in parallel text generation: {str(e)}")
        return {model_path: f"Error: {str(e)}" for model_path in model_paths}
@spaces.GPU
def parallel_text_classification(model_paths, texts):
    """Run several preloaded local classification pipelines in one GPU request.

    The models run sequentially inside a single ``@spaces.GPU`` call, so the GPU
    is requested once instead of once per model.

    Args:
        model_paths: Model paths whose pipelines are expected in ``pipelines``.
        texts: Inputs, paired positionally with ``model_paths`` (as with
            ``zip``, surplus items of the longer sequence are ignored).

    Returns:
        dict mapping each processed model path to ``str()`` of its pipeline
        output, or to an ``"Error: ..."`` string if that model could not be run.
    """
    try:
        total = len(model_paths)
        logger.info(f"Running parallel text classification for {total} models")
        results = {}
        for i, (model_path, text) in enumerate(zip(model_paths, texts), start=1):
            logger.info(f"Processing classification model {i}/{total}: {model_path}")
            classifier = pipelines.get(model_path)
            if classifier is None:
                # Preloading failed or never ran. Report it explicitly rather
                # than as a bare KeyError message.
                logger.error(f"Classification model not loaded: {model_path}")
                results[model_path] = f"Error: model {model_path} is not loaded"
                continue
            try:
                results[model_path] = str(classifier(text))
            except Exception as e:
                logger.exception(f"Error in classification with {model_path}: {str(e)}")
                results[model_path] = f"Error: {str(e)}"
        return results
    except Exception as e:
        logger.exception(f"Error in parallel text classification: {str(e)}")
        return {model_path: f"Error: {str(e)}" for model_path in model_paths}
# Handler for the Invoke button
def handle_invoke(text, selected_types):
    """Run every model whose type is in ``selected_types`` on ``text``.

    Inference API calls are submitted to a thread pool first, so they are in
    flight while the local models run. The local models are batched into one
    ``parallel_text_generation`` call and one ``parallel_text_classification``
    call.

    Args:
        text: Input text from the UI.
        selected_types: Model types to run (subset of ``[LOCAL, INFERENCE_API]``).
            ``None`` is treated as "nothing selected".

    Returns:
        list[str] with one entry per registered model: text-generation models
        first, then classification models (the order of the UI output boxes).
        Unselected models get ``""``. The others get ``"<model name>: <output>"``
        or ``"<model name>: Error - <message>"``.
    """
    start_time = time.perf_counter()
    logger.info("Starting parallel model execution")
    selected_types = selected_types or []

    gen_offset = len(TEXT_GENERATION_MODELS)
    # Pre-sized so every model writes its fixed slot regardless of completion order.
    results = [""] * (gen_offset + len(CLASSIFICATION_MODELS))

    # (result index, model dict) pairs. Keeping the model dict means the display
    # name never has to be looked up again by "model_path", a key that
    # Inference API entries do not have.
    local_gen = []
    local_cls = []
    api_tasks = []  # (result index, model dict, task kind)

    for idx, model in enumerate(TEXT_GENERATION_MODELS):
        if model["type"] not in selected_types:
            continue
        if model["type"] == LOCAL:
            local_gen.append((idx, model))
        else:  # api
            api_tasks.append((idx, model, "gen_api"))

    for idx, model in enumerate(CLASSIFICATION_MODELS, start=gen_offset):
        if model["type"] not in selected_types:
            continue
        if model["type"] == LOCAL:
            local_cls.append((idx, model))
        else:  # api
            api_tasks.append((idx, model, "cls_api"))

    def process_api_task(task_data):
        """Run one Inference API task and return ``(result index, labelled output)``."""
        idx, model, task_type = task_data
        try:
            if task_type == "gen_api":
                result = generate_text_api(model["model_id"], text)
            else:  # "cls_api"
                result = classify_text_api(model["model_id"], text)
            return idx, f"{model['name']}: {result}"
        except Exception as e:
            logger.error(f"Error in {model['name']}: {str(e)}")
            return idx, f"{model['name']}: Error - {str(e)}"

    def run_local_batch(batch, runner, label):
        """Run ``runner`` once over every local model in ``batch`` and fill ``results``."""
        if not batch:
            return
        model_paths = [model["model_path"] for _, model in batch]
        try:
            outputs = runner(model_paths, [text] * len(model_paths))
            for idx, model in batch:
                results[idx] = f"{model['name']}: {outputs[model['model_path']]}"
        except Exception as e:
            logger.error(f"Error in {label}: {str(e)}")
            for idx, model in batch:
                results[idx] = f"{model['name']}: Error - {str(e)}"

    # The executor stays open while the local models run. Leaving a ``with``
    # block calls shutdown(wait=True), so closing it right after submitting
    # would wait for every API call before any local work started.
    # max(1, ...) because ThreadPoolExecutor rejects max_workers=0; threads are
    # only started on submit, so an unused pool costs nothing.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, len(api_tasks))) as executor:
        futures = [executor.submit(process_api_task, task) for task in api_tasks]

        run_local_batch(local_gen, parallel_text_generation, "parallel text generation")
        run_local_batch(local_cls, parallel_text_classification, "parallel text classification")

        for future in concurrent.futures.as_completed(futures):
            idx, result = future.result()
            results[idx] = result

    elapsed_time = time.perf_counter() - start_time
    logger.info(f"Parallel model execution completed in {elapsed_time:.2f} seconds")
    return results
# Update which model cards are visible
def update_model_visibility(selected_types):
    """Return one ``gr.update(visible=...)`` per model card.

    A card is visible when its model's type is in ``selected_types``. Updates
    are ordered as ``gen_model_outputs`` followed by ``class_model_outputs``.
    """
    logger.info(f"Updating visibility for types: {selected_types}")
    visibility_updates = []
    for card in [*gen_model_outputs, *class_model_outputs]:
        is_visible = card["type"] in selected_types
        logger.info(f"Model {card['name']} (type: {card['type']}): visible = {is_visible}")
        visibility_updates.append(gr.update(visible=is_visible))
    return visibility_updates
# Load the models (once) and reveal the main UI.
# This function is wired to ``demo.load``, which Gradio runs each time the page
# is opened in a browser. Without this guard every visit would re-create the
# API clients and reload every local pipeline. The lock makes concurrent first
# visits wait for a single load instead of starting their own.
_models_load_lock = threading.Lock()
_models_loaded = False


def load_models_and_update_ui():
    """Load the models on first call, then swap the loading screen for the main UI.

    Returns:
        A pair of ``gr.update`` objects for ``(loading_group, main_ui_group)``.
        On success the loading group is hidden and the main UI shown. On failure
        the main UI stays hidden, and a later call retries the load.
    """
    global _models_loaded
    try:
        with _models_load_lock:
            if _models_loaded:
                logger.info("Models already loaded; skipping reload")
            else:
                initialize_api_clients()
                preload_local_models()
                _models_loaded = True
                logger.info("Models loaded successfully")
        return gr.update(visible=False), gr.update(visible=True)
    except Exception as e:
        logger.exception(f"Error loading models: {e}")
        # NOTE(review): the first output is a gr.Group, which has no ``value``,
        # so this message is probably not displayed; consider a dedicated
        # Markdown component for load errors.
        return gr.update(value=f"Error loading models: {e}"), gr.update(visible=False)
# Build a grid of model output cards
def create_model_grid(models):
    """Lay out one output card per model, two cards per row.

    Meant to run inside an open ``gr.Blocks`` context (it is called from
    ``create_ui``).

    Returns:
        A list of dicts, one per model in ``models`` order, with keys
        ``"type"``, ``"name"``, ``"output"`` (the card's Textbox) and
        ``"group"`` (the card's ``gr.Group``).
    """
    cards = []
    with gr.Column():
        for row_start in range(0, len(models), 2):
            with gr.Row():
                for spec in models[row_start:row_start + 2]:
                    with gr.Column():
                        with gr.Group() as card_group:
                            gr.Markdown(f"### {spec['name']}")
                            gr.Markdown(f"Type: {spec['type']}")
                            textbox = gr.Textbox(
                                label="Model Output",
                                lines=5,
                                interactive=False,
                                info=spec['description'],
                            )
                        cards.append({
                            "type": spec["type"],
                            "name": spec["name"],
                            "output": textbox,
                            "group": card_group,
                        })
    return cards
# UI components kept at module level; create_ui assigns them and
# update_model_visibility reads the two output-card lists.
input_text = filter_checkboxes = invoke_button = community_output = None
gen_model_outputs, class_model_outputs = [], []
# Build the UI
def create_ui():
    """Build the Toxic Eye Gradio app and wire up its event handlers.

    The main components are also stored in module-level globals; for example,
    ``update_model_visibility`` reads ``gen_model_outputs`` and
    ``class_model_outputs``.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    global input_text, filter_checkboxes, invoke_button, gen_model_outputs, class_model_outputs, community_output
    with gr.Blocks() as demo:
        # Loading placeholder: visible at first; load_models_and_update_ui hides it.
        with gr.Group(visible=True) as loading_group:
            gr.Markdown("""
# Toxic Eye
### Loading models... This may take a few minutes.
The application is initializing and preloading all models.
Please wait while the models are being loaded...
""")
        # Main UI (hidden initially; load_models_and_update_ui reveals it)
        with gr.Group(visible=False) as main_ui_group:
            # Header
            gr.Markdown("""
# Toxic Eye
This system evaluates the toxicity level of input text using multiple approaches.
""")
            # Input section
            with gr.Row():
                input_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text to analyze...",
                    lines=3
                )
            # Filter section: which model types are shown (and run by Invoke)
            with gr.Row():
                filter_checkboxes = gr.CheckboxGroup(
                    choices=[LOCAL, INFERENCE_API],
                    value=[LOCAL, INFERENCE_API],
                    label="Filter Models",
                    info="Choose which types of models to display",
                    interactive=True
                )
            # Invoke button
            with gr.Row():
                invoke_button = gr.Button(
                    "Invoke Selected Models",
                    variant="primary",
                    size="lg"
                )
            # Model tabs
            with gr.Tabs():
                with gr.Tab("Text Generation LLM"):
                    gen_model_outputs = create_model_grid(TEXT_GENERATION_MODELS)
                with gr.Tab("Classification LLM"):
                    class_model_outputs = create_model_grid(CLASSIFICATION_MODELS)
                with gr.Tab("Community (Not implemented)"):
                    with gr.Column():
                        community_output = gr.Textbox(
                            label="Related Community Topics",
                            lines=5,
                            interactive=False
                        )
        # Event handlers
        # Show/hide each model card (its gr.Group) when the filter changes.
        filter_checkboxes.change(
            fn=update_model_visibility,
            inputs=[filter_checkboxes],
            outputs=[
                output["group"]
                for outputs in [gen_model_outputs, class_model_outputs]
                for output in outputs
            ]
        )
        # Run the selected models. Output boxes are ordered generation models
        # first, then classification models, matching handle_invoke's result list.
        invoke_button.click(
            fn=handle_invoke,
            inputs=[input_text, filter_checkboxes],
            outputs=[
                output["output"]
                for outputs in [gen_model_outputs, class_model_outputs]
                for output in outputs
            ]
        )
        # On page load, load the models and swap the loading group for the main UI.
        demo.load(
            fn=load_models_and_update_ui,
            inputs=None,
            outputs=[loading_group, main_ui_group]
        )
    return demo
# Entry point
def main():
    """Build the Toxic Eye Gradio app and launch it."""
    logger.info("Starting Toxic Eye application")
    app = create_ui()
    app.launch()


if __name__ == "__main__":
    main()