Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
import torch | |
from datasets import load_dataset | |
# Загрузка датасета | |
dataset = load_dataset("ZhenDOS/alpha_bank_data") | |
# Инициализация разных моделей | |
model_name = "ai-forever/rugpt3large_based_on_gpt2" | |
"GigaChat-like:" "ai-forever/rugpt3large_based_on_gpt2", # Русская модель большого размера | |
"ChatGPT-like": "tinkoff-ai/ruDialoGPT-medium", # Диалоговая модель для русского языка | |
"DeepSeek-like": "ai-forever/sbert_large_nlu_ru" # Русская модель для понимания текста | |
# Инициализация моделей и токенизаторов | |
models = model_name | |
tokenizers = AutoTokenizer.from_pretrained(model_name) | |
for model_name, model_path in MODELS.items(): | |
try: | |
if model_name == "DeepSeek-like": | |
# Для SBERT используем pipeline | |
models[model_name] = pipeline("text-generation", model=model_path) | |
else: | |
tokenizers[model_name] = AutoTokenizer.from_pretrained(model_path) | |
models[model_name] = AutoModelForCausalLM.from_pretrained(model_path) | |
except Exception as e: | |
print(f"Ошибка при загрузке модели {model_name}: {e}") | |
# Промпты для обработки обращений | |
PROMPTS = { | |
"Анализ проблемы": | |
"Проанализируй клиентское обращение и выдели основную проблему. " | |
"Обращение: {text}\n\nПроблема:", | |
"Формирование ответа": | |
"Клиент обратился с проблемой: {problem}\n\n" | |
"Сформируй вежливый и профессиональный ответ, предлагая решение. " | |
"Используй информацию о банковских услугах. Ответ:" | |
} | |
def generate_with_model(prompt, model_name, max_length=150): | |
"""Генерация ответа с помощью выбранной модели""" | |
if model_name not in models: | |
return f"Модель {model_name} не загружена" | |
try: | |
if model_name == "DeepSeek-like": | |
# Обработка через pipeline | |
result = models[model_name]( | |
prompt, | |
max_length=max_length, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.9 | |
) | |
return result[0]['generated_text'] | |
else: | |
# Обработка через transformers | |
inputs = tokenizers[model_name](prompt, return_tensors="pt", truncation=True) | |
with torch.no_grad(): | |
outputs = models[model_name].generate( | |
**inputs, | |
max_new_tokens=max_length, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.9, | |
eos_token_id=tokenizers[model_name].eos_token_id | |
) | |
response = tokenizers[model_name].decode(outputs[0], skip_special_tokens=True) | |
return response[len(prompt):] if response.startswith(prompt) else response | |
except Exception as e: | |
return f"Ошибка генерации: {str(e)}" | |
def process_complaint(text, prompt_type): | |
"""Обработка клиентского обращения с выбранным промптом""" | |
if prompt_type not in PROMPTS: | |
return "Неверный тип промпта" | |
# Получаем случайный пример из датасета, если текст не введен | |
if not text.strip(): | |
example = dataset['train'].shuffle().select(range(1))[0] | |
text = example['text'] | |
prompt = PROMPTS[prompt_type].format(text=text, problem="") | |
results = {} | |
for model_name in MODELS.keys(): | |
results[model_name] = generate_with_model(prompt, model_name) | |
return results | |
# Интерфейс Gradio | |
with gr.Blocks(title="Анализ клиентских обращений Alpha Bank") as demo: | |
gr.Markdown("## Анализ клиентских обращений Alpha Bank") | |
gr.Markdown("Тестирование разных моделей на обработку обращений") | |
with gr.Row(): | |
with gr.Column(): | |
text_input = gr.Textbox( | |
label="Текст обращения", | |
placeholder="Введите текст обращения или оставьте пустым для примера из датасета", | |
lines=5 | |
) | |
prompt_type = gr.Radio( | |
list(PROMPTS.keys()), | |
label="Тип промпта", | |
value=list(PROMPTS.keys())[0] | |
) | |
submit_btn = gr.Button("Обработать") | |
with gr.Column(): | |
outputs = [] | |
for model_name in MODELS.keys(): | |
outputs.append( | |
gr.Textbox( | |
label=f"{model_name}", | |
interactive=False, | |
lines=5 | |
) | |
) | |
# Примеры из датасета | |
examples = gr.Examples( | |
examples=[x['text'] for x in dataset['train'].select(range(3))], | |
inputs=text_input, | |
label="Примеры из датасета" | |
) | |
def process_and_display(text, prompt_type): | |
results = process_complaint(text, prompt_type) | |
return [results.get(model_name, "") for model_name in MODELS.keys()] | |
submit_btn.click( | |
fn=process_and_display, | |
inputs=[text_input, prompt_type], | |
outputs=outputs | |
) | |
if __name__ == "__main__": | |
demo.launch() | |