# toxic-eye / app.py
# Author: nyasukun (commit aacf53d, 13 kB)
import gradio as gr
from huggingface_hub import InferenceClient
from typing import List, Dict, Optional, Union
import logging
from enum import Enum, auto
import torch
from transformers import AutoTokenizer, pipeline
import spaces
# Logger configuration: INFO and above, each line prefixed with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# Model backend types: "local" entries are loaded in-process as transformers
# pipelines, "api" entries are called through a huggingface_hub InferenceClient.
LOCAL, INFERENCE_API = "local", "api"

# Model registry. Local entries identify the model with "model_path";
# Inference API entries identify it with "model_id".
TEXT_GENERATION_MODELS = [
    dict(
        name="Zephyr-7B",
        description="Specialized in understanding context and nuance",
        type=INFERENCE_API,
        model_id="HuggingFaceH4/zephyr-7b-beta",
    ),
    dict(
        name="Llama-2",
        description="Known for its robust performance in content analysis",
        type=LOCAL,
        model_path="meta-llama/Llama-2-7b-hf",
    ),
    dict(
        name="Mistral-7B",
        description="Offers precise and detailed text evaluation",
        type=LOCAL,
        model_path="mistralai/Mistral-7B-v0.1",
    ),
]

CLASSIFICATION_MODELS = [
    dict(
        name="Toxic-BERT",
        description="Fine-tuned for toxic content detection",
        type=LOCAL,
        model_path="unitary/toxic-bert",
    ),
]
# Module-level caches for models, tokenizers and API clients:
#   tokenizers  -> tokenizer per local model, keyed by model_path
#   pipelines   -> transformers pipeline per local model, keyed by model_path
#   api_clients -> InferenceClient per Inference API model, keyed by model_id
tokenizers = dict()
pipelines = dict()
api_clients = dict()
# Initialize Inference API clients
def initialize_api_clients():
    """Create an InferenceClient for every registered model of type INFERENCE_API.

    Clients are stored in the module-level ``api_clients`` dict, keyed by the
    model's "model_id". Entries without a "model_id" are ignored.
    """
    api_models = (
        cfg
        for cfg in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS
        if cfg["type"] == INFERENCE_API and "model_id" in cfg
    )
    for cfg in api_models:
        logger.info(f"Initializing API client for {cfg['name']}")
        # token=True: authenticate with the stored Hugging Face token
        api_clients[cfg["model_id"]] = InferenceClient(cfg["model_id"], token=True)
    logger.info("API clients initialized")
# Preload local models
def _preload_pipeline(task, label, model_path):
    """Load the tokenizer and transformers pipeline for one local model into the caches.

    Args:
        task: transformers pipeline task name (e.g. "text-generation").
        label: human-readable task name used only in log messages.
        model_path: Hub path of the model; used as the cache key.

    Models already present in ``pipelines`` are skipped, so repeated calls are
    cheap. Load failures are logged with a traceback and swallowed, so one bad
    model does not stop the others from loading.
    """
    if model_path in pipelines:
        logger.info(f"Model already loaded, skipping: {model_path}")
        return
    try:
        logger.info(f"Preloading {label} model: {model_path}")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model_pipeline = pipeline(
            task,
            model=model_path,
            tokenizer=tokenizer,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )
    except Exception as e:
        logger.exception(f"Error preloading model {model_path}: {e}")
        return
    # Store both only after the pipeline built, so the two caches never disagree.
    tokenizers[model_path] = tokenizer
    pipelines[model_path] = model_pipeline
    logger.info(f"Model preloaded successfully: {model_path}")


def preload_local_models():
    """Load every LOCAL text-generation and classification model into the caches.

    Idempotent: this runs from load_models_and_update_ui, which create_ui wires
    to demo.load, so it may be invoked many times. Models already loaded are
    not loaded again.
    """
    logger.info("Preloading local models at application startup...")
    model_groups = (
        ("text-generation", "text generation", TEXT_GENERATION_MODELS),
        ("text-classification", "classification", CLASSIFICATION_MODELS),
    )
    for task, label, models in model_groups:
        for model in models:
            if model["type"] == LOCAL and "model_path" in model:
                _preload_pipeline(task, label, model["model_path"])
@spaces.GPU
def generate_text_local(model_path, text):
    """Generate a short greedy continuation of *text* with a preloaded local pipeline.

    Args:
        model_path: key into the module-level ``pipelines`` cache.
        text: prompt text.

    Returns:
        The pipeline's "generated_text" (up to 40 new tokens, no sampling), or
        an "Error: ..." string if the model is not loaded or generation fails.
    """
    text_pipeline = pipelines.get(model_path)
    if text_pipeline is None:
        # Model loading logs and swallows failures, so the model may be missing.
        logger.error(f"Local model not loaded: {model_path}")
        return f"Error: model {model_path} is not loaded"
    try:
        logger.info(f"Running local text generation with {model_path}")
        outputs = text_pipeline(
            text,
            max_new_tokens=40,
            do_sample=False,
            num_return_sequences=1,
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        logger.exception(f"Error in local text generation with {model_path}: {e}")
        return f"Error: {str(e)}"
def generate_text_api(model_id, text):
    """Generate text for *text* through the Inference API client for *model_id*.

    Requests at most 40 new tokens with temperature 0.7. Any failure, including
    a missing client for *model_id*, is logged and returned as "Error: <message>".
    """
    try:
        logger.info(f"Running API text generation with {model_id}")
        client = api_clients[model_id]
        return client.text_generation(text, max_new_tokens=40, temperature=0.7)
    except Exception as exc:
        reason = str(exc)
        logger.error(f"Error in API text generation with {model_id}: {reason}")
        return f"Error: {reason}"
@spaces.GPU
def classify_text_local(model_path, text):
    """Classify *text* with a preloaded local pipeline.

    Args:
        model_path: key into the module-level ``pipelines`` cache.
        text: text to classify.

    Returns:
        str() of the pipeline's raw result, or an "Error: ..." string if the
        model is not loaded or classification fails.
    """
    classifier = pipelines.get(model_path)
    if classifier is None:
        # Model loading logs and swallows failures, so the model may be missing.
        logger.error(f"Local model not loaded: {model_path}")
        return f"Error: model {model_path} is not loaded"
    try:
        logger.info(f"Running local classification with {model_path}")
        result = classifier(text)
        return str(result)
    except Exception as e:
        logger.exception(f"Error in local classification with {model_path}: {e}")
        return f"Error: {str(e)}"
def classify_text_api(model_id, text):
    """Classify *text* through the Inference API client for *model_id*.

    Returns str() of the API response. Any failure, including a missing client
    for *model_id*, is logged and returned as "Error: <message>".
    """
    try:
        logger.info(f"Running API classification with {model_id}")
        client = api_clients[model_id]
        return str(client.text_classification(text))
    except Exception as exc:
        reason = str(exc)
        logger.error(f"Error in API classification with {model_id}: {reason}")
        return f"Error: {reason}"
# Handler for the Invoke button
def handle_invoke(text, selected_types):
    """Run every model whose type is selected and return one output per model.

    The returned list is positional: it has exactly one entry per model, in the
    order TEXT_GENERATION_MODELS then CLASSIFICATION_MODELS. That matches the
    order of the output textboxes wired to the Invoke button. Models whose type
    is not in *selected_types* get "" so later results never shift into another
    model's textbox.

    Args:
        text: input text sent to every selected model.
        selected_types: iterable of model types (LOCAL / INFERENCE_API) to run;
            None is treated as "nothing selected".

    Returns:
        List of "<model name>: <result>" strings, or "" for skipped models.
    """
    selected = selected_types or []
    results = []

    # Text generation models
    for model in TEXT_GENERATION_MODELS:
        if model["type"] not in selected:
            results.append("")
            continue
        if model["type"] == LOCAL:
            result = generate_text_local(model["model_path"], text)
        else:  # api
            result = generate_text_api(model["model_id"], text)
        results.append(f"{model['name']}: {result}")

    # Classification models
    for model in CLASSIFICATION_MODELS:
        if model["type"] not in selected:
            results.append("")
            continue
        if model["type"] == LOCAL:
            result = classify_text_local(model["model_path"], text)
        else:  # api
            result = classify_text_api(model["model_id"], text)
        results.append(f"{model['name']}: {result}")

    return results
# Update model visibility
def update_model_visibility(selected_types):
    """Return one gr.update(visible=...) per model card.

    Cards are ordered generation models first, then classification models. A
    card is visible when its model type is in *selected_types*.
    """
    logger.info(f"Updating visibility for types: {selected_types}")
    every_card = [*gen_model_outputs, *class_model_outputs]
    visibility_updates = []
    for card in every_card:
        is_visible = card["type"] in selected_types
        logger.info(f"Model {card['name']} (type: {card['type']}): visible = {is_visible}")
        visibility_updates.append(gr.update(visible=is_visible))
    return visibility_updates
# Load models and update the UI
def load_models_and_update_ui():
    """Initialize API clients and local models, then swap loading screen for main UI.

    Returns:
        A pair of updates for (loading_group, main_ui_group). On success the
        loading group is hidden and the main UI shown. On failure the loading
        group stays visible, the main UI stays hidden, and the traceback is
        logged.
    """
    try:
        initialize_api_clients()
        preload_local_models()
    except Exception as e:
        # The first output is a gr.Group, which has no `value`, so an error
        # message cannot be displayed through it; report via the log instead.
        logger.exception(f"Error loading models: {e}")
        return gr.update(visible=True), gr.update(visible=False)
    logger.info("Models loaded successfully")
    # Hide the loading placeholder and reveal the main UI
    return gr.update(visible=False), gr.update(visible=True)
# Build a model grid
def create_model_grid(models):
    """Lay out one card per model, two cards per row, inside the current Blocks context.

    Returns:
        A list in the same order as *models*. Each element is a dict with keys
        "type" and "name" (copied from the model entry), "output" (the card's
        read-only Textbox) and "group" (the gr.Group wrapping the card).
    """
    cards = []
    with gr.Column():
        for start in range(0, len(models), 2):
            with gr.Row():
                for model in models[start:start + 2]:
                    with gr.Column():
                        with gr.Group() as card_group:
                            gr.Markdown(f"### {model['name']}")
                            gr.Markdown(f"Type: {model['type']}")
                            textbox = gr.Textbox(
                                label="Model Output",
                                lines=5,
                                interactive=False,
                                info=model['description'],
                            )
                    cards.append({
                        "type": model["type"],
                        "name": model["name"],
                        "output": textbox,
                        "group": card_group,
                    })
    return cards
# UI components assigned by create_ui(), kept at module level so event handlers can reach them.
input_text = filter_checkboxes = invoke_button = community_output = None
gen_model_outputs = []  # per-model card dicts for the text generation tab
class_model_outputs = []  # per-model card dicts for the classification tab
# Build the UI
def create_ui():
    """Build the Toxic Eye Gradio Blocks app and wire up its event handlers.

    The created input, filter, button and output components are also assigned
    to module-level globals (see the `global` statement). Module-level handlers
    such as update_model_visibility can then reference them.

    Returns:
        The gr.Blocks app; the caller is responsible for launching it.
    """
    global input_text, filter_checkboxes, invoke_button, gen_model_outputs, class_model_outputs, community_output
    with gr.Blocks() as demo:
        # Loading placeholder, visible until the demo.load handler below swaps it out
        with gr.Group(visible=True) as loading_group:
            gr.Markdown("""
# Toxic Eye
### Loading models... This may take a few minutes.
The application is initializing and preloading all models.
Please wait while the models are being loaded...
""")
        # Main UI components (hidden initially)
        with gr.Group(visible=False) as main_ui_group:
            # Header
            gr.Markdown("""
# Toxic Eye
This system evaluates the toxicity level of input text using multiple approaches.
""")
            # Input section
            with gr.Row():
                input_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text to analyze...",
                    lines=3
                )
            # Filter section: the selected model types control which cards are shown
            # and which models handle_invoke runs
            with gr.Row():
                filter_checkboxes = gr.CheckboxGroup(
                    choices=[LOCAL, INFERENCE_API],
                    value=[LOCAL, INFERENCE_API],
                    label="Filter Models",
                    info="Choose which types of models to display",
                    interactive=True
                )
            # Invoke button
            with gr.Row():
                invoke_button = gr.Button(
                    "Invoke Selected Models",
                    variant="primary",
                    size="lg"
                )
            # Model tabs
            with gr.Tabs():
                with gr.Tab("Text Generation LLM"):
                    gen_model_outputs = create_model_grid(TEXT_GENERATION_MODELS)
                with gr.Tab("Classification LLM"):
                    class_model_outputs = create_model_grid(CLASSIFICATION_MODELS)
                with gr.Tab("Community (Not implemented)"):
                    with gr.Column():
                        # Placeholder only: no event handler writes to this textbox
                        community_output = gr.Textbox(
                            label="Related Community Topics",
                            lines=5,
                            interactive=False
                        )
        # Event handlers
        # Filter change -> one visibility update per model card group
        # (generation models first, then classification models)
        filter_checkboxes.change(
            fn=update_model_visibility,
            inputs=[filter_checkboxes],
            outputs=[
                output["group"]
                for outputs in [gen_model_outputs, class_model_outputs]
                for output in outputs
            ]
        )
        # Invoke -> one result per model textbox, same ordering as above
        invoke_button.click(
            fn=handle_invoke,
            inputs=[input_text, filter_checkboxes],
            outputs=[
                output["output"]
                for outputs in [gen_model_outputs, class_model_outputs]
                for output in outputs
            ]
        )
        # Run model loading when the page loads, then toggle loading/main groups.
        # NOTE(review): Gradio fires `load` on every page visit, not once per
        # process — confirm model loading is safe to repeat.
        demo.load(
            fn=load_models_and_update_ui,
            inputs=None,
            outputs=[loading_group, main_ui_group]
        )
    return demo
# Entry point
def main():
    """Build the Toxic Eye UI and start the Gradio server."""
    logger.info("Starting Toxic Eye application")
    create_ui().launch()


if __name__ == "__main__":
    main()