File size: 3,097 Bytes
57645e8
af5aa20
57878bd
af5aa20
bd826f0
af5aa20
90cd180
 
 
884ac0c
af5aa20
 
 
 
 
 
 
 
 
 
90cd180
af5aa20
 
 
 
 
 
 
 
 
90cd180
 
 
 
b87b483
90cd180
 
 
af5aa20
 
 
fdefd2f
af5aa20
90cd180
af5aa20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dee7eb
90cd180
 
af5aa20
90cd180
af5aa20
 
 
90cd180
5dee7eb
90cd180
1e10719
90cd180
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from datasets import load_dataset

# Загружаем модель, токенизатор и датасет
model_name = "ai-forever/rugpt3small_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Загружаем банковский датасет для контекста
bank_dataset = load_dataset("ZhenDOS/alpha_bank_data")

# Создаем контекст из датасета (первые несколько примеров)
context_examples = "\n".join([
    f"Вопрос: {example['question']}\nОтвет: {example['answer']}" 
    for example in bank_dataset['train'].select(range(5))
])

# Функция генерации ответа с учетом банковского контекста
def generate_response(prompt):
    # Добавляем контекст из датасета к промпту
    full_prompt = f"""Контекст по банковским вопросам:
{context_examples}

Вопрос клиента: {prompt}
Ответ:"""
    
    inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            eos_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            early_stopping=True
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Удаляем промпт из ответа
    if response.startswith(full_prompt):
        response = response[len(full_prompt):].strip()
    
    # Постобработка ответа
    response = response.split("\n")[0]  # Берем только первую строку ответа
    response = response.replace("Ответ:", "").strip()
    
    return response

# Интерфейс Gradio с примерами вопросов
examples = [
    "Как восстановить утерянную карту?",
    "Какие документы нужны для открытия счета?",
    "Как проверить баланс карты?",
    "Как оформить кредитную карту?",
    "Какие комиссии за перевод между счетами?"
]

demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
    outputs=gr.Textbox(label="Ответ модели"),
    title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data",
    description="Используется модель ai-forever/rugpt3small_based_on_gpt2, дообученная на данных ZhenDOS/alpha_bank_data.",
    examples=examples
)

# Запуск
if __name__ == "__main__":
    demo.launch()