|
import os |
|
import os.path as op |
|
import gc |
|
import json |
|
from typing import List |
|
import logging |
|
|
|
try:
    from .blob_storage import BlobStorage, disk_usage
except ImportError:
    # The blob-storage helper is optional.  Only a failed import triggers the
    # fallback; any other error raised while importing it is a real problem
    # and is allowed to propagate instead of being silently hidden.
    class BlobStorage:
        """Placeholder used when the ``blob_storage`` module cannot be imported.

        It exists only so that the name ``BlobStorage`` is defined at module
        level.  Note that ``disk_usage`` is *not* defined by this fallback.
        """
        pass
|
|
|
|
|
def generate_lineidx(filein: str, idxout: str) -> None:
    """Write the starting byte offset of every line of ``filein`` to ``idxout``.

    The index contains one decimal offset per line, in file order.  It is
    first written to ``idxout + '.tmp'`` and then moved into place, so a
    partially written index never appears under the final name.

    Args:
        filein: Path of the file to index.
        idxout: Path of the index file to create (overwritten if present).

    Raises:
        OSError: If either file cannot be opened, written or renamed.
    """
    idxout_tmp = idxout + '.tmp'
    offset = 0
    # Read in binary mode: offsets are then exact byte positions, no decoding
    # is attempted on the payload, and the loop always terminates at EOF
    # (text-mode ``tell()`` returns opaque cookies that need not ever equal
    # the file size).  Lines are delimited by b'\n'.
    with open(filein, 'rb') as tsvin, open(idxout_tmp, 'w') as tsvout:
        for line in tsvin:
            tsvout.write(str(offset) + "\n")
            offset += len(line)
    # os.replace overwrites an existing destination on every platform.
    os.replace(idxout_tmp, idxout)
|
|
|
|
|
def read_to_character(fp, c):
    """Read from ``fp`` until ``c`` is seen and return the text before it.

    Data is consumed in 32-character chunks, so after the call the stream
    position is at the end of the chunk that contained ``c``, not directly
    after ``c``.

    Args:
        fp: Text stream positioned where reading should start.
        c: Delimiter to search for.  A delimiter longer than one character
            that straddles a chunk boundary is not detected.

    Returns:
        The text between the starting position and the first ``c``.

    Raises:
        AssertionError: If the end of the stream is reached before ``c``.
    """
    pieces = []
    while True:
        chunk = fp.read(32)
        if not chunk:
            # Raised explicitly (rather than via ``assert``) so the check
            # survives ``python -O`` and cannot degrade into an endless loop.
            raise AssertionError(
                'reached end of file before finding {!r}'.format(c))
        cut = chunk.find(c)
        if cut >= 0:
            pieces.append(chunk[:cut])
            break
        pieces.append(chunk)
    return ''.join(pieces)
|
|
|
|
|
class TSVFile(object):
    """Random-access reader for a TSV file indexed by a ``.lineidx`` file.

    The ``.lineidx`` file holds one integer offset per row.  Two optional
    side files, located next to the TSV and sharing its base name, restrict
    which rows are exposed:

    * ``.linelist`` - one row number per line; only those rows are kept.
    * ``.chunks``   - JSON object mapping a class name to an inclusive
      ``[first_row, last_row]`` pair; combined with ``class_selector``.

    The TSV handle and the index are loaded lazily on first access.
    """

    def __init__(self,
                 tsv_file: str,
                 if_generate_lineidx: bool = False,
                 lineidx: str = None,
                 class_selector: List[str] = None,
                 blob_storage: 'BlobStorage' = None):
        """Record the paths involved and optionally build the line index.

        Args:
            tsv_file: Path of the TSV file.
            if_generate_lineidx: Build the index file when it does not exist.
            lineidx: Explicit index path; defaults to the TSV path with its
                extension replaced by ``.lineidx``.
            class_selector: Class names to keep when a ``.chunks`` file
                exists.  ``None`` keeps every class.
            blob_storage: Optional object whose ``open(path)`` is used to
                open the TSV instead of the built-in ``open``.
        """
        base = op.splitext(tsv_file)[0]
        self.tsv_file = tsv_file
        self.lineidx = lineidx if lineidx else base + '.lineidx'
        self.linelist = base + '.linelist'
        self.chunks = base + '.chunks'
        self._fp = None
        self._lineidx = None
        self._sample_indices = None
        self._class_boundaries = None
        self._class_selector = class_selector
        self._blob_storage = blob_storage
        self._len = None

        # pid of the process that opened ``_fp``; compared against the
        # current pid in ``_ensure_tsv_opened`` to decide whether to re-open.
        self.pid = None

        if not op.isfile(self.lineidx) and if_generate_lineidx:
            generate_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        """Drop cached indices, close the TSV handle and purge azcopy files."""
        self.gcidx()

        # ``getattr`` keeps the finalizer quiet for instances whose
        # ``__init__`` never ran to completion.
        fp = getattr(self, '_fp', None)
        if fp:
            fp.close()

        blob_storage = getattr(self, '_blob_storage', None)
        tsv_file = getattr(self, 'tsv_file', None)
        if (blob_storage and tsv_file and 'azcopy' in tsv_file
                and os.path.exists(tsv_file)):
            try:
                original_usage = disk_usage('/')
                os.remove(tsv_file)
                # Both readings are scaled the same way so the log line is
                # consistent (assumes disk_usage returns a fraction, as the
                # pre-existing ``* 100`` on the second value implies).
                logging.info("Purged %s (disk usage: %.2f%% => %.2f%%)" %
                             (tsv_file, original_usage * 100,
                              disk_usage('/') * 100))
            except Exception:
                # Purging is best effort and must never raise from a
                # finalizer; keep a trace for debugging only.
                logging.debug('=> failed to purge {}'.format(tsv_file),
                              exc_info=True)

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def gcidx(self):
        """Release the in-memory indices and run a garbage collection.

        Returns:
            The value returned by ``gc.collect()``.
        """
        logging.debug('Run gc collect')
        self._lineidx = None
        self._sample_indices = None

        return gc.collect()

    def get_class_boundaries(self):
        """Return the ``(start, end)`` pairs built from the ``.chunks`` file.

        ``None`` until the index has been loaded with a ``.chunks`` file
        present.
        """
        return self._class_boundaries

    def num_rows(self, gcf=False):
        """Return the number of exposed rows, caching the result.

        Args:
            gcf: When true, free the loaded indices right after counting.
        """
        if self._len is None:
            self._ensure_lineidx_loaded()
            retval = len(self._sample_indices)

            if gcf:
                self.gcidx()

            self._len = retval

        return self._len

    def seek(self, idx: int):
        """Return the ``idx``-th exposed row as a list of stripped columns."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[self._sample_indices[idx]]
        except Exception:
            # Log which file/index failed, then let the error propagate.
            logging.info('=> {}-{}'.format(self.tsv_file, idx))
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx: int):
        """Return the first column of row ``idx`` of the underlying file.

        NOTE(review): unlike ``seek``, ``idx`` indexes the raw line index and
        is not mapped through the linelist/chunks selection - confirm that
        callers rely on this before changing it.
        """
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, '\t')

    def get_key(self, idx: int):
        """Alias of ``seek_first_column``."""
        return self.seek_first_column(idx)

    def __getitem__(self, index: int):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        """Load the line index and compute the exposed row numbers (once)."""
        if self._lineidx is not None:
            return

        logging.debug('=> loading lineidx: {}'.format(self.lineidx))
        with open(self.lineidx, 'r') as fp:
            self._lineidx = [int(line.strip()) for line in fp]

        # Optional explicit list of row numbers to keep.
        linelist = None
        if op.isfile(self.linelist):
            with open(self.linelist, 'r') as fp:
                linelist = sorted(int(line.strip()) for line in fp)

        if op.isfile(self.chunks):
            self._sample_indices = []
            self._class_boundaries = []
            with open(self.chunks, 'r') as fp:
                class_boundaries = json.load(fp)

            # Set membership keeps the per-row filter O(1).  An empty
            # linelist means "no filtering", exactly like a missing one.
            allowed = set(linelist) if linelist else None
            selector = self._class_selector
            for class_name, boundary in class_boundaries.items():
                start = len(self._sample_indices)
                # ``None`` selector keeps every class instead of failing.
                if selector is None or class_name in selector:
                    # Boundaries are inclusive on both ends.
                    for idx in range(boundary[0], boundary[1] + 1):
                        if allowed is not None and idx not in allowed:
                            continue
                        self._sample_indices.append(idx)
                end = len(self._sample_indices)
                # Classes that were not selected get an empty (start, start).
                self._class_boundaries.append((start, end))
        elif linelist:
            self._sample_indices = linelist
        else:
            self._sample_indices = list(range(len(self._lineidx)))

    def _ensure_tsv_opened(self):
        """Open the TSV lazily and re-open it when the process id changed."""
        if self._fp is None:
            if self._blob_storage:
                self._fp = self._blob_storage.open(self.tsv_file)
            else:
                self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()

        if self.pid != os.getpid():
            logging.debug('=> re-open {} because the process id changed'.format(self.tsv_file))
            # NOTE(review): the re-open uses the built-in ``open`` even when
            # a blob storage was supplied - confirm this is intended.
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()
|
|
|
|
|
class TSVWriter(object):
    """Write a TSV file together with its ``.lineidx`` byte-offset index.

    Both files are written under a ``.tmp`` suffix and only moved to their
    final names by ``close()``.
    """

    def __init__(self, tsv_file):
        """Open the temporary TSV and index files for writing.

        Args:
            tsv_file: Final path of the TSV file; the index path is the same
                path with its extension replaced by ``.lineidx``.
        """
        self.tsv_file = tsv_file
        self.lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
        self.tsv_file_tmp = self.tsv_file + '.tmp'
        self.lineidx_file_tmp = self.lineidx_file + '.tmp'

        # newline='\n' disables platform newline translation, so each row
        # terminator occupies exactly one byte and offsets stay accurate.
        self.tsv_fp = open(self.tsv_file_tmp, 'w', newline='\n')
        try:
            self.lineidx_fp = open(self.lineidx_file_tmp, 'w')
        except Exception:
            # Do not leak the first handle when the second open fails.
            self.tsv_fp.close()
            raise

        # Byte offset at which the next row will start.
        self.idx = 0

    def write(self, values, sep='\t'):
        """Append one row and record its starting byte offset.

        Args:
            values: Iterable of column values; each is converted with ``str``.
            sep: Column separator.
        """
        v = '{0}\n'.format(sep.join(map(str, values)))
        # Readers seek by byte offset, so advance by the encoded length, not
        # by the number of characters (they differ for non-ASCII text).
        size = len(v.encode(self.tsv_fp.encoding))
        self.tsv_fp.write(v)
        self.lineidx_fp.write(str(self.idx) + '\n')
        self.idx = self.idx + size

    def close(self):
        """Close both files and move them to their final names."""
        self.tsv_fp.close()
        self.lineidx_fp.close()
        # os.replace overwrites an existing destination on every platform.
        os.replace(self.tsv_file_tmp, self.tsv_file)
        os.replace(self.lineidx_file_tmp, self.lineidx_file)
|
|