"""Build a BM25 retriever tool over chunked English Wikipedia for a smolagents agent."""

import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from smolagents import Tool

# Load the wikipedia dataset.
# `split="train"` is required: without it `load_dataset` returns a DatasetDict,
# and iterating that yields split names (strings) instead of article rows, so
# `doc["text"]` below would raise a TypeError.
# NOTE(review): this is the full English Wikipedia dump; materialising every
# article as a Document in memory may be impractical -- consider a subset.
knowledge_base = datasets.load_dataset(
    "wikimedia/wikipedia", "20231101.en", split="train"
)

# Convert dataset entries to Document objects with metadata
source_docs = [
    Document(page_content=doc["text"], metadata={"title": doc["title"]})
    for doc in knowledge_base
]

# Split documents into smaller chunks for better retrieval
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,  # Characters per chunk
    chunk_overlap=50,  # Overlap between chunks to maintain context
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],  # Priority order for splitting
)
docs_processed = text_splitter.split_documents(source_docs)

print(f"Knowledge base prepared with {len(docs_processed)} document chunks")


class RetrieverTool(Tool):
    """Agent tool that returns the top-ranked document chunks for a text query.

    Ranking is delegated to a ``BM25Retriever`` built once, at construction
    time, from the chunks passed to ``__init__``.
    """

    name = "retriever"
    description = (
        "Faster than WikipediaSearchTool, it uses semantic search to retrieve "
        "wikipedia article that could be most relevant to answer your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": (
                "The query to perform. This should be semantically close to "
                "your target documents. Use the affirmative form rather than "
                "a question."
            ),
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        """Index ``docs`` for retrieval.

        Args:
            docs: Document chunks to build the BM25 index from.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        # Initialize the retriever with our processed documents
        self.retriever = BM25Retriever.from_documents(
            docs,
            k=10,  # Return top 10 most relevant documents
        )

    def forward(self, query: str) -> str:
        """Execute the retrieval based on the provided query.

        Args:
            query: Search text handed to the retriever.

        Returns:
            A single string: a ``Retrieved documents:`` header followed by the
            content of each hit under a ``===== Document <i> =====`` banner,
            numbered from 0 in the order the retriever returned them.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        # Kept as an assert so callers see the same exception type as before.
        assert isinstance(query, str), "Your search query must be a string"

        # Retrieve relevant documents
        docs = self.retriever.invoke(query)

        # Format the retrieved documents for readability
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        )


# Initialize our retriever tool with the processed documents
retriever_tool = RetrieverTool(docs_processed)