Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,30 +1,65 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
3 |
import torch
|
4 |
from datasets import load_dataset
|
5 |
|
6 |
-
# Загружаем
|
7 |
model_name = "ai-forever/rugpt3small_based_on_gpt2"
|
8 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
9 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
10 |
|
11 |
-
# Загружаем банковский датасет
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
# Создаем контекст
|
15 |
-
context_examples =
|
16 |
-
|
17 |
-
for example in bank_dataset['train'].select(range(5))
|
18 |
-
])
|
19 |
|
20 |
-
# Функция генерации ответа
|
21 |
def generate_response(prompt):
|
22 |
-
# Добавляем
|
23 |
-
|
|
|
24 |
{context_examples}
|
25 |
|
26 |
Вопрос клиента: {prompt}
|
27 |
Ответ:"""
|
|
|
|
|
28 |
|
29 |
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
|
30 |
|
@@ -48,12 +83,13 @@ def generate_response(prompt):
|
|
48 |
response = response[len(full_prompt):].strip()
|
49 |
|
50 |
# Постобработка ответа
|
51 |
-
response = response.split("\n")[0]
|
52 |
-
|
|
|
53 |
|
54 |
return response
|
55 |
|
56 |
-
#
|
57 |
examples = [
|
58 |
"Как восстановить утерянную карту?",
|
59 |
"Какие документы нужны для открытия счета?",
|
@@ -62,15 +98,15 @@ examples = [
|
|
62 |
"Какие комиссии за перевод между счетами?"
|
63 |
]
|
64 |
|
|
|
65 |
demo = gr.Interface(
|
66 |
fn=generate_response,
|
67 |
inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
|
68 |
outputs=gr.Textbox(label="Ответ модели"),
|
69 |
title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data",
|
70 |
-
description="Используется модель ai-forever/rugpt3small_based_on_gpt2
|
71 |
examples=examples
|
72 |
)
|
73 |
|
74 |
-
# Запуск
|
75 |
if __name__ == "__main__":
|
76 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
3 |
import torch
|
4 |
from datasets import load_dataset
|
5 |
|
6 |
+
# Загружаем модель и токенизатор
|
7 |
model_name = "ai-forever/rugpt3small_based_on_gpt2"
|
8 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
9 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
10 |
|
11 |
+
# Загружаем и анализируем банковский датасет
|
12 |
+
try:
|
13 |
+
bank_dataset = load_dataset("ZhenDOS/alpha_bank_data")
|
14 |
+
# Выводим структуру первого элемента для анализа
|
15 |
+
first_example = bank_dataset['train'][0]
|
16 |
+
print("Структура датасета (первый элемент):", first_example)
|
17 |
+
|
18 |
+
# Определяем используемые поля на основе анализа датасета
|
19 |
+
question_field = 'question' if 'question' in first_example else 'input'
|
20 |
+
answer_field = 'answer' if 'answer' in first_example else 'output'
|
21 |
+
|
22 |
+
except Exception as e:
|
23 |
+
print(f"Ошибка при загрузке датасета: {e}")
|
24 |
+
bank_dataset = None
|
25 |
+
question_field = 'input'
|
26 |
+
answer_field = 'output'
|
27 |
+
|
28 |
+
# Функция для создания контекста из датасета
|
29 |
+
def create_context(dataset, num_examples=3):
|
30 |
+
if dataset is None:
|
31 |
+
return ""
|
32 |
+
|
33 |
+
try:
|
34 |
+
examples = []
|
35 |
+
for example in dataset['train'].select(range(num_examples)):
|
36 |
+
# Используем определенные поля или альтернативные варианты
|
37 |
+
question = example.get(question_field) or example.get('text') or example.get('message')
|
38 |
+
answer = example.get(answer_field) or example.get('response') or example.get('content')
|
39 |
+
|
40 |
+
if question and answer:
|
41 |
+
examples.append(f"Вопрос: {question}\nОтвет: {answer}")
|
42 |
+
|
43 |
+
return "\n\n".join(examples) if examples else ""
|
44 |
+
except Exception as e:
|
45 |
+
print(f"Ошибка при создании контекста: {e}")
|
46 |
+
return ""
|
47 |
|
48 |
+
# Создаем контекст
|
49 |
+
context_examples = create_context(bank_dataset)
|
50 |
+
print("Созданный контекст:\n", context_examples)
|
|
|
|
|
51 |
|
52 |
+
# Функция генерации ответа
|
53 |
def generate_response(prompt):
|
54 |
+
# Добавляем контекст, если он есть
|
55 |
+
if context_examples:
|
56 |
+
full_prompt = f"""Контекст банковских вопросов:
|
57 |
{context_examples}
|
58 |
|
59 |
Вопрос клиента: {prompt}
|
60 |
Ответ:"""
|
61 |
+
else:
|
62 |
+
full_prompt = f"Вопрос клиента: {prompt}\nОтвет:"
|
63 |
|
64 |
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
|
65 |
|
|
|
83 |
response = response[len(full_prompt):].strip()
|
84 |
|
85 |
# Постобработка ответа
|
86 |
+
response = response.split("\n")[0].strip()
|
87 |
+
if not response.endswith((".", "!", "?")):
|
88 |
+
response += "."
|
89 |
|
90 |
return response
|
91 |
|
92 |
+
# Примеры вопросов
|
93 |
examples = [
|
94 |
"Как восстановить утерянную карту?",
|
95 |
"Какие документы нужны для открытия счета?",
|
|
|
98 |
"Какие комиссии за перевод между счетами?"
|
99 |
]
|
100 |
|
101 |
+
# Интерфейс Gradio
|
102 |
demo = gr.Interface(
|
103 |
fn=generate_response,
|
104 |
inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
|
105 |
outputs=gr.Textbox(label="Ответ модели"),
|
106 |
title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data",
|
107 |
+
description="Используется модель ai-forever/rugpt3small_based_on_gpt2 с учетом данных из датасета Alpha Bank.",
|
108 |
examples=examples
|
109 |
)
|
110 |
|
|
|
111 |
if __name__ == "__main__":
|
112 |
demo.launch()
|