chat / app.py
SimrusDenuvo's picture
Update app.py
d0ca3ae verified
raw
history blame
4.78 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datasets import load_dataset
# Загружаем модель и токенизатор
model_name = "ai-forever/rugpt3small_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Дополнительные знания о банковских услугах
BANK_KNOWLEDGE = {
"Как проверить баланс карты?": [
"1. Через мобильное приложение банка (раздел 'Карты' → 'Баланс')",
"2. В интернет-банке (в личном кабинете выберите карту)",
"3. По SMS (отправьте BALANCE на номер 900)",
"4. В банкомате (вставьте карту и выберите 'Запрос баланса')",
"5. По телефону горячей линии (8-800-100-00-00)"
],
"Как восстановить утерянную карту?": [
"1. Немедленно позвоните в банк по телефону 8-800-100-00-00 для блокировки карты",
"2. Обратитесь в отделение банка с паспортом",
"3. Заполните заявление на перевыпуск карты",
"4. Новая карта будет готова через 3-5 рабочих дней"
]
}
def enhance_response(question, generated_response):
# Если вопрос есть в наших знаниях, возвращаем структурированный ответ
if question in BANK_KNOWLEDGE:
return "\n".join(BANK_KNOWLEDGE[question])
# Улучшаем стандартные ответы модели
improvements = {
"баланс": "Вы можете проверить баланс карты:\n"
"1. В мобильном приложении\n"
"2. Через интернет-банк\n"
"3. В банкомате\n"
"4. По телефону горячей линии 8-800-100-00-00",
"кредит": "По вопросам кредитования вы можете:\n"
"1. Оставить заявку на сайте\n"
"2. Обратиться в отделение банка\n"
"3. Позвонить по телефону 8-800-100-00-00",
"карт": "По вопросам банковских карт:\n"
"1. Обратитесь в отделение банка\n"
"2. Позвоните на горячую линию\n"
"3. Используйте чат в мобильном приложении"
}
for keyword, improved_answer in improvements.items():
if keyword in question.lower():
return improved_answer
return generated_response
def generate_response(prompt):
# Генерируем ответ с помощью модели
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=150,
do_sample=True,
temperature=0.7,
top_k=50,
top_p=0.95,
eos_token_id=tokenizer.eos_token_id,
no_repeat_ngram_size=3,
early_stopping=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Удаляем промпт из ответа
if response.startswith(prompt):
response = response[len(prompt):].strip()
# Улучшаем ответ
enhanced_response = enhance_response(prompt, response)
return enhanced_response
# Интерфейс Gradio
demo = gr.Interface(
fn=generate_response,
inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
outputs=gr.Textbox(label="Ответ модели"),
title="Анализ клиентских обращений — Alpha Bank Assistant",
description="Получите точные ответы на вопросы о банковских услугах",
examples=[
"Как проверить баланс карты?",
"Как восстановить утерянную карту?",
"Как оформить кредитную карту?",
"Какие документы нужны для открытия счета?"
]
)
if __name__ == "__main__":
demo.launch()