chat / app.py
SimrusDenuvo's picture
Update app.py
af5aa20 verified
raw
history blame
3.1 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from datasets import load_dataset
# Загружаем модель, токенизатор и датасет
model_name = "ai-forever/rugpt3small_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Загружаем банковский датасет для контекста
bank_dataset = load_dataset("ZhenDOS/alpha_bank_data")
# Создаем контекст из датасета (первые несколько примеров)
context_examples = "\n".join([
f"Вопрос: {example['question']}\nОтвет: {example['answer']}"
for example in bank_dataset['train'].select(range(5))
])
# Функция генерации ответа с учетом банковского контекста
def generate_response(prompt):
# Добавляем контекст из датасета к промпту
full_prompt = f"""Контекст по банковским вопросам:
{context_examples}
Вопрос клиента: {prompt}
Ответ:"""
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=150,
do_sample=True,
temperature=0.7,
top_k=50,
top_p=0.95,
eos_token_id=tokenizer.eos_token_id,
no_repeat_ngram_size=3,
early_stopping=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Удаляем промпт из ответа
if response.startswith(full_prompt):
response = response[len(full_prompt):].strip()
# Постобработка ответа
response = response.split("\n")[0] # Берем только первую строку ответа
response = response.replace("Ответ:", "").strip()
return response
# Интерфейс Gradio с примерами вопросов
examples = [
"Как восстановить утерянную карту?",
"Какие документы нужны для открытия счета?",
"Как проверить баланс карты?",
"Как оформить кредитную карту?",
"Какие комиссии за перевод между счетами?"
]
demo = gr.Interface(
fn=generate_response,
inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
outputs=gr.Textbox(label="Ответ модели"),
title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data",
description="Используется модель ai-forever/rugpt3small_based_on_gpt2, дообученная на данных ZhenDOS/alpha_bank_data.",
examples=examples
)
# Запуск
if __name__ == "__main__":
demo.launch()