import streamlit as st
import os
import requests
import re
import urllib.parse
from dotenv import load_dotenv
from bs4 import BeautifulSoup
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
import chromadb
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
# Page configuration
st.set_page_config(layout="wide")
# Load environment variables from .env
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Initialize Gemini API
genai.configure(api_key=GEMINI_API_KEY)
# Initialize ChromaDB
CHROMA_PATH = "chroma_db"
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
# Initialize session state
if 'scraped' not in st.session_state:
    st.session_state.scraped = False
if 'collection_name' not in st.session_state:
    st.session_state.collection_name = "default_collection"
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []
# Initialize embedding model (optimized for Chinese text)
embedding_model = SentenceTransformer("shibing624/text2vec-base-chinese")
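# Strip URLs and collapse runs of whitespace before chunking.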
def clean_text(text):
    return re.sub(r'\s+', ' ', re.sub(r'http\S+', '', text)).strip()
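# Split scraped text into overlapping chunks sized for the embedding model.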
def split_content_into_chunks(content):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, length_function=len)
    return text_splitter.split_documents([Document(page_content=content)])
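# Embed each chunk and upsert it into the ChromaDB collection.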
def add_chunks_to_db(chunks, collection_name):
    collection = chroma_client.get_or_create_collection(name=collection_name)
    documents = [chunk.page_content for chunk in chunks]
    embeddings = embedding_model.encode(documents).tolist()
    # Offset IDs by the current collection size so chunks from later pages don't overwrite earlier ones
    start = collection.count()
    collection.upsert(documents=documents, ids=[f"ID{start + i}" for i in range(len(chunks))], embeddings=embeddings)
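# Crawl the given URL (optionally following same-domain links up to max_depth),
# clean and chunk each page's text, and store the chunks in the vector database.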
def scrape_text(url, max_depth=1, same_domain=True):
    visited = set()
    base_domain = urllib.parse.urlparse(url).netloc

    def _scrape(u, depth):
        if depth > max_depth or u in visited:
            return
        visited.add(u)
        try:
            response = requests.get(u, timeout=10)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')
            text = clean_text(soup.get_text())
            chunks = split_content_into_chunks(text)
            add_chunks_to_db(chunks, st.session_state.collection_name)
            # Recursively scrape newly discovered links
            if depth < max_depth:
                for link in soup.find_all('a', href=True):
                    next_url = urllib.parse.urljoin(u, link['href'])
                    next_domain = urllib.parse.urlparse(next_url).netloc
                    if same_domain and next_domain != base_domain:
                        continue
                    if next_url.startswith('mailto:') or next_url.startswith('javascript:'):
                        continue
                    _scrape(next_url, depth + 1)
        except requests.exceptions.RequestException:
            pass  # Ignore errors on individual pages

    _scrape(url, 1)
    st.session_state.scraped = True
    return "Scraping and processing complete. You can now ask questions!"
def ask_question(query, collection_name):
    collection = chroma_client.get_or_create_collection(name=collection_name)
    query_embedding = embedding_model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=2)
    top_chunks = results.get("documents", [[]])[0]
    system_prompt = f"""
You are a helpful assistant. Answer only from the provided context.
If you lack information, say: "I don't have enough information to answer that question."
Context:
{str(top_chunks)}
"""
    model = genai.GenerativeModel('gemini-2.0-flash')
    response = model.generate_content(system_prompt + "\nUser Query: " + query)
    return response.text
# Sidebar
with st.sidebar:
    st.header("Database Management")
    if st.button("Clear Chat History"):
        st.session_state.chat_history = []
        st.rerun()
    st.header("Step 1: Scrape a Website")
    url = st.text_input("Enter URL:")
    max_depth = st.selectbox("Recursion Depth", options=[1, 2, 3, 4, 5], index=0, help="How many levels of links to crawl recursively (default: 1)")
    same_domain = st.checkbox("Restrict recursion to the same domain", value=True, help="By default, only links on the same domain are crawled")
    if url and st.button("Scrape & Process"):
        with st.spinner("Scraping..."):
            st.success(scrape_text(url, max_depth=max_depth, same_domain=same_domain))
# Main content
st.title("Web Scraper & Q&A Chatbot")
if st.session_state.scraped:
    st.subheader("Step 2: Ask Questions")
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    user_query = st.chat_input("Ask your question here")
    if user_query:
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        with st.spinner("Searching..."):
            answer = ask_question(user_query, st.session_state.collection_name)
        st.session_state.chat_history.append({"role": "assistant", "content": answer})
        # Limit chat history to 6 messages
        st.session_state.chat_history = st.session_state.chat_history[-6:]
        st.rerun()
else:
    st.info("Please scrape a website first.")