Skip to content

Commit

Permalink
Add support for iterating EXtra-data SourceData objects
Browse files: browse the repository at this point in the history
  • Loading branch information
philsmt committed Mar 22, 2022
1 parent b9e7564 commit 769f086
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions pasha/functor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -266,9 +266,20 @@ def wrap(cls, value):
# Same assumption as in DataArrayFunctor.
return

import extra_data as xd
try:
from extra_data import DataCollection, SourceData, KeyData
except ImportError:
# Only support versions for which these types are top-level
# symbols.
return

if hasattr(SourceData, 'select_trains'):
# Added in EXtra-data 1.10.0.
supported_types = (DataCollection, SourceData, KeyData)
else:
supported_types = (DataCollection, KeyData)

if isinstance(value, (xd.DataCollection, xd.keydata.KeyData)):
if isinstance(value, supported_types):
return cls(value)

def split(self, num_workers):
Expand All @@ -286,7 +297,24 @@ def iterate(self, share):
for f in subobj.files:
f.close()

it = zip(range(*share.indices(self.n_trains)), subobj.trains())

for index, (train_id, data) in it:
index_it = range(*share.indices(self.n_trains))

from extra_data import DataCollection, SourceData

if isinstance(subobj, SourceData):
# SourceData has no trains() iterator yet, so simulate one
# ourselves by reconstructing a DataCollection object and
# using its trains() iterator.
dc = DataCollection(
subobj.files, {subobj.source: subobj}, subobj.train_ids,
inc_suspect_trains=subobj.inc_suspect_trains,
is_single_run=True)
data_it = ((train_id, data[subobj.source])
for train_id, data in dc.trains())
else:
# Use the regular trains() iterator for DataCollection and
# KeyData
data_it = subobj.trains()

for index, (train_id, data) in zip(index_it, data_it):
yield index, train_id, data

0 comments on commit 769f086

Please sign in to comment.