# Author: Pierre Brault
# NOTE(review): the original first lines ("pierre Brault", "imit", "3ff674d")
# were stray non-Python residue (likely a truncated commit header) that made
# the module unimportable; converted to comments to preserve the text.
import pickle
from typing import List, Optional
from xml.etree.ElementTree import fromstring

import requests
from cohere import Client
from groq import Groq
from tqdm import tqdm

from gossip_semantic_search.constant import HEADERS, CHANNEL_KEY, ITEM_KEY
from gossip_semantic_search.models import Article
from gossip_semantic_search.utils import (xml_to_dict, article_raw_to_article,
                                          generates_questions, embed_content)
class DatasetCreator:
    """Build a dataset of articles scraped from RSS feeds.

    Pipeline (see :meth:`run`): fetch and parse the feeds into
    ``Article`` objects, optionally generate questions per article
    (Groq), optionally embed each article's content (Cohere), then
    pickle the resulting list to ``save_path``.
    """

    def __init__(self,
                 urls: List[str],
                 save_path: Optional[str] = None,
                 number_questions: int = 0,
                 embed_articles: bool = False):
        """Configure the pipeline.

        Args:
            urls: RSS feed URLs to fetch.
            save_path: destination file for the pickled article list;
                when ``None``, :meth:`run` skips persistence.
            number_questions: questions to generate per article
                (0 disables question generation).
            embed_articles: whether to embed each article's content.
        """
        self.urls = urls
        self.save_path = save_path
        self.number_questions = number_questions
        self.embed_articles = embed_articles
        # Forward-ref string avoids evaluating the annotation at runtime.
        self.articles: "List[Article]" = []

    def extract_articles(self) -> None:
        """Fetch each RSS feed and append its items as ``Article`` objects."""
        for url in self.urls:
            response = requests.get(url, headers=HEADERS)
            # Fail fast on HTTP errors instead of parsing an error page as XML.
            response.raise_for_status()
            root = fromstring(response.text)
            articles_raw = xml_to_dict(root)[CHANNEL_KEY][ITEM_KEY]
            self.articles.extend(article_raw_to_article(article_raw)
                                 for article_raw in articles_raw)

    def save_articles(self) -> None:
        """Pickle the accumulated articles to ``save_path``.

        Raises:
            ValueError: if no ``save_path`` was configured (previously this
                surfaced as an opaque ``TypeError`` from ``open(None)``).
        """
        if not self.save_path:
            raise ValueError("save_path must be set before calling save_articles()")
        with open(self.save_path, 'wb') as f:
            pickle.dump(self.articles, f)

    def generate_questions(self) -> None:
        """Generate ``number_questions`` questions per article via the Groq API."""
        client = Groq()
        for article in tqdm(self.articles, desc="Generating questions"):
            article.questions = generates_questions(article.content,
                                                    self.number_questions,
                                                    client)

    def embed_article(self) -> None:
        """Embed each article's content via the Cohere API."""
        client = Client()
        for article in tqdm(self.articles, desc="Embedding content"):
            # embed_content returns one row per input; keep the single row.
            article.embeded_content = embed_content([article.content], client)[0, :]

    def run(self) -> None:
        """Execute the full pipeline: extract, enrich, and persist articles."""
        self.extract_articles()
        if self.number_questions:
            self.generate_questions()
        if self.embed_articles:
            self.embed_article()
        # Persist only when there is both data and a destination
        # (previously a non-empty result with save_path=None crashed in open()).
        if self.articles and self.save_path:
            self.save_articles()