boryasbora's picture
Update app.py
a485938 verified
raw
history blame
7.48 kB
import streamlit as st
import os
import pickle
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain.llms import LlamaCpp
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from datetime import date
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Environment variables
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
os.environ['LANGCHAIN_API_KEY'] = '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'
@st.cache_resource
def load_model():
model_name = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
return model, tokenizer
def load_from_pickle(filename):
with open(filename, "rb") as file:
return pickle.load(file)
def load_retriever(docstore_path,chroma_path,embeddings,child_splitter,parent_splitter):
"""Loads the vector store and document store, initializing the retriever."""
db3 = Chroma(collection_name="full_documents", #collection_name shoud be the same as in the first time
embedding_function=embeddings,
persist_directory=chroma_path
)
store_dict = load_from_pickle(docstore_path)
store = InMemoryStore()
store.mset(list(store_dict.items()))
retriever = ParentDocumentRetriever(
vectorstore=db3,
docstore=store,
child_splitter=child_splitter,
parent_splitter=parent_splitter,
search_kwargs={"k": 5}
)
return retriever
def inspect(state):
if "context_sources" not in st.session_state:
st.session_state.context_sources = []
context = state['normal_context']
st.session_state.context_sources =[doc.metadata['source'] for doc in context]
st.session_state.context_content = [doc.page_content for doc in context]
return state
def retrieve_normal_context(retriever, question):
docs = retriever.invoke(question)
return docs
# Your OLMOLLM class implementation here (adapted for the Hugging Face model)
@st.cache_resource
def get_chain(temperature):
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
docstore_path = 'ohw_proj_chorma_db.pcl'
chroma_path = 'ohw_proj_chorma_db'
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000,
chunk_overlap=500)
# create the child documents - The small chunks
child_splitter = RecursiveCharacterTextSplitter(chunk_size=300,
chunk_overlap=50)
retriever = load_retriever(docstore_path,chroma_path,embeddings,child_splitter,parent_splitter)
model, tokenizer = load_model()
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_length=4000,
max_new_tokens = 500,
temperature=temperature,
top_p=0.95,
repetition_penalty=1.15
)
llm = HuggingFacePipeline(pipeline=pipe)
today = date.today()
# Response prompt
response_prompt_template = """You are an assistant who helps Ocean Hack Week community to answer their questions. I am going to ask you a question. Your response should be comprehensive and not contradicted with the following context if they are relevant. Otherwise, ignore them if they are not relevant.
Keep track of chat history: {chat_history}
Today's date: {date}
## Normal Context:
{normal_context}
# Original Question: {question}
# Answer (embed links where relevant):
"""
response_prompt = ChatPromptTemplate.from_template(response_prompt_template)
context_chain = RunnableLambda(lambda x: {
"question": x["question"],
"normal_context": retrieve_normal_context(retriever,x["question"]),
# "step_back_context": retrieve_step_back_context(retriever,generate_queries_step_back.invoke({"question": x["question"]})),
"chat_history": x["chat_history"],
"date": today})
chain = (
context_chain
| RunnableLambda(inspect)
| response_prompt
| llm
| StrOutputParser()
)
return chain
def clear_chat_history():
st.session_state.messages = []
st.session_state.context_sources = []
st.session_state.key = 0
# st.set_page_config(page_title='OHW AI')
# Sidebar
with st.sidebar:
st.title("OHW Assistant")
temperature = st.slider("Temperature: ", 0.0, 1.0, 0.5, 0.1)
chain = get_chain(temperature)
st.button('Clear Chat History', on_click=clear_chat_history)
# Main app
if "messages" not in st.session_state:
st.session_state.messages = []
for q, message in enumerate(st.session_state.messages):
if (message["role"] == 'assistant'):
with st.chat_message(message["role"]):
tab1, tab2 = st.tabs(["Answer", "Sources"])
with tab1:
for i, source in enumerate(message["sources"]):
name = f'{source}'
with st.expander(name):
st.markdown(f'{message["context"][i]}')
# st.markdown(message["content"])
# with tab2:
# for i, source in enumerate(message["sources"]):
# name = f'{source}'
# with st.expander(name):
# st.markdown(f'{message["context"][i]}')
else:
question = message["content"]
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("How may I assist you today?"):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
query=st.session_state.messages[-1]['content']
tab1, tab2 = st.tabs(["Answer", "Sources"])
with tab1:
for i, source in enumerate(st.session_state.context_sources):
name = f'{source}'
with st.expander(name):
st.markdown(f'{st.session_state.context_content[i]}')
# with st.spinner("Generating answer..."):
# Generate the full answer at once
# full_answer = chain.invoke({"question": query, "chat_history": st.session_state.messages})
# Display the full answer
st.markdown(full_answer, unsafe_allow_html=True)
# with tab2:
# for i, source in enumerate(st.session_state.context_sources):
# name = f'{source}'
# with st.expander(name):
# st.markdown(f'{st.session_state.context_content[i]}')
st.session_state.messages.append({"role": "assistant", "content": full_answer})
st.session_state.messages[-1]['sources'] = st.session_state.context_sources
st.session_state.messages[-1]['context'] = st.session_state.context_content