|
from pydantic import BaseModel |
|
# Globally enable `arbitrary_types_allowed` on pydantic's default BaseModel
# config. This mutates shared library state for the whole process, not just
# this module.
# NOTE(review): presumably a compatibility workaround so that pydantic models
# declared by the libraries imported below (haystack / chainlit) accept
# non-pydantic field types. It is placed before those imports, likely so that
# models created at import time inherit it (pydantic v1 config-inheritance
# semantics) — confirm, and keep it above the haystack imports.
BaseModel.Config.arbitrary_types_allowed = True
|
from haystack.document_stores import InMemoryDocumentStore |
|
from haystack.nodes import PromptNode, PromptTemplate, AnswerParser, BM25Retriever |
|
from haystack.pipelines import Pipeline |
|
from haystack.utils import print_answers |
|
import chainlit as cl |
|
from datasets import load_dataset |
|
import os |
|
from dotenv import load_dotenv |
|
|
|
# Load variables from a local `.env` file into os.environ before any library
# that may read credentials runs. `load_dotenv` was imported but never called,
# so `.env` values were never visible.
# NOTE(review): presumably this carries a Hugging Face token or similar for the
# dataset download / model load below — confirm. Missing `.env` is harmless.
load_dotenv()

# The "330k_train" split of the PKU-Alignment/BeaverTails dataset.
dataset = load_dataset("PKU-Alignment/BeaverTails", split="330k_train")

# In-memory document store with BM25 enabled, so BM25Retriever can query it.
document_store = InMemoryDocumentStore(use_bm25=True)
|
|
|
|
|
# Prompt text of every dataset row whose `category` mapping has a truthy
# `child_abuse` flag.
child_abuse = [
    row["prompt"]
    for row in dataset
    if row["category"]["child_abuse"]
]
|
|
|
# Echo each selected prompt once, for inspection at startup.
# (Previously two identical loops printed the entire list twice.)
for prompt in child_abuse:
    print(prompt)
|
|
|
# Independent (shallow) copy of the selected prompts; these are what get
# wrapped as Documents and indexed below. list() replaces the manual append loop.
prompts_list = list(child_abuse)
|
|
|
from haystack import Document |
|
|
|
|
|
# One Haystack Document per prompt string, written into the BM25-enabled store.
documents = [
    Document(content=text)
    for text in prompts_list
]
document_store.write_documents(documents)

# Keyword (BM25) retriever over that store, returning up to three documents per query.
retriever = BM25Retriever(
    document_store=document_store,
    top_k=3,
)
|
|
|
# Grounded-QA instruction. `{join(documents)}` and `{query}` are placeholders
# that Haystack's PromptTemplate fills from the retriever output and the user query.
_QA_PROMPT = """
Answer the question truthfully based solely on the given documents. If the documents do not contain the answer,
say that answering is not possible given the available information.
Documents: {join(documents)}
Question: {query}
Answer:
"""

# AnswerParser turns the generated text into Answer objects (read via
# response["answers"] in main()).
prompt_template = PromptTemplate(
    prompt=_QA_PROMPT,
    output_parser=AnswerParser(),
)
|
|
|
# Generator node: Mistral-7B-Instruct v0.2, prompted with prompt_template by default.
prompt_node = PromptNode(
    model_name_or_path="mistralai/Mistral-7B-Instruct-v0.2",
    default_prompt_template=prompt_template,
)

# Two-stage graph: Query -> retriever -> prompt_node (nodes added in this order).
generative_pipeline = Pipeline()
for _component, _node_name, _upstream in (
    (retriever, "retriever", ["Query"]),
    (prompt_node, "prompt_node", ["retriever"]),
):
    generative_pipeline.add_node(component=_component, name=_node_name, inputs=_upstream)
|
|
|
def _tidy_answer(text: str) -> str:
    """Return generated *text* trimmed for display in the chat UI.

    The text is split on newlines. If the last line does not end with ".",
    "?" or "!", it is dropped, presumably because generation stopped
    mid-sentence. The first line is then always dropped, and the remaining
    lines are re-joined with newlines.

    NOTE(review): dropping the first line assumes the model output starts with
    a throwaway (e.g. blank) line — confirm against real outputs.
    """
    lines = text.split("\n")
    if lines and not lines[-1].strip().endswith((".", "?", "!")):
        lines.pop()
    return "\n".join(lines[1:])


@cl.on_message
async def main(message: "cl.Message | str"):
    """Answer one chat message with the retrieval-augmented pipeline and reply.

    Args:
        message: The incoming chat message. This accepts either the raw text
            (older Chainlit releases) or a ``cl.Message`` whose text is in
            ``.content`` (newer releases; changed around Chainlit 0.7 — confirm).

    Sends exactly one ``cl.Message`` authored by "Bot"; returns nothing.
    """
    query = getattr(message, "content", message)

    # make_async expects a *function* and runs it in a worker thread. The
    # previous code called run() inline, which blocked the event loop. It then
    # awaited the wrapper function itself, which is not awaitable (TypeError).
    response = await cl.make_async(generative_pipeline.run)(query=query)

    answers = response.get("answers") or []
    result = _tidy_answer(answers[0].answer or "") if answers else ""
    if not result.strip():
        # Avoid sending an empty bubble when no answer came back or trimming
        # removed everything (e.g. a single-line answer).
        result = "Sorry, I could not produce an answer to that question."

    await cl.Message(author="Bot", content=result).send()