File size: 2,083 Bytes
3ff674d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pickle
from typing import List, Optional
from xml.etree.ElementTree import fromstring

import requests
from tqdm import tqdm
from groq import Groq
from cohere import Client

from gossip_semantic_search.constant import HEADERS, CHANNEL_KEY,  ITEM_KEY
from gossip_semantic_search.utils import (xml_to_dict, article_raw_to_article,
                                          generates_questions, embed_content)
from gossip_semantic_search.models import Article

class DatasetCreator:
    """Build a list of ``Article`` objects from XML feeds and persist it.

    The pipeline (see :meth:`run`) downloads every feed in ``urls``, converts
    the feed items to articles, optionally attaches generated questions and
    content embeddings to each article, and finally pickles the list.

    Args:
        urls: Feed URLs to download. Each response is parsed as XML and the
            items found under ``CHANNEL_KEY`` / ``ITEM_KEY`` become articles.
        save_path: Destination file for the pickled article list. When left
            as ``None``, :meth:`run` skips the save step.
        number_questions: Number of questions requested per article; ``0``
            disables question generation in :meth:`run`.
        embed_articles: Whether :meth:`run` should embed article contents.
        request_timeout: Timeout, in seconds, applied to every feed download
            so that a stalled server cannot hang the pipeline.
    """

    def __init__(self,
                 urls: List[str],
                 save_path: Optional[str] = None,
                 number_questions: int = 0,
                 embed_articles: bool = False,
                 request_timeout: float = 30.0):

        self.urls = urls
        self.save_path = save_path
        self.number_questions = number_questions
        self.embed_articles = embed_articles
        self.request_timeout = request_timeout

        # Filled by extract_articles(); enriched in place by the other steps.
        self.articles: List[Article] = []

    def extract_articles(self) -> None:
        """Download every feed and append its items to ``self.articles``.

        Raises:
            requests.RequestException: On network failure, timeout, or a
                non-success HTTP status.
        """
        for url in self.urls:
            response = requests.get(url, headers=HEADERS,
                                    timeout=self.request_timeout)
            # Fail on HTTP errors here instead of feeding an error page to
            # the XML parser.
            response.raise_for_status()
            # Parse the raw bytes so the parser honours the encoding declared
            # in the XML prolog; ``response.text`` relies on the charset
            # guessed from the HTTP headers, which may be wrong for XML.
            root = fromstring(response.content)
            # NOTE(review): assumes the item entry is a list of raw articles
            # even when the feed holds a single item -- confirm xml_to_dict.
            articles_raw = xml_to_dict(root)[CHANNEL_KEY][ITEM_KEY]
            self.articles.extend(article_raw_to_article(article_raw)
                                 for article_raw in articles_raw)

    def save_articles(self) -> None:
        """Pickle ``self.articles`` to ``self.save_path``.

        ``self.save_path`` must be set; :meth:`run` only calls this method
        when it is.
        """
        with open(self.save_path, 'wb') as f:
            pickle.dump(self.articles, f)

    def generate_questions(self) -> None:
        """Set ``questions`` on every article via ``generates_questions``."""
        # The client is built without arguments; credentials are presumably
        # read from the environment -- confirm against deployment config.
        client = Groq()
        for article in tqdm(self.articles, desc="Generating questions"):
            article.questions = generates_questions(
                article.content, self.number_questions, client)

    def embed_article(self) -> None:
        """Set ``embeded_content`` on every article via ``embed_content``."""
        # The client is built without arguments; credentials are presumably
        # read from the environment -- confirm against deployment config.
        client = Client()
        for article in tqdm(self.articles, desc="Embedding content"):
            # One single-element batch per article; row 0 of the result is
            # the embedding of that article.
            article.embeded_content = embed_content(
                [article.content], client)[0, :]

    def run(self) -> None:
        """Run the pipeline: extract, optionally enrich, then save.

        Questions are generated only when ``number_questions`` is non-zero,
        embeddings only when ``embed_articles`` is true, and the result is
        saved only when there are articles and a ``save_path`` was given.
        """
        self.extract_articles()

        if self.number_questions:
            self.generate_questions()

        if self.embed_articles:
            self.embed_article()

        # Without a save path there is nowhere to write; skip instead of
        # crashing in open() after the expensive steps have completed.
        if self.articles and self.save_path:
            self.save_articles()