import pickle
from typing import List, Optional
from xml.etree.ElementTree import fromstring

import requests
from tqdm import tqdm
from groq import Groq
from cohere import Client

from gossip_semantic_search.constant import HEADERS, CHANNEL_KEY, ITEM_KEY
from gossip_semantic_search.models import Article
from gossip_semantic_search.utils import (xml_to_dict, article_raw_to_article,
                                          generates_questions, embed_content)


class DatasetCreator:
    """Builds a dataset of articles from RSS feeds, optionally generating
    questions (Groq) and embeddings (Cohere) for each article."""

    def __init__(self,
                 urls: List[str],
                 save_path: Optional[str] = None,
                 number_questions: int = 0,
                 embed_articles: bool = False):
        self.urls = urls
        self.save_path = save_path
        self.number_questions = number_questions
        self.embed_articles = embed_articles
        self.articles: List[Article] = []

    def extract_articles(self):
        """Fetches each RSS feed and converts its items into Article objects."""
        for url in self.urls:
            response = requests.get(url, headers=HEADERS)
            response.raise_for_status()  # fail fast instead of parsing an error page
            root = fromstring(response.text)
            articles_raw = xml_to_dict(root)[CHANNEL_KEY][ITEM_KEY]
            self.articles.extend([article_raw_to_article(article_raw)
                                  for article_raw in articles_raw])

    def save_articles(self):
        """Pickles the collected articles to save_path."""
        with open(self.save_path, 'wb') as f:
            pickle.dump(self.articles, f)

    def generate_questions(self):
        """Generates number_questions questions per article via the Groq client."""
        client = Groq()
        for article in tqdm(self.articles, desc="Generating questions"):
            article.questions = generates_questions(article.content,
                                                    self.number_questions,
                                                    client)

    def embed_article(self):
        """Embeds each article's content via the Cohere client."""
        client = Client()
        for article in tqdm(self.articles, desc="Embedding content"):
            # embed_content returns one embedding per input; keep the single row
            article.embeded_content = embed_content([article.content], client)[0, :]

    def run(self):
        """Runs the full pipeline: extract, optionally enrich, then save."""
        self.extract_articles()
        if self.number_questions:
            self.generate_questions()
        if self.embed_articles:
            self.embed_article()
        if self.articles and self.save_path:  # only persist when a path was given
            self.save_articles()
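

# Minimal usage sketch (not part of the library): the feed URL and file name
# below are hypothetical, and the Groq/Cohere clients are assumed to pick up
# their API keys from the environment.
if __name__ == "__main__":
    creator = DatasetCreator(
        urls=["https://example.com/rss"],  # hypothetical RSS feed
        save_path="articles.pkl",
        number_questions=3,
        embed_articles=True,
    )
    creator.run()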