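"""Factories for the LLMs, agents, chains, and retrievers used by the app.

Also provides `StreamHandler`, a callback handler that streams LLM tokens
into a UI container as they are generated.
"""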
import uuid
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Tuple

from langchain.agents import AgentExecutor, AgentType, initialize_agent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import (
    AzureChatOpenAI,
    ChatAnthropic,
    ChatAnyscale,
    ChatOpenAI,
)
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.memory import ConversationBufferMemory
from langchain.prompts import MessagesPlaceholder
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema import BaseRetriever, Document
from langchain.schema.chat_history import BaseChatMessageHistory
from langchain.schema.runnable import RunnablePassthrough
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from langchain.vectorstores import FAISS
from langchain_core.messages import SystemMessage

from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
from qagen import get_rag_qa_gen_chain
from summarize import get_rag_summarization_chain


def get_agent(
    tools: List[BaseTool],
    chat_history: BaseChatMessageHistory,
    llm: BaseLLM,
    callbacks,
):
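    """Build an OpenAI-functions agent over `tools` with persistent chat memory.

    Returns an LCEL runnable that wraps the agent executor so callers can
    pass raw input and receive the final answer string.
    """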
    memory_key = "agent_history"
    system_message = SystemMessage(
        content=(
            "Do your best to answer the questions. "
            "Feel free to use any tools available to look up "
            "relevant information, only if necessary"
        ),
    )
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=system_message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    # agent_memory = AgentTokenBufferMemory(
    #     chat_memory=chat_history,
    #     memory_key=memory_key,
    #     llm=llm,
    # )
    agent_memory = ConversationBufferMemory(
        chat_memory=chat_history,
        return_messages=True,
        memory_key=memory_key,
    )
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        memory=agent_memory,
        verbose=True,
        return_intermediate_steps=False,
        callbacks=callbacks,
    )
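    # Coerce the raw input into the executor's expected dict, then unwrap
    # the final answer, so the agent behaves as a str -> str runnable.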
    return (
        {"input": RunnablePassthrough()}
        | agent_executor
        | (lambda output: output["output"])
    )


def get_doc_agent(
    tools: List[BaseTool],
    llm: Optional[BaseLLM] = None,
    agent_type: AgentType = AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
):
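    """Build a standalone agent for answering questions about a document.

    If no LLM is supplied, defaults to a streaming gpt-4-1106-preview
    ChatOpenAI instance. Returns an LCEL runnable that maps input text to
    the agent's final answer string.
    """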
    if llm is None:
        llm = ChatOpenAI(
            model_name="gpt-4-1106-preview",
            temperature=0.0,
            streaming=True,
        )
    # initialize_agent builds the ReAct prompt internally and silently ignores
    # an unknown `prompt` kwarg, so the system instructions are injected via
    # `agent_kwargs` instead.
    system_message_prefix = (
        "You assist a chatbot with answering questions about a document. "
        "If necessary, break up incoming questions into multiple parts, "
        "and use the tools provided to answer smaller questions before "
        "answering the larger question."
    )
    agent_executor = initialize_agent(
        tools,
        llm,
        agent=agent_type,
        verbose=True,
        memory=None,
        handle_parsing_errors=True,
        agent_kwargs={"system_message_prefix": system_message_prefix},
    )
    return (
        {"input": RunnablePassthrough()}
        | agent_executor
        | (lambda output: output["output"])
    )


def get_runnable(
    use_document_chat: bool,
    document_chat_chain_type: str,
    llm,
    retriever,
    memory,
    chat_prompt,
    summarization_prompt,
):
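    """Pick the chain for the current chat mode.

    Plain chat uses an LLMChain; document chat routes to Q&A generation,
    summarization, or a RetrievalQA chain of the requested chain type.
    """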
    if not use_document_chat:
        return LLMChain(
            prompt=chat_prompt,
            llm=llm,
            memory=memory,
        ) | (lambda output: output["text"])
    if document_chat_chain_type == "Q&A Generation":
        return get_rag_qa_gen_chain(
            retriever,
            llm,
        )
    elif document_chat_chain_type == "Summarization":
        return get_rag_summarization_chain(
            summarization_prompt,
            retriever,
            llm,
        )
    else:
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type=document_chat_chain_type,
            retriever=retriever,
            output_key="output_text",
        ) | (lambda output: output["output_text"])


def get_llm(
    provider: str,
    model: str,
    provider_api_key: str,
    temperature: float,
    max_tokens: int,
    azure_available: bool,
    azure_dict: Dict[str, str],
):
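    """Instantiate a streaming chat model for the selected provider.

    Azure OpenAI takes precedence when configured; otherwise the provider
    API key selects OpenAI, Anthropic, or Anyscale Endpoints. Returns None
    if no provider can be configured.
    """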
    if azure_available and provider == "Azure OpenAI":
        return AzureChatOpenAI(
            azure_endpoint=azure_dict["AZURE_OPENAI_BASE_URL"],
            openai_api_version=azure_dict["AZURE_OPENAI_API_VERSION"],
            deployment_name=azure_dict["AZURE_OPENAI_DEPLOYMENT_NAME"],
            openai_api_key=azure_dict["AZURE_OPENAI_API_KEY"],
            openai_api_type="azure",
            model_version=azure_dict["AZURE_OPENAI_MODEL_VERSION"],
            temperature=temperature,
            streaming=True,
            max_tokens=max_tokens,
        )
    elif provider_api_key:
        if provider == "OpenAI":
            return ChatOpenAI(
                model_name=model,
                openai_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens=max_tokens,
            )
        elif provider == "Anthropic":
            return ChatAnthropic(
                model=model,
                anthropic_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens_to_sample=max_tokens,
            )
        elif provider == "Anyscale Endpoints":
            return ChatAnyscale(
                model_name=model,
                anyscale_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens=max_tokens,
            )
    return None


def get_texts_and_multiretriever(
    uploaded_file_bytes: bytes,
    openai_api_key: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    k: int = DEFAULT_RETRIEVER_K,
    azure_kwargs: Optional[Dict[str, str]] = None,
    use_azure: bool = False,
) -> Tuple[List[Document], BaseRetriever]:
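    """Load a PDF from raw bytes and build an ensemble retriever over it.

    Combines a MultiVectorRetriever (small child chunks indexed against
    larger parent chunks) with a MultiQueryRetriever over conventionally
    sized chunks, weighted equally. Returns the chunks and the retriever.
    """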
    with NamedTemporaryFile() as temp_file:
        temp_file.write(uploaded_file_bytes)
        temp_file.seek(0)
        loader = PyPDFLoader(temp_file.name)
        documents = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=10000,
        chunk_overlap=0,
    )
    child_text_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
    texts = text_splitter.split_documents(documents)

    id_key = "doc_id"
    text_ids = [str(uuid.uuid4()) for _ in texts]
    sub_texts = []
    for i, text in enumerate(texts):
        _id = text_ids[i]
        _sub_texts = child_text_splitter.split_documents([text])
        for _text in _sub_texts:
            _text.metadata[id_key] = _id
        sub_texts.extend(_sub_texts)

    embeddings_kwargs = {"openai_api_key": openai_api_key}
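    # AzureOpenAIEmbeddings expects `azure_endpoint`, so the legacy
    # `openai_api_base` key from the app's Azure config is renamed here.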
    if use_azure and azure_kwargs:
        azure_kwargs["azure_endpoint"] = azure_kwargs.pop("openai_api_base")
        embeddings_kwargs.update(azure_kwargs)
        embeddings = AzureOpenAIEmbeddings(**embeddings_kwargs)
    else:
        embeddings = OpenAIEmbeddings(**embeddings_kwargs)

    store = InMemoryStore()
    # MultiVectorRetriever
    multivectorstore = FAISS.from_documents(sub_texts, embeddings)
    multivector_retriever = MultiVectorRetriever(
        vectorstore=multivectorstore,
        docstore=store,
        id_key=id_key,
    )
    multivector_retriever.docstore.mset(list(zip(text_ids, texts)))
    # multivector_retriever.k = k

    multiquery_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
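    # NOTE: MultiQueryRetriever uses a default ChatOpenAI() below, which reads
    # OPENAI_API_KEY from the environment rather than `openai_api_key`.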
    # MultiQueryRetriever
    multiquery_texts = multiquery_text_splitter.split_documents(documents)
    multiquerystore = FAISS.from_documents(multiquery_texts, embeddings)
    multiquery_retriever = MultiQueryRetriever.from_llm(
        retriever=multiquerystore.as_retriever(search_kwargs={"k": k}),
        llm=ChatOpenAI(),
    )

    ensemble_retriever = EnsembleRetriever(
        retrievers=[multiquery_retriever, multivector_retriever],
        weights=[0.5, 0.5],
    )
    return multiquery_texts, ensemble_retriever


class StreamHandler(BaseCallbackHandler):
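    """Stream LLM tokens into a UI container as they arrive.

    Each new token is appended to the accumulated text, which is re-rendered
    via the container's `markdown` method (e.g. a Streamlit placeholder).
    """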
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        self.container.markdown(self.text)
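

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way these helpers might be wired
# together in a Streamlit app. The file path, API key, and history class
# below are assumptions for illustration, not part of this module.
# ---------------------------------------------------------------------------
# import streamlit as st
# from langchain.memory import StreamlitChatMessageHistory
#
# llm = get_llm(
#     provider="OpenAI",
#     model="gpt-3.5-turbo",
#     provider_api_key="sk-...",  # hypothetical key
#     temperature=0.7,
#     max_tokens=1000,
#     azure_available=False,
#     azure_dict={},
# )
# texts, retriever = get_texts_and_multiretriever(
#     uploaded_file_bytes=open("example.pdf", "rb").read(),  # hypothetical file
#     openai_api_key="sk-...",
# )
# history = StreamlitChatMessageHistory(key="chat_history")
# handler = StreamHandler(st.empty())
# agent = get_agent(tools=[], chat_history=history, llm=llm, callbacks=[handler])
# answer = agent.invoke("What is this document about?")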