|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
from typing import List, Dict, Optional, Union |
|
import logging |
|
from enum import Enum, auto |
|
import torch |
|
from transformers import AutoTokenizer, pipeline |
|
import spaces |
|
import concurrent.futures |
|
import time |
|
|
|
|
|
# Configure root logging once at import time: INFO level, with timestamp,
# logger name and level prefixed to every message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Model "type" tags used in the model catalogs and in the UI filter:
# LOCAL marks models run through an in-process pipeline, INFERENCE_API marks
# models called through a Hugging Face InferenceClient.
LOCAL, INFERENCE_API = "local", "api"
|
|
|
|
|
# Catalog of text-generation models shown in the "Text Generation LLM" tab.
# Every entry has "name", "description" and "type"; INFERENCE_API entries
# identify the model with "model_id", LOCAL entries with "model_path".
TEXT_GENERATION_MODELS = [
    dict(
        name="Zephyr-7B",
        description="Specialized in understanding context and nuance",
        type=INFERENCE_API,
        model_id="HuggingFaceH4/zephyr-7b-beta",
    ),
    dict(
        name="Llama-2",
        description="Known for its robust performance in content analysis",
        type=LOCAL,
        model_path="meta-llama/Llama-2-7b-hf",
    ),
    dict(
        name="Mistral-7B",
        description="Offers precise and detailed text evaluation",
        type=LOCAL,
        model_path="mistralai/Mistral-7B-v0.1",
    ),
]
|
|
|
# Catalog of classification models shown in the "Classification LLM" tab.
# Entries use the same keys as the text-generation catalog.
CLASSIFICATION_MODELS = [
    dict(
        name="Toxic-BERT",
        description="Fine-tuned for toxic content detection",
        type=LOCAL,
        model_path="unitary/toxic-bert",
    ),
]
|
|
|
|
|
# Module-level registries filled in at load time:
#   tokenizers  - keyed by local model path
#   pipelines   - keyed by local model path
#   api_clients - keyed by Inference API model id
tokenizers, pipelines, api_clients = {}, {}, {}
|
|
|
|
|
def initialize_api_clients():
    """Create an InferenceClient for every API-type model in both catalogs.

    Each client is stored in the module-level ``api_clients`` registry under
    the model's ``model_id``. Entries that are not of type ``INFERENCE_API``
    or that have no ``model_id`` key are skipped.
    """
    catalog = TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS
    for entry in catalog:
        if entry["type"] != INFERENCE_API or "model_id" not in entry:
            continue
        logger.info(f"Initializing API client for {entry['name']}")
        repo_id = entry["model_id"]
        api_clients[repo_id] = InferenceClient(repo_id, token=True)
    logger.info("API clients initialized")
|
|
|
|
|
def _preload_pipeline(task, label, model_path):
    """Load one tokenizer/pipeline pair into the module-level registries.

    Args:
        task: Pipeline task name passed to ``transformers.pipeline``
            (``"text-generation"`` or ``"text-classification"``).
        label: Human-readable model kind used only in the log message.
        model_path: Hub path of the model; also the registry key.

    Returns:
        True if ``pipelines[model_path]`` is available afterwards, else False.
    """
    # The caller runs from the page-load handler, which fires for every
    # browser session; never reload a model that is already registered.
    if model_path in pipelines:
        logger.info(f"Model already loaded, skipping: {model_path}")
        return True

    try:
        logger.info(f"Preloading {label} model: {model_path}")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # NOTE(security): trust_remote_code=True lets the model repository
        # execute its own Python code at load time; only keep it for
        # repositories that are trusted.
        loaded = pipeline(
            task,
            model=model_path,
            tokenizer=tokenizer,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )
    except Exception as e:
        # Best effort: one failing model must not prevent the others from
        # loading. logger.exception keeps the traceback for diagnosis.
        logger.exception(f"Error preloading model {model_path}: {str(e)}")
        return False

    # Register both objects only after both loaded, so the registries never
    # hold a tokenizer without its pipeline.
    tokenizers[model_path] = tokenizer
    pipelines[model_path] = loaded
    logger.info(f"Model preloaded successfully: {model_path}")
    return True


def preload_local_models():
    """Preload every LOCAL model from both catalogs into ``pipelines``.

    Text-generation models are loaded first, then classification models.
    Failures are logged per model and do not raise. Models that are already
    present in ``pipelines`` are skipped, so calling this function more than
    once is cheap.
    """
    logger.info("Preloading local models at application startup...")

    plan = (
        ("text-generation", "text generation", TEXT_GENERATION_MODELS),
        ("text-classification", "classification", CLASSIFICATION_MODELS),
    )
    for task, label, catalog in plan:
        for model in catalog:
            if model["type"] == LOCAL and "model_path" in model:
                _preload_pipeline(task, label, model["model_path"])
|
|
|
@spaces.GPU
def generate_text_local(model_path, text):
    """Generate text for ``text`` with the preloaded local pipeline.

    Args:
        model_path: Key of the pipeline in the module-level ``pipelines``.
        text: Prompt passed to the pipeline.

    Returns:
        The ``generated_text`` field of the first returned sequence, or a
        string of the form ``"Error: <message>"`` if anything fails.
    """
    generation_args = {
        "max_new_tokens": 40,
        "do_sample": False,
        "num_return_sequences": 1,
    }
    try:
        logger.info(f"Running local text generation with {model_path}")
        generator = pipelines[model_path]
        sequences = generator(text, **generation_args)
        return sequences[0]["generated_text"]
    except Exception as exc:
        logger.error(f"Error in local text generation with {model_path}: {str(exc)}")
        return f"Error: {str(exc)}"
|
|
|
def generate_text_api(model_id, text):
    """Generate text for ``text`` through the Inference API client.

    Args:
        model_id: Key of the client in the module-level ``api_clients``.
        text: Prompt sent to the client's ``text_generation`` call.

    Returns:
        Whatever ``text_generation`` returns, or a string of the form
        ``"Error: <message>"`` if anything fails.
    """
    try:
        logger.info(f"Running API text generation with {model_id}")
        client = api_clients[model_id]
        return client.text_generation(text, max_new_tokens=40, temperature=0.7)
    except Exception as exc:
        logger.error(f"Error in API text generation with {model_id}: {str(exc)}")
        return f"Error: {str(exc)}"
|
|
|
@spaces.GPU
def classify_text_local(model_path, text):
    """Classify ``text`` with the preloaded local pipeline.

    Args:
        model_path: Key of the pipeline in the module-level ``pipelines``.
        text: Text passed to the pipeline.

    Returns:
        ``str()`` of the pipeline result, or a string of the form
        ``"Error: <message>"`` if anything fails.
    """
    try:
        logger.info(f"Running local classification with {model_path}")
        classifier = pipelines[model_path]
        prediction = classifier(text)
        return str(prediction)
    except Exception as exc:
        logger.error(f"Error in local classification with {model_path}: {str(exc)}")
        return f"Error: {str(exc)}"
|
|
|
def classify_text_api(model_id, text):
    """Classify ``text`` through the Inference API client.

    Args:
        model_id: Key of the client in the module-level ``api_clients``.
        text: Text sent to the client's ``text_classification`` call.

    Returns:
        ``str()`` of the client response, or a string of the form
        ``"Error: <message>"`` if anything fails.
    """
    try:
        logger.info(f"Running API classification with {model_id}")
        client = api_clients[model_id]
        prediction = client.text_classification(text)
        return str(prediction)
    except Exception as exc:
        logger.error(f"Error in API classification with {model_id}: {str(exc)}")
        return f"Error: {str(exc)}"
|
|
|
@spaces.GPU
def parallel_text_generation(model_paths, texts):
    """Run several local text-generation pipelines inside one decorated call.

    The models are processed one after another; per the original author's
    note, grouping them in a single ``spaces.GPU`` call is meant to avoid a
    separate GPU allocation per model.

    Args:
        model_paths: Keys into the module-level ``pipelines`` registry.
        texts: Prompts, paired positionally with ``model_paths``.

    Returns:
        A dict mapping each processed model path to its generated text, or
        to ``"Error: <message>"`` when that model failed.
    """
    try:
        total = len(model_paths)
        logger.info(f"Running parallel text generation for {total} models")
        generated = {}

        for position, (path, prompt) in enumerate(zip(model_paths, texts), start=1):
            try:
                logger.info(f"Processing model {position}/{total}: {path}")
                sequences = pipelines[path](
                    prompt,
                    max_new_tokens=40,
                    do_sample=False,
                    num_return_sequences=1,
                )
                generated[path] = sequences[0]["generated_text"]
            except Exception as exc:
                # A failure in one model must not abort the remaining ones.
                logger.error(f"Error in text generation with {path}: {str(exc)}")
                generated[path] = f"Error: {str(exc)}"

        return generated
    except Exception as exc:
        logger.error(f"Error in parallel text generation: {str(exc)}")
        failure = f"Error: {str(exc)}"
        return {path: failure for path in model_paths}
|
|
|
@spaces.GPU
def parallel_text_classification(model_paths, texts):
    """Run several local classification pipelines inside one decorated call.

    The models are processed one after another; per the original author's
    note, grouping them in a single ``spaces.GPU`` call is meant to avoid a
    separate GPU allocation per model.

    Args:
        model_paths: Keys into the module-level ``pipelines`` registry.
        texts: Input texts, paired positionally with ``model_paths``.

    Returns:
        A dict mapping each processed model path to ``str()`` of its
        pipeline result, or to ``"Error: <message>"`` when that model failed.
    """
    try:
        total = len(model_paths)
        logger.info(f"Running parallel text classification for {total} models")
        classified = {}

        for position, (path, sample) in enumerate(zip(model_paths, texts), start=1):
            try:
                logger.info(f"Processing classification model {position}/{total}: {path}")
                prediction = pipelines[path](sample)
                classified[path] = str(prediction)
            except Exception as exc:
                # A failure in one model must not abort the remaining ones.
                logger.error(f"Error in classification with {path}: {str(exc)}")
                classified[path] = f"Error: {str(exc)}"

        return classified
    except Exception as exc:
        logger.error(f"Error in parallel text classification: {str(exc)}")
        failure = f"Error: {str(exc)}"
        return {path: failure for path in model_paths}
|
|
|
|
|
def handle_invoke(text, selected_types):
    """Run every model whose type is selected and collect one output each.

    API models are submitted to a thread pool and run while the local models
    are processed on the calling thread, so the two kinds of work overlap.

    Args:
        text: Input text passed unchanged to every selected model.
        selected_types: Iterable of model type strings (``LOCAL`` and/or
            ``INFERENCE_API``) to run. ``None`` is treated as "nothing
            selected".

    Returns:
        A list with one string per entry of ``TEXT_GENERATION_MODELS``
        followed by ``CLASSIFICATION_MODELS``, in catalog order. Selected
        models yield ``"<name>: <output>"`` or ``"<name>: Error - <msg>"``;
        models that were not selected yield ``""``.
    """
    start_time = time.time()
    logger.info("Starting parallel model execution")

    selected = set(selected_types or ())
    gen_count = len(TEXT_GENERATION_MODELS)
    results = [""] * (gen_count + len(CLASSIFICATION_MODELS))

    # Each work item keeps the model dict next to its result slot, so the
    # display name never has to be looked up again by model_path (API-type
    # catalog entries have no "model_path" key and broke that lookup).
    local_gen = []   # (slot, model)
    local_cls = []   # (slot, model)
    api_tasks = []   # (slot, model, api_call)

    for slot, model in enumerate(TEXT_GENERATION_MODELS):
        if model["type"] not in selected:
            continue
        if model["type"] == LOCAL:
            local_gen.append((slot, model))
        else:
            api_tasks.append((slot, model, generate_text_api))

    # Classification results are stored after all text-generation slots.
    for slot, model in enumerate(CLASSIFICATION_MODELS, start=gen_count):
        if model["type"] not in selected:
            continue
        if model["type"] == LOCAL:
            local_cls.append((slot, model))
        else:
            api_tasks.append((slot, model, classify_text_api))

    def run_api_task(slot, model, api_call):
        """Call one API model and return ``(slot, formatted result)``."""
        try:
            return slot, f"{model['name']}: {api_call(model['model_id'], text)}"
        except Exception as e:
            logger.error(f"Error in {model['name']}: {str(e)}")
            return slot, f"{model['name']}: Error - {str(e)}"

    def run_local_batch(batch, runner, label):
        """Run one batch of local models and write their result slots."""
        if not batch:
            return
        try:
            paths = [model["model_path"] for _, model in batch]
            outputs = runner(paths, [text] * len(paths))
            for slot, model in batch:
                results[slot] = f"{model['name']}: {outputs[model['model_path']]}"
        except Exception as e:
            logger.error(f"Error in parallel {label}: {str(e)}")
            for slot, model in batch:
                results[slot] = f"{model['name']}: Error - {str(e)}"

    # The executor stays open while the local batches run; leaving the
    # ``with`` block earlier would wait for every API call before any local
    # inference started. max_workers must be >= 1 even with no API tasks.
    pool_size = max(1, len(api_tasks))
    with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as executor:
        futures = [executor.submit(run_api_task, *task) for task in api_tasks]

        run_local_batch(local_gen, parallel_text_generation, "text generation")
        run_local_batch(local_cls, parallel_text_classification, "text classification")

        for future in concurrent.futures.as_completed(futures):
            slot, message = future.result()
            results[slot] = message

    elapsed_time = time.time() - start_time
    logger.info(f"Parallel model execution completed in {elapsed_time:.2f} seconds")

    return results
|
|
|
|
|
def update_model_visibility(selected_types):
    """Build one visibility update per model card for the selected types.

    Args:
        selected_types: Model type strings currently ticked in the filter.

    Returns:
        A list of ``gr.update(visible=...)`` objects, ordered as the entries
        of ``gen_model_outputs`` followed by ``class_model_outputs``.
    """
    logger.info(f"Updating visibility for types: {selected_types}")

    visibility_updates = []
    for entry in [*gen_model_outputs, *class_model_outputs]:
        is_shown = entry["type"] in selected_types
        logger.info(f"Model {entry['name']} (type: {entry['type']}): visible = {is_shown}")
        visibility_updates.append(gr.update(visible=is_shown))
    return visibility_updates
|
|
|
|
|
def load_models_and_update_ui():
    """Initialize API clients, preload local models, then swap the UI panels.

    Returns:
        A pair of ``gr.update`` objects. On success the first hides its
        target and the second shows its target. On failure the first carries
        an error message as ``value`` and the second keeps its target hidden.
    """
    try:
        initialize_api_clients()
        preload_local_models()
        logger.info("Models loaded successfully")

        hide_loading = gr.update(visible=False)
        show_main = gr.update(visible=True)
        return hide_loading, show_main
    except Exception as exc:
        logger.error(f"Error loading models: {exc}")
        # NOTE(review): create_ui wires the first output to a gr.Group;
        # confirm that component can actually render a `value` update.
        error_update = gr.update(value=f"Error loading models: {exc}")
        return error_update, gr.update(visible=False)
|
|
|
|
|
def create_model_grid(models):
    """Lay out one card per model, two cards per row.

    Must be called inside an active Gradio layout context.

    Args:
        models: List of catalog entries with "name", "type" and
            "description" keys.

    Returns:
        A list of dicts, one per model in input order, with keys "type",
        "name", "output" (the card's Textbox) and "group" (the card's Group).
    """
    cards = []
    with gr.Column():
        for row_start in range(0, len(models), 2):
            with gr.Row():
                # At most two models per row; the last row may hold one.
                for model in models[row_start:row_start + 2]:
                    with gr.Column(), gr.Group() as card:
                        gr.Markdown(f"### {model['name']}")
                        gr.Markdown(f"Type: {model['type']}")
                        textbox = gr.Textbox(
                            label="Model Output",
                            lines=5,
                            interactive=False,
                            info=model['description'],
                        )
                        cards.append({
                            "type": model["type"],
                            "name": model["name"],
                            "output": textbox,
                            "group": card,
                        })
    return cards
|
|
|
|
|
# Placeholders for UI components; create_ui() rebinds all of them via
# `global` when the interface is built.
input_text = filter_checkboxes = invoke_button = community_output = None

# Per-model card descriptors returned by create_model_grid(); read by
# update_model_visibility() and by the event wiring in create_ui().
gen_model_outputs, class_model_outputs = [], []
|
|
|
|
|
def create_ui():
    """Build the Gradio Blocks app and wire its event handlers.

    The app has two groups: a loading notice (visible at first) and the main
    UI (hidden at first). The Blocks ``load`` event runs
    ``load_models_and_update_ui`` and sends its two updates to those groups.
    The created components are also stored in module-level globals.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    global input_text, filter_checkboxes, invoke_button, gen_model_outputs, class_model_outputs, community_output

    with gr.Blocks() as demo:
        # Shown while models are being loaded.
        with gr.Group(visible=True) as loading_group:
            gr.Markdown("""
            # Toxic Eye

            ### Loading models... This may take a few minutes.

            The application is initializing and preloading all models.
            Please wait while the models are being loaded...
            """)

        # Main interface, hidden until loading has finished.
        with gr.Group(visible=False) as main_ui_group:
            gr.Markdown("""
            # Toxic Eye
            This system evaluates the toxicity level of input text using multiple approaches.
            """)

            with gr.Row():
                input_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text to analyze...",
                    lines=3,
                )

            model_types = [LOCAL, INFERENCE_API]
            with gr.Row():
                filter_checkboxes = gr.CheckboxGroup(
                    choices=model_types,
                    value=[LOCAL, INFERENCE_API],
                    label="Filter Models",
                    info="Choose which types of models to display",
                    interactive=True,
                )

            with gr.Row():
                invoke_button = gr.Button(
                    "Invoke Selected Models",
                    variant="primary",
                    size="lg",
                )

            with gr.Tabs():
                with gr.Tab("Text Generation LLM"):
                    gen_model_outputs = create_model_grid(TEXT_GENERATION_MODELS)
                with gr.Tab("Classification LLM"):
                    class_model_outputs = create_model_grid(CLASSIFICATION_MODELS)
                with gr.Tab("Community (Not implemented)"):
                    with gr.Column():
                        community_output = gr.Textbox(
                            label="Related Community Topics",
                            lines=5,
                            interactive=False,
                        )

        # Generation cards first, then classification cards: the same order
        # in which update_model_visibility and handle_invoke return values.
        model_entries = [*gen_model_outputs, *class_model_outputs]
        card_groups = [entry["group"] for entry in model_entries]
        card_textboxes = [entry["output"] for entry in model_entries]

        filter_checkboxes.change(
            fn=update_model_visibility,
            inputs=[filter_checkboxes],
            outputs=card_groups,
        )

        invoke_button.click(
            fn=handle_invoke,
            inputs=[input_text, filter_checkboxes],
            outputs=card_textboxes,
        )

        # Load the models when the page loads, then swap the two groups.
        demo.load(
            fn=load_models_and_update_ui,
            inputs=None,
            outputs=[loading_group, main_ui_group],
        )

    return demo
|
|
|
|
|
def main():
    """Log startup, build the UI and launch the Gradio app."""
    logger.info("Starting Toxic Eye application")
    create_ui().launch()


# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()