File size: 4,516 Bytes
57645e8
7882653
57878bd
af5aa20
bd826f0
7882653
90cd180
 
 
884ac0c
7882653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af5aa20
7882653
 
 
af5aa20
7882653
90cd180
7882653
 
 
af5aa20
 
 
 
7882653
 
af5aa20
 
 
90cd180
 
 
 
b87b483
90cd180
 
 
af5aa20
 
 
fdefd2f
af5aa20
90cd180
af5aa20
 
 
 
 
 
7882653
 
 
af5aa20
 
 
7882653
af5aa20
 
 
 
 
 
 
5dee7eb
7882653
90cd180
 
af5aa20
90cd180
af5aa20
7882653
af5aa20
90cd180
5dee7eb
1e10719
90cd180
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datasets import load_dataset

# Загружаем модель и токенизатор
model_name = "ai-forever/rugpt3small_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Загружаем и анализируем банковский датасет
try:
    bank_dataset = load_dataset("ZhenDOS/alpha_bank_data")
    # Выводим структуру первого элемента для анализа
    first_example = bank_dataset['train'][0]
    print("Структура датасета (первый элемент):", first_example)
    
    # Определяем используемые поля на основе анализа датасета
    question_field = 'question' if 'question' in first_example else 'input'
    answer_field = 'answer' if 'answer' in first_example else 'output'
    
except Exception as e:
    print(f"Ошибка при загрузке датасета: {e}")
    bank_dataset = None
    question_field = 'input'
    answer_field = 'output'

# Функция для создания контекста из датасета
def create_context(dataset, num_examples=3):
    if dataset is None:
        return ""
    
    try:
        examples = []
        for example in dataset['train'].select(range(num_examples)):
            # Используем определенные поля или альтернативные варианты
            question = example.get(question_field) or example.get('text') or example.get('message')
            answer = example.get(answer_field) or example.get('response') or example.get('content')
            
            if question and answer:
                examples.append(f"Вопрос: {question}\nОтвет: {answer}")
        
        return "\n\n".join(examples) if examples else ""
    except Exception as e:
        print(f"Ошибка при создании контекста: {e}")
        return ""

# Создаем контекст
context_examples = create_context(bank_dataset)
print("Созданный контекст:\n", context_examples)

# Функция генерации ответа
def generate_response(prompt):
    # Добавляем контекст, если он есть
    if context_examples:
        full_prompt = f"""Контекст банковских вопросов:
{context_examples}

Вопрос клиента: {prompt}
Ответ:"""
    else:
        full_prompt = f"Вопрос клиента: {prompt}\nОтвет:"
    
    inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            eos_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            early_stopping=True
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Удаляем промпт из ответа
    if response.startswith(full_prompt):
        response = response[len(full_prompt):].strip()
    
    # Постобработка ответа
    response = response.split("\n")[0].strip()
    if not response.endswith((".", "!", "?")):
        response += "."
    
    return response

# Примеры вопросов
examples = [
    "Как восстановить утерянную карту?",
    "Какие документы нужны для открытия счета?",
    "Как проверить баланс карты?",
    "Как оформить кредитную карту?",
    "Какие комиссии за перевод между счетами?"
]

# Интерфейс Gradio
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
    outputs=gr.Textbox(label="Ответ модели"),
    title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data",
    description="Используется модель ai-forever/rugpt3small_based_on_gpt2 с учетом данных из датасета Alpha Bank.",
    examples=examples
)

if __name__ == "__main__":
    demo.launch()