Upload 6 files
Browse files- .gitattributes +2 -0
- all_books.csv +3 -0
- app.py +228 -0
- model/book_data.csv +3 -0
- model/embeddings.npy +3 -0
- model/faiss_index.bin +3 -0
- preproc.py +185 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
all_books.csv filter=lfs diff=lfs merge=lfs -text
|
37 |
+
model/book_data.csv filter=lfs diff=lfs merge=lfs -text
|
all_books.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:822102db8bdd23be463436a64a29e453a67c75eb37e43b57a91955187588dd04
|
3 |
+
size 10940812
|
app.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app_modified.py - Оптимизированная версия Streamlit-приложения
|
2 |
+
import streamlit as st
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer, AutoModel
|
7 |
+
import faiss
|
8 |
+
import re
|
9 |
+
import nltk
|
10 |
+
from nltk.corpus import stopwords
|
11 |
+
from nltk.tokenize import word_tokenize
|
12 |
+
import os
|
13 |
+
|
14 |
+
# Загрузим стоп-слова для русского языка
|
15 |
+
try:
|
16 |
+
nltk.data.find('corpora/stopwords')
|
17 |
+
except LookupError:
|
18 |
+
nltk.download('stopwords')
|
19 |
+
|
20 |
+
try:
|
21 |
+
nltk.data.find('tokenizers/punkt')
|
22 |
+
except LookupError:
|
23 |
+
nltk.download('punkt')
|
24 |
+
|
25 |
+
stop_words = set(stopwords.words('russian'))
|
26 |
+
|
27 |
+
# Класс для получения эмбеддингов с помощью RuBERT
|
28 |
+
class RuBERTEmbedder:
|
29 |
+
def __init__(self, model_name="DeepPavlov/rubert-base-cased"):
|
30 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
31 |
+
self.model = AutoModel.from_pretrained(model_name)
|
32 |
+
self.model.eval()
|
33 |
+
# Используем CPU для запуска в Spaces
|
34 |
+
self.device = "cpu"
|
35 |
+
self.model.to(self.device)
|
36 |
+
|
37 |
+
def mean_pooling(self, model_output, attention_mask):
|
38 |
+
"""Среднее значение по токенам для получения эмбеддинга предложения"""
|
39 |
+
token_embeddings = model_output[0]
|
40 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
41 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
42 |
+
|
43 |
+
def get_embedding(self, text):
|
44 |
+
"""Получение векторного представления текста"""
|
45 |
+
encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
|
46 |
+
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
|
47 |
+
|
48 |
+
with torch.no_grad():
|
49 |
+
model_output = self.model(**encoded_input)
|
50 |
+
|
51 |
+
embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
|
52 |
+
return embeddings.cpu().numpy()[0]
|
53 |
+
|
54 |
+
def preprocess_text(text):
|
55 |
+
"""Предобработка текста: удаление специальных символов, приведение к нижнему регистру, удаление стоп-слов"""
|
56 |
+
if isinstance(text, str):
|
57 |
+
# Приведение к нижнему регистру
|
58 |
+
text = text.lower()
|
59 |
+
# Удаление специальных символов
|
60 |
+
text = re.sub(r'[^\w\s]', '', text)
|
61 |
+
# Токенизация
|
62 |
+
tokens = word_tokenize(text, language='russian')
|
63 |
+
# Удаление стоп-слов
|
64 |
+
filtered_tokens = [word for word in tokens if word not in stop_words]
|
65 |
+
# Объединение обратно в строку
|
66 |
+
return ' '.join(filtered_tokens)
|
67 |
+
return ''
|
68 |
+
|
69 |
+
# Класс поисковой системы
|
70 |
+
class BookSearchEngine:
|
71 |
+
def __init__(self, embedder=None):
|
72 |
+
self.embedder = embedder
|
73 |
+
self.faiss_index = None
|
74 |
+
self.book_data = None
|
75 |
+
self.embeddings = None
|
76 |
+
|
77 |
+
def load_model(self, model_dir='model'):
|
78 |
+
"""Загрузка модели из сохраненных файлов"""
|
79 |
+
try:
|
80 |
+
# Загружаем данные книг
|
81 |
+
self.book_data = pd.read_csv(f"{model_dir}/book_data.csv")
|
82 |
+
|
83 |
+
# Загружаем эмбеддинги
|
84 |
+
self.embeddings = np.load(f"{model_dir}/embeddings.npy")
|
85 |
+
|
86 |
+
# Загружаем индекс FAISS
|
87 |
+
self.faiss_index = faiss.read_index(f"{model_dir}/faiss_index.bin")
|
88 |
+
|
89 |
+
return True
|
90 |
+
except Exception as e:
|
91 |
+
st.error(f"Ошибка при загрузке модели: {e}")
|
92 |
+
return False
|
93 |
+
|
94 |
+
def search(self, query, k=5):
|
95 |
+
"""Поиск книг по пользовательскому запросу"""
|
96 |
+
if self.embedder is None or self.faiss_index is None:
|
97 |
+
st.error("Поисковая система не инициализирована")
|
98 |
+
return []
|
99 |
+
|
100 |
+
# Предобработка запроса
|
101 |
+
processed_query = preprocess_text(query)
|
102 |
+
|
103 |
+
# Получение эмбеддинга запроса
|
104 |
+
query_embedding = self.embedder.get_embedding(processed_query)
|
105 |
+
query_embedding = query_embedding.reshape(1, -1)
|
106 |
+
|
107 |
+
# Нормализуем вектор запроса
|
108 |
+
faiss.normalize_L2(query_embedding)
|
109 |
+
|
110 |
+
# Поиск ближайших соседей
|
111 |
+
scores, indices = self.faiss_index.search(query_embedding, k)
|
112 |
+
|
113 |
+
# Формирование результатов
|
114 |
+
results = []
|
115 |
+
for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
|
116 |
+
if idx < len(self.book_data):
|
117 |
+
book = self.book_data.iloc[idx]
|
118 |
+
results.append({
|
119 |
+
'rank': i + 1,
|
120 |
+
'score': float(score),
|
121 |
+
'title': book.get('title', 'Нет названия'),
|
122 |
+
'author': book.get('author', 'Нет автора'),
|
123 |
+
'annotation': book.get('annotation', 'Нет аннотации'),
|
124 |
+
'page_url': book.get('page_url', '#'),
|
125 |
+
'book_image_url': book.get('book_image_url', book.get('image_url', ''))
|
126 |
+
})
|
127 |
+
|
128 |
+
return results
|
129 |
+
|
130 |
+
# Инициализация поисковой системы
|
131 |
+
@st.cache_resource
|
132 |
+
def initialize_search_engine():
|
133 |
+
# Инициализация модели RuBERT
|
134 |
+
embedder = RuBERTEmbedder()
|
135 |
+
|
136 |
+
# Создание поисковой системы
|
137 |
+
search_engine = BookSearchEngine(embedder)
|
138 |
+
|
139 |
+
# Загрузка подготовленной модели
|
140 |
+
if search_engine.load_model():
|
141 |
+
st.success(f"Поисковая система загружена. Всего книг: {len(search_engine.book_data)}")
|
142 |
+
else:
|
143 |
+
st.error("Не удалось загрузить модель. Пожалуйста, убедитесь, что директория 'model' содержит необходимые файлы.")
|
144 |
+
st.info("Перед запуском приложения нужно выполнить предварительную обработку данных с помощью скрипта preprocess.py")
|
145 |
+
|
146 |
+
return search_engine
|
147 |
+
|
148 |
+
# Основной код приложения
|
149 |
+
def main():
|
150 |
+
st.set_page_config(
|
151 |
+
page_title="Умный поиск книг",
|
152 |
+
page_icon="📚",
|
153 |
+
layout="wide"
|
154 |
+
)
|
155 |
+
|
156 |
+
st.title("📚 Умный поиск книг")
|
157 |
+
st.subheader("Найдите книги, соответствующие вашему запросу")
|
158 |
+
|
159 |
+
# Инициализация поисковой системы
|
160 |
+
search_engine = initialize_search_engine()
|
161 |
+
|
162 |
+
# Основной интерфейс поиска
|
163 |
+
st.write("### Введите описание книги, которую вы ищете")
|
164 |
+
|
165 |
+
col1, col2 = st.columns([3, 1])
|
166 |
+
|
167 |
+
with col1:
|
168 |
+
query = st.text_area("Описание книги:", height=150)
|
169 |
+
|
170 |
+
with col2:
|
171 |
+
num_results = st.slider("Количество результатов:", min_value=1, max_value=20, value=5)
|
172 |
+
search_button = st.button("🔍 Искать", type="primary")
|
173 |
+
|
174 |
+
# Если нажата кнопка поиска
|
175 |
+
if search_button:
|
176 |
+
if query:
|
177 |
+
with st.spinner("Ищем подходящие книги..."):
|
178 |
+
results = search_engine.search(query, k=num_results)
|
179 |
+
|
180 |
+
if results:
|
181 |
+
st.write(f"### Найдено {len(results)} подходящих книг:")
|
182 |
+
|
183 |
+
for i, result in enumerate(results):
|
184 |
+
col_image, col_content, col_score = st.columns([1, 2, 1])
|
185 |
+
|
186 |
+
with col_image:
|
187 |
+
if 'book_image_url' in result and result['book_image_url']:
|
188 |
+
try:
|
189 |
+
st.image(result['book_image_url'], width=150)
|
190 |
+
except Exception:
|
191 |
+
st.write("Изображение недоступно")
|
192 |
+
|
193 |
+
with col_content:
|
194 |
+
if 'page_url' in result and result['page_url']:
|
195 |
+
st.markdown(f"#### [{i+1}. {result['title']}]({result['page_url']})")
|
196 |
+
else:
|
197 |
+
st.markdown(f"#### {i+1}. {result['title']}")
|
198 |
+
st.write(f"**Автор:** {result['author']}")
|
199 |
+
with st.expander("Показать аннотацию"):
|
200 |
+
st.write(result['annotation'])
|
201 |
+
|
202 |
+
with col_score:
|
203 |
+
st.metric(
|
204 |
+
"Релевантность",
|
205 |
+
f"{result['score']:.2f}",
|
206 |
+
delta=None
|
207 |
+
)
|
208 |
+
|
209 |
+
st.divider()
|
210 |
+
else:
|
211 |
+
st.info("К сожалению, подходящих книг не найдено.")
|
212 |
+
else:
|
213 |
+
st.warning("Пожалуйста, введите описание книги для поиска.")
|
214 |
+
|
215 |
+
st.markdown("---")
|
216 |
+
st.markdown("### О проекте")
|
217 |
+
st.write("""
|
218 |
+
Этот сервис позволяет искать книги по их описанию с использованием семантиче��кой близости.
|
219 |
+
Система анализирует смысл вашего запроса и находит книги с наиболее подходящими аннотациями.
|
220 |
+
|
221 |
+
**Технологии:**
|
222 |
+
- RuBERT для создания векторных представлений текста
|
223 |
+
- FAISS для быстрого поиска ближайших соседей
|
224 |
+
- Streamlit для веб-интерфейса
|
225 |
+
""")
|
226 |
+
|
227 |
+
if __name__ == "__main__":
|
228 |
+
main()
|
model/book_data.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9b0b05dfb2ae1360e81a8d3505de2f0494b758aff5b2771732499c8c841a3989
|
3 |
+
size 18740967
|
model/embeddings.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:67120e2811a0d7624613267bb99e6f663bfcd18b6121caee020966bf2da35657
|
3 |
+
size 17126528
|
model/faiss_index.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6527eab896450d47e13be2af9bc34d78eb52470175221c984644c218bffe6300
|
3 |
+
size 17126445
|
preproc.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
preprocess.py - Скрипт для предварительной обработки датасета книг
|
3 |
+
и создания векторных представлений для поисковой системы
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import pandas as pd
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from transformers import AutoTokenizer, AutoModel
|
11 |
+
import faiss
|
12 |
+
import re
|
13 |
+
import nltk
|
14 |
+
from nltk.corpus import stopwords
|
15 |
+
from nltk.tokenize import word_tokenize
|
16 |
+
import argparse
|
17 |
+
from tqdm import tqdm
|
18 |
+
import nltk
|
19 |
+
nltk.download('punkt')
|
20 |
+
nltk.download('stopwords')
|
21 |
+
nltk.download('punkt_tab')
|
22 |
+
|
23 |
+
# Загрузим стоп-слова для русского языка
|
24 |
+
try:
|
25 |
+
nltk.data.find('corpora/stopwords')
|
26 |
+
except LookupError:
|
27 |
+
nltk.download('stopwords')
|
28 |
+
|
29 |
+
try:
|
30 |
+
nltk.data.find('tokenizers/punkt')
|
31 |
+
except LookupError:
|
32 |
+
nltk.download('punkt')
|
33 |
+
|
34 |
+
stop_words = set(stopwords.words('russian'))
|
35 |
+
|
36 |
+
# Класс для получения эмбеддингов с помощью RuBERT
|
37 |
+
class RuBERTEmbedder:
|
38 |
+
def __init__(self, model_name="DeepPavlov/rubert-base-cased"):
|
39 |
+
print(f"Загрузка модели {model_name}...")
|
40 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
41 |
+
self.model = AutoModel.from_pretrained(model_name)
|
42 |
+
self.model.eval()
|
43 |
+
# Используем GPU если доступен, иначе CPU
|
44 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
45 |
+
print(f"Используется устройство: {self.device}")
|
46 |
+
self.model.to(self.device)
|
47 |
+
|
48 |
+
def mean_pooling(self, model_output, attention_mask):
|
49 |
+
"""Среднее значение по токенам для получения эмбеддинга предложения"""
|
50 |
+
token_embeddings = model_output[0]
|
51 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
52 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
53 |
+
|
54 |
+
def get_embedding(self, text):
|
55 |
+
"""Получение векторного представления текста"""
|
56 |
+
encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
|
57 |
+
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
|
58 |
+
|
59 |
+
with torch.no_grad():
|
60 |
+
model_output = self.model(**encoded_input)
|
61 |
+
|
62 |
+
embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
|
63 |
+
return embeddings.cpu().numpy()[0]
|
64 |
+
|
65 |
+
def get_embeddings_batch(self, texts, batch_size=8):
|
66 |
+
"""Получение векторных представлений для списка текстов с использованием батчей"""
|
67 |
+
all_embeddings = []
|
68 |
+
|
69 |
+
for i in tqdm(range(0, len(texts), batch_size), desc="Создание эмбеддингов"):
|
70 |
+
batch_texts = texts[i:i+batch_size]
|
71 |
+
# Обработка пустых строк
|
72 |
+
batch_texts = [text if text and isinstance(text, str) else " " for text in batch_texts]
|
73 |
+
|
74 |
+
encoded_input = self.tokenizer(batch_texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
|
75 |
+
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
|
76 |
+
|
77 |
+
with torch.no_grad():
|
78 |
+
model_output = self.model(**encoded_input)
|
79 |
+
|
80 |
+
embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
|
81 |
+
all_embeddings.append(embeddings.cpu().numpy())
|
82 |
+
|
83 |
+
return np.vstack(all_embeddings)
|
84 |
+
|
85 |
+
def preprocess_text(text):
|
86 |
+
"""Предобработка текста: удаление специальных символов, приведение к нижнему регистру, удаление стоп-слов"""
|
87 |
+
if isinstance(text, str):
|
88 |
+
# Приведение к нижнему регистру
|
89 |
+
text = text.lower()
|
90 |
+
# Удаление специальных символов
|
91 |
+
text = re.sub(r'[^\w\s]', '', text)
|
92 |
+
# Токенизация
|
93 |
+
tokens = word_tokenize(text, language='russian')
|
94 |
+
# Удаление стоп-слов
|
95 |
+
filtered_tokens = [word for word in tokens if word not in stop_words]
|
96 |
+
# Объединение обратно в строку
|
97 |
+
return ' '.join(filtered_tokens)
|
98 |
+
return ''
|
99 |
+
|
100 |
+
def prepare_data(input_file, output_dir="model", annotation_column="annotation", title_column="title",
|
101 |
+
author_column="author", image_url_column="image_url", page_url_column="page_url", sample_size=None):
|
102 |
+
"""Подготовка данных для поисковой системы"""
|
103 |
+
# Создание выходной директории, если она не существует
|
104 |
+
os.makedirs(output_dir, exist_ok=True)
|
105 |
+
|
106 |
+
print(f"Загрузка данных из {input_file}...")
|
107 |
+
df = pd.read_csv(input_file)
|
108 |
+
|
109 |
+
# Проверка наличия обязательной колонки с аннотацией
|
110 |
+
if annotation_column not in df.columns:
|
111 |
+
raise ValueError(f"В файле отсутствует колонка с аннотациями: {annotation_column}")
|
112 |
+
|
113 |
+
# Очистка от записей без аннотации
|
114 |
+
initial_size = len(df)
|
115 |
+
df = df.dropna(subset=[annotation_column])
|
116 |
+
print(f"Удалено записей без аннотаций: {initial_size - len(df)}")
|
117 |
+
|
118 |
+
# Если указан размер выборки, отбираем случайные записи
|
119 |
+
if sample_size and sample_size < len(df):
|
120 |
+
df = df.sample(sample_size, random_state=42)
|
121 |
+
print(f"Используется случайная выборка из {sample_size} записей")
|
122 |
+
|
123 |
+
# Предобработка аннотаций
|
124 |
+
print("Предобработка аннотаций...")
|
125 |
+
df['processed_annotation'] = df[annotation_column].apply(preprocess_text)
|
126 |
+
|
127 |
+
# Загрузка модели для векторизации
|
128 |
+
print("Инициализация модели для векторизации...")
|
129 |
+
embedder = RuBERTEmbedder()
|
130 |
+
|
131 |
+
# Векторизация аннотаций
|
132 |
+
print("Векторизация аннотаций...")
|
133 |
+
annotations = df['processed_annotation'].tolist()
|
134 |
+
embeddings = embedder.get_embeddings_batch(annotations)
|
135 |
+
|
136 |
+
# Создание индекса FAISS
|
137 |
+
print("Создание индекса FAISS...")
|
138 |
+
dimension = embeddings.shape[1]
|
139 |
+
index = faiss.IndexFlatIP(dimension)
|
140 |
+
faiss.normalize_L2(embeddings)
|
141 |
+
index.add(embeddings)
|
142 |
+
|
143 |
+
# Сохранение обработанных данных и индекса
|
144 |
+
print(f"Сохранение данных в {output_dir}...")
|
145 |
+
|
146 |
+
# Сохраняем только нужные колонки
|
147 |
+
columns_to_save = [col for col in [annotation_column, title_column, author_column, image_url_column, page_url_column, 'processed_annotation'] if col in df.columns]
|
148 |
+
df[columns_to_save].to_csv(f"{output_dir}/book_data.csv", index=False)
|
149 |
+
|
150 |
+
# Сохраняем эмбеддинги
|
151 |
+
np.save(f"{output_dir}/embeddings.npy", embeddings)
|
152 |
+
|
153 |
+
# Сохраняем индекс FAISS
|
154 |
+
faiss.write_index(index, f"{output_dir}/faiss_index.bin")
|
155 |
+
|
156 |
+
print(f"Данные успешно обработаны и сохранены в {output_dir}")
|
157 |
+
print(f"Всего книг: {len(df)}")
|
158 |
+
return df
|
159 |
+
|
160 |
+
def main():
|
161 |
+
parser = argparse.ArgumentParser(description='Предобработка датасета книг для поисковой системы')
|
162 |
+
parser.add_argument('--input', type=str, required=True, help='Путь к CSV файлу с данными книг')
|
163 |
+
parser.add_argument('--output', type=str, default='model', help='Директория для сохранения модели и данных')
|
164 |
+
parser.add_argument('--annotation', type=str, default='annotation', help='Имя колонки с аннотациями')
|
165 |
+
parser.add_argument('--title', type=str, default='title', help='Имя колонки с названиями книг')
|
166 |
+
parser.add_argument('--author', type=str, default='author', help='Имя колонки с авторами')
|
167 |
+
parser.add_argument('--image_url', type=str, default='image_url', help='Имя колонки с URL изображений')
|
168 |
+
parser.add_argument('--page_url', type=str, default='page_url', help='Имя колонки с URL страниц')
|
169 |
+
parser.add_argument('--sample', type=int, default=None, help='Размер выборки (если нужно ограничить)')
|
170 |
+
|
171 |
+
args = parser.parse_args()
|
172 |
+
|
173 |
+
prepare_data(
|
174 |
+
input_file=args.input,
|
175 |
+
output_dir=args.output,
|
176 |
+
annotation_column=args.annotation,
|
177 |
+
title_column=args.title,
|
178 |
+
author_column=args.author,
|
179 |
+
image_url_column=args.image_url,
|
180 |
+
page_url_column=args.page_url,
|
181 |
+
sample_size=args.sample
|
182 |
+
)
|
183 |
+
|
184 |
+
if __name__ == "__main__":
|
185 |
+
main()
|