chat / app.py
SimrusDenuvo's picture
Update app.py
e66857b verified
raw
history blame
5.23 kB
import os
import json
import gradio as gr
import time
from transformers import pipeline
from kaggle.api.kaggle_api_extended import KaggleApi
# === Подготовка банковского набора данных через Kaggle ===
# Будет скачан dataset PromptCloudHQ/banking-chatbot-dataset, содержащий примеры вопросов и ответов для банковского чат-бота.
DATA_DIR = './data'
json_file = None
# Скачиваем при первом запуске
if not os.path.exists(DATA_DIR):
os.makedirs(DATA_DIR)
api = KaggleApi()
api.authenticate()
api.dataset_download_files('PromptCloudHQ/banking-chatbot-dataset', path=DATA_DIR, unzip=True)
# Находим JSON-файл с данными
for fname in os.listdir(DATA_DIR):
if fname.endswith('.json'):
json_file = os.path.join(DATA_DIR, fname)
break
else:
# Если папка есть — ищем файл
for fname in os.listdir(DATA_DIR):
if fname.endswith('.json'):
json_file = os.path.join(DATA_DIR, fname)
break
if json_file is None:
raise FileNotFoundError('Не удалось найти JSON-файл с банковскими данными в ./data')
# Загружаем JSON с примерами
with open(json_file, 'r', encoding='utf-8') as f:
kb = json.load(f)
# Структура: {"intents": [ ... ]}
intents = kb.get('intents')
if intents is None:
raise ValueError('Ожидался ключ "intents" в JSON-файле датасета')
# Собираем два few-shot примера
examples = []
for intent in intents[:2]:
patterns = intent.get('patterns', [])
responses = intent.get('responses', [])
ex = f"Паттерны: {', '.join(patterns)}\nОтветы: {', '.join(responses)}"
examples.append(ex)
# === Инициализация трёх бесплатных русскоязычных моделей (GPT-2 based) ===
models = {
'ruDialoGPT-small': pipeline('text-generation', model='t-bank-ai/ruDialoGPT-small', tokenizer='t-bank-ai/ruDialoGPT-small', device=-1),
'ruDialoGPT-medium': pipeline('text-generation', model='t-bank-ai/ruDialoGPT-medium', tokenizer='t-bank-ai/ruDialoGPT-medium', device=-1),
'ruGPT3-small': pipeline('text-generation', model='ai-forever/rugpt3small_based_on_gpt2', tokenizer='ai-forever/rugpt3small_based_on_gpt2', device=-1),
}
# Системная инструкция для CoT
system_instruction = (
"Вы — банковский ассистент. Ваша задача — корректно и вежливо отвечать на запросы клиентов банка, "
"давать рекомендации по банковским операциям и услугам."
)
# Строим полный промпт с CoT и примерами
def build_prompt(question: str) -> str:
few_shot_text = "\n\n".join(f"Пример:\n{ex}" for ex in examples)
prompt = (
f"{system_instruction}\n\n"
f"{few_shot_text}\n\n"
f"Вопрос клиента: {question}\n"
"Сначала подробно опишите рассуждения шаг за шагом, а затем кратко сформулируйте ответ."
)
return prompt
# Генерация ответов и измерение времени
def generate(question: str):
prompt = build_prompt(question)
results = {}
for name, pipe in models.items():
start = time.time()
out = pipe(prompt, max_length=200, do_sample=True, top_p=0.9, temperature=0.7)[0]['generated_text']
elapsed = round(time.time() - start, 2)
# Извлекаем связный ответ — последнюю строку
answer = out.strip().split('\n')[-1]
results[name] = {'answer': answer, 'time': elapsed}
return results
# Форматируем вывод для Gradio
def format_outputs(question: str):
res = generate(question)
return (
res['ruDialoGPT-small']['answer'], f"{res['ruDialoGPT-small']['time']}s",
res['ruDialoGPT-medium']['answer'], f"{res['ruDialoGPT-medium']['time']}s",
res['ruGPT3-small']['answer'], f"{res['ruGPT3-small']['time']}s"
)
# === Интерфейс Gradio ===
with gr.Blocks() as demo:
gr.Markdown("## Ответы на клиентские обращения\nCoT + тайминг по трём бесплатным моделям")
txt = gr.Textbox(label='Описание проблемы клиента', placeholder='Например: "Почему я не могу снять деньги с карты?"', lines=2)
btn = gr.Button('Сгенерировать ответы')
out1 = gr.Textbox(label='ruDialoGPT-small Ответ')
t1 = gr.Textbox(label='ruDialoGPT-small Время')
out2 = gr.Textbox(label='ruDialoGPT-medium Ответ')
t2 = gr.Textbox(label='ruDialoGPT-medium Время')
out3 = gr.Textbox(label='ruGPT3-small Ответ')
t3 = gr.Textbox(label='ruGPT3-small Время')
btn.click(format_outputs, inputs=[txt], outputs=[out1, t1, out2, t2, out3, t3])
demo.launch()