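"""Streamlit app: scrape a website (optionally following links recursively),
index the page text in ChromaDB using Chinese-optimized sentence embeddings,
and answer questions about the scraped content with Google Gemini."""
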
import os
import re
import urllib.parse

import streamlit as st
import requests
from dotenv import load_dotenv
from bs4 import BeautifulSoup
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
import chromadb
from sentence_transformers import SentenceTransformer
import google.generativeai as genai

# Page configuration
st.set_page_config(layout="wide")

# Load environment variables from .env
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Initialize Gemini API
genai.configure(api_key=GEMINI_API_KEY)

# Initialize ChromaDB
CHROMA_PATH = "chroma_db"
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

# Initialize session state
if 'scraped' not in st.session_state:
    st.session_state.scraped = False
if 'collection_name' not in st.session_state:
    st.session_state.collection_name = "default_collection"
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Initialize embedding model (optimized for Chinese text)
embedding_model = SentenceTransformer("shibing624/text2vec-base-chinese")

def clean_text(text):
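    # Strip URLs, then collapse runs of whitespace into single spaces.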
    return re.sub(r'\s+', ' ', re.sub(r'http\S+', '', text)).strip()

def split_content_into_chunks(content):
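    # Split the page text into 1000-character chunks with 200-character overlap.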
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, length_function=len)
    return text_splitter.split_documents([Document(page_content=content)])

def add_chunks_to_db(chunks, collection_name):
    # Embed each chunk and upsert it into the ChromaDB collection.
    collection = chroma_client.get_or_create_collection(name=collection_name)
    documents = [chunk.page_content for chunk in chunks]
    embeddings = embedding_model.encode(documents).tolist()
    # Offset IDs by the current collection size so chunks from previously
    # scraped pages are not overwritten by later ones.
    start = collection.count()
    ids = [f"ID{start + i}" for i in range(len(chunks))]
    collection.upsert(documents=documents, ids=ids, embeddings=embeddings)

def scrape_text(url, max_depth=1, same_domain=True):
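    # Crawl the given URL recursively up to max_depth levels and index each page.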
    visited = set()
    base_domain = urllib.parse.urlparse(url).netloc

    def _scrape(u, depth):
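        # Stop once the depth limit is reached or the URL was already visited.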
        if depth > max_depth or u in visited:
            return
        visited.add(u)
        try:
            response = requests.get(u, timeout=10)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')
            
            text = clean_text(soup.get_text())
            chunks = split_content_into_chunks(text)
            add_chunks_to_db(chunks, st.session_state.collection_name)
            
            # Recursively follow links found on this page
            if depth < max_depth:
                for link in soup.find_all('a', href=True):
                    next_url = urllib.parse.urljoin(u, link['href'])
                    next_domain = urllib.parse.urlparse(next_url).netloc
                    if same_domain and next_domain != base_domain:
                        continue
                    if next_url.startswith('mailto:') or next_url.startswith('javascript:'):
                        continue
                    _scrape(next_url, depth + 1)
        except requests.exceptions.RequestException:
            pass  # Ignore errors from individual pages and keep crawling

    _scrape(url, 1)
    st.session_state.scraped = True
    return "Scraping and processing complete. You can now ask questions!"

def ask_question(query, collection_name):
    # Retrieve the most relevant chunks for the query and answer only from them.
    collection = chroma_client.get_or_create_collection(name=collection_name)
    query_embedding = embedding_model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=2)
    top_chunks = results.get("documents", [[]])[0]
    context = "\n\n".join(top_chunks)

    system_prompt = f"""
    You are a helpful assistant. Answer only from the provided context.
    If you lack information, say: "I don't have enough information to answer that question."
    Context:
    {context}
    """
    
    model = genai.GenerativeModel('gemini-2.0-flash')
    response = model.generate_content(system_prompt + "\nUser Query: " + query)
    return response.text

# Sidebar
with st.sidebar:
    st.header("Database Management")
    if st.button("Clear Chat History"):
        st.session_state.chat_history = []
        st.rerun()
    
    st.header("Step 1: Scrape a Website")
    url = st.text_input("Enter URL:")
    max_depth = st.selectbox("Recursion Depth", options=[1, 2, 3, 4, 5], index=0, help="How many levels of links to crawl recursively (default: 1).")
    same_domain = st.checkbox("Restrict crawling to the same domain", value=True, help="By default, only links on the same domain as the starting URL are followed.")
    if url and st.button("Scrape & Process"):
        with st.spinner("Scraping..."):
            st.success(scrape_text(url, max_depth=max_depth, same_domain=same_domain))

# Main content
st.title("Web Scraper & Q&A Chatbot")
if st.session_state.scraped:
    st.subheader("Step 2: Ask Questions")
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    
    user_query = st.chat_input("Ask your question here")
    if user_query:
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        with st.spinner("Searching..."):
            answer = ask_question(user_query, st.session_state.collection_name)
        st.session_state.chat_history.append({"role": "assistant", "content": answer})
        
        # Limit chat history to 6 messages
        st.session_state.chat_history = st.session_state.chat_history[-6:]
        st.rerun()
else:
    st.info("Please scrape a website first.")