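"""Ingest a directory of text documents into a FAISS vector store.

Documents are parsed in parallel with Ray, split with a recursive character
splitter, and embedded with the HuggingFace model configured in ``common``.
"""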
import os
from tqdm import tqdm
import pathlib

from langchain_community.document_loaders import TextLoader
from langchain.docstore.document import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS

os.environ["RAY_memory_monitor_refresh_ms"] = "0"
os.environ["RAY_DEDUP_LOGS"] = "0"
import ray

from common import DATASET_DIR, EMBEDDING_MODEL_NAME, MODEL_KWARGS, VECTORSTORE_FILENAME

# Each document is parsed in its own Ray task to reduce paging and data copies.
# Keep CONCURRENCY at or below the number of available vCPUs.
CONCURRENCY = 32


# @ray.remote(num_cpus=1)  # Outside a container, setting num_cpus=1 explicitly may speed things up dramatically.
@ray.remote
def parse_doc(document_path: str) -> list[Document]:
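    """Load a single text file into LangChain ``Document`` objects as a Ray task."""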
    print("Loading", document_path)
    loader = TextLoader(document_path)
    langchain_dataset_documents = loader.load()

    # TODO: Update the metadata with the proper metadata JSON file parsed from Arxiv.com.
    return langchain_dataset_documents


def add_documents_to_vector_store(
    vector_store, new_documents, text_splitter, embeddings
):
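    """Split ``new_documents``, embed the chunks, and merge them into ``vector_store``.

    Creates the store on first use; returns the updated store.
    """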
    split_docs = text_splitter.split_documents(new_documents)
    # print("Embedding vectors...")
    store = FAISS.from_documents(split_docs, embeddings)
    if vector_store is None:
        vector_store = store
    else:
        print("Updating vector store", store)
        vector_store.merge_from(store)
    return vector_store


def ingest_dataset_to_vector_store(
    vectorstore_filename: str, dataset_directory: os.PathLike
):
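    """Walk ``dataset_directory``, parse every document with Ray, and persist a FAISS index.

    If an index already exists at ``vectorstore_filename``, it is loaded and extended.
    """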
    ray.init()
    vector_store = None
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=160,  # TODO: Finetune
        chunk_overlap=40,  # TODO: Finetune
        length_function=len,
    )

    dataset_documents = []
    dataset_dir_path = pathlib.Path(dataset_directory)
    dataset_dir_path.mkdir(exist_ok=True)

    for _dirname in os.listdir(str(dataset_dir_path)):
        if _dirname.startswith("."):
            continue
        category_path = dataset_dir_path / _dirname
        for filename in os.listdir(str(category_path)):
            dataset_path = category_path / filename
            dataset_documents.append(str(dataset_path))
    print(dataset_documents)
    print(f"Found {len(dataset_documents)} items in dataset")
    langchain_documents = []

    model_name = EMBEDDING_MODEL_NAME
    model_kwargs = MODEL_KWARGS
    print("Creating huggingface embeddings for ", model_name)
    embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

    if vector_store is None and os.path.exists(vectorstore_filename):
        print("Loading existing vector store from", vectorstore_filename)
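        # The FAISS docstore is pickled on disk, so loading requires opting in to deserialization.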
        vector_store = FAISS.load_local(
            vectorstore_filename, embeddings, allow_dangerous_deserialization=True
        )

    jobs = []
    docs_count = len(dataset_documents)
    failed = 0
    print(f"Parsing and embedding {docs_count} documents with Ray...")
    for i, document in enumerate(tqdm(dataset_documents)):
        try:
            # print(f"Submitting job ", i)
            job = parse_doc.remote(document)
            jobs.append(job)

            # Every CONCURRENCY submissions, block on the pending Ray jobs,
            # then flush the parsed documents into the vector store and save it.
            if i > 1 and i % CONCURRENCY == 0:
                # print(f"Collecting {len(jobs)} jobs...")
                for job in jobs:
                    try:
                        data = ray.get(job)
                        langchain_documents.extend(data)
                    except Exception as e:
                        print("error in job:", e)
                jobs = []

                if langchain_documents:
                    vector_store = add_documents_to_vector_store(
                        vector_store, langchain_documents, text_splitter, embeddings
                    )
                    print(f"\nSaving vector store to disk at {vectorstore_filename}...")
                    try:
                        os.unlink(vectorstore_filename)
                    except OSError:
                        pass

                    vector_store.save_local(vectorstore_filename)
                    langchain_documents = []
        except Exception as e:
            print(f"\n\nERROR reading dataset {i}:", e)
            failed += 1
            continue

    # Collect any jobs still pending after the last full batch.
    for job in jobs:
        try:
            print("waiting for ray job", job)
            data = ray.get(job)
            langchain_documents.extend(data)
        except Exception as e:
            print("error in job:", e)

    if langchain_documents:
        vector_store = add_documents_to_vector_store(
            vector_store, langchain_documents, text_splitter, embeddings
        )
        print(f"\nSaving vector store to disk at {vectorstore_filename}...")
        try:
            os.unlink(vectorstore_filename)
        except OSError:
            pass

        vector_store.save_local(vectorstore_filename)

    if failed:
        print(f"Failed to parse {failed} documents.")
    return vector_store


def main():
    vectorstore_filename = VECTORSTORE_FILENAME
    dataset_directory = DATASET_DIR
    ingest_dataset_to_vector_store(
        vectorstore_filename=vectorstore_filename, dataset_directory=dataset_directory
    )


if __name__ == "__main__":
    main()