# Chat with your PDF: Streamlit app that indexes an uploaded PDF with FAISS
# and answers questions via a Google Generative AI retrieval chain.
import os, tempfile, streamlit as st
from langchain.prompts import PromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
# from langchain_chroma import Chroma
# from langchain.vectorstores import FAISS
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
# --- Page layout and inputs -------------------------------------------------
st.subheader("Upload PDFs and have interactive conversations to get instant answers.")

# Sidebar holds the credentials and the PDF to index.
with st.sidebar:
    google_api_key = st.text_input("Google API key", type="password")
    source_doc = st.file_uploader("Source document", type="pdf")

# Main row: a wide query box next to a narrow submit button.
col1, col2 = st.columns([4, 1])
query = col1.text_input("Query", label_visibility="collapsed")

# The Google SDK picks the key up from the environment.
os.environ['GOOGLE_API_KEY'] = google_api_key

# Seed session state so the retriever and the indexed document survive
# Streamlit's script reruns; (re)initialise both together if either is missing.
if 'retriever' not in st.session_state or 'loaded_doc' not in st.session_state:
    st.session_state.update(retriever=None, loaded_doc=None)
submit = col2.button("Submit")
if submit:
    # Validate inputs before doing any work.
    if not google_api_key or not query:
        st.warning("Please provide the missing fields.")
    elif not source_doc:
        st.warning("Please upload the source document.")
    else:
        with st.spinner("Please wait..."):
            # Rebuild the index only when the uploaded document changed,
            # so repeated queries against the same PDF skip the embedding step.
            if st.session_state.loaded_doc != source_doc:
                tmp_path = None
                try:
                    # Persist the upload to disk so PyPDFLoader can read it.
                    # Close the handle before loading (required on Windows,
                    # where an open NamedTemporaryFile cannot be reopened).
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                        tmp_file.write(source_doc.read())
                        tmp_path = tmp_file.name
                    loader = PyPDFLoader(tmp_path)
                    pages = loader.load_and_split()
                    # Embed the pages and index them in a FAISS vector store,
                    # exposed as a top-k=5 retriever.
                    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
                    vectorstore = FAISS.from_documents(pages, embeddings)
                    st.session_state.retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
                    # Remember which upload produced this retriever so the
                    # next submit can reuse it.
                    st.session_state.loaded_doc = source_doc
                except Exception as e:
                    # Invalidate any stale retriever so we don't silently
                    # answer from a previously indexed document below.
                    st.session_state.retriever = None
                    st.session_state.loaded_doc = None
                    st.error(f"An error occurred: {e}")
                finally:
                    # Always remove the temp file, even if loading failed.
                    if tmp_path is not None and os.path.exists(tmp_path):
                        os.remove(tmp_path)
            # Only run the LLM chain if a retriever is actually available
            # (indexing above may have just failed).
            if st.session_state.retriever is not None:
                try:
                    # Build and invoke the retrieval chain for this query.
                    llm = ChatGoogleGenerativeAI(model="gemini-pro")
                    template = """
                    You are a helpful AI assistant. Answer based on the context provided.
                    context: {context}
                    input: {input}
                    answer:
                    """
                    prompt = PromptTemplate.from_template(template)
                    combine_docs_chain = create_stuff_documents_chain(llm, prompt)
                    retrieval_chain = create_retrieval_chain(st.session_state.retriever, combine_docs_chain)
                    response = retrieval_chain.invoke({"input": query})
                    st.success(response['answer'])
                except Exception as e:
                    st.error(f"An error occurred: {e}")