Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
import torch | |
from datasets import load_dataset | |
# Загружаем модель, токенизатор и датасет | |
model_name = "ai-forever/rugpt3small_based_on_gpt2" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForCausalLM.from_pretrained(model_name) | |
# Загружаем банковский датасет для контекста | |
bank_dataset = load_dataset("ZhenDOS/alpha_bank_data") | |
# Создаем контекст из датасета (первые несколько примеров) | |
context_examples = "\n".join([ | |
f"Вопрос: {example['question']}\nОтвет: {example['answer']}" | |
for example in bank_dataset['train'].select(range(5)) | |
]) | |
# Функция генерации ответа с учетом банковского контекста | |
def generate_response(prompt): | |
# Добавляем контекст из датасета к промпту | |
full_prompt = f"""Контекст по банковским вопросам: | |
{context_examples} | |
Вопрос клиента: {prompt} | |
Ответ:""" | |
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512) | |
with torch.no_grad(): | |
outputs = model.generate( | |
**inputs, | |
max_new_tokens=150, | |
do_sample=True, | |
temperature=0.7, | |
top_k=50, | |
top_p=0.95, | |
eos_token_id=tokenizer.eos_token_id, | |
no_repeat_ngram_size=3, | |
early_stopping=True | |
) | |
response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Удаляем промпт из ответа | |
if response.startswith(full_prompt): | |
response = response[len(full_prompt):].strip() | |
# Постобработка ответа | |
response = response.split("\n")[0] # Берем только первую строку ответа | |
response = response.replace("Ответ:", "").strip() | |
return response | |
# Интерфейс Gradio с примерами вопросов | |
examples = [ | |
"Как восстановить утерянную карту?", | |
"Какие документы нужны для открытия счета?", | |
"Как проверить баланс карты?", | |
"Как оформить кредитную карту?", | |
"Какие комиссии за перевод между счетами?" | |
] | |
demo = gr.Interface( | |
fn=generate_response, | |
inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"), | |
outputs=gr.Textbox(label="Ответ модели"), | |
title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data", | |
description="Используется модель ai-forever/rugpt3small_based_on_gpt2, дообученная на данных ZhenDOS/alpha_bank_data.", | |
examples=examples | |
) | |
# Запуск | |
if __name__ == "__main__": | |
demo.launch() | |