SimrusDenuvo commited on
Commit
7882653
·
verified ·
1 Parent(s): af5aa20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -17
app.py CHANGED
@@ -1,30 +1,65 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
  import torch
4
  from datasets import load_dataset
5
 
6
- # Загружаем модель, токенизатор и датасет
7
  model_name = "ai-forever/rugpt3small_based_on_gpt2"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
- # Загружаем банковский датасет для контекста
12
- bank_dataset = load_dataset("ZhenDOS/alpha_bank_data")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Создаем контекст из датасета (первые несколько примеров)
15
- context_examples = "\n".join([
16
- f"Вопрос: {example['question']}\nОтвет: {example['answer']}"
17
- for example in bank_dataset['train'].select(range(5))
18
- ])
19
 
20
- # Функция генерации ответа с учетом банковского контекста
21
  def generate_response(prompt):
22
- # Добавляем контекст из датасета к промпту
23
- full_prompt = f"""Контекст по банковским вопросам:
 
24
  {context_examples}
25
 
26
  Вопрос клиента: {prompt}
27
  Ответ:"""
 
 
28
 
29
  inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
30
 
@@ -48,12 +83,13 @@ def generate_response(prompt):
48
  response = response[len(full_prompt):].strip()
49
 
50
  # Постобработка ответа
51
- response = response.split("\n")[0] # Берем только первую строку ответа
52
- response = response.replace("Ответ:", "").strip()
 
53
 
54
  return response
55
 
56
- # Интерфейс Gradio с примерами вопросов
57
  examples = [
58
  "Как восстановить утерянную карту?",
59
  "Какие документы нужны для открытия счета?",
@@ -62,15 +98,15 @@ examples = [
62
  "Какие комиссии за перевод между счетами?"
63
  ]
64
 
 
65
  demo = gr.Interface(
66
  fn=generate_response,
67
  inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
68
  outputs=gr.Textbox(label="Ответ модели"),
69
  title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data",
70
- description="Используется модель ai-forever/rugpt3small_based_on_gpt2, дообученная на данных ZhenDOS/alpha_bank_data.",
71
  examples=examples
72
  )
73
 
74
- # Запуск
75
  if __name__ == "__main__":
76
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  from datasets import load_dataset
5
 
6
+ # Загружаем модель и токенизатор
7
  model_name = "ai-forever/rugpt3small_based_on_gpt2"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
+ # Загружаем и анализируем банковский датасет
12
+ try:
13
+ bank_dataset = load_dataset("ZhenDOS/alpha_bank_data")
14
+ # Выводим структуру первого элемента для анализа
15
+ first_example = bank_dataset['train'][0]
16
+ print("Структура датасета (первый элемент):", first_example)
17
+
18
+ # Определяем используемые поля на основе анализа датасета
19
+ question_field = 'question' if 'question' in first_example else 'input'
20
+ answer_field = 'answer' if 'answer' in first_example else 'output'
21
+
22
+ except Exception as e:
23
+ print(f"Ошибка при загрузке датасета: {e}")
24
+ bank_dataset = None
25
+ question_field = 'input'
26
+ answer_field = 'output'
27
+
28
+ # Функция для создания контекста из датасета
29
+ def create_context(dataset, num_examples=3):
30
+ if dataset is None:
31
+ return ""
32
+
33
+ try:
34
+ examples = []
35
+ for example in dataset['train'].select(range(num_examples)):
36
+ # Используем определенные поля или альтернативные варианты
37
+ question = example.get(question_field) or example.get('text') or example.get('message')
38
+ answer = example.get(answer_field) or example.get('response') or example.get('content')
39
+
40
+ if question and answer:
41
+ examples.append(f"Вопрос: {question}\nОтвет: {answer}")
42
+
43
+ return "\n\n".join(examples) if examples else ""
44
+ except Exception as e:
45
+ print(f"Ошибка при создании контекста: {e}")
46
+ return ""
47
 
48
+ # Создаем контекст
49
+ context_examples = create_context(bank_dataset)
50
+ print("Созданный контекст:\n", context_examples)
 
 
51
 
52
+ # Функция генерации ответа
53
  def generate_response(prompt):
54
+ # Добавляем контекст, если он есть
55
+ if context_examples:
56
+ full_prompt = f"""Контекст банковских вопросов:
57
  {context_examples}
58
 
59
  Вопрос клиента: {prompt}
60
  Ответ:"""
61
+ else:
62
+ full_prompt = f"Вопрос клиента: {prompt}\nОтвет:"
63
 
64
  inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
65
 
 
83
  response = response[len(full_prompt):].strip()
84
 
85
  # Постобработка ответа
86
+ response = response.split("\n")[0].strip()
87
+ if not response.endswith((".", "!", "?")):
88
+ response += "."
89
 
90
  return response
91
 
92
+ # Примеры вопросов
93
  examples = [
94
  "Как восстановить утерянную карту?",
95
  "Какие документы нужны для открытия счета?",
 
98
  "Какие комиссии за перевод между счетами?"
99
  ]
100
 
101
+ # Интерфейс Gradio
102
  demo = gr.Interface(
103
  fn=generate_response,
104
  inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
105
  outputs=gr.Textbox(label="Ответ модели"),
106
  title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data",
107
+ description="Используется модель ai-forever/rugpt3small_based_on_gpt2 с учетом данных из датасета Alpha Bank.",
108
  examples=examples
109
  )
110
 
 
111
  if __name__ == "__main__":
112
  demo.launch()