Spaces:
Runtime error
Runtime error
import json | |
import torch | |
import datasets | |
from torch.utils.data import DataLoader | |
from .collator import Collator | |
from .batch_sampler import BatchSampler | |
from .norm import normalize_dataset | |
from torch.utils.data import Dataset | |
from typing import Dict, Any, List, Union | |
import pandas as pd | |
def prepare_dataloaders(args, tokenizer, logger): | |
"""Prepare train, validation and test dataloaders.""" | |
# Process datasets | |
train_dataset = datasets.load_dataset(args.dataset)['train'] | |
train_dataset_token_lengths = [len(item['aa_seq']) for item in train_dataset] | |
val_dataset = datasets.load_dataset(args.dataset)['validation'] | |
val_dataset_token_lengths = [len(item['aa_seq']) for item in val_dataset] | |
test_dataset = datasets.load_dataset(args.dataset)['test'] | |
test_dataset_token_lengths = [len(item['aa_seq']) for item in test_dataset] | |
if args.normalize is not None: | |
train_dataset, val_dataset, test_dataset = normalize_dataset(train_dataset, val_dataset, test_dataset, args.normalize) | |
# log dataset info | |
logger.info("Dataset Statistics:") | |
logger.info("------------------------") | |
logger.info(f"Dataset: {args.dataset}") | |
logger.info(f" Number of train samples: {len(train_dataset)}") | |
logger.info(f" Number of val samples: {len(val_dataset)}") | |
logger.info(f" Number of test samples: {len(test_dataset)}") | |
# log 3 data points from train_dataset | |
logger.info("Sample 3 data points from train dataset:") | |
logger.info(f" Train data point 1: {train_dataset[0]}") | |
logger.info(f" Train data point 2: {train_dataset[1]}") | |
logger.info(f" Train data point 3: {train_dataset[2]}") | |
logger.info("------------------------") | |
collator = Collator( | |
tokenizer=tokenizer, | |
max_length=args.max_seq_len if args.max_seq_len > 0 else None, | |
structure_seq=args.structure_seq, | |
problem_type=args.problem_type, | |
plm_model=args.plm_model, | |
num_labels=args.num_labels | |
) | |
# Common dataloader parameters | |
dataloader_params = { | |
'num_workers': args.num_workers, | |
'collate_fn': collator, | |
'pin_memory': True, | |
'persistent_workers': True if args.num_workers > 0 else False, | |
'prefetch_factor': 2, | |
} | |
# Create dataloaders based on batching strategy | |
if args.batch_token is not None: | |
train_loader = create_token_based_loader(train_dataset, train_dataset_token_lengths, args.batch_token, True, **dataloader_params) | |
val_loader = create_token_based_loader(val_dataset, val_dataset_token_lengths, args.batch_token, False, **dataloader_params) | |
test_loader = create_token_based_loader(test_dataset, test_dataset_token_lengths, args.batch_token, False, **dataloader_params) | |
else: | |
train_loader = create_size_based_loader(train_dataset, args.batch_size, True, **dataloader_params) | |
val_loader = create_size_based_loader(val_dataset, args.batch_size, False, **dataloader_params) | |
test_loader = create_size_based_loader(test_dataset, args.batch_size, False, **dataloader_params) | |
return train_loader, val_loader, test_loader | |
def create_token_based_loader(dataset, token_lengths, batch_token, shuffle, **kwargs): | |
"""Create dataloader with token-based batching.""" | |
sampler = BatchSampler(token_lengths, batch_token, shuffle=shuffle) | |
return DataLoader(dataset, batch_sampler=sampler, **kwargs) | |
def create_size_based_loader(dataset, batch_size, shuffle, **kwargs): | |
"""Create dataloader with size-based batching.""" | |
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs) | |