import streamlit as st
import requests
import re
import os
import urllib.parse
from dotenv import load_dotenv
from bs4 import BeautifulSoup
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
import chromadb
from sentence_transformers import SentenceTransformer
import google.generativeai as genai

# Page configuration
st.set_page_config(layout="wide")

# Load environment variables from .env
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    st.error("GEMINI_API_KEY not found. Please add it to your .env file.")
    st.stop()

# Initialize Gemini API
genai.configure(api_key=GEMINI_API_KEY)

# Initialize ChromaDB (persisted to disk)
CHROMA_PATH = "chroma_db"
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

# Initialize session state
if 'scraped' not in st.session_state:
    st.session_state.scraped = False
if 'collection_name' not in st.session_state:
    st.session_state.collection_name = "default_collection"
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Initialize embedding model (optimized for Chinese text)
embedding_model = SentenceTransformer("shibing624/text2vec-base-chinese")


def clean_text(text):
    # Strip URLs, then collapse runs of whitespace into single spaces
    return re.sub(r'\s+', ' ', re.sub(r'http\S+', '', text)).strip()


def split_content_into_chunks(content):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200, length_function=len
    )
    return text_splitter.split_documents([Document(page_content=content)])


def add_chunks_to_db(chunks, collection_name):
    if not chunks:
        return
    collection = chroma_client.get_or_create_collection(name=collection_name)
    documents = [chunk.page_content for chunk in chunks]
    # encode() returns a numpy array; convert to plain lists for Chroma
    # (the original passed convert_to_list=True, which encode() does not accept)
    embeddings = embedding_model.encode(documents).tolist()
    # Offset IDs by the current collection size so chunks from later pages
    # don't overwrite earlier ones on upsert
    start = collection.count()
    ids = [f"ID{start + i}" for i in range(len(chunks))]
    collection.upsert(documents=documents, ids=ids, embeddings=embeddings)


def scrape_text(url, max_depth=1, same_domain=True):
    visited = set()
    base_domain = urllib.parse.urlparse(url).netloc

    def _scrape(u, depth):
        if depth > max_depth or u in visited:
            return
        visited.add(u)
        try:
            response = requests.get(u, timeout=10)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')
            text = clean_text(soup.get_text())
            chunks = split_content_into_chunks(text)
            add_chunks_to_db(chunks, st.session_state.collection_name)
            # Recursively crawl links discovered on this page
            if depth < max_depth:
                for link in soup.find_all('a', href=True):
                    next_url = urllib.parse.urljoin(u, link['href'])
                    # Drop #fragments so anchor links aren't treated as new pages
                    next_url, _ = urllib.parse.urldefrag(next_url)
                    if next_url.startswith(('mailto:', 'javascript:')):
                        continue
                    next_domain = urllib.parse.urlparse(next_url).netloc
                    if same_domain and next_domain != base_domain:
                        continue
                    _scrape(next_url, depth + 1)
        except requests.exceptions.RequestException:
            pass  # Ignore failures on individual pages

    _scrape(url, 1)
    st.session_state.scraped = True
    return "Scraping and processing complete. You can now ask questions!"


def ask_question(query, collection_name):
    collection = chroma_client.get_or_create_collection(name=collection_name)
    query_embedding = embedding_model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=2)
    top_chunks = results.get("documents", [[]])[0]
    system_prompt = f"""You are a helpful assistant. Answer only from the provided context.
If you lack information, say: "I don't have enough information to answer that question."
Context: {top_chunks}"""
    model = genai.GenerativeModel('gemini-2.0-flash')
    response = model.generate_content(system_prompt + "\nUser Query: " + query)
    return response.text


# Sidebar
with st.sidebar:
    st.header("Database Management")
    if st.button("Clear Chat History"):
        st.session_state.chat_history = []
        st.rerun()
    st.header("Step 1: Scrape a Website")
    url = st.text_input("Enter URL:")
    max_depth = st.selectbox(
        "Recursion Depth (levels)", options=[1, 2, 3, 4, 5], index=0,
        help="How many levels of links to crawl recursively; defaults to 1"
    )
    same_domain = st.checkbox(
        "Restrict crawling to the same domain", value=True,
        help="By default, only links on the starting domain are followed"
    )
    if url and st.button("Scrape & Process"):
        with st.spinner("Scraping..."):
            st.success(scrape_text(url, max_depth=max_depth, same_domain=same_domain))

# Main content
st.title("Web Scraper & Q&A Chatbot")
if st.session_state.scraped:
    st.subheader("Step 2: Ask Questions")
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    user_query = st.chat_input("Ask your question here")
    if user_query:
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        with st.spinner("Searching..."):
            answer = ask_question(user_query, st.session_state.collection_name)
        st.session_state.chat_history.append({"role": "assistant", "content": answer})
        # Keep only the most recent 6 messages
        st.session_state.chat_history = st.session_state.chat_history[-6:]
        st.rerun()
else:
    st.info("Please scrape a website first.")
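# Usage note (a sketch; the filename app.py and exact package set below are
# assumptions, not part of the original source):
#
#   pip install streamlit requests python-dotenv beautifulsoup4 langchain \
#       langchain-text-splitters chromadb sentence-transformers google-generativeai
#   echo "GEMINI_API_KEY=<your key>" > .env
#   streamlit run app.py
#
# On first run, SentenceTransformer downloads the shibing624/text2vec-base-chinese
# weights from the Hugging Face Hub, so initial startup is slower than later runs.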