from dotenv import load_dotenv

from langchain.chains import LLMChain, ConversationalRetrievalChain
from langchain.chains.conversation.memory import (
    ConversationBufferMemory,
    ConversationBufferWindowMemory,
)
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader, UnstructuredHTMLLoader, PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import HuggingFacePipeline, HuggingFaceHub, HuggingFaceTextGenInference
from langchain.prompts import (
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    ChatPromptTemplate,
    PromptTemplate,
)
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
)
from langchain.vectorstores import Chroma
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    T5Tokenizer,
    T5ForConditionalGeneration,
    GPT2TokenizerFast,
    LlamaForCausalLM,
    LlamaTokenizer,
)


def get_llm_hf_online(inference_api_url=""):
    """Return an LLM backed by a remote text-generation-inference endpoint."""
    if not inference_api_url:
        # Default to the hosted HuggingFace inference endpoint for zephyr-7b-beta.
        inference_api_url = (
            "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
        )

    llm = HuggingFaceTextGenInference(
        verbose=True,
        max_new_tokens=1024,
        top_p=0.95,
        typical_p=0.95,
        temperature=0.1,
        inference_server_url=inference_api_url,
        timeout=10,
    )

    return llm
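
# Example usage (sketch): the call below relies only on the default endpoint defined
# above; swap in any text-generation-inference-compatible URL for a self-hosted server.
#
#   llm = get_llm_hf_online()
#   print(llm("Explain retrieval-augmented generation in one sentence."))
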
def get_llm_hf_local(model_path):
    """Load a local HuggingFace causal-LM checkpoint and wrap it as a LangChain LLM."""
    model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        model_kwargs={"temperature": 0.1},
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    return llm


def get_llm_hf_local_zephyr(model_path):
    """Load a local zephyr-style checkpoint.

    Unlike get_llm_hf_local, sampling parameters are passed directly to the pipeline
    and the prompt is echoed back in the output (return_full_text=True).
    """
    model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0.1,
        return_full_text=True,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    return llm
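
# Example usage (sketch): the path below is an assumption for illustration; point it
# at any locally downloaded checkpoint directory.
#
#   llm = get_llm_hf_local("/path/to/zephyr-7b-beta")
#   print(llm("What is a vector store?"))
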
def get_chat_vllm(model_name, inference_server_url, langfuse_callback=None):
    """Create a ChatOpenAI client pointed at an OpenAI-compatible vLLM server."""
    # Forward the optional Langfuse tracing callback if one is provided.
    callbacks = [langfuse_callback] if langfuse_callback else None

    chat = ChatOpenAI(
        model=model_name,
        openai_api_key="EMPTY",
        openai_api_base=inference_server_url,
        max_tokens=512,
        temperature=0.1,
        callbacks=callbacks,
    )

    return chat
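
# Example usage (sketch): the URL and model name below are assumptions for
# illustration; use whatever your vLLM OpenAI-compatible server actually serves.
#
#   chat = get_chat_vllm("HuggingFaceH4/zephyr-7b-beta", "http://localhost:8000/v1")
#   print(chat.predict("Hello, who are you?"))
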
def get_chat_vllm_stream(model_name, inference_server_url, langfuse_callback=None):
    """Streaming variant of get_chat_vllm: tokens are echoed to stdout as they arrive."""
    # Only include the Langfuse callback when it is actually provided; a None entry
    # in the callbacks list would break LangChain's callback handling.
    callbacks = [StreamingStdOutCallbackHandler()]
    if langfuse_callback:
        callbacks.append(langfuse_callback)

    chat = ChatOpenAI(
        model=model_name,
        openai_api_key="EMPTY",
        openai_api_base=inference_server_url,
        max_tokens=512,
        temperature=0.1,
        streaming=True,
        callbacks=callbacks,
    )

    return chat


def get_chat_vllm_stream_TODO(model_name, inference_server_url, streaming=True):
    """Work-in-progress streaming factory.

    Builds two candidate clients: a chat-style ChatOpenAI client (returned) and a
    completion-style VLLMOpenAI client (constructed for comparison, not returned yet).
    """
    callbacks = [StreamingStdOutCallbackHandler()] if streaming else None

    chat = ChatOpenAI(
        model=model_name,
        openai_api_key="EMPTY",
        openai_api_base=inference_server_url,
        max_tokens=512,
        temperature=0.1,
        streaming=streaming,
        callbacks=callbacks,
    )

    # Completion-style alternative, kept while evaluating which interface to use.
    from langchain_community.llms import VLLMOpenAI

    llm = VLLMOpenAI(
        openai_api_key="EMPTY",
        openai_api_base=inference_server_url,
        model=model_name,
        max_tokens=512,
        temperature=0.1,
        streaming=streaming,
        callbacks=callbacks,
    )

    return chat


def _get_llm_hf_local(model_path):
    """Local debugging variant of get_llm_hf_local (greedy decoding, temperature 0)."""
    # Hard-coded paths used during development, kept for reference:
    # model_path = "/mnt/localstorage/yinghan/llm/orca_mini_v3_13b"
    # model_path = "/mnt/localstorage/yinghan/llm/zephyr-7b-beta"
    model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        model_kwargs={"temperature": 0},
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    return llm


from langchain.chains import RetrievalQAWithSourcesChain, StuffDocumentsChain


def get_cite_combine_docs_chain(llm):
    """Build a stuff-documents chain whose prompt numbers each passage so the LLM can
    cite sources with [[i]]-style markers."""

    def format_document(doc, index, prompt):
        """Format a document into a string based on a prompt template, exposing its index."""
        base_info = {
            "page_content": doc.page_content,
            "index": index,
            "source": doc.metadata["source"],
        }
        missing_metadata = set(prompt.input_variables).difference(base_info)
        if len(missing_metadata) > 0:
            raise ValueError(f"Missing metadata: {list(missing_metadata)}.")
        document_info = {k: base_info[k] for k in prompt.input_variables}
        return prompt.format(**document_info)

    class StuffDocumentsWithIndexChain(StuffDocumentsChain):
        """StuffDocumentsChain that numbers the stuffed passages starting from 1."""

        def _get_inputs(self, docs, **kwargs):
            doc_strings = [
                format_document(doc, i, self.document_prompt)
                for i, doc in enumerate(docs, 1)
            ]
            inputs = {
                k: v
                for k, v in kwargs.items()
                if k in self.llm_chain.prompt.input_variables
            }
            inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
            return inputs

    combine_doc_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""You are given a question and passages. Provide a clear and structured Helpful Answer based on the passages provided, the context and the guidelines.

Guidelines:
- If the passages contain useful facts or numbers, use them in your answer.
- When you use information from a passage, mention where it came from by using the format [[i]] at the end of the sentence, where i is the passage index.
- Do not cite a passage in a style like 'passage i'; always use the format [[i]], where i is the passage index.
- Do not use sentences such as 'Doc i says ...', '... in Doc i', or 'Passage i ...' to say where information came from.
- If the same thing is said in more than one passage, you can cite all of them like this: [[i]], [[j]], [[k]].
- Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
- If it makes sense, use bullet points and lists to make your answer easier to understand.
- You do not need to use every passage. Only use the ones that help answer the question.
- If the passages do not contain the information needed to answer the question, just say you do not have enough information.
- If a passage is the caption of a picture, you can still use it in your answer like any other passage.

-----------------------
Passages:
{context}
-----------------------
Question: {question}

Helpful Answer with format citations:""",
    )

    combine_docs_chain = StuffDocumentsWithIndexChain(
        llm_chain=LLMChain(
            llm=llm,
            prompt=combine_doc_prompt,
        ),
        document_prompt=PromptTemplate(
            input_variables=["index", "source", "page_content"],
            template="[[{index}]]\nsource: {source}:\n{page_content}",
        ),
        document_variable_name="context",
    )

    return combine_docs_chain
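
# Example usage (sketch): swap the citation-aware combine-docs chain into an existing
# ConversationalRetrievalChain (this is what RAGChain.create does below when
# add_citation=True). `retriever` and `llm` are placeholders you supply.
#
#   chain = ConversationalRetrievalChain.from_llm(
#       llm=llm, retriever=retriever, return_source_documents=True
#   )
#   chain.combine_docs_chain = get_cite_combine_docs_chain(llm)
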
class ConversationChainFactory_bp:
    """Variant of ConversationChainFactory that takes a vectorstore instead of a
    ready-made retriever and uses the chain's default prompts."""

    def __init__(
        self, memory_key="chat_history", output_key="answer", return_messages=True
    ):
        self.memory_key = memory_key
        self.output_key = output_key
        self.return_messages = return_messages

    def create(self, vectorstore, llm):
        memory = ConversationBufferWindowMemory(
            memory_key=self.memory_key,
            return_messages=self.return_messages,
            output_key=self.output_key,
        )

        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vectorstore.as_retriever(),
            memory=memory,
            return_source_documents=True,
        )

        return conversation_chain


class ConversationChainFactory:
    """Build a ConversationalRetrievalChain with windowed conversation memory."""

    def __init__(
        self, memory_key="chat_history", output_key="answer", return_messages=True
    ):
        self.memory_key = memory_key
        self.output_key = output_key
        self.return_messages = return_messages

    def create(self, retriever, llm, langfuse_callback=None):
        memory = ConversationBufferWindowMemory(
            memory_key=self.memory_key,
            return_messages=self.return_messages,
            output_key=self.output_key,
        )

        # QA prompt, kept here for reference; it is not yet passed to the chain, which
        # therefore uses the library default combine-docs prompt.
        prompt_template = """You are a helpful research assistant. Use the following pieces of context to answer the question at the end.
Please ignore the contexts if they are not related to the question. If you don't know the answer, just say that you don't know,
don't try to make up an answer.

{context}

Question: {question}

Helpful Answer:"""
        PROMPT = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )

        # Question-condensing prompt, also not yet wired in (would go to
        # condense_question_prompt).
        _template = """Return text in the original language of the follow up question.
If the follow up question does not need context, return the exact same text back.
Never rephrase the follow up question given the chat history unless the follow up question needs context.

Chat History: {chat_history}

Follow Up Question: {question}

Standalone Question:"""
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            return_source_documents=True,
            rephrase_question=False,
            get_chat_history=lambda x: x,
        )

        return conversation_chain
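
# Example usage (sketch): `retriever` and `llm` are whatever you built above, e.g. a
# Chroma retriever and get_chat_vllm(...). Conversation history is kept in the chain's
# windowed memory between calls.
#
#   factory = ConversationChainFactory()
#   chain = factory.create(retriever, llm)
#   result = chain({"question": "What does the paper conclude?"})
#   print(result["answer"], result["source_documents"])
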
class ConversationChainFactoryDev:
    """Development copy of ConversationChainFactory, kept separate for experimentation."""

    def __init__(
        self, memory_key="chat_history", output_key="answer", return_messages=True
    ):
        self.memory_key = memory_key
        self.output_key = output_key
        self.return_messages = return_messages

    def create(self, retriever, llm, langfuse_callback=None):
        memory = ConversationBufferWindowMemory(
            memory_key=self.memory_key,
            return_messages=self.return_messages,
            output_key=self.output_key,
        )

        # QA prompt, kept for reference; not yet passed to the chain.
        prompt_template = """You are a helpful research assistant. Use the following pieces of context to answer the question at the end.
Please ignore the contexts if they are not related to the question. If you don't know the answer, just say that you don't know,
don't try to make up an answer.

{context}

Question: {question}

Helpful Answer:"""
        PROMPT = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )

        # Question-condensing prompt, also not yet wired in.
        _template = """Return text in the original language of the follow up question.
If the follow up question does not need context, return the exact same text back.
Never rephrase the follow up question given the chat history unless the follow up question needs context.

Chat History: {chat_history}

Follow Up Question: {question}

Standalone Question:"""
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            return_source_documents=True,
            rephrase_question=False,
            get_chat_history=lambda x: x,
        )

        return conversation_chain


class RAGChain:
    """Build the retrieval-augmented conversation chain, optionally with [[i]]-style
    citations in the answer."""

    def __init__(
        self, memory_key="chat_history", output_key="answer", return_messages=True
    ):
        self.memory_key = memory_key
        self.output_key = output_key
        self.return_messages = return_messages

    def create(self, retriever, llm, add_citation=False):
        # Keep only the last two exchanges in memory.
        memory = ConversationBufferWindowMemory(
            k=2,
            memory_key=self.memory_key,
            return_messages=self.return_messages,
            output_key=self.output_key,
        )

        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            return_source_documents=True,
            rephrase_question=False,
            get_chat_history=lambda x: x,
        )

        if add_citation:
            # Swap in the combine-docs chain that numbers passages and asks for citations.
            cite_combine_docs_chain = get_cite_combine_docs_chain(llm)
            conversation_chain.combine_docs_chain = cite_combine_docs_chain

        return conversation_chain
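
# End-to-end example (sketch): all names below are assumptions for illustration.
# Build a Chroma retriever over your own documents, pick one of the LLM factories
# above, then ask a question and map the [[i]] citations back to source_documents.
#
#   embeddings = HuggingFaceEmbeddings()
#   vectorstore = Chroma(persist_directory="chroma_db", embedding_function=embeddings)
#   llm = get_chat_vllm("HuggingFaceH4/zephyr-7b-beta", "http://localhost:8000/v1")
#   rag = RAGChain().create(vectorstore.as_retriever(), llm, add_citation=True)
#   result = rag({"question": "What are the main findings?"})
#   print(result["answer"])
#   for i, doc in enumerate(result["source_documents"], 1):
#       print(f"[[{i}]] {doc.metadata['source']}")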