SimrusDenuvo commited on
Commit
7a9745c
·
verified ·
1 Parent(s): 0cc54f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -3,9 +3,10 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
  # Загружаем модель для русского языка
6
- model_name = "DeepPavlov/rubert-base-cased" # или другая модель для русского языка
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
9
 
10
  # Проверка доступности GPU (если оно есть)
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -14,11 +15,16 @@ model = model.to(device)
14
  # Генерация ответа с более точным форматом
15
  # Генерация ответа
16
  def generate_response(question):
17
- prompt = f"Вы — сотрудник банка. Клиент задает вопрос, и вы должны дать ясный и точный ответ.\nВопрос клиента: {question}\nОтвет банка:"
 
 
 
18
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
19
 
20
- outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, top_k=50, temperature=0.7)
 
21
 
 
22
  generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
  response = generated.replace(prompt, "").strip()
24
 
 
3
  import torch
4
 
5
  # Загружаем модель для русского языка
6
+ model_name = "sberbank-ai/ruT5-base"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
+
10
 
11
  # Проверка доступности GPU (если оно есть)
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
15
  # Генерация ответа с более точным форматом
16
  # Генерация ответа
17
  def generate_response(question):
18
+ # Новый промпт
19
+ prompt = f"Представьте, что вы сотрудник банка, и клиент спрашивает вас: '{question}'. Пожалуйста, дайте подробный ответ."
20
+
21
+ # Подготовка входных данных
22
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
23
 
24
+ # Генерация ответа с измененными параметрами
25
+ outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, top_k=50, temperature=1.0)
26
 
27
+ # Декодирование и удаление лишнего текста
28
  generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
  response = generated.replace(prompt, "").strip()
30