"""Chainlit chat app answering questions over BeaverTails prompts.

At import time the script loads the BeaverTails dataset, keeps the prompts
flagged with the ``child_abuse`` category, indexes them in an in-memory BM25
document store, and wires a retriever plus a prompt node into a Haystack
pipeline. Each incoming chat message is run through that pipeline and the
trimmed answer is sent back to the user.
"""

import os

from pydantic import BaseModel

# Kept ahead of the haystack/chainlit imports, as in the original ordering.
# NOTE(review): presumably haystack's pydantic models need arbitrary types
# enabled before they are defined - confirm before reordering.
BaseModel.Config.arbitrary_types_allowed = True

from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import PromptNode, PromptTemplate, AnswerParser, BM25Retriever
from haystack.pipelines import Pipeline
from haystack.utils import print_answers
import chainlit as cl
from datasets import load_dataset
from dotenv import load_dotenv

# NOTE(review): `os`, `load_dotenv` and `print_answers` are imported but never
# used in this file. If the model is meant to be called with an API token
# read from a .env file, that wiring is missing - confirm.

# Reply sent when the pipeline yields no usable answer text.
NO_ANSWER_MESSAGE = "Sorry, I could not produce an answer to that question."

# Characters that mark a line as ending in a complete sentence.
_SENTENCE_ENDINGS = (".", "?", "!")

# Load the dataset.
dataset = load_dataset("PKU-Alignment/BeaverTails", split="330k_train")
document_store = InMemoryDocumentStore(use_bm25=True)

# Extract the prompts whose "child_abuse" category flag is set.
child_abuse = [
    entry["prompt"] for entry in dataset if entry["category"]["child_abuse"]
]

# Print the filtered cases (a single pass is enough).
for case in child_abuse:
    print(case)

# Convert each prompt into a Document and write it to the document store.
documents = [Document(content=prompt) for prompt in child_abuse]
document_store.write_documents(documents)

# Initialise the retriever.
retriever = BM25Retriever(document_store=document_store, top_k=3)

# The template text is assembled from adjacent literals only to keep the
# source lines short; the resulting string is unchanged.
prompt_template = PromptTemplate(
    prompt=(
        " Answer the question truthfully based solely on the given documents. "
        "If the documents do not contain the answer, say that answering is "
        "not possible given the available information. \n"
        "Documents: {join(documents)} Question: {query} Answer: "
    ),
    output_parser=AnswerParser(),
)

prompt_node = PromptNode(
    model_name_or_path="mistralai/Mistral-7B-Instruct-v0.2",
    default_prompt_template=prompt_template,
)

# Query -> retriever -> prompt_node
generative_pipeline = Pipeline()
generative_pipeline.add_node(
    component=retriever, name="retriever", inputs=["Query"]
)
generative_pipeline.add_node(
    component=prompt_node, name="prompt_node", inputs=["retriever"]
)


def _trim_answer(raw_answer: str) -> str:
    """Return *raw_answer* without its first line and any unfinished last line.

    The last line is dropped when it does not end in ``.``, ``?`` or ``!``.
    The first line is always dropped.
    NOTE(review): presumably the first line of the model output is a blank or
    echoed line - confirm against real model output.
    """
    lines = raw_answer.split("\n")
    if not lines[-1].strip().endswith(_SENTENCE_ENDINGS):
        lines.pop()
    return "\n".join(lines[1:])


@cl.on_message
async def main(message: str):
    """Answer one chat message using the retrieval pipeline.

    Args:
        message: The user's message. Plain strings are used as-is; objects
            exposing a ``content`` attribute (as newer Chainlit releases
            pass) have that attribute used as the query.
    """
    query = getattr(message, "content", message)

    # make_async must wrap the callable itself; calling run() first would
    # execute the pipeline on the event loop and then await a non-awaitable.
    response = await cl.make_async(generative_pipeline.run)(query)

    answers = response.get("answers") or []
    result = _trim_answer(answers[0].answer) if answers else ""

    # Avoid sending an empty chat bubble when nothing usable came back.
    await cl.Message(author="Bot", content=result or NO_ANSWER_MESSAGE).send()