# project_ABM / test_web_rag.py
import time
import urllib.request
import urllib.error
from urllib.parse import quote
from seleniumbase import SB
import markdownify
from bs4 import BeautifulSoup
from requests_html import HTMLSession
import html2text
import re
from openai import OpenAI
import tiktoken
from zenrows import ZenRowsClient
import requests
import os
from dotenv import load_dotenv
from threading import Thread
load_dotenv()
ZENROWS_KEY = os.getenv('ZENROWS_KEY')
you_key = os.getenv("YOU_API_KEY")
client = OpenAI()
def get_fast_url_source(url):
    session = HTMLSession()
    r = session.get(url)
    return r.text
def convert_html_to_text(html):
    h = html2text.HTML2Text()
    h.body_width = 0  # Disable line wrapping
    text = h.handle(html)
    text = re.sub(r'\n\s*', '', text)
    text = re.sub(r'\* \\', '', text)
    text = " ".join(text.split())
    return text
def get_google_search_url(query):
    url = 'https://www.google.com/search?q=' + quote(query)
    # Perform the request
    request = urllib.request.Request(url)
    # Set a normal User-Agent header, otherwise Google will block the request.
    request.add_header('User-Agent',
                       'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36')
    raw_response = urllib.request.urlopen(request).read()
    # Read the response as a utf-8 string
    html = raw_response.decode("utf-8")
    soup = BeautifulSoup(html, 'html.parser')
    # Find all the search result divs and collect the first link in each
    divs = soup.select("#search div.g")
    urls = []
    for div in divs:
        links = div.select('a')
        if links:
            urls.append(links[0]['href'])
    return urls
def format_text(text):
    soup = BeautifulSoup(text, 'html.parser')
    results = soup.find_all(['p', 'h1', 'h2', 'span'])
    text = ''
    for result in results:
        text = text + str(result) + '  '
    return text
def get_page_source_selenium_base(url):
    with SB(uc_cdp=True, guest_mode=True, headless=True) as sb:
        sb.open(url)
        sb.sleep(5)
        page_source = sb.driver.get_page_source()
    return page_source
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
def encoding_getter(encoding_type: str):
    """
    Returns the appropriate encoding based on the given encoding type
    (either an encoding string or a model name).
    """
    if "k_base" in encoding_type:
        return tiktoken.get_encoding(encoding_type)
    return tiktoken.encoding_for_model(encoding_type)
def tokenizer(string: str, encoding_type: str) -> list:
    """
    Returns the tokens in a text string using the specified encoding.
    """
    encoding = encoding_getter(encoding_type)
    tokens = encoding.encode(string)
    return tokens
def token_counter(string: str, encoding_type: str) -> int:
    """
    Returns the number of tokens in a text string using the specified encoding.
    """
    num_tokens = len(tokenizer(string, encoding_type))
    return num_tokens
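# Usage sketch for the token helpers above (the prompt string is illustrative only):
#     token_counter("What is the GDP of India in 2023?", "gpt-3.5-turbo")
# returns the token count, which scraping_job later compares against the
# context budget to decide whether a page is too large to summarize.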
def format_output(text):
    page_source = format_text(text)
    page_source = markdownify.markdownify(page_source)
    page_source = " ".join(page_source.split())
    return page_source
def clean_text(text):
    # Remove URLs
    text = re.sub(r'http[s]?://\S+', '', text)
    # Remove special characters (keep only letters, numbers, and basic punctuation)
    text = re.sub(r'[^a-zA-Z0-9\s,.!?-]', '', text)
    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text
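# Illustrative input/output for clean_text (example string only):
#     clean_text("See https://example.com – GDP grew 7.2% in FY23!")
#     -> "See GDP grew 7.2 in FY23!"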
def call_open_ai(system_prompt, max_tokens=800, stream=False):
    messages = [
        {
            "role": "user",
            "content": system_prompt
        }
    ]
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0,
        max_tokens=max_tokens,
        top_p=0,
        frequency_penalty=0,
        presence_penalty=0,
        stream=stream
    )
    return response.choices[0].message.content
def url_summary(text, question):
    system_prompt = """
    Summarize the given text, including all the important topics and numerical data.
    While summarizing, keep this question in mind.
    question:- {question}
    text:
    {text}
    """.format(question=question, text=text)
    return call_open_ai(system_prompt=system_prompt, max_tokens=800)
def get_google_search_query(question):
    system_prompt = """
    Convert this question to a Google search query and return only the query.
    question:- {question}
    """.format(question=question)
    return call_open_ai(system_prompt=system_prompt, max_tokens=50)
def is_urlfile(url):
    # Check if the online file exists
    try:
        r = urllib.request.urlopen(url)  # response
        return r.getcode() == 200
    except urllib.error.URLError:
        return False
def check_url_pdf_file(url):
    r = requests.get(url)
    content_type = r.headers.get('content-type', '')
    return 'application/pdf' in content_type
def get_ai_snippets_for_query(query, num):
    headers = {"X-API-Key": you_key}
    params = {"query": query, "num_web_results": num}
    return requests.get(
        "https://api.ydc-index.io/search",
        params=params,
        headers=headers,
    ).json().get('hits')
def get_web_search_you(query, num):
    docs = get_ai_snippets_for_query(query, num)
    markdown = ""
    for doc in docs:
        for key, value in doc.items():
            if key == 'snippets':
                markdown += f"{key}:\n"
                for snippet in value:
                    markdown += f"- {snippet}\n"
            else:
                markdown += f"{key}: {value}\n"
        markdown += "\n"
    return markdown
def zenrows_scrapper(url):
    zen_client = ZenRowsClient(ZENROWS_KEY)
    params = {"js_render": "true"}
    response = zen_client.get(url, params=params)
    return response.text
def get_new_question_from_history(pre_question, new_question, answer):
    system_prompt = """
    Generate a new Google search query using the previous question and answer, and return only the query.
    previous question:- {pre_question}
    answer:- {answer}
    new question:- {new_question}
    """.format(pre_question=pre_question, answer=answer, new_question=new_question)
    return call_open_ai(system_prompt=system_prompt, max_tokens=50)
def scraping_job(strategy, question, url, results, key):
    if strategy == 'Deep':
        # page_source = get_page_source_selenium_base(url)
        page_source = zenrows_scrapper(url)
    else:
        page_source = get_fast_url_source(url)
    formatted_page_source = format_output(page_source)
    formatted_page_source = clean_text(formatted_page_source)
    tokens = token_counter(formatted_page_source, 'gpt-3.5-turbo')
    if tokens >= 15585:
        # Page is too large for the model's context window; skip it.
        results[key] = ''
    else:
        summary = url_summary(formatted_page_source, question)
        results[key] = summary
def get_docs_from_web(question, history, n_web_search, strategy):
    if history:
        question = get_new_question_from_history(history[0][0], question, history[0][1])
    docs = ''
    if strategy == 'Normal Fast':
        docs = get_web_search_you(question, n_web_search)
    else:
        urls = get_google_search_url(get_google_search_query(question))[:n_web_search]
        urls = list(set(urls))
        yield f"Scraping started for {len(urls)} urls:-\n\n"
        threads = [None] * len(urls)
        results = [None] * len(urls)
        for key, url in enumerate(urls):
            if '.pdf' in url or '.PDF' in url:
                yield f"Scraping skipped, pdf detected. {key + 1}/{len(urls)} - {url} ❌\n"
                results[key] = ''
                continue
            threads[key] = Thread(target=scraping_job, args=(strategy, question, url, results, key))
            threads[key].start()
        for i in range(len(threads)):
            if threads[i] is not None:
                threads[i].join()
        for key, result in enumerate(results):
            if result is not None and result != '':
                docs += result
                docs += '\n Source:-' + urls[key] + '\n\n'
                yield f"Scraping Done {key + 1}/{len(urls)} - {urls[key]} ✅\n"
    yield {"data": docs}
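# Minimal usage sketch (not part of the original module; the question and the
# parameter values below are assumptions for illustration). It shows how the
# get_docs_from_web generator is meant to be consumed: progress strings are
# streamed first, and the final item is a dict holding the aggregated docs.
# Running it requires OPENAI_API_KEY and, for the 'Deep' strategy, ZENROWS_KEY.
if __name__ == "__main__":
    docs = None
    for update in get_docs_from_web("What is the GDP of India in 2023?",
                                    history=None, n_web_search=3, strategy='Deep'):
        if isinstance(update, dict):
            docs = update["data"]
        else:
            print(update, end="")
    print(docs)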