import re
import torch
from collections import Counter
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer
import random

# ====== Dataset Loading ======

def load_emotion_dataset(split="train"):
    # Load a split of the dair-ai/emotion dataset (six emotion classes) from
    # the Hugging Face Hub.
    return load_dataset("dair-ai/emotion", split=split)

def encode_labels(dataset):
    # Fit a LabelEncoder on every label in the dataset and replace each
    # example's label with its encoded integer id; returns the remapped
    # dataset together with the fitted encoder.
    le = LabelEncoder()
    all_labels = [example["label"] for example in dataset]
    le.fit(all_labels)
    dataset = dataset.map(lambda x: {"label": le.transform([x["label"]])[0]})
    return dataset, le

# ====== Tokenizer for RNN/LSTM ======

def simple_tokenizer(text):
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]", "", text)  # Remove special characters
    return text.split()

# ====== Vocab Builder for RNN/LSTM ======

def build_vocab(dataset, min_freq=2):
    # Count token frequencies over the whole dataset and keep every token that
    # appears at least min_freq times; indices 0 and 1 are reserved for <PAD>
    # and <UNK>.
    counter = Counter()
    for example in dataset:
        tokens = simple_tokenizer(example["text"])
        counter.update(tokens)

    vocab = {"<PAD>": 0, "<UNK>": 1}
    idx = 2
    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab

# ====== Collate Function for RNN/LSTM ======

def collate_fn_rnn(batch, vocab, max_length=32, partial_prob=0.0):
    # Batch collator for the RNN/LSTM path: tokenize each text, optionally keep
    # only a random prefix (controlled by partial_prob), map tokens to vocab
    # ids, then pad/clip every sequence to max_length.
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]

    all_input_ids = []
    for text in texts:
        tokens = simple_tokenizer(text)

        # πŸ”₯ Randomly truncate tokens with some probability
        if random.random() < partial_prob and len(tokens) > 5:
            # Keep between 30% to 70% of the tokens
            cutoff = random.randint(int(len(tokens)*0.3), int(len(tokens)*0.7))
            tokens = tokens[:cutoff]

        ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
        if len(ids) < max_length:
            ids += [vocab["<PAD>"]] * (max_length - len(ids))
        else:
            ids = ids[:max_length]
        all_input_ids.append(ids)

    input_ids = torch.tensor(all_input_ids)
    labels = torch.tensor(labels)
    return input_ids, labels

# ====== Collate Function for Transformer ======

def collate_fn_transformer(batch, tokenizer, max_length=128, partial_prob=0.5):
    # Batch collator for the Transformer path: optionally truncate each raw
    # text to a random prefix, then let the Hugging Face tokenizer handle
    # subword tokenization, padding, and truncation to max_length.
    # (random is already imported at module level.)
    texts = []
    labels = []

    for item in batch:
        text = item["text"]
        tokens = text.split()

        # πŸ”₯ Random truncation
        if random.random() < partial_prob and len(tokens) > 5:
            cutoff = random.randint(int(len(tokens)*0.3), int(len(tokens)*0.7))
            tokens = tokens[:cutoff]
            text = " ".join(tokens)

        texts.append(text)
        labels.append(item["label"])

    encoding = tokenizer(texts, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
    encoding["labels"] = torch.tensor(labels)
    return encoding
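
# ====== Example Usage (illustrative sketch) ======
# A minimal sketch of how the helpers above can be wired into PyTorch
# DataLoaders. The batch size, max lengths, partial_prob values, and the
# "bert-base-uncased" checkpoint are assumptions chosen for illustration,
# not part of the original pipeline.
if __name__ == "__main__":
    from functools import partial
    from torch.utils.data import DataLoader

    train_ds, label_encoder = encode_labels(load_emotion_dataset("train"))
    vocab = build_vocab(train_ds, min_freq=2)

    # Word-level loader for the RNN/LSTM models.
    rnn_loader = DataLoader(
        train_ds,
        batch_size=32,
        shuffle=True,
        collate_fn=partial(collate_fn_rnn, vocab=vocab, max_length=32, partial_prob=0.3),
    )

    # Subword loader for the Transformer model.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    transformer_loader = DataLoader(
        train_ds,
        batch_size=32,
        shuffle=True,
        collate_fn=partial(collate_fn_transformer, tokenizer=tokenizer, max_length=128),
    )

    input_ids, labels = next(iter(rnn_loader))
    print(input_ids.shape, labels.shape)  # expected: [32, 32] and [32]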