chat / app.py
SimrusDenuvo's picture
Update app.py
7882653 verified
raw
history blame
4.52 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datasets import load_dataset
# Загружаем модель и токенизатор
model_name = "ai-forever/rugpt3small_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Загружаем и анализируем банковский датасет
try:
bank_dataset = load_dataset("ZhenDOS/alpha_bank_data")
# Выводим структуру первого элемента для анализа
first_example = bank_dataset['train'][0]
print("Структура датасета (первый элемент):", first_example)
# Определяем используемые поля на основе анализа датасета
question_field = 'question' if 'question' in first_example else 'input'
answer_field = 'answer' if 'answer' in first_example else 'output'
except Exception as e:
print(f"Ошибка при загрузке датасета: {e}")
bank_dataset = None
question_field = 'input'
answer_field = 'output'
# Функция для создания контекста из датасета
def create_context(dataset, num_examples=3):
if dataset is None:
return ""
try:
examples = []
for example in dataset['train'].select(range(num_examples)):
# Используем определенные поля или альтернативные варианты
question = example.get(question_field) or example.get('text') or example.get('message')
answer = example.get(answer_field) or example.get('response') or example.get('content')
if question and answer:
examples.append(f"Вопрос: {question}\nОтвет: {answer}")
return "\n\n".join(examples) if examples else ""
except Exception as e:
print(f"Ошибка при создании контекста: {e}")
return ""
# Создаем контекст
context_examples = create_context(bank_dataset)
print("Созданный контекст:\n", context_examples)
# Функция генерации ответа
def generate_response(prompt):
# Добавляем контекст, если он есть
if context_examples:
full_prompt = f"""Контекст банковских вопросов:
{context_examples}
Вопрос клиента: {prompt}
Ответ:"""
else:
full_prompt = f"Вопрос клиента: {prompt}\nОтвет:"
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=150,
do_sample=True,
temperature=0.7,
top_k=50,
top_p=0.95,
eos_token_id=tokenizer.eos_token_id,
no_repeat_ngram_size=3,
early_stopping=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Удаляем промпт из ответа
if response.startswith(full_prompt):
response = response[len(full_prompt):].strip()
# Постобработка ответа
response = response.split("\n")[0].strip()
if not response.endswith((".", "!", "?")):
response += "."
return response
# Примеры вопросов
examples = [
"Как восстановить утерянную карту?",
"Какие документы нужны для открытия счета?",
"Как проверить баланс карты?",
"Как оформить кредитную карту?",
"Какие комиссии за перевод между счетами?"
]
# Интерфейс Gradio
demo = gr.Interface(
fn=generate_response,
inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
outputs=gr.Textbox(label="Ответ модели"),
title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data",
description="Используется модель ai-forever/rugpt3small_based_on_gpt2 с учетом данных из датасета Alpha Bank.",
examples=examples
)
if __name__ == "__main__":
demo.launch()