Spaces:
Sleeping
Sleeping
import os | |
import json | |
import gradio as gr | |
import time | |
from transformers import pipeline | |
from kaggle.api.kaggle_api_extended import KaggleApi | |
# === Подготовка банковского набора данных через Kaggle === | |
# Будет скачан dataset PromptCloudHQ/banking-chatbot-dataset, содержащий примеры вопросов и ответов для банковского чат-бота. | |
DATA_DIR = './data' | |
json_file = None | |
# Скачиваем при первом запуске | |
if not os.path.exists(DATA_DIR): | |
os.makedirs(DATA_DIR) | |
api = KaggleApi() | |
api.authenticate() | |
api.dataset_download_files('PromptCloudHQ/banking-chatbot-dataset', path=DATA_DIR, unzip=True) | |
# Находим JSON-файл с данными | |
for fname in os.listdir(DATA_DIR): | |
if fname.endswith('.json'): | |
json_file = os.path.join(DATA_DIR, fname) | |
break | |
else: | |
# Если папка есть — ищем файл | |
for fname in os.listdir(DATA_DIR): | |
if fname.endswith('.json'): | |
json_file = os.path.join(DATA_DIR, fname) | |
break | |
if json_file is None: | |
raise FileNotFoundError('Не удалось найти JSON-файл с банковскими данными в ./data') | |
# Загружаем JSON с примерами | |
with open(json_file, 'r', encoding='utf-8') as f: | |
kb = json.load(f) | |
# Структура: {"intents": [ ... ]} | |
intents = kb.get('intents') | |
if intents is None: | |
raise ValueError('Ожидался ключ "intents" в JSON-файле датасета') | |
# Собираем два few-shot примера | |
examples = [] | |
for intent in intents[:2]: | |
patterns = intent.get('patterns', []) | |
responses = intent.get('responses', []) | |
ex = f"Паттерны: {', '.join(patterns)}\nОтветы: {', '.join(responses)}" | |
examples.append(ex) | |
# === Инициализация трёх бесплатных русскоязычных моделей (GPT-2 based) === | |
models = { | |
'ruDialoGPT-small': pipeline('text-generation', model='t-bank-ai/ruDialoGPT-small', tokenizer='t-bank-ai/ruDialoGPT-small', device=-1), | |
'ruDialoGPT-medium': pipeline('text-generation', model='t-bank-ai/ruDialoGPT-medium', tokenizer='t-bank-ai/ruDialoGPT-medium', device=-1), | |
'ruGPT3-small': pipeline('text-generation', model='ai-forever/rugpt3small_based_on_gpt2', tokenizer='ai-forever/rugpt3small_based_on_gpt2', device=-1), | |
} | |
# Системная инструкция для CoT | |
system_instruction = ( | |
"Вы — банковский ассистент. Ваша задача — корректно и вежливо отвечать на запросы клиентов банка, " | |
"давать рекомендации по банковским операциям и услугам." | |
) | |
# Строим полный промпт с CoT и примерами | |
def build_prompt(question: str) -> str: | |
few_shot_text = "\n\n".join(f"Пример:\n{ex}" for ex in examples) | |
prompt = ( | |
f"{system_instruction}\n\n" | |
f"{few_shot_text}\n\n" | |
f"Вопрос клиента: {question}\n" | |
"Сначала подробно опишите рассуждения шаг за шагом, а затем кратко сформулируйте ответ." | |
) | |
return prompt | |
# Генерация ответов и измерение времени | |
def generate(question: str): | |
prompt = build_prompt(question) | |
results = {} | |
for name, pipe in models.items(): | |
start = time.time() | |
out = pipe(prompt, max_length=200, do_sample=True, top_p=0.9, temperature=0.7)[0]['generated_text'] | |
elapsed = round(time.time() - start, 2) | |
# Извлекаем связный ответ — последнюю строку | |
answer = out.strip().split('\n')[-1] | |
results[name] = {'answer': answer, 'time': elapsed} | |
return results | |
# Форматируем вывод для Gradio | |
def format_outputs(question: str): | |
res = generate(question) | |
return ( | |
res['ruDialoGPT-small']['answer'], f"{res['ruDialoGPT-small']['time']}s", | |
res['ruDialoGPT-medium']['answer'], f"{res['ruDialoGPT-medium']['time']}s", | |
res['ruGPT3-small']['answer'], f"{res['ruGPT3-small']['time']}s" | |
) | |
# === Интерфейс Gradio === | |
with gr.Blocks() as demo: | |
gr.Markdown("## Ответы на клиентские обращения\nCoT + тайминг по трём бесплатным моделям") | |
txt = gr.Textbox(label='Описание проблемы клиента', placeholder='Например: "Почему я не могу снять деньги с карты?"', lines=2) | |
btn = gr.Button('Сгенерировать ответы') | |
out1 = gr.Textbox(label='ruDialoGPT-small Ответ') | |
t1 = gr.Textbox(label='ruDialoGPT-small Время') | |
out2 = gr.Textbox(label='ruDialoGPT-medium Ответ') | |
t2 = gr.Textbox(label='ruDialoGPT-medium Время') | |
out3 = gr.Textbox(label='ruGPT3-small Ответ') | |
t3 = gr.Textbox(label='ruGPT3-small Время') | |
btn.click(format_outputs, inputs=[txt], outputs=[out1, t1, out2, t2, out3, t3]) | |
demo.launch() | |