import os
import pickle
from datetime import date

import streamlit as st
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Environment variables (LangSmith tracing)
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
os.environ['LANGCHAIN_API_KEY'] = '<your-langsmith-api-key>'  # set your own key; never commit a live key
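# A safer pattern (a sketch, not in the original): read the key from the host
# environment rather than hardcoding it in the source. Standard library only.
#
# api_key = os.environ.get('LANGCHAIN_API_KEY')
# if api_key is None:
#     raise RuntimeError('Set LANGCHAIN_API_KEY before launching the app')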
def load_model():
    """Load the causal language model and its tokenizer from the Hugging Face Hub."""
    model_name = "bigscience/bloom-560m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
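# Optional optimization (a sketch, not in the original): Streamlit re-executes this
# script on every interaction, so wrapping the loader in st.cache_resource would
# keep the model in memory across reruns. The same applies to get_chain below.
#
# @st.cache_resource
# def load_model_cached():
#     return load_model()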
def load_from_pickle(filename):
    """Deserialize an object (here, the docstore's key/value dict) from disk."""
    with open(filename, "rb") as file:
        return pickle.load(file)
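# For reference, a minimal counterpart that could have produced the pickle file
# (a sketch; the actual indexing script is not part of this file):
#
# def save_to_pickle(obj, filename):
#     with open(filename, "wb") as file:
#         pickle.dump(obj, file)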
def load_retriever(docstore_path, chroma_path, embeddings, child_splitter, parent_splitter):
    """Loads the vector store and document store, initializing the retriever."""
    db3 = Chroma(
        collection_name="full_documents",  # must match the collection name used when the store was built
        embedding_function=embeddings,
        persist_directory=chroma_path,
    )
    # Rehydrate the parent-document store from its pickled dict.
    store_dict = load_from_pickle(docstore_path)
    store = InMemoryStore()
    store.mset(list(store_dict.items()))
    retriever = ParentDocumentRetriever(
        vectorstore=db3,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": 5},
    )
    return retriever
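# Usage sketch (assumes the persisted stores exist at the given paths; the
# question string is hypothetical):
#
# retriever = load_retriever(docstore_path, chroma_path, embeddings,
#                            child_splitter, parent_splitter)
# docs = retriever.invoke("What is OceanHackWeek?")  # returns parent documents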
def inspect(state):
    """Passthrough step that records the retrieved sources and chunk text in session state."""
    if "context_sources" not in st.session_state:
        st.session_state.context_sources = []
    context = state['normal_context']
    st.session_state.context_sources = [doc.metadata['source'] for doc in context]
    st.session_state.context_content = [doc.page_content for doc in context]
    return state
def retrieve_normal_context(retriever, question):
    """Fetch the documents relevant to the question from the retriever."""
    docs = retriever.invoke(question)
    return docs
def get_chain(temperature):
    """Build the retrieval-augmented generation chain."""
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
    docstore_path = 'ohw_proj_chorma_db.pcl'
    chroma_path = 'ohw_proj_chorma_db'
    # Parent documents: the large chunks handed to the LLM as context.
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    # Child documents: the small chunks that are embedded and searched.
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
    retriever = load_retriever(docstore_path, chroma_path, embeddings, child_splitter, parent_splitter)
    model, tokenizer = load_model()
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=500,       # max_length is omitted: it conflicts with max_new_tokens
        do_sample=True,           # required for temperature/top_p to take effect
        temperature=temperature,  # must be > 0 when sampling
        top_p=0.95,
        repetition_penalty=1.15,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    today = date.today()
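    # Equivalent construction via LangChain's convenience wrapper (a sketch;
    # the pipeline_kwargs mirror the arguments used above):
    #
    # llm = HuggingFacePipeline.from_model_id(
    #     model_id="bigscience/bloom-560m",
    #     task="text-generation",
    #     pipeline_kwargs={"max_new_tokens": 500, "temperature": temperature},
    # )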
    # Response prompt
    response_prompt_template = """You are an assistant who helps the Ocean Hack Week community answer their questions. I am going to ask you a question. Your response should be comprehensive and should not contradict the following context when it is relevant; ignore the context when it is not relevant.
Keep track of chat history: {chat_history}
Today's date: {date}
## Normal Context:
{normal_context}
# Original Question: {question}
# Answer (embed links where relevant):
"""
    response_prompt = ChatPromptTemplate.from_template(response_prompt_template)
    # Assemble the prompt inputs: retrieve context for the question and pass
    # through the chat history and today's date.
    context_chain = RunnableLambda(lambda x: {
        "question": x["question"],
        "normal_context": retrieve_normal_context(retriever, x["question"]),
        # "step_back_context": retrieve_step_back_context(retriever, generate_queries_step_back.invoke({"question": x["question"]})),
        "chat_history": x["chat_history"],
        "date": today,
    })
    chain = (
        context_chain
        | RunnableLambda(inspect)  # side effect: stores sources/content in session state
        | response_prompt
        | llm
        | StrOutputParser()
    )
    return chain
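# Usage sketch for the chain outside Streamlit (the question is hypothetical):
#
# chain = get_chain(temperature=0.5)
# answer = chain.invoke({"question": "How do I join OHW?", "chat_history": []})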
def clear_chat_history():
    """Reset the conversation and any cached retrieval context."""
    st.session_state.messages = []
    st.session_state.context_sources = []
    st.session_state.key = 0
# st.set_page_config(page_title='OHW AI')
# Sidebar
with st.sidebar:
    st.title("OHW Assistant")
    temperature = st.slider("Temperature: ", 0.0, 1.0, 0.5, 0.1)
    chain = get_chain(temperature)
    st.button('Clear Chat History', on_click=clear_chat_history)
# Main app
if "messages" not in st.session_state:
    st.session_state.messages = []
# Replay the conversation so far: answers go in the "Answer" tab, with the
# retrieved chunks behind per-source expanders in the "Sources" tab.
for message in st.session_state.messages:
    if message["role"] == 'assistant':
        with st.chat_message(message["role"]):
            tab1, tab2 = st.tabs(["Answer", "Sources"])
            with tab1:
                st.markdown(message["content"])
            with tab2:
                for i, source in enumerate(message["sources"]):
                    with st.expander(f'{source}'):
                        st.markdown(f'{message["context"][i]}')
    else:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
if prompt := st.chat_input("How may I assist you today?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        query = st.session_state.messages[-1]['content']
        with st.spinner("Generating answer..."):
            # Generate the full answer at once; inspect() refreshes the context
            # sources/content in session state as a side effect.
            full_answer = chain.invoke({"question": query, "chat_history": st.session_state.messages})
        tab1, tab2 = st.tabs(["Answer", "Sources"])
        with tab1:
            st.markdown(full_answer, unsafe_allow_html=True)
        with tab2:
            for i, source in enumerate(st.session_state.context_sources):
                with st.expander(f'{source}'):
                    st.markdown(f'{st.session_state.context_content[i]}')
    st.session_state.messages.append({
        "role": "assistant",
        "content": full_answer,
        "sources": st.session_state.context_sources,
        "context": st.session_state.context_content,
    })