From 19824a8d9850371880ec01fd2698dc67299e6d96 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 14 Aug 2017 19:28:38 +0800 Subject: [PATCH] Move local data from global into class DataGenerator. --- data_utils/data.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/data_utils/data.py b/data_utils/data.py index f404b4fa..98180b4b 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -17,11 +17,6 @@ from data_utils.featurizer.speech_featurizer import SpeechFeaturizer from data_utils.speech import SpeechSegment from data_utils.normalizer import FeatureNormalizer -# for caching tar files info -local_data = local() -local_data.tar2info = {} -local_data.tar2object = {} - class DataGenerator(object): """ @@ -89,6 +84,10 @@ class DataGenerator(object): self._num_threads = num_threads self._rng = random.Random(random_seed) self._epoch = 0 + # for caching tar files info + self.local_data = local() + self.local_data.tar2info = {} + self.local_data.tar2object = {} def process_utterance(self, filename, transcript): """Load, augment, featurize and normalize for speech data. @@ -241,16 +240,16 @@ class DataGenerator(object): """ if file.startswith('tar:'): tarpath, filename = file.split(':', 1)[1].split('#', 1) - if 'tar2info' not in local_data.__dict__: - local_data.tar2info = {} - if 'tar2object' not in local_data.__dict__: - local_data.tar2object = {} - if tarpath not in local_data.tar2info: + if 'tar2info' not in self.local_data.__dict__: + self.local_data.tar2info = {} + if 'tar2object' not in self.local_data.__dict__: + self.local_data.tar2object = {} + if tarpath not in self.local_data.tar2info: object, infoes = self._parse_tar(tarpath) - local_data.tar2info[tarpath] = infoes - local_data.tar2object[tarpath] = object - return local_data.tar2object[tarpath].extractfile( - local_data.tar2info[tarpath][filename]) + self.local_data.tar2info[tarpath] = infoes + self.local_data.tar2object[tarpath] = object + return self.local_data.tar2object[tarpath].extractfile( + self.local_data.tar2info[tarpath][filename]) else: return open(file, 'r')