chat / app.py
SimrusDenuvo's picture
Update app.py
d5e56ca verified
raw
history blame
6.03 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from datasets import load_dataset
# Загрузка датасета
dataset = load_dataset("ZhenDOS/alpha_bank_data")
# Инициализация разных моделей
model_name = "ai-forever/rugpt3large_based_on_gpt2"
"GigaChat-like:" "ai-forever/rugpt3large_based_on_gpt2", # Русская модель большого размера
"ChatGPT-like": "tinkoff-ai/ruDialoGPT-medium", # Диалоговая модель для русского языка
"DeepSeek-like": "ai-forever/sbert_large_nlu_ru" # Русская модель для понимания текста
# Инициализация моделей и токенизаторов
models = model_name
tokenizers = AutoTokenizer.from_pretrained(model_name)
for model_name, model_path in MODELS.items():
try:
if model_name == "DeepSeek-like":
# Для SBERT используем pipeline
models[model_name] = pipeline("text-generation", model=model_path)
else:
tokenizers[model_name] = AutoTokenizer.from_pretrained(model_path)
models[model_name] = AutoModelForCausalLM.from_pretrained(model_path)
except Exception as e:
print(f"Ошибка при загрузке модели {model_name}: {e}")
# Промпты для обработки обращений
PROMPTS = {
"Анализ проблемы":
"Проанализируй клиентское обращение и выдели основную проблему. "
"Обращение: {text}\n\nПроблема:",
"Формирование ответа":
"Клиент обратился с проблемой: {problem}\n\n"
"Сформируй вежливый и профессиональный ответ, предлагая решение. "
"Используй информацию о банковских услугах. Ответ:"
}
def generate_with_model(prompt, model_name, max_length=150):
"""Генерация ответа с помощью выбранной модели"""
if model_name not in models:
return f"Модель {model_name} не загружена"
try:
if model_name == "DeepSeek-like":
# Обработка через pipeline
result = models[model_name](
prompt,
max_length=max_length,
do_sample=True,
temperature=0.7,
top_p=0.9
)
return result[0]['generated_text']
else:
# Обработка через transformers
inputs = tokenizers[model_name](prompt, return_tensors="pt", truncation=True)
with torch.no_grad():
outputs = models[model_name].generate(
**inputs,
max_new_tokens=max_length,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizers[model_name].eos_token_id
)
response = tokenizers[model_name].decode(outputs[0], skip_special_tokens=True)
return response[len(prompt):] if response.startswith(prompt) else response
except Exception as e:
return f"Ошибка генерации: {str(e)}"
def process_complaint(text, prompt_type):
"""Обработка клиентского обращения с выбранным промптом"""
if prompt_type not in PROMPTS:
return "Неверный тип промпта"
# Получаем случайный пример из датасета, если текст не введен
if not text.strip():
example = dataset['train'].shuffle().select(range(1))[0]
text = example['text']
prompt = PROMPTS[prompt_type].format(text=text, problem="")
results = {}
for model_name in MODELS.keys():
results[model_name] = generate_with_model(prompt, model_name)
return results
# Интерфейс Gradio
with gr.Blocks(title="Анализ клиентских обращений Alpha Bank") as demo:
gr.Markdown("## Анализ клиентских обращений Alpha Bank")
gr.Markdown("Тестирование разных моделей на обработку обращений")
with gr.Row():
with gr.Column():
text_input = gr.Textbox(
label="Текст обращения",
placeholder="Введите текст обращения или оставьте пустым для примера из датасета",
lines=5
)
prompt_type = gr.Radio(
list(PROMPTS.keys()),
label="Тип промпта",
value=list(PROMPTS.keys())[0]
)
submit_btn = gr.Button("Обработать")
with gr.Column():
outputs = []
for model_name in MODELS.keys():
outputs.append(
gr.Textbox(
label=f"{model_name}",
interactive=False,
lines=5
)
)
# Примеры из датасета
examples = gr.Examples(
examples=[x['text'] for x in dataset['train'].select(range(3))],
inputs=text_input,
label="Примеры из датасета"
)
def process_and_display(text, prompt_type):
results = process_complaint(text, prompt_type)
return [results.get(model_name, "") for model_name in MODELS.keys()]
submit_btn.click(
fn=process_and_display,
inputs=[text_input, prompt_type],
outputs=outputs
)
if __name__ == "__main__":
demo.launch()