# NOTE(review): the original paste began with "Spaces:" / "Sleeping" /
# "Sleeping" — Hugging Face Spaces page-status residue, not code.
# Converted to this comment so the file parses.
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever

# Load the English Wikipedia dump.
# BUG FIX: without split="train", load_dataset returns a DatasetDict, and
# iterating a DatasetDict yields split *names* (strings), so doc["text"]
# below would raise. Selecting the split yields the actual rows.
# NOTE(review): this loads the full English Wikipedia into memory — consider
# .select(range(n)) or streaming for anything but a large machine.
knowledge_base = datasets.load_dataset(
    "wikimedia/wikipedia", "20231101.en", split="train"
)

# Convert dataset rows to LangChain Document objects, keeping the article
# title as metadata so retrieved chunks can be attributed.
source_docs = [
    Document(page_content=doc["text"], metadata={"title": doc["title"]})
    for doc in knowledge_base
]

# Split documents into smaller chunks for better retrieval granularity.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,  # characters per chunk
    chunk_overlap=50,  # overlap between chunks to maintain context
    add_start_index=True,  # record each chunk's offset within its source doc
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],  # priority order for splitting
)
docs_processed = text_splitter.split_documents(source_docs)

print(f"Knowledge base prepared with {len(docs_processed)} document chunks")
from smolagents import Tool


class RetrieverTool(Tool):
    """Agent tool that returns the Wikipedia chunks most relevant to a query.

    NOTE(review): despite the ``description`` shown to the agent, BM25 is a
    lexical (keyword-overlap) ranker, not embedding-based semantic search;
    it is fast precisely because it needs no model inference. The description
    string is left unchanged because the agent consumes it at runtime.
    """

    name = "retriever"
    description = "Faster than WikipediaSearchTool, it uses semantic search to retrieve wikipedia article that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        """Build a BM25 index over pre-chunked LangChain documents.

        Args:
            docs: list of langchain Document chunks to index.
            **kwargs: forwarded to smolagents.Tool.
        """
        super().__init__(**kwargs)
        # k=10: each query returns the ten highest-scoring chunks.
        self.retriever = BM25Retriever.from_documents(docs, k=10)

    def forward(self, query: str) -> str:
        """Run BM25 retrieval for ``query`` and return a formatted string.

        Raises:
            AssertionError: if ``query`` is not a string (stripped under -O;
                kept to preserve the original contract).
        """
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.retriever.invoke(query)
        # Number each chunk so the agent can tell retrieved documents apart.
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        )
# Instantiate the retriever tool over the chunked knowledge base so it can
# be handed to an agent.
retriever_tool = RetrieverTool(docs_processed)