VenusFactory / src /data /dataloader.py
2dogey's picture
Upload folder using huggingface_hub
8918ac7 verified
import json
import torch
import datasets
from torch.utils.data import DataLoader
from .collator import Collator
from .batch_sampler import BatchSampler
from .norm import normalize_dataset
from torch.utils.data import Dataset
from typing import Dict, Any, List, Union
import pandas as pd
def prepare_dataloaders(args, tokenizer, logger):
    """Prepare train, validation and test dataloaders.

    Loads ``args.dataset`` once with ``datasets.load_dataset`` and takes its
    ``'train'``, ``'validation'`` and ``'test'`` splits. If requested, it
    normalizes them with ``normalize_dataset``, then logs a short summary and
    wraps each split in a ``DataLoader`` that uses a shared ``Collator``.

    Args:
        args: Run configuration. Attributes read here: ``dataset``,
            ``normalize``, ``max_seq_len``, ``structure_seq``,
            ``problem_type``, ``plm_model``, ``num_labels``, ``num_workers``,
            ``batch_token`` and ``batch_size``.
        tokenizer: Tokenizer forwarded to ``Collator``.
        logger: Logger used (via ``logger.info``) for the dataset summary.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. When
        ``args.batch_token`` is not None, batches are built by a
        ``BatchSampler`` from each sample's ``aa_seq`` length. Otherwise
        fixed-size batches of ``args.batch_size`` are used. Only the train
        loader shuffles.
    """
    # Load once and reuse for every split instead of calling load_dataset
    # three times.
    dataset_dict = datasets.load_dataset(args.dataset)
    train_dataset = dataset_dict['train']
    val_dataset = dataset_dict['validation']
    test_dataset = dataset_dict['test']

    # Per-sample sequence lengths drive token-based batching. Reading the
    # single 'aa_seq' column avoids decoding every other field of every row.
    # Lengths are taken before normalization, as before.
    train_dataset_token_lengths = [len(seq) for seq in train_dataset['aa_seq']]
    val_dataset_token_lengths = [len(seq) for seq in val_dataset['aa_seq']]
    test_dataset_token_lengths = [len(seq) for seq in test_dataset['aa_seq']]

    if args.normalize is not None:
        train_dataset, val_dataset, test_dataset = normalize_dataset(
            train_dataset, val_dataset, test_dataset, args.normalize
        )

    # Log dataset info.
    logger.info("Dataset Statistics:")
    logger.info("------------------------")
    logger.info(f"Dataset: {args.dataset}")
    logger.info(f" Number of train samples: {len(train_dataset)}")
    logger.info(f" Number of val samples: {len(val_dataset)}")
    logger.info(f" Number of test samples: {len(test_dataset)}")
    # Preview up to 3 training rows. Indexing a fixed 0..2 would raise
    # IndexError on a tiny train split.
    num_preview = min(3, len(train_dataset))
    logger.info(f"Sample {num_preview} data points from train dataset:")
    for idx in range(num_preview):
        logger.info(f" Train data point {idx + 1}: {train_dataset[idx]}")
    logger.info("------------------------")

    collator = Collator(
        tokenizer=tokenizer,
        # A non-positive max_seq_len means "no truncation limit".
        max_length=args.max_seq_len if args.max_seq_len > 0 else None,
        structure_seq=args.structure_seq,
        problem_type=args.problem_type,
        plm_model=args.plm_model,
        num_labels=args.num_labels,
    )

    # Common dataloader parameters.
    use_workers = args.num_workers > 0
    dataloader_params = {
        'num_workers': args.num_workers,
        'collate_fn': collator,
        'pin_memory': True,
        'persistent_workers': use_workers,
    }
    if use_workers:
        # prefetch_factor is only valid with worker processes. PyTorch >= 2.0
        # raises ValueError if it is given while num_workers == 0.
        dataloader_params['prefetch_factor'] = 2

    # Create dataloaders based on the batching strategy.
    if args.batch_token is not None:
        train_loader = create_token_based_loader(
            train_dataset, train_dataset_token_lengths, args.batch_token, True, **dataloader_params
        )
        val_loader = create_token_based_loader(
            val_dataset, val_dataset_token_lengths, args.batch_token, False, **dataloader_params
        )
        test_loader = create_token_based_loader(
            test_dataset, test_dataset_token_lengths, args.batch_token, False, **dataloader_params
        )
    else:
        train_loader = create_size_based_loader(train_dataset, args.batch_size, True, **dataloader_params)
        val_loader = create_size_based_loader(val_dataset, args.batch_size, False, **dataloader_params)
        test_loader = create_size_based_loader(test_dataset, args.batch_size, False, **dataloader_params)

    return train_loader, val_loader, test_loader
def create_token_based_loader(dataset, token_lengths, batch_token, shuffle, **kwargs):
    """Build a DataLoader whose batches are grouped by token budget.

    Args:
        dataset: Dataset to index into.
        token_lengths: Per-sample lengths handed to ``BatchSampler``.
        batch_token: Token budget handed to ``BatchSampler``.
        shuffle: Forwarded to ``BatchSampler`` as ``shuffle=``.
        **kwargs: Extra ``DataLoader`` options (workers, collate_fn, ...).

    Returns:
        A ``DataLoader`` driven by the ``BatchSampler`` via ``batch_sampler=``.
    """
    token_sampler = BatchSampler(token_lengths, batch_token, shuffle=shuffle)
    loader = DataLoader(
        dataset,
        batch_sampler=token_sampler,
        **kwargs,
    )
    return loader
def create_size_based_loader(dataset, batch_size, shuffle, **kwargs):
    """Build a DataLoader that yields fixed-size batches.

    Args:
        dataset: Dataset to iterate.
        batch_size: Number of samples per batch.
        shuffle: Whether the DataLoader shuffles sample order.
        **kwargs: Extra ``DataLoader`` options (workers, collate_fn, ...).

    Returns:
        A ``DataLoader`` configured with ``batch_size`` and ``shuffle``.
    """
    loader = DataLoader(
        dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        **kwargs,
    )
    return loader