SimrusDenuvo commited on
Commit
90cd180
·
verified ·
1 Parent(s): 57878bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -68
app.py CHANGED
@@ -1,76 +1,40 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
- # 1. Инициализация модели (с квантованием для экономии памяти)
6
- try:
7
- model_name = "ai-forever/rugpt3small_based_on_gpt2"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForCausalLM.from_pretrained(
10
- model_name,
11
- torch_dtype=torch.float16,
12
- device_map="auto",
13
- load_in_8bit=True
14
- )
15
- generator = pipeline(
16
- "text-generation",
17
- model=model,
18
- tokenizer=tokenizer,
19
- device="cuda" if torch.cuda.is_available() else "cpu"
20
- )
21
- except Exception as e:
22
- raise RuntimeError(f"Ошибка загрузки модели: {str(e)}")
23
 
24
- # 2. Примеры обращений
25
- examples = [
26
- "Мой заказ #12345 не пришел",
27
- "Как оформить возврат товара?",
28
- "Не приходит SMS-подтверждение",
29
- "Ошибка при оплате картой"
30
- ]
31
-
32
- # 3. Функция генерации ответа с правильным форматом сообщений
33
- def generate_response(message, chat_history):
34
- # Формируем промпт с историей диалога
35
- prompt = "Ты оператор поддержки. Вежливо отвечай клиенту на русском.\n\n"
36
- for user_msg, bot_msg in chat_history:
37
- prompt += f"Клиент: {user_msg}\nОператор: {bot_msg}\n"
38
- prompt += f"Клиент: {message}\nОператор:"
39
-
40
- try:
41
- # Генерация ответа
42
- response = generator(
43
- prompt,
44
- max_new_tokens=200,
45
- temperature=0.7,
46
  do_sample=True,
47
- top_p=0.9
 
 
 
48
  )
49
- bot_message = response[0]["generated_text"].split("Оператор:")[-1].strip()
50
-
51
- # Возвращаем обновленную историю диалога в правильном формате
52
- return chat_history + [(message, bot_message)]
53
- except Exception as e:
54
- print(f"Ошибка генерации: {str(e)}")
55
- return chat_history + [(message, f"Извините, произошла ошибка. {str(e)}")]
56
-
57
- # 4. Создание интерфейса с правильным форматом Chatbot
58
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
59
- gr.Markdown("""<h1><center>📞 Поддержка клиентов</center></h1>""")
60
-
61
- with gr.Row():
62
- with gr.Column():
63
- chatbot = gr.Chatbot(height=350)
64
- msg = gr.Textbox(label="Ваш вопрос", placeholder="Опишите проблему...")
65
- btn = gr.Button("Отправить", variant="primary")
66
-
67
- with gr.Column():
68
- gr.Examples(examples, inputs=msg, label="Примеры обращений")
69
- gr.Markdown("**Подсказки:**\n1. Укажите номер заказа\n2. Опишите проблему подробно")
70
 
71
- # Обработчики с правильным форматом сообщений
72
- btn.click(generate_response, [msg, chatbot], [chatbot])
73
- msg.submit(generate_response, [msg, chatbot], [chatbot])
 
 
 
 
 
74
 
75
- if __name__ == "__main__":
76
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
+ # Загружаем модель и токенизатор
6
+ model_name = "ai-forever/rugpt3small_based_on_gpt2"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Функция генерации ответа
11
+ def generate_response(prompt):
12
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
13
+ with torch.no_grad():
14
+ outputs = model.generate(
15
+ **inputs,
16
+ max_new_tokens=150,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  do_sample=True,
18
+ temperature=0.7,
19
+ top_k=50,
20
+ top_p=0.95,
21
+ eos_token_id=tokenizer.eos_token_id
22
  )
23
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
+ # Удаляем промпт из ответа, если повторяется
25
+ if response.startswith(prompt):
26
+ response = response[len(prompt):].strip()
27
+ return response.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ # Интерфейс Gradio
30
+ demo = gr.Interface(
31
+ fn=generate_response,
32
+ inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обещаниям банка"),
33
+ outputs=gr.Textbox(label="Ответ модели"),
34
+ title="Анализ клиентских обещаний — RuGPT-3",
35
+ description="Используется модель ai-forever/rugpt3small_based_on_gpt2 на основе данных ZhenDOS/alpha_bank_data."
36
+ )
37
 
38
+ # Запуск
39
+ if name == "main":
40
+ demo.launch()