Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
from datasets import load_dataset | |
# Загружаем модель и токенизатор | |
model_name = "ai-forever/rugpt3small_based_on_gpt2" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForCausalLM.from_pretrained(model_name) | |
# Загружаем и анализируем банковский датасет | |
try: | |
bank_dataset = load_dataset("ZhenDOS/alpha_bank_data") | |
# Выводим структуру первого элемента для анализа | |
first_example = bank_dataset['train'][0] | |
print("Структура датасета (первый элемент):", first_example) | |
# Определяем используемые поля на основе анализа датасета | |
question_field = 'question' if 'question' in first_example else 'input' | |
answer_field = 'answer' if 'answer' in first_example else 'output' | |
except Exception as e: | |
print(f"Ошибка при загрузке датасета: {e}") | |
bank_dataset = None | |
question_field = 'input' | |
answer_field = 'output' | |
# Функция для создания контекста из датасета | |
def create_context(dataset, num_examples=3): | |
if dataset is None: | |
return "" | |
try: | |
examples = [] | |
for example in dataset['train'].select(range(num_examples)): | |
# Используем определенные поля или альтернативные варианты | |
question = example.get(question_field) or example.get('text') or example.get('message') | |
answer = example.get(answer_field) or example.get('response') or example.get('content') | |
if question and answer: | |
examples.append(f"Вопрос: {question}\nОтвет: {answer}") | |
return "\n\n".join(examples) if examples else "" | |
except Exception as e: | |
print(f"Ошибка при создании контекста: {e}") | |
return "" | |
# Создаем контекст | |
context_examples = create_context(bank_dataset) | |
print("Созданный контекст:\n", context_examples) | |
# Функция генерации ответа | |
def generate_response(prompt): | |
# Добавляем контекст, если он есть | |
if context_examples: | |
full_prompt = f"""Контекст банковских вопросов: | |
{context_examples} | |
Вопрос клиента: {prompt} | |
Ответ:""" | |
else: | |
full_prompt = f"Вопрос клиента: {prompt}\nОтвет:" | |
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512) | |
with torch.no_grad(): | |
outputs = model.generate( | |
**inputs, | |
max_new_tokens=150, | |
do_sample=True, | |
temperature=0.7, | |
top_k=50, | |
top_p=0.95, | |
eos_token_id=tokenizer.eos_token_id, | |
no_repeat_ngram_size=3, | |
early_stopping=True | |
) | |
response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Удаляем промпт из ответа | |
if response.startswith(full_prompt): | |
response = response[len(full_prompt):].strip() | |
# Постобработка ответа | |
response = response.split("\n")[0].strip() | |
if not response.endswith((".", "!", "?")): | |
response += "." | |
return response | |
# Примеры вопросов | |
examples = [ | |
"Как восстановить утерянную карту?", | |
"Какие документы нужны для открытия счета?", | |
"Как проверить баланс карты?", | |
"Как оформить кредитную карту?", | |
"Какие комиссии за перевод между счетами?" | |
] | |
# Интерфейс Gradio | |
demo = gr.Interface( | |
fn=generate_response, | |
inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"), | |
outputs=gr.Textbox(label="Ответ модели"), | |
title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data", | |
description="Используется модель ai-forever/rugpt3small_based_on_gpt2 с учетом данных из датасета Alpha Bank.", | |
examples=examples | |
) | |
if __name__ == "__main__": | |
demo.launch() | |