import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
# Load the English Wikipedia dataset (specify the split so we iterate over rows, not split names)
knowledge_base = datasets.load_dataset("wikimedia/wikipedia", "20231101.en", split="train")

# Convert dataset entries to Document objects with metadata
source_docs = [
    Document(page_content=doc["text"], metadata={"title": doc["title"]})
    for doc in knowledge_base
]
# Split documents into smaller chunks for better retrieval
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,  # Characters per chunk
    chunk_overlap=50,  # Overlap between chunks to maintain context
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],  # Priority order for splitting
)
docs_processed = text_splitter.split_documents(source_docs)
print(f"Knowledge base prepared with {len(docs_processed)} document chunks")
from smolagents import Tool

class RetrieverTool(Tool):
    name = "retriever"
    description = (
        "Faster than WikipediaSearchTool: uses BM25 ranking over a local index to retrieve the "
        "Wikipedia passages that could be most relevant to answer your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        super().__init__(**kwargs)
        # Initialize the retriever with our processed documents
        self.retriever = BM25Retriever.from_documents(
            docs, k=10  # Return top 10 most relevant documents
        )

    def forward(self, query: str) -> str:
        """Execute the retrieval based on the provided query."""
        assert isinstance(query, str), "Your search query must be a string"
        # Retrieve relevant documents
        docs = self.retriever.invoke(query)
        # Format the retrieved documents for readability
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        )

# Initialize our retriever tool with the processed documents
retriever_tool = RetrieverTool(docs_processed)
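
# Usage sketch (an addition, not part of the original script). smolagents Tool instances are
# callable, so the retriever can be exercised directly; the agent wiring below assumes a recent
# smolagents release that exposes CodeAgent and InferenceClientModel (older releases named the
# latter HfApiModel). The queries are illustrative only.
print(retriever_tool(query="history of the Eiffel Tower")[:500])

from smolagents import CodeAgent, InferenceClientModel

agent = CodeAgent(tools=[retriever_tool], model=InferenceClientModel())
agent.run("When was the Eiffel Tower built, and by whom?")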