import os
import logging
from typing import Any, Dict, List, Tuple

import faiss
import streamlit as st
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate

# pip install -qU langchain-ollama
# pip install langchain
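# The imports above also assume these packages are installed (versions not pinned here):
# pip install streamlit langchain-community langchain-text-splitters pypdf faiss-cpu
# (use faiss-gpu instead of faiss-cpu if a CUDA build is preferred)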

##### Logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


def format_docs(docs) -> str:
    """Join retrieved documents into a single context string for the prompt."""
    return "\n\n".join(doc.page_content for doc in docs)


def extract_model_names(
    models_info: Dict[str, List[Dict[str, Any]]],
) -> Tuple[str, ...]:
    """
    Extract model names from the provided models information.

    Args:
        models_info (Dict[str, List[Dict[str, Any]]]): Dictionary containing
            information about available models.

    Returns:
        Tuple[str, ...]: A tuple of model names.
    """
    logger.info("Extracting model names from models_info")
    model_names = tuple(model["name"] for model in models_info["models"])
    logger.info(f"Extracted model names: {model_names}")
    return model_names
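
# Note: extract_model_names expects the {"models": [{"name": ...}]} dict shape
# (e.g. as returned by older versions of the ollama Python client's list());
# it is kept as a helper and is not called in this app's main flow.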


def generate_response(rag_chain, input_text: str) -> str:
    """Run the RAG chain on the user's question and return the generated answer."""
    response = rag_chain.invoke(input_text)
    return response


### Ken 12/11/2024 ADD START
def get_pdf(uploaded_file):
    """Save the uploaded PDF to a temp file and load it into LangChain documents."""
    if not uploaded_file:
        return None
    temp_file = "./temp.pdf"
    # Delete the existing temp.pdf file if it exists
    if os.path.exists(temp_file):
        os.remove(temp_file)
    with open(temp_file, "wb") as file:
        file.write(uploaded_file.getvalue())
    loader = PyPDFLoader(temp_file)
    docs = loader.load()
    return docs
### Ken 12/11/2024 ADD END


def main() -> None:
    st.title("🧠 This is a RAG Chatbot with Ollama and LangChain!")
    st.write("The LLM model unsloth/Llama-3.2-3B-Instruct is used.")
    st.write("You can upload a PDF to chat with!")

    with st.sidebar:
        st.title("PDF FILE UPLOAD:")
        docs = st.file_uploader(
            "Upload your PDF file and click on the Submit & Process button",
            accept_multiple_files=False,
            key="pdf_uploader",
        )

    # Split the PDF into overlapping chunks; the 100-character overlap preserves
    # context across chunk boundaries
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

    ### Ken 12/11/2024 ADD START
    raw_text = get_pdf(docs)
    ### Ken 12/11/2024 ADD END
    if raw_text is None:
        st.info("Please upload a PDF file to start chatting.")
        st.stop()
    chunks = text_splitter.split_documents(raw_text)
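
    # Embed one probe string to discover the embedding dimension, then build a
    # FAISS L2 index of that size backed by an in-memory docstore.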
    embeddings = OllamaEmbeddings(model="nomic-embed-text", base_url="http://localhost:11434")
    single_vector = embeddings.embed_query("this is some text data")
    index = faiss.IndexFlatL2(len(single_vector))
    vector_store = FAISS(
        embedding_function=embeddings,
        index=index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )
    vector_store.add_documents(documents=chunks)

    ## Retrieval
    # MMR (maximal marginal relevance): fetch the 100 nearest chunks, then keep 3;
    # lambda_mult=1 ranks purely by relevance (0 would maximize diversity).
    retriever = vector_store.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 3, "fetch_k": 100, "lambda_mult": 1},
    )
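
    # The prompt template instructs the model to answer in bullet points and
    # only from the retrieved context.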
prompt = """ | |
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. | |
If you don't know the answer, just say that you don't know. | |
Answer in bullet points. Make sure your answer is relevant to the question and it is answered from the context only. | |
Question: {question} | |
Context: {context} | |
Answer: | |
""" | |
prompt = ChatPromptTemplate.from_template(prompt) | |

    model = ChatOllama(model="unsloth/Llama-3.2-3B-Instruct")

    # The dict step feeds retrieved-and-formatted docs into {context} and the
    # raw question into {question}; the chain then prompts the model and parses text.
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
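
    # Chat UI: a simple form for input; chat history persists in st.session_state
    # across Streamlit reruns.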
    with st.form("llm-form"):
        text = st.text_area("Enter your question or statement:")
        submit = st.form_submit_button("Submit")

    if "chat_history" not in st.session_state:
        st.session_state["chat_history"] = []

    if submit and text:
        with st.spinner("Generating response..."):
            response = generate_response(rag_chain, text)
            st.session_state["chat_history"].append({"user": text, "ollama": response})
            st.write(response)

    st.write("## Chat History")
    for chat in reversed(st.session_state["chat_history"]):
        st.write(f"**🧑 User**: {chat['user']}")
        st.write(f"**🧠 Assistant**: {chat['ollama']}")
        st.write("---")


if __name__ == "__main__":
    main()