SimrusDenuvo commited on
Commit
af5aa20
·
verified ·
1 Parent(s): 1e10719

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -13
app.py CHANGED
@@ -1,15 +1,33 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
4
 
5
- # Загружаем модель и токенизатор
6
  model_name = "ai-forever/rugpt3small_based_on_gpt2"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
- # Функция генерации ответа
 
 
 
 
 
 
 
 
 
11
  def generate_response(prompt):
12
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
 
 
 
 
 
 
 
13
  with torch.no_grad():
14
  outputs = model.generate(
15
  **inputs,
@@ -18,21 +36,39 @@ def generate_response(prompt):
18
  temperature=0.7,
19
  top_k=50,
20
  top_p=0.95,
21
- eos_token_id=tokenizer.eos_token_id
 
 
22
  )
 
23
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
- # Удаляем промпт из ответа, если повторяется
25
- if response.startswith(prompt):
26
- response = response[len(prompt):].strip()
27
- return response.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Интерфейс Gradio
30
  demo = gr.Interface(
31
  fn=generate_response,
32
- inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обещаниям банка"),
33
  outputs=gr.Textbox(label="Ответ модели"),
34
- title="Анализ клиентских обещаний — RuGPT-3",
35
- description="Используется модель ai-forever/rugpt3small_based_on_gpt2 на основе данных ZhenDOS/alpha_bank_data."
 
36
  )
37
 
38
  # Запуск
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
  import torch
4
+ from datasets import load_dataset
5
 
6
+ # Загружаем модель, токенизатор и датасет
7
  model_name = "ai-forever/rugpt3small_based_on_gpt2"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
+ # Загружаем банковский датасет для контекста
12
+ bank_dataset = load_dataset("ZhenDOS/alpha_bank_data")
13
+
14
+ # Создаем контекст из датасета (первые несколько примеров)
15
+ context_examples = "\n".join([
16
+ f"Вопрос: {example['question']}\nОтвет: {example['answer']}"
17
+ for example in bank_dataset['train'].select(range(5))
18
+ ])
19
+
20
+ # Функция генерации ответа с учетом банковского контекста
21
  def generate_response(prompt):
22
+ # Добавляем контекст из датасета к промпту
23
+ full_prompt = f"""Контекст по банковским вопросам:
24
+ {context_examples}
25
+
26
+ Вопрос клиента: {prompt}
27
+ Ответ:"""
28
+
29
+ inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
30
+
31
  with torch.no_grad():
32
  outputs = model.generate(
33
  **inputs,
 
36
  temperature=0.7,
37
  top_k=50,
38
  top_p=0.95,
39
+ eos_token_id=tokenizer.eos_token_id,
40
+ no_repeat_ngram_size=3,
41
+ early_stopping=True
42
  )
43
+
44
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
+
46
+ # Удаляем промпт из ответа
47
+ if response.startswith(full_prompt):
48
+ response = response[len(full_prompt):].strip()
49
+
50
+ # Постобработка ответа
51
+ response = response.split("\n")[0] # Берем только первую строку ответа
52
+ response = response.replace("Ответ:", "").strip()
53
+
54
+ return response
55
+
56
+ # Интерфейс Gradio с примерами вопросов
57
+ examples = [
58
+ "Как восстановить утерянную карту?",
59
+ "Какие документы нужны для открытия счета?",
60
+ "Как проверить баланс карты?",
61
+ "Как оформить кредитную карту?",
62
+ "Какие комиссии за перевод между счетами?"
63
+ ]
64
 
 
65
  demo = gr.Interface(
66
  fn=generate_response,
67
+ inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
68
  outputs=gr.Textbox(label="Ответ модели"),
69
+ title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data",
70
+ description="Используется модель ai-forever/rugpt3small_based_on_gpt2, дообученная на данных ZhenDOS/alpha_bank_data.",
71
+ examples=examples
72
  )
73
 
74
  # Запуск