#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#
import os
from datasets import concatenate_datasets
from datasets.arrow_dataset import Dataset
from tlt import TLT_BASE_DIR
from tlt.utils.file_utils import read_json_file
from tlt.datasets.hf_dataset import HFDataset
from tlt.datasets.text_classification.text_classification_dataset import TextClassificationDataset
from downloader.datasets import DataDownloader

# Location of the JSON config files describing the supported datasets
DATASET_CONFIG_DIR = os.path.join(TLT_BASE_DIR, "datasets/configs")


class HFTextClassificationDataset(TextClassificationDataset, HFDataset):
"""
A text classification dataset from the Hugging Face datasets catalog
"""
    def __init__(self, dataset_dir, dataset_name, split=['train'], num_workers=0, shuffle_files=True,
                 distributed=False):
        if not isinstance(split, list):
            raise ValueError("Value of split argument must be a list.")

        TextClassificationDataset.__init__(self, dataset_dir, dataset_name, "huggingface")
        self._preprocessed = {}
        self._split = split
        self._data_loader = None
        self._train_loader = None
        self._test_loader = None
        self._validation_loader = None
        self._train_subset = None
        self._test_subset = None
        self._validation_subset = None
        self._num_workers = num_workers
        self._shuffle = shuffle_files
        self._distributed = distributed
        self._info = {
            'name': dataset_name
        }
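        # A single requested split is used directly as the dataset. Multiple
        # splits are concatenated into one Dataset, with index ranges recorded
        # so that the train/test/validation subsets can be sliced back out.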
        if len(split) == 1:
            self._validation_type = None  # Train & evaluate on the whole dataset
            # If only one split is given, use it as the main dataset object
            self._dataset = self.load_hf_dataset(dataset_name, split=split[0])
        else:
            self._validation_type = 'defined_split'  # Defined by user or Hugging Face
            # Start empty so the concatenation logic below also works when
            # 'train' is not among the requested splits
            self._dataset = None
            if 'train' in split:
                self._dataset = self.load_hf_dataset(dataset_name, split='train')
                self._train_indices = range(len(self._dataset))
                self._train_subset = self.train_subset
            if 'test' in split:
                test_dataset = self.load_hf_dataset(dataset_name, split='test')
                test_length = len(test_dataset)
                if self._dataset:
                    current_length = len(self._dataset)
                    self._dataset = concatenate_datasets([self._dataset, test_dataset])
                    self._test_indices = range(current_length, current_length + test_length)
                else:
                    self._dataset = test_dataset
                    self._test_indices = range(test_length)
                self._test_subset = self.test_subset
            if 'validation' in split:
                validation_dataset = self.load_hf_dataset(dataset_name, split='validation')
                validation_length = len(validation_dataset)
                if self._dataset:
                    current_length = len(self._dataset)
                    self._dataset = concatenate_datasets([self._dataset, validation_dataset])
                    self._validation_indices = range(current_length, current_length + validation_length)
                else:
                    self._dataset = validation_dataset
                    self._validation_indices = range(validation_length)
                self._validation_subset = self.validation_subset
            if 'unsupervised' in split:
                # Unlabeled examples are appended to the dataset, but no subset
                # indices are recorded for them
                unsupervised_dataset = self.load_hf_dataset(dataset_name, split='unsupervised')
                if self._dataset:
                    self._dataset = concatenate_datasets([self._dataset, unsupervised_dataset])
                else:
                    self._dataset = unsupervised_dataset

    def load_hf_dataset(self, dataset_name: str, split: str) -> Dataset:
        """
        Helper function to load the dataset from the Hugging Face catalog
        """
        main_dataset = dataset_name
        subset = None
        config_file = os.path.join(DATASET_CONFIG_DIR, "hf_text_classification_datasets.json")
        config_dict = read_json_file(config_file)
        available_datasets = list(config_dict.keys())
        if dataset_name not in available_datasets:
            raise ValueError("Dataset is not supported. Choose from: {}".format(available_datasets))
        # Dataset names with the format "dataset/subset" are separated into the
        # main dataset name and its subset name
        if '/' in dataset_name:
            main_dataset = dataset_name.split('/')[0]
            subset = dataset_name.split('/')[1]
        if subset is not None:
            downloader = DataDownloader(main_dataset, self._dataset_dir, catalog='hugging_face', subset=subset)
        else:
            downloader = DataDownloader(main_dataset, self._dataset_dir, catalog='hugging_face')
        return downloader.download(split=split)

    @property
    def dataset(self) -> Dataset:
        """
        Returns the underlying datasets.arrow_dataset.Dataset object
        """
        return self._dataset

    @property
    def class_names(self) -> list:
        """
        Returns a list of class labels
        """
        # The label column is named either 'label' or 'labels', depending on
        # the dataset
        try:
            names = self.dataset.features['label'].names
        except KeyError:
            names = self.dataset.features['labels'].names
        return names

    @property
    def info(self):
        """
        Returns a dictionary of information about the dataset
        """
        return {'dataset_info': self._info, 'preprocessing_info': self._preprocessed}

    def __len__(self):
        return len(self._dataset)
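
# A minimal usage sketch (illustrative only; it assumes 'imdb' appears in
# hf_text_classification_datasets.json and that the downloader can reach the
# Hugging Face hub):
#
#   dataset = HFTextClassificationDataset('/tmp/data', 'imdb', split=['train', 'test'])
#   print(len(dataset), dataset.class_names)
#   train_subset = dataset.train_subset  # sliced out via the recorded indices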