Ekkyboss committed on
Commit
4e1ad49
·
verified ·
1 Parent(s): 7aa9e42

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import chainlit as cl
from datasets import load_dataset
from dotenv import load_dotenv
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import PromptNode, PromptTemplate, AnswerParser, BM25Retriever
from haystack.pipelines import Pipeline
from haystack.utils import print_answers
7
from haystack import Document

# Load the "330k_train" split of the BeaverTails dataset and create the
# in-memory document store with BM25 enabled.
dataset = load_dataset("PKU-Alignment/BeaverTails", split="330k_train")
document_store = InMemoryDocumentStore(use_bm25=True)

# Keep only the prompts whose 'child_abuse' category flag is truthy.
child_abuse = [entry['prompt'] for entry in dataset if entry['category']['child_abuse']]

# Print the filtered prompts.
for prompt in child_abuse:
    print(prompt)

# Independent copy of the filtered prompts, kept under its original name.
prompts_list = list(child_abuse)

# Wrap each prompt in a Document and write them all to the document store
# that the retriever searches.
documents = [Document(content=prompt) for prompt in prompts_list]
document_store.write_documents(documents)
33
+
34
# BM25 retriever over the document store; passes the top 3 matches per query on.
retriever = BM25Retriever(document_store=document_store, top_k=3)

# Prompt that restricts the model to the retrieved documents. The raw model
# output is parsed by AnswerParser.
prompt_template = PromptTemplate(
    prompt="""
Answer the question truthfully based solely on the given documents. If the documents do not contain the answer,
say that answering is not possible given the available information.
Documents: {join(documents)}
Question: {query}
Answer:
""",
    output_parser=AnswerParser(),
)

# Prompt node for the Mistral instruct model, using the template above by default.
# NOTE(review): `os` and `load_dotenv` are imported by this file but no API key
# or environment loading is visible here -- confirm how the model is meant to be
# loaded/authenticated.
prompt_node = PromptNode(
    model_name_or_path="mistralai/Mistral-7B-Instruct-v0.2",
    default_prompt_template=prompt_template,
)

# Wire the pipeline: Query -> retriever -> prompt_node.
generative_pipeline = Pipeline()
generative_pipeline.add_node(component=retriever, name='retriever', inputs=["Query"])
generative_pipeline.add_node(component=prompt_node, name="prompt_node", inputs=["retriever"])
58
+
59
@cl.on_message
async def main(message: str):
    """Answer one incoming chat message using ``generative_pipeline``.

    The message text is run through the pipeline as the query. The first
    returned answer is split into lines, a trailing line that does not end in
    '.', '?' or '!' is dropped, the first line is skipped, and the remaining
    lines are sent back as a message authored by "Bot".

    Args:
        message: The user's message. Declared as ``str``; an object exposing a
            ``content`` attribute is unwrapped to that attribute instead
            (newer Chainlit releases appear to pass a message object --
            TODO confirm against the installed version).
    """
    # Accept either a plain string or a message object carrying `.content`.
    query = getattr(message, "content", message)

    # make_async must wrap the callable itself. Calling run() inline would
    # execute the pipeline synchronously on the event loop and then pass its
    # result dict, not a function, to make_async.
    response = await cl.make_async(generative_pipeline.run)(query)
    sentences = response["answers"][0].answer.split("\n")

    # Drop a trailing line that lacks terminal punctuation.
    if sentences and not sentences[-1].strip().endswith(('.', '?', '!')):
        sentences.pop()

    # NOTE(review): the first line is always skipped, so a single-line answer
    # produces an empty reply -- confirm this is intended.
    result = "\n".join(sentences[1:])
    await cl.Message(author="Bot", content=result).send()