import datasets
from torch.utils.data import DataLoader

from .collator import Collator
from .batch_sampler import BatchSampler
from .norm import normalize_dataset

def prepare_dataloaders(args, tokenizer, logger):
    """Prepare train, validation and test dataloaders."""
    # Load every split once and record per-sample sequence lengths,
    # which the token-based batch sampler uses to build batches.
    dataset = datasets.load_dataset(args.dataset)
    train_dataset = dataset['train']
    val_dataset = dataset['validation']
    test_dataset = dataset['test']
    train_dataset_token_lengths = [len(item['aa_seq']) for item in train_dataset]
    val_dataset_token_lengths = [len(item['aa_seq']) for item in val_dataset]
    test_dataset_token_lengths = [len(item['aa_seq']) for item in test_dataset]
    
    if args.normalize is not None:
        # Apply label normalization consistently across all three splits.
        train_dataset, val_dataset, test_dataset = normalize_dataset(
            train_dataset, val_dataset, test_dataset, args.normalize
        )
    
    # Log dataset statistics
    logger.info("Dataset Statistics:")
    logger.info("------------------------")
    logger.info(f"Dataset: {args.dataset}")
    logger.info(f"  Number of train samples: {len(train_dataset)}")
    logger.info(f"  Number of val samples: {len(val_dataset)}")
    logger.info(f"  Number of test samples: {len(test_dataset)}")
    
    # Log the first 3 data points from the train split as a sanity check
    logger.info("Sample 3 data points from train dataset:")
    for i in range(3):
        logger.info(f"  Train data point {i + 1}: {train_dataset[i]}")
    logger.info("------------------------")
    
    collator = Collator(
        tokenizer=tokenizer,
        max_length=args.max_seq_len if args.max_seq_len > 0 else None,
        structure_seq=args.structure_seq,
        problem_type=args.problem_type,
        plm_model=args.plm_model,
        num_labels=args.num_labels
    )
    
    # Common dataloader parameters. persistent_workers and prefetch_factor
    # are only valid with worker processes, so gate them on num_workers > 0
    # to avoid a ValueError from DataLoader when num_workers == 0.
    dataloader_params = {
        'num_workers': args.num_workers,
        'collate_fn': collator,
        'pin_memory': True,
    }
    if args.num_workers > 0:
        dataloader_params['persistent_workers'] = True
        dataloader_params['prefetch_factor'] = 2
    
    # Create dataloaders based on batching strategy
    if args.batch_token is not None:
        # Token-based batching: cap the total number of tokens per batch.
        train_loader = create_token_based_loader(
            train_dataset, train_dataset_token_lengths, args.batch_token,
            shuffle=True, **dataloader_params)
        val_loader = create_token_based_loader(
            val_dataset, val_dataset_token_lengths, args.batch_token,
            shuffle=False, **dataloader_params)
        test_loader = create_token_based_loader(
            test_dataset, test_dataset_token_lengths, args.batch_token,
            shuffle=False, **dataloader_params)
    else:
        # Size-based batching: fixed number of samples per batch.
        train_loader = create_size_based_loader(
            train_dataset, args.batch_size, shuffle=True, **dataloader_params)
        val_loader = create_size_based_loader(
            val_dataset, args.batch_size, shuffle=False, **dataloader_params)
        test_loader = create_size_based_loader(
            test_dataset, args.batch_size, shuffle=False, **dataloader_params)
    
    return train_loader, val_loader, test_loader

def create_token_based_loader(dataset, token_lengths, batch_token, shuffle, **kwargs):
    """Create a dataloader whose batches respect a total-token budget.

    BatchSampler groups indices so each batch's summed sequence length
    stays within batch_token.
    """
    sampler = BatchSampler(token_lengths, batch_token, shuffle=shuffle)
    return DataLoader(dataset, batch_sampler=sampler, **kwargs)
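
# For context, a minimal sketch of the greedy token-budget strategy that
# BatchSampler is assumed to implement; the real implementation lives in
# .batch_sampler and may differ (e.g. it may sort by length first).
class _GreedyTokenBatchSamplerSketch:
    """Yield lists of indices whose summed token lengths fit a budget."""

    def __init__(self, token_lengths, batch_token, shuffle=True):
        self.token_lengths = token_lengths
        self.batch_token = batch_token
        self.shuffle = shuffle

    def __iter__(self):
        import random
        indices = list(range(len(self.token_lengths)))
        if self.shuffle:
            random.shuffle(indices)
        batch, tokens = [], 0
        for idx in indices:
            length = self.token_lengths[idx]
            # Flush the current batch before it would exceed the budget.
            if batch and tokens + length > self.batch_token:
                yield batch
                batch, tokens = [], 0
            batch.append(idx)
            tokens += length
        if batch:
            yield batch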

def create_size_based_loader(dataset, batch_size, shuffle, **kwargs):
    """Create dataloader with size-based batching."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
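
if __name__ == "__main__":
    # Minimal usage sketch. The attribute names on `args` mirror those
    # consumed above; the dataset id and hyperparameter values below are
    # hypothetical placeholders, and the tokenizer is assumed to be a
    # Hugging Face AutoTokenizer matching the chosen protein language model.
    import logging
    from argparse import Namespace
    from transformers import AutoTokenizer

    logging.basicConfig(level=logging.INFO)
    example_logger = logging.getLogger(__name__)

    example_args = Namespace(
        dataset="org/protein-task",  # hypothetical Hugging Face dataset id
        normalize=None,
        max_seq_len=1024,
        structure_seq=None,
        problem_type="regression",
        plm_model="facebook/esm2_t12_35M_UR50D",
        num_labels=1,
        num_workers=4,
        batch_token=8192,  # token budget per batch; set to None to use batch_size
        batch_size=32,
    )
    example_tokenizer = AutoTokenizer.from_pretrained(example_args.plm_model)
    train_loader, val_loader, test_loader = prepare_dataloaders(
        example_args, example_tokenizer, example_logger
    )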