Spaces:
Runtime error
Runtime error
from typing import List | |
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
) | |
from langchain.schema.language_model import BaseLanguageModel | |
from langchain.schema.retriever import BaseRetriever | |
from langchain.schema.runnable import RunnablePassthrough, RunnableSequence | |
from pydantic import BaseModel, Field | |
# A single generated question together with its answer.
# NOTE: documented with comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into its JSON schema, and this model's
# schema is what the output parser turns into prompt format instructions.
class QuestionAnswerPair(BaseModel):
    question: str = Field(..., description="The question that will be answered.")
    answer: str = Field(..., description="The answer to the question that was asked.")

    def to_str(self, idx: int) -> str:
        """Render this pair as a numbered Markdown question/answer entry.

        Args:
            idx: Number printed in front of the question.

        Returns:
            The question line and the answer line, separated by a blank line.
        """
        question_piece = f"{idx}. **Q:** {self.question}"
        # Pad the answer by the width of the "<idx>. " prefix so that "**A:**"
        # starts in the same column as "**Q:**".
        whitespace = " " * (len(str(idx)) + 2)
        answer_piece = f"{whitespace}**A:** {self.answer}"
        return f"{question_piece}\n\n{answer_piece}"
# Container for all question/answer pairs produced for one piece of text.
# NOTE: documented with comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into its JSON schema, and this model's
# schema is what the output parser turns into prompt format instructions.
class QuestionAnswerPairList(BaseModel):
    # The field name is part of the JSON shape the model is asked to emit,
    # so its spelling must not change.
    QuestionAnswerPairs: List[QuestionAnswerPair]

    def to_str(self) -> str:
        """Render every pair as a numbered Markdown list.

        Returns:
            Each pair's ``to_str`` output, numbered from 1 and separated by
            blank lines. An empty list yields an empty string.
        """
        return "\n\n".join(
            qap.to_str(idx)
            for idx, qap in enumerate(self.QuestionAnswerPairs, start=1)
        )
# Output parser built for the QuestionAnswerPairList model. It is used twice in
# this module: as the source of the prompt's format instructions and as the
# parser wrapped by the output-fixing step of the chain.
PYDANTIC_PARSER: PydanticOutputParser = PydanticOutputParser(
    pydantic_object=QuestionAnswerPairList,
)
# System-message template. `{format_instructions}` is the only placeholder; it
# is pre-filled when CHAT_PROMPT is built.
templ1 = """You are a smart assistant designed to help college professors come up with reading comprehension questions.
Given a piece of text, you must come up with question and answer pairs that can be used to test a student's reading comprehension abilities.
Generate as many question/answer pairs as you can.
When coming up with the question/answer pairs, you must respond in the following format:
{format_instructions}
Do not provide additional commentary and do not wrap your response in Markdown formatting. Return RAW, VALID JSON.
"""
# Human-message template. `{prompt}` and `{context}` are filled in at call
# time by the chain built in get_rag_qa_gen_chain.
templ2 = """{prompt}
Please create question/answer pairs, in the specified JSON format, for the following text:
----------------
{context}"""
# Two-message chat prompt: `templ1` as the system message and `templ2` as the
# human message. `format_instructions` is bound here, which leaves `prompt`
# and `context` to be supplied when the prompt is invoked.
# The parser's bound method is passed uncalled. LangChain partial variables
# may be callables; presumably it is evaluated when the prompt is formatted
# (confirm against the installed LangChain version).
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", templ1),
        ("human", templ2),
    ],
).partial(format_instructions=PYDANTIC_PARSER.get_format_instructions)
def get_rag_qa_gen_chain(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    input_key: str = "prompt",
) -> RunnableSequence:
    """Build the retrieval-backed question/answer generation chain.

    The chain input is sent both to ``retriever`` (whose output fills the
    prompt's ``context`` variable) and, unchanged, to the prompt under
    ``input_key``. The formatted prompt goes to ``llm``, the reply is parsed
    into a ``QuestionAnswerPairList`` and finally rendered with ``to_str``.

    Args:
        retriever: Provides the value for the prompt's ``context`` variable.
        llm: Model that answers the prompt; it is also handed to the
            output-fixing parser.
        input_key: Prompt variable that receives the raw chain input.
            NOTE(review): the human template references ``{prompt}``, so a
            value other than the default looks like it would leave that
            variable unfilled; confirm before passing anything else.

    Returns:
        A runnable whose last step yields the Markdown string produced by
        ``QuestionAnswerPairList.to_str``.
    """
    return (
        {"context": retriever, input_key: RunnablePassthrough()}
        | CHAT_PROMPT
        | llm
        # Wraps the pydantic parser; the same llm is supplied to it for
        # handling output the parser rejects.
        | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
        | (lambda parsed_output: parsed_output.to_str())
    )