oiisa committed on
Commit
80d600f
·
verified ·
1 Parent(s): 8b53a7b

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +167 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,169 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from langchain_community.vectorstores import FAISS
3
+ from langchain_community.embeddings import HuggingFaceEmbeddings
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
6
+ from transformers import pipeline
7
+ import os
8
 
9
# Configuration
DATA_DIR = "data"  # folder scanned for source documents
INDEX_DIR = "faiss_index"  # folder the FAISS index is saved to and loaded from
MODEL_NAME = "IlyaGusev/saiga_llama3_8b"  # model id given to the text-generation pipeline
13
+
14
# Model initialisation
@st.cache_resource
def load_llm():
    """Create the text-generation pipeline for ``MODEL_NAME``.

    The function is decorated with ``st.cache_resource`` so the pipeline
    object is built once and reused instead of being recreated on every
    call.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        ``"text-generation"`` task.
    """
    pipeline_options = {
        "model": MODEL_NAME,
        "device_map": "auto",
        "model_kwargs": {"torch_dtype": "auto"},
    }
    return pipeline("text-generation", **pipeline_options)
23
+
24
# Embeddings initialisation
@st.cache_resource
def load_embeddings():
    """Create the embedding model used to build and query the index.

    Decorated with ``st.cache_resource`` so the same embeddings object is
    reused across calls.

    Returns:
        A ``HuggingFaceEmbeddings`` instance for ``cointegrated/LaBSE-en-ru``.
    """
    embedding_model = HuggingFaceEmbeddings(model_name="cointegrated/LaBSE-en-ru")
    return embedding_model
28
+
29
# Document loading and processing
def process_documents():
    """Rebuild the FAISS index from the files found in ``DATA_DIR``.

    Every ``.pdf``, ``.docx`` and ``.txt`` file (extension matched
    case-insensitively) is loaded, split into overlapping chunks, embedded
    and stored in a FAISS index that is saved to ``INDEX_DIR``. A file that
    fails to load is reported with ``st.error`` and skipped so one broken
    document does not abort the whole rebuild.

    Returns:
        The freshly built FAISS vector store, or ``None`` when ``DATA_DIR``
        does not exist or no document could be loaded.
    """
    # A missing data folder means "nothing to index" rather than a crash.
    if not os.path.isdir(DATA_DIR):
        return None

    documents = []

    # Sorted so the index is built in a reproducible order.
    for filename in sorted(os.listdir(DATA_DIR)):
        filepath = os.path.join(DATA_DIR, filename)
        lowered = filename.lower()
        try:
            if lowered.endswith(".pdf"):
                loader = PyPDFLoader(filepath)
            elif lowered.endswith(".docx"):
                loader = Docx2txtLoader(filepath)
            elif lowered.endswith(".txt"):
                # Explicit encoding: the platform default is not reliable
                # for non-ASCII (e.g. Russian) text files.
                loader = TextLoader(filepath, encoding="utf-8")
            else:
                continue

            documents.extend(loader.load())
        except Exception as e:
            # Best-effort per file: report the problem and keep going.
            st.error(f"Ошибка загрузки {filename}: {str(e)}")

    if not documents:
        return None

    # Split the text into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
    )
    chunks = text_splitter.split_documents(documents)

    # Build and persist the vector store
    embeddings = load_embeddings()
    vectorstore = FAISS.from_documents(chunks, embeddings)
    vectorstore.save_local(INDEX_DIR)

    return vectorstore
65
+
66
# Relevant document search
def retrieve_docs(query):
    """Return the text of the three chunks most similar to ``query``.

    The FAISS index is loaded from ``INDEX_DIR`` when that path exists;
    otherwise it is built by ``process_documents()``.

    Args:
        query: The user question to search for.

    Returns:
        A list with the ``page_content`` of up to three matching chunks, or
        an empty list when no index exists and none could be built.
    """
    if os.path.exists(INDEX_DIR):
        embeddings = load_embeddings()
        # NOTE: the saved index is pickle-backed, and recent
        # langchain_community releases refuse to load it without this
        # explicit opt-in. It is acceptable here only because the index is
        # written by this application itself (see save_local); never point
        # INDEX_DIR at files obtained from an untrusted source.
        vectorstore = FAISS.load_local(
            INDEX_DIR,
            embeddings,
            allow_dangerous_deserialization=True,
        )
    else:
        vectorstore = process_documents()
        if vectorstore is None:
            return []

    results = vectorstore.similarity_search(query, k=3)
    return [doc.page_content for doc in results]
78
+
79
# Answer generation with RAG
def generate_with_rag(query, history):
    """Answer ``query`` using retrieved document chunks as context.

    The conversation is passed to the pipeline as a list of chat messages so
    the model's own chat template is applied, instead of a hand-written
    prompt with markers the model was not trained on.

    Args:
        query: The current user question.
        history: Recent chat messages, each a dict with ``"role"`` and
            ``"content"`` keys. Any role other than ``"user"`` is treated as
            the assistant. If the last entry is the current question (the
            caller appends it before calling), it is not sent twice.

    Returns:
        The generated answer text with surrounding whitespace stripped.
    """
    # Fetch the relevant document chunks
    context_docs = retrieve_docs(query)

    if context_docs:
        context = "\n\n".join(
            f"[Документ {i+1}]: {doc}" for i, doc in enumerate(context_docs)
        )
    else:
        context = "Информация не найдена в документах."

    system_prompt = (
        "Ты ассистент по вопросам магистратуры. "
        "Отвечай ТОЛЬКО на основе предоставленной информации.\n"
        "Если в контексте нет ответа - скажи "
        "\"Я не нашел информации по этому вопросу в документах\"."
    )

    # Drop the trailing copy of the current question so the model does not
    # see it both in the history and as the final user turn.
    prior_turns = list(history or [])
    if (
        prior_turns
        and prior_turns[-1].get("role") == "user"
        and prior_turns[-1].get("content") == query
    ):
        prior_turns.pop()

    messages = [
        {"role": "system", "content": f"{system_prompt}\n\nКонтекст:\n{context}"}
    ]
    messages.extend(
        {
            "role": "user" if msg["role"] == "user" else "assistant",
            "content": msg["content"],
        }
        for msg in prior_turns
    )
    messages.append({"role": "user", "content": query})

    # Generate the answer. No eos_token_id override is passed, so the stop
    # tokens come from the model's own generation config (presumably
    # including the chat end-of-turn token - confirm against the model repo).
    generator = load_llm()
    response = generator(
        messages,
        max_new_tokens=1024,
        temperature=0.3,
        do_sample=True,
        return_full_text=False,  # only the new text, no prompt to strip off
    )

    generated = response[0]["generated_text"]
    # Some transformers versions return the whole chat as a message list;
    # in that case the answer is the content of the last message.
    if isinstance(generated, list):
        generated = generated[-1]["content"]
    return generated.strip()
123
+
124
# Streamlit interface
st.title("🎓 Ассистент по магистратуре с RAG")
st.write("Загрузите документы в папку 'data' и задавайте вопросы")

# Knowledge-base refresh: rebuild the index on demand and report the real
# outcome. process_documents() returns None when nothing was indexed, so
# success must not be announced unconditionally.
if st.sidebar.button("Обновить базу знаний"):
    with st.spinner("Обработка документов..."):
        rebuilt_store = process_documents()
    if rebuilt_store is None:
        st.sidebar.warning("Документы не найдены, база знаний не обновлена.")
    else:
        st.sidebar.success("База знаний обновлена!")
133
+
134
# Conversation history: seed it with a greeting the first time it is needed
if "messages" not in st.session_state:
    greeting = {
        "role": "assistant",
        "content": "Привет! Задайте вопрос о магистратуре, и я отвечу на основе документов.",
    }
    st.session_state.messages = [greeting]

# Render every stored message in its own chat bubble
for entry in st.session_state.messages:
    bubble = st.chat_message(entry["role"])
    bubble.write(entry["content"])
143
+
144
# Input handling
prompt = st.chat_input("Ваш вопрос о магистратуре...")
if prompt:
    # Store and show the question
    user_turn = {"role": "user", "content": prompt}
    st.session_state.messages.append(user_turn)
    st.chat_message("user").write(prompt)

    # Generate the answer with RAG; any failure is shown as the reply text
    with st.spinner("Ищу информацию..."):
        try:
            recent_messages = st.session_state.messages[-5:]  # last 5 messages as context
            response = generate_with_rag(prompt, recent_messages)
        except Exception as e:
            response = f"Ошибка: {str(e)}"

    # Store and show the answer
    assistant_turn = {"role": "assistant", "content": response}
    st.session_state.messages.append(assistant_turn)
    st.chat_message("assistant").write(response)
163
+
164
# Clear-history button: reset the conversation to one notice and rerun
clear_requested = st.sidebar.button("Очистить историю диалога")
if clear_requested:
    cleared_notice = {
        "role": "assistant",
        "content": "История очищена. Чем могу помочь?",
    }
    st.session_state.messages = [cleared_notice]
    st.rerun()