# Utilities to build a RAG system to query information from the
# gwIAS search pipeline using Langchain
# Thanks to Pablo Villanueva Domingo for sharing his CAMELS template
# https://huggingface.co/spaces/PabloVD/CAMELSDocBot
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain.schema import Document
import requests
import json
import base64
from bs4 import BeautifulSoup
import re

def github_to_raw(url):
    """Convert GitHub URL to raw content URL"""
    return url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/")
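
# Illustrative example (the URL below is hypothetical):
#   >>> github_to_raw("https://github.com/owner/repo/blob/main/src/module.py")
#   'https://raw.githubusercontent.com/owner/repo/main/src/module.py'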

def load_github_notebook(url):
    """Load Jupyter notebook from GitHub URL using GitHub API"""
    try:
        # Convert GitHub blob URL to API URL
        if "github.com" in url and "/blob/" in url:
            # Extract owner, repo, branch and path from URL
            parts = url.replace("https://github.com/", "").split("/")
            owner = parts[0]
            repo = parts[1]
            branch = parts[3]  # usually 'main' or 'master'
            path = "/".join(parts[4:])
            api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}?ref={branch}"
        else:
            raise ValueError("URL must be a GitHub blob URL")

        # Fetch notebook content
        response = requests.get(api_url)
        response.raise_for_status()
        content_data = response.json()
        if content_data.get('encoding') == 'base64':
            notebook_content = base64.b64decode(content_data['content']).decode('utf-8')
        else:
            notebook_content = content_data['content']

        # Parse notebook JSON
        notebook = json.loads(notebook_content)
        docs = []
        cell_count = 0

        # Process each cell
        for cell in notebook.get('cells', []):
            cell_count += 1
            cell_type = cell.get('cell_type', 'unknown')
            source = cell.get('source', [])

            # Join source lines
            if isinstance(source, list):
                content = ''.join(source)
            else:
                content = str(source)

            if content.strip():  # Only add non-empty cells
                metadata = {
                    'source': url,
                    'cell_type': cell_type,
                    'cell_number': cell_count,
                    'name': f"{url} - Cell {cell_count} ({cell_type})"
                }
                # Add cell type prefix for better context
                formatted_content = f"[{cell_type.upper()} CELL {cell_count}]\n{content}"
                docs.append(Document(page_content=formatted_content, metadata=metadata))

        return docs
    except Exception as e:
        print(f"Error loading notebook from {url}: {str(e)}")
        return []
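
# Usage sketch (hypothetical URL; requires network access to the GitHub API).
# Each non-empty notebook cell becomes one Document:
#   >>> nb_docs = load_github_notebook("https://github.com/owner/repo/blob/main/demo.ipynb")
#   >>> [d.metadata['cell_type'] for d in nb_docs]   # e.g. ['markdown', 'code', ...]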

def clean_text(text):
    """Clean text content from a webpage"""
    # Remove excessive newlines
    text = re.sub(r'\n{3,}', '\n\n', text)
    # Remove excessive whitespace
    text = re.sub(r'\s{2,}', ' ', text)
    return text.strip()
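
# Illustrative example of the whitespace cleanup:
#   >>> clean_text("  some   text   with    extra   spaces  ")
#   'some text with extra spaces'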

def clean_github_content(html_content):
    """Extract meaningful content from GitHub pages"""
    # Ensure we're working with a BeautifulSoup object
    if isinstance(html_content, str):
        soup = BeautifulSoup(html_content, 'html.parser')
    else:
        soup = html_content

    # Remove navigation, footer, and other boilerplate
    for element in soup.find_all(['nav', 'footer', 'header']):
        element.decompose()

    # For README and markdown files
    readme_content = soup.find('article', class_='markdown-body')
    if readme_content:
        return clean_text(readme_content.get_text())

    # For code files
    code_content = soup.find('table', class_='highlight')
    if code_content:
        return clean_text(code_content.get_text())

    # For directory listings
    file_list = soup.find('div', role='grid')
    if file_list:
        return clean_text(file_list.get_text())

    # Fallback to main content
    main_content = soup.find('main')
    if main_content:
        return clean_text(main_content.get_text())

    # If no specific content found, get text from body
    body = soup.find('body')
    if body:
        return clean_text(body.get_text())

    # Final fallback
    return clean_text(soup.get_text())
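
# Illustrative example with a minimal HTML snippet (falls through to the
# <main> branch, since no GitHub-specific markup is present):
#   >>> clean_github_content("<main><p>hello    world</p></main>")
#   'hello world'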

class GitHubLoader(WebBaseLoader):
    """Custom loader for GitHub pages with better content cleaning"""

    def clean_text(self, text):
        """Clean text content"""
        # Remove excessive newlines and spaces
        text = re.sub(r'\n{2,}', '\n', text)
        text = re.sub(r'\s{2,}', ' ', text)
        # Remove common GitHub boilerplate
        text = re.sub(r'Skip to content|Sign in|Search or jump to|Footer navigation|Terms|Privacy|Security|Status|Docs', '', text)
        return text.strip()

    def _scrape(self, url: str, *args, **kwargs) -> str:
        """Scrape data from URL and clean it.

        Args:
            url: The URL to scrape
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments including bs_kwargs

        Returns:
            str: The cleaned content
        """
        response = requests.get(url)
        response.raise_for_status()

        # For directory listings (tree URLs), use the API
        if '/tree/' in url:
            # Parse URL components
            parts = url.replace("https://github.com/", "").split("/")
            owner = parts[0]
            repo = parts[1]
            branch = parts[3]  # usually 'main' or 'master'
            path = "/".join(parts[4:]) if len(parts) > 4 else ""

            # Construct API URL
            api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}?ref={branch}"
            api_response = requests.get(api_url)
            api_response.raise_for_status()

            # Parse directory listing
            contents = api_response.json()
            if isinstance(contents, list):
                # Format directory contents
                files = [f"{item['name']} ({item['type']})" for item in contents]
                return "Directory contents:\n" + "\n".join(files)
            else:
                return f"Error: Unexpected API response for {url}"

        # For regular files, parse HTML
        soup = BeautifulSoup(response.text, 'html.parser')

        # For README and markdown files
        readme_content = soup.find('article', class_='markdown-body')
        if readme_content:
            return self.clean_text(readme_content.get_text())

        # For code files
        code_content = soup.find('table', class_='highlight')
        if code_content:
            return self.clean_text(code_content.get_text())

        # For other content, get main content
        main_content = soup.find('main')
        if main_content:
            return self.clean_text(main_content.get_text())

        # Final fallback
        return self.clean_text(soup.get_text())
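
# Usage sketch, mirroring how load_docs() uses this loader below (the
# repository URL is hypothetical and network access is required):
#   >>> pages = GitHubLoader(["https://github.com/owner/repo"]).load()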

# Load documentation from urls
def load_docs():
    # Get urls
    with open("urls.txt") as urlsfile:
        urls = urlsfile.readlines()
    urls = [url.replace("\n", "") for url in urls]

    # Load documents from URLs
    docs = []
    for url in urls:
        url = url.strip()
        if not url:
            continue

        # Check if URL is a Jupyter notebook
        if url.endswith('.ipynb') and 'github.com' in url and '/blob/' in url:
            print(f"Loading notebook: {url}")
            notebook_docs = load_github_notebook(url)
            docs.extend(notebook_docs)
        # Handle Python and Markdown files using raw content
        elif url.endswith(('.py', '.md')) and 'github.com' in url and '/blob/' in url:
            print(f"Loading raw content: {url}")
            try:
                raw_url = github_to_raw(url)
                loader = WebBaseLoader([raw_url])
                web_docs = loader.load()
                # Preserve original URL in metadata
                for doc in web_docs:
                    doc.metadata['source'] = url
                docs.extend(web_docs)
            except Exception as e:
                print(f"Error loading {url}: {str(e)}")
        # Handle directory listings
        elif '/tree/' in url and 'github.com' in url:
            print(f"Loading directory: {url}")
            try:
                # Parse URL components
                parts = url.replace("https://github.com/", "").split("/")
                owner = parts[0]
                repo = parts[1]
                branch = parts[3]  # usually 'main' or 'master'
                path = "/".join(parts[4:]) if len(parts) > 4 else ""

                # Construct API URL
                api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}?ref={branch}"
                response = requests.get(api_url)
                response.raise_for_status()

                # Parse directory listing
                contents = response.json()
                if isinstance(contents, list):
                    # Format directory contents
                    content = "Directory contents:\n" + "\n".join([f"{item['name']} ({item['type']})" for item in contents])
                    docs.append(Document(page_content=content, metadata={'source': url}))
                else:
                    print(f"Error: Unexpected API response for {url}")
            except Exception as e:
                print(f"Error loading directory {url}: {str(e)}")
        else:
            print(f"Loading web page: {url}")
            try:
                loader = GitHubLoader([url])  # Use custom loader
                web_docs = loader.load()
                docs.extend(web_docs)
            except Exception as e:
                print(f"Error loading {url}: {str(e)}")

    # Add source URLs as document names for reference
    for i, doc in enumerate(docs):
        if 'source' in doc.metadata:
            doc.metadata['name'] = doc.metadata['source']
        else:
            doc.metadata['name'] = f"Document {i+1}"

    print(f"Loaded {len(docs)} documents:")
    for doc in docs:
        print(f" - {doc.metadata.get('name')}")

    return docs
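
# load_docs() expects a urls.txt file in the working directory with one URL
# per line. Hypothetical example contents:
#   https://github.com/owner/repo
#   https://github.com/owner/repo/blob/main/notebook.ipynb
#   https://github.com/owner/repo/tree/main/src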

def extract_reference(url):
    """Extract a reference keyword from the GitHub URL"""
    if "blob/main" in url:
        return url.split("blob/main/")[-1]
    elif "tree/main" in url:
        return url.split("tree/main/")[-1] or "root"
    return url
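
# Illustrative example (hypothetical URL):
#   >>> extract_reference("https://github.com/owner/repo/blob/main/src/module.py")
#   'src/module.py'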

# Join content pages for processing
def format_docs(docs):
    formatted_docs = []
    for doc in docs:
        source = doc.metadata.get('source', 'Unknown source')
        reference = f"[{extract_reference(source)}]"
        content = doc.page_content
        formatted_docs.append(f"{content}\n\nReference: {reference}")
    return "\n\n---\n\n".join(formatted_docs)
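
# Minimal sketch of the context string handed to the LLM (hypothetical document):
#   >>> doc = Document(page_content="def f(): pass",
#   ...                metadata={"source": "https://github.com/owner/repo/blob/main/f.py"})
#   >>> format_docs([doc])
#   'def f(): pass\n\nReference: [f.py]'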

# Create a RAG chain
def RAG(llm, docs, embeddings):
    # Split text
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Create vector store
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)

    # Retrieve and generate using the relevant snippets of the documents
    retriever = vectorstore.as_retriever()

    # Prompt basis example for RAG systems
    prompt = hub.pull("rlm/rag-prompt")

    # Adding custom instructions to the prompt
    template = prompt.messages[0].prompt.template
    template_parts = template.split("\nQuestion: {question}")
    combined_template = "You are an assistant for question-answering tasks. "\
        + "Use the following pieces of retrieved context to answer the question. "\
        + "If you don't know the answer, just say that you don't know. "\
        + "Try to keep the answer concise if possible. "\
        + "Write the names of the relevant functions from the retrieved code and include code snippets to aid the user's understanding. "\
        + "Include the references used in square brackets at the end of your answer."\
        + template_parts[1]
    prompt.messages[0].prompt.template = combined_template

    # Create the chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain
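
# Minimal end-to-end sketch, not used by the app itself. It assumes the
# langchain-openai package and an OpenAI API key are available; any other
# LangChain chat model, embeddings class, model name, or question can be
# swapped in.
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings

    documents = load_docs()
    rag_chain = RAG(ChatOpenAI(model="gpt-4o-mini"), documents, OpenAIEmbeddings())
    print(rag_chain.invoke("What does the gwIAS search pipeline do?"))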