marondeau committed on
Commit
60b4723
·
unverified ·
2 Parent(s): 17cfdef 5a55b5b

Merge pull request #51 from marondeau/formatter

Browse files
buster/chatbot.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import os
3
  from dataclasses import dataclass, field
 
4
 
5
  import numpy as np
6
  import openai
@@ -9,6 +10,11 @@ import promptlayer
9
  from openai.embeddings_utils import cosine_similarity, get_embedding
10
 
11
  from buster.docparser import read_documents
 
 
 
 
 
12
 
13
  logger = logging.getLogger(__name__)
14
  logging.basicConfig(level=logging.INFO)
@@ -149,53 +155,49 @@ class Chatbot:
149
  documents_str: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
150
  return text_before_documents + documents_str + text_before_prompt + question
151
 
152
- def get_gpt_response(self, **completion_kwargs):
153
  # Call the API to generate a response
154
  logger.info(f"querying GPT...")
155
  try:
156
- return openai.Completion.create(**completion_kwargs)
157
-
158
  except Exception as e:
159
  # log the error and return a generic response instead.
160
  logger.exception("Error connecting to OpenAI API. See traceback:")
161
- response = {"choices": [{"text": "We're having trouble connecting to OpenAI right now... Try again soon!"}]}
162
- return response
163
 
164
- def generate_response(self, prompt: str, matched_documents: pd.DataFrame, unknown_prompt: str) -> str:
 
 
 
 
 
165
  """
166
  Generate a response based on the retrieved documents.
167
  """
168
  if len(matched_documents) == 0:
169
  # No matching documents were retrieved, return
170
- return unknown_prompt
 
171
 
172
  logger.info(f"Prompt: {prompt}")
173
  response = self.get_gpt_response(prompt=prompt, **self.cfg.completion_kwargs)
174
- response_str = response["choices"][0]["text"]
175
- logger.info(f"GPT Response:\n{response_str}")
176
- return response_str
177
-
178
- def add_sources(self, response: str, matched_documents: pd.DataFrame, sep: str, format: str):
179
- """
180
- Add sources fromt the matched documents to the response.
181
- """
182
-
183
- urls = matched_documents.url.to_list()
184
- titles = matched_documents.title.to_list()
185
- similarities = matched_documents.similarity.to_list()
186
-
187
- response += f"{sep}{sep}πŸ“ Here are the sources I used to answer your question:{sep}{sep}"
188
- for url, title, similarity in zip(urls, titles, similarities):
189
- if format == "markdown":
190
- response += f"[πŸ”— {title}]({url}), relevance: {similarity:2.3f}{sep}"
191
- elif format == "html":
192
- response += f"<a href='{url}'>πŸ”— {title}</a>{sep}"
193
- elif format == "slack":
194
- response += f"<{url}|πŸ”— {title}>, relevance: {similarity:2.3f}{sep}"
195
  else:
196
- raise ValueError(f"{format} is not a valid URL format.")
197
 
198
- return response
199
 
200
  def check_response_relevance(
201
  self, response: str, engine: str, unk_embedding: np.array, unk_threshold: float
@@ -217,36 +219,16 @@ class Chatbot:
217
  # Likely that the answer is meaningful, add the top sources
218
  return score < unk_threshold
219
 
220
- def format_response(self, response: str, matched_documents: pd.DataFrame, text_after_response: str) -> str:
221
- """
222
- Format the response by adding the sources if necessary, and a disclaimer prompt.
223
- """
224
- sep = self.cfg.separator
225
-
226
- is_relevant = self.check_response_relevance(
227
- response=response,
228
- engine=self.cfg.embedding_model,
229
- unk_embedding=self.unk_embedding,
230
- unk_threshold=self.cfg.unknown_threshold,
231
- )
232
- if is_relevant:
233
- # Passes our relevance detection mechanism that the answer is meaningful, add the top sources
234
- response = self.add_sources(
235
- response=response,
236
- matched_documents=matched_documents,
237
- sep=self.cfg.separator,
238
- format=self.cfg.link_format,
239
- )
240
-
241
- response += f"{sep}{sep}{sep}{text_after_response}{sep}"
242
-
243
- return response
244
-
245
- def process_input(self, question: str) -> str:
246
  """
247
  Main function to process the input question and generate a formatted output.
248
  """
249
 
 
 
 
 
 
250
  logger.info(f"User Question:\n{question}")
251
 
252
  matched_documents = self.rank_documents(
@@ -262,9 +244,6 @@ class Chatbot:
262
  text_before_prompt=self.cfg.text_before_prompt,
263
  text_before_documents=self.cfg.text_before_documents,
264
  )
265
- response = self.generate_response(prompt, matched_documents, self.cfg.unknown_prompt)
266
- formatted_output = self.format_response(
267
- response, matched_documents, text_after_response=self.cfg.text_after_response
268
- )
269
 
270
- return formatted_output
 
1
  import logging
2
  import os
3
  from dataclasses import dataclass, field
4
+ from typing import Iterable
5
 
6
  import numpy as np
7
  import openai
 
10
  from openai.embeddings_utils import cosine_similarity, get_embedding
11
 
12
  from buster.docparser import read_documents
13
+ from buster.formatter import Formatter, HTMLFormatter, MarkdownFormatter, SlackFormatter
14
+ from buster.formatter.base import Response, Source
15
+
16
+ FORMATTERS = {"text": Formatter, "slack": SlackFormatter, "html": HTMLFormatter, "markdown": MarkdownFormatter}
17
+
18
 
19
  logger = logging.getLogger(__name__)
20
  logging.basicConfig(level=logging.INFO)
 
155
  documents_str: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
156
  return text_before_documents + documents_str + text_before_prompt + question
157
 
158
def get_gpt_response(self, **completion_kwargs) -> Response:
    """Query the OpenAI completion endpoint and wrap the outcome in a ``Response``.

    Args:
        **completion_kwargs: Forwarded unchanged to ``openai.Completion.create``.

    Returns:
        A ``Response`` carrying the text of the first returned choice. If the
        API call raises, the exception is logged and a ``Response`` flagged
        with ``error=True`` and a generic user-facing message is returned
        instead of propagating the exception.
    """
    logger.info("querying GPT...")
    try:
        completion = openai.Completion.create(**completion_kwargs)
    except Exception:
        # Deliberate catch-all at the API boundary: log the traceback and
        # hand back a generic error response rather than crashing the caller.
        logger.exception("Error connecting to OpenAI API. See traceback:")
        return Response("", True, "We're having trouble connecting to OpenAI right now... Try again soon!")

    # Only the first choice is used.
    return Response(completion["choices"][0]["text"])
170
+
171
def generate_response(
    self, prompt: str, matched_documents: pd.DataFrame, unknown_prompt: str
) -> tuple[Response, Iterable[Source]]:
    """Generate a response based on the retrieved documents.

    Args:
        prompt: Full prompt sent to the completion model.
        matched_documents: Documents retrieved for the question; each row
            is read through its ``name``, ``url`` and ``similarity`` keys.
        unknown_prompt: Text returned as the answer when no document matched.

    Returns:
        A ``(response, sources)`` pair. ``sources`` is an empty tuple when
        no document matched, when the completion call failed, or when the
        answer did not pass ``check_response_relevance``; otherwise it holds
        one ``Source`` per matched document.
    """
    if len(matched_documents) == 0:
        # No matching documents were retrieved: answer with the fallback text.
        return Response(unknown_prompt), ()

    logger.info(f"Prompt: {prompt}")
    response = self.get_gpt_response(prompt=prompt, **self.cfg.completion_kwargs)
    if response.error:
        # A dataclass instance is always truthy, so the error flag has to be
        # tested explicitly. There is no answer text to score, so skip the
        # relevance check (and the extra embedding request it performs).
        return response, ()

    logger.info(f"GPT Response:\n{response.text}")
    relevant = self.check_response_relevance(
        response=response.text,
        engine=self.cfg.embedding_model,
        unk_embedding=self.unk_embedding,
        unk_threshold=self.cfg.unknown_threshold,
    )
    if not relevant:
        return response, ()

    # Materialize the sources so the result can be iterated more than once.
    # NOTE(review): the pre-merge `add_sources` read a `title` column from
    # this frame; confirm that `name` is the column actually present.
    sources = tuple(
        Source(dct["name"], dct["url"], dct["similarity"])
        for dct in matched_documents.to_dict(orient="records")
    )
    return response, sources
201
 
202
  def check_response_relevance(
203
  self, response: str, engine: str, unk_embedding: np.array, unk_threshold: float
 
219
  # Likely that the answer is meaningful, add the top sources
220
  return score < unk_threshold
221
 
222
+ def process_input(self, question: str, formatter: Formatter = None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  """
224
  Main function to process the input question and generate a formatted output.
225
  """
226
 
227
+ if formatter is None and self.cfg.link_format not in FORMATTERS:
228
+ raise ValueError(f"Unknown link format {self.cfg.link_format}")
229
+ elif formatter is None:
230
+ formatter = FORMATTERS[self.cfg.link_format]()
231
+
232
  logger.info(f"User Question:\n{question}")
233
 
234
  matched_documents = self.rank_documents(
 
244
  text_before_prompt=self.cfg.text_before_prompt,
245
  text_before_documents=self.cfg.text_before_documents,
246
  )
247
+ response, sources = self.generate_response(prompt, matched_documents, self.cfg.unknown_prompt)
 
 
 
248
 
249
+ return formatter(response, sources)
buster/formatter/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .base import Formatter
2
+ from .html import HTMLFormatter
3
+ from .markdown import MarkdownFormatter
4
+ from .slack import SlackFormatter
5
+
6
# `__all__` must contain the public names as strings, not the objects themselves;
# otherwise `from buster.formatter import *` fails with a TypeError.
__all__ = ["Formatter", "HTMLFormatter", "MarkdownFormatter", "SlackFormatter"]
buster/formatter/base.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Iterable, NamedTuple
3
+
4
+
5
+ # Should be from the `documents` module.
6
class Source(NamedTuple):
    """A document cited as a source for an answer.

    Instances are plain tuples of ``(name, url, question_similarity)``.
    """

    # Label shown for the source.
    name: str
    # Link target for the source.
    url: str
    # Similarity score attached to the source (formatted as a relevance value).
    question_similarity: float
    # TODO Add answer similarity.
    # answer_similarity: float
12
+
13
+
14
+ # Should be from the `nlp` module.
15
+ @dataclass(slots=True)
16
+ class Response:
17
+ text: str
18
+ error: bool = False
19
+ error_msg: str | None = None
20
+
21
+
22
@dataclass
class Formatter:
    """Plain-text formatter for a response and its sources.

    Every field is a ``str.format`` template. ``source_template`` receives
    ``source``; the error templates receive ``response``; the answer
    templates receive ``response`` and, for the sourced variant, the
    already-rendered ``sources`` list. Subclasses customise the output by
    overriding the templates and/or ``sources_list``.
    """

    source_template: str = "{source.name} (relevance: {source.question_similarity:2.3f})"
    error_msg_template: str = "Something went wrong: {response.error_msg}"
    error_fallback_template: str = "Something went very wrong."
    sourced_answer_template: str = "{response.text}\n\nSources:\n{sources}\n\nBut what do I know, I'm a chatbot."
    unsourced_answer_template: str = "{response.text}\n\nBut what do I know, I'm a chatbot."

    def source_item(self, source: Source) -> str:
        """Format a single source item."""
        return self.source_template.format(source=source)

    def sources_list(self, sources: Iterable[Source]) -> str | None:
        """Format sources into a numbered list, or return None when there are none."""
        rendered = [self.source_item(source) for source in sources]
        if not rendered:
            return None  # No list needed.

        return "\n".join(f"{position}. {item}" for position, item in enumerate(rendered, 1))

    def error(self, response: Response) -> str:
        """Format an error message, falling back to a generic one without details."""
        template = self.error_msg_template if response.error_msg else self.error_fallback_template
        return template.format(response=response)

    def answer(self, response: Response, sources: Iterable[Source]) -> str:
        """Format an answer, appending the sources list when there is one."""
        sources_list = self.sources_list(sources)
        if sources_list:
            # Only the sourced template has a `{sources}` slot to fill.
            return self.sourced_answer_template.format(response=response, sources=sources_list)

        return self.unsourced_answer_template.format(response=response)

    def __call__(self, response: Response, sources: Iterable[Source]) -> str:
        """Format an answer and its sources, or an error message."""
        if response.error:
            return self.error(response)
        return self.answer(response, sources)
buster/formatter/html.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import html
2
+ from dataclasses import dataclass
3
+ from typing import Iterable
4
+
5
+ from buster.formatter.base import Formatter, Response, Source
6
+
7
+
8
@dataclass
class HTMLFormatter(Formatter):
    """Format the answer in HTML.

    The response text, the error message, and each source's name and URL
    are HTML-escaped before being substituted into the templates.
    """

    source_template: str = """<li><a href='{source.url}'>🔗 {source.name}</a></li>"""
    error_msg_template: str = """<div class="error">Something went wrong:\n<p>{response.error_msg}</p></div>"""
    error_fallback_template: str = """<div class="error">Something went very wrong.</div>"""
    sourced_answer_template: str = (
        """<div class="answer"><p>{response.text}</p></div>\n"""
        """<div class="sources">📝 Here are the sources I used to answer your question:\n"""
        """<ol>\n{sources}</ol></div>\n"""
        """<div class="footer">I'm a chatbot, bleep bloop.</div>"""
    )
    unsourced_answer_template: str = (
        """<div class="answer">{response.text}</div>\n<div class="footer">I'm a chatbot, bleep bloop.</div>"""
    )

    def sources_list(self, sources: Iterable[Source]) -> str | None:
        """Join the ``<li>`` items, or return None when there are no sources.

        No numbering is added here: the items are placed inside an ``<ol>``
        by ``sourced_answer_template``.
        """
        rendered = [self.source_item(source) for source in sources]
        if not rendered:
            return None  # No list needed.

        return "\n".join(rendered)

    def __call__(self, response: Response, sources: Iterable[Source]) -> str:
        """Escape all substituted text, then format as the base class does."""
        escaped_response = Response(
            html.escape(response.text) if response.text else response.text,
            response.error,
            html.escape(response.error_msg) if response.error_msg else response.error_msg,
        )
        # The URL lands inside a quoted attribute, so it needs escaping too;
        # html.escape also escapes both quote characters by default.
        escaped_sources = [
            Source(html.escape(source.name), html.escape(source.url), source.question_similarity)
            for source in sources
        ]
        return super().__call__(escaped_response, escaped_sources)
buster/formatter/markdown.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Iterable
3
+
4
+ from buster.formatter.base import Formatter, Source
5
+
6
+
7
@dataclass
class MarkdownFormatter(Formatter):
    """Format the answer in markdown.

    Only the templates differ from ``Formatter``; the numbered sources list
    is produced by the inherited ``sources_list``.
    """

    source_template: str = """[🔗 {source.name}]({source.url}), relevance: {source.question_similarity:2.3f}"""
    error_msg_template: str = """Something went wrong:\n{response.error_msg}"""
    error_fallback_template: str = """Something went very wrong."""
    sourced_answer_template: str = (
        """{response.text}\n\n"""
        """📝 Here are the sources I used to answer your question:\n"""
        """{sources}\n\n"""
        """I'm a chatbot, bleep bloop."""
    )
    unsourced_answer_template: str = """{response.text}\n\nI'm a chatbot, bleep bloop."""
buster/formatter/slack.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Iterable
3
+
4
+ from buster.formatter.base import Formatter, Source
5
+
6
+
7
@dataclass
class SlackFormatter(Formatter):
    """Format the answer for Slack.

    Sources are rendered with Slack's ``<url|label>`` link syntax, one per
    line and without numbering.
    """

    source_template: str = """<{source.url}|🔗 {source.name}>, relevance: {source.question_similarity:2.3f}"""
    error_msg_template: str = """Something went wrong:\n{response.error_msg}"""
    error_fallback_template: str = """Something went very wrong."""
    sourced_answer_template: str = (
        """{response.text}\n\n"""
        """📝 Here are the sources I used to answer your question:\n"""
        """{sources}\n\n"""
        """I'm a chatbot, bleep bloop."""
    )
    unsourced_answer_template: str = """{response.text}\n\nI'm a chatbot, bleep bloop."""

    def sources_list(self, sources: Iterable[Source]) -> str | None:
        """Join the formatted sources one per line, or return None when there are none."""
        rendered = [self.source_item(source) for source in sources]
        if not rendered:
            return None  # No list needed.

        return "\n".join(rendered)