File size: 2,275 Bytes
fa73541
 
4e1ad49
 
 
 
220d04b
 
4e1ad49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# NOTE(review): BaseModel.Config is patched before haystack is imported,
# presumably so pydantic models defined afterwards (e.g. inside haystack)
# accept arbitrary, non-pydantic field types. This looks like a version
# compatibility workaround — confirm before removing or reordering.
from pydantic import BaseModel
BaseModel.Config.arbitrary_types_allowed = True
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import PromptNode, PromptTemplate, AnswerParser, BM25Retriever
from haystack.pipelines import Pipeline
# NOTE(review): print_answers is not used anywhere in this file.
from haystack.utils import print_answers
import chainlit as cl
from datasets import load_dataset
# NOTE(review): `os` and `load_dotenv` are imported but never used, and
# load_dotenv() is never called — confirm whether reading secrets (e.g. an
# API token) from a .env file was intended.
import os
from dotenv import load_dotenv
# Load the BeaverTails training split and index the child-abuse prompts in an
# in-memory store. BM25 indexing is enabled so the store can back a
# BM25Retriever.
from haystack import Document

dataset = load_dataset("PKU-Alignment/BeaverTails", split="330k_train")
document_store = InMemoryDocumentStore(use_bm25=True)

# Keep only the prompts whose `category` flags include child abuse.
child_abuse = [entry['prompt'] for entry in dataset if entry['category']['child_abuse']]

# Print the filtered cases (once).
for case in child_abuse:
    print(case)

# Plain copy of the filtered prompts, kept under its existing name.
prompts_list = list(child_abuse)

# Convert each distinct prompt into a Document object. The same prompt text
# can occur in several dataset rows (presumably one row per prompt/response
# pair — confirm), so repeats are dropped up front; dict.fromkeys keeps the
# first-seen order.
documents = [Document(content=prompt) for prompt in dict.fromkeys(prompts_list)]

# Write the documents to the document store.
document_store.write_documents(documents)

# Retrieval step: BM25 over the indexed prompts, keeping the 3 best matches.
retriever = BM25Retriever(top_k=3, document_store=document_store)

# Prompt sent to the LLM; {join(documents)} and {query} are template
# placeholders. The whitespace (including the trailing spaces) is part of the
# text the model receives, so it is spelled out explicitly here.
_QA_PROMPT = (
    "\n"
    "    Answer the question truthfully based solely on the given documents. "
    "If the documents do not contain the answer,\n"
    "    say that answering is not possible given the available information. \n"
    "    Documents: {join(documents)}\n"
    "    Question: {query}\n"
    "    Answer: \n"
    "    "
)
prompt_template = PromptTemplate(output_parser=AnswerParser(), prompt=_QA_PROMPT)

# Generation step.
# NOTE(review): no api_key is passed to PromptNode, while the file imports
# `os` and `load_dotenv` without using them — confirm whether a token read
# from .env was meant to be supplied here.
prompt_node = PromptNode(
    default_prompt_template=prompt_template,
    model_name_or_path="mistralai/Mistral-7B-Instruct-v0.2",
)

# Wire the graph: Query -> retriever -> prompt_node.
generative_pipeline = Pipeline()
for _node, _node_name, _upstream in (
    (retriever, "retriever", ["Query"]),
    (prompt_node, "prompt_node", ["retriever"]),
):
    generative_pipeline.add_node(component=_node, name=_node_name, inputs=_upstream)

# Line endings that mark a line as a complete sentence.
_SENTENCE_END = ('.', '?', '!')

# Reply used when the pipeline yields no usable answer text.
_NO_ANSWER = "Sorry, no answer could be generated."


def _tidy_answer(text: str) -> str:
    """Return *text* cleaned up for display in the chat UI.

    Surrounding whitespace and blank lines are stripped. A final line that
    does not end with '.', '?' or '!' (e.g. generation stopped mid-sentence)
    is dropped. If that would leave nothing, the stripped text is returned
    as-is so a partial answer is preferred over an empty one.
    """
    stripped = text.strip()
    lines = stripped.split("\n")
    if not lines[-1].rstrip().endswith(_SENTENCE_END):
        lines.pop()
    return "\n".join(lines).strip() or stripped


@cl.on_message
async def main(message) -> None:
    """Answer an incoming chat message with the retrieval pipeline.

    Args:
        message: The user's message. Depending on the Chainlit version this
            is either the text itself (str) or a ``cl.Message`` whose
            ``content`` attribute holds the text; both are accepted.
    """
    query = getattr(message, "content", message)

    # Pipeline.run is synchronous: hand the *callable* to make_async so it
    # runs off the event loop. Calling run() first blocks the loop, and then
    # awaiting make_async(...) without calling it raises TypeError.
    response = await cl.make_async(generative_pipeline.run)(query=query)

    answers = response.get("answers") or []
    result = _tidy_answer(answers[0].answer or "") if answers else ""
    await cl.Message(author="Bot", content=result or _NO_ANSWER).send()