File size: 2,363 Bytes
cd5b6a8
d4ae976
0ccee0d
d4ae976
1aba791
15e4bac
7f8cf27
e67deaf
719919b
1aba791
 
 
d4ae976
1aba791
15e4bac
5d56f39
1aba791
 
 
 
 
15e4bac
 
 
5d56f39
d4ae976
15e4bac
0ccee0d
cd5b6a8
1aba791
15e4bac
1aba791
5d56f39
1aba791
 
5d56f39
1aba791
 
 
 
15e4bac
1aba791
 
 
15e4bac
1aba791
0ccee0d
1aba791
15e4bac
1aba791
 
15e4bac
1aba791
 
 
 
 
 
5d56f39
d4ae976
 
1aba791
 
15e4bac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter

class KnowledgeManager:
    """Local retrieval-augmented QA over the ``.txt`` files in a directory.

    On construction, every ``.txt`` file found directly in *root_dir* is
    read, split into overlapping chunks, embedded with a sentence
    transformer, and indexed in a FAISS vector store.  Queries are then
    answered by a local FLAN-T5 model through a LangChain ``RetrievalQA``
    chain ("stuff" strategy).
    """

    def __init__(
        self,
        root_dir=".",
        *,
        llm_model="google/flan-t5-small",
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        chunk_size=500,
        chunk_overlap=50,
        max_length=1024,
    ):
        """Initialize the models and build the knowledge base.

        Args:
            root_dir: Directory scanned (non-recursively) for ``.txt`` files.
            llm_model: HuggingFace text2text-generation model id.
            embedding_model: Sentence-transformer model id for embeddings.
            chunk_size: Character length of each text chunk.
            chunk_overlap: Character overlap between adjacent chunks.
            max_length: Maximum generation length for the LLM pipeline.

        Raises:
            FileNotFoundError: If *root_dir* contains no ``.txt`` files.
        """
        self.root_dir = root_dir
        # Generation / embedding configuration (previously hard-coded).
        self.llm_model = llm_model
        self.embedding_model = embedding_model
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.max_length = max_length

        self.docsearch = None
        self.qa_chain = None
        self.llm = None
        self.embeddings = None

        self._initialize_llm()
        self._initialize_embeddings()
        self._load_knowledge_base()

    def _initialize_llm(self):
        """Load the local text2text model behind a HuggingFace pipeline."""
        local_pipe = pipeline(
            "text2text-generation",
            model=self.llm_model,
            max_length=self.max_length,
        )
        self.llm = HuggingFacePipeline(pipeline=local_pipe)

    def _initialize_embeddings(self):
        """Load the sentence-transformer embedding model."""
        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

    def _load_knowledge_base(self):
        """Index all .txt files under root_dir and build the QA chain.

        Raises:
            FileNotFoundError: If no ``.txt`` files are present.
        """
        # Sort for deterministic chunk ordering: os.listdir returns entries
        # in arbitrary order, which would make the index non-reproducible.
        # The isfile() guard skips a directory that merely *ends* in ".txt",
        # which would otherwise crash open() below.
        txt_files = sorted(
            f
            for f in os.listdir(self.root_dir)
            if f.endswith(".txt")
            and os.path.isfile(os.path.join(self.root_dir, f))
        )

        if not txt_files:
            raise FileNotFoundError("No .txt files found in root directory.")

        all_texts = []
        for filename in txt_files:
            path = os.path.join(self.root_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                all_texts.append(f.read())

        full_text = "\n\n".join(all_texts)

        # Split text into chunks for embedding.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        docs = text_splitter.create_documents([full_text])

        # Create FAISS vector store over the chunk embeddings.
        self.docsearch = FAISS.from_documents(docs, self.embeddings)

        # Build the QA chain; "stuff" concatenates retrieved chunks into
        # a single prompt for the LLM.
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.docsearch.as_retriever(),
            return_source_documents=True,
        )

    def ask(self, query):
        """Answer *query* from the indexed knowledge base.

        Args:
            query: Natural-language question.

        Returns:
            The generated answer string (source documents are retrieved
            by the chain but not returned here).

        Raises:
            ValueError: If the knowledge base was never initialized.
        """
        if not self.qa_chain:
            raise ValueError("Knowledge base not initialized.")
        result = self.qa_chain(query)
        return result['result']