Commit 37aa39a

Support 0<data_fraction<1 for CustomDatasetWithoutLabels (#328)
* Support 0<data_fraction<1 for CustomDatasetWithoutLabels

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 74c1d6d commit 37aa39a

1 file changed: +15 -8 lines changed

solo/data/pretrain_dataloader.py

Lines changed: 15 additions & 8 deletions
@@ -352,16 +352,23 @@ def prepare_datasets(
 
     if data_fraction > 0:
         assert data_fraction < 1, "Only use data_fraction for values smaller than 1."
-        data = train_dataset.samples
-        files = [f for f, _ in data]
-        labels = [l for _, l in data]
-
         from sklearn.model_selection import train_test_split
 
-        files, _, labels, _ = train_test_split(
-            files, labels, train_size=data_fraction, stratify=labels, random_state=42
-        )
-        train_dataset.samples = [tuple(p) for p in zip(files, labels)]
+        if isinstance(train_dataset, CustomDatasetWithoutLabels):
+            files = train_dataset.images
+            (
+                files,
+                _,
+            ) = train_test_split(files, train_size=data_fraction, random_state=42)
+            train_dataset.images = files
+        else:
+            data = train_dataset.samples
+            files = [f for f, _ in data]
+            labels = [l for _, l in data]
+            files, _, labels, _ = train_test_split(
+                files, labels, train_size=data_fraction, stratify=labels, random_state=42
+            )
+            train_dataset.samples = [tuple(p) for p in zip(files, labels)]
 
     return train_dataset

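To illustrate the two branches this commit introduces, here is a minimal plain-Python sketch. It is not the code from `solo/data/pretrain_dataloader.py`: the classes `UnlabeledImages` and `LabeledImages` and the helper `subsample` are hypothetical stand-ins, and the selection uses the standard library's `random` module in place of scikit-learn's `train_test_split`, so the exact files picked and the rounding of subset sizes may differ from the real implementation. The point is only the dispatch: a dataset exposing `images` gets a plain random subset, while a dataset exposing `(file, label)` `samples` gets a per-class subset.

```python
import random
from collections import defaultdict


class UnlabeledImages:
    """Hypothetical stand-in for CustomDatasetWithoutLabels: only an `images` list."""

    def __init__(self, images):
        self.images = list(images)


class LabeledImages:
    """Hypothetical stand-in for a dataset with (file, label) `samples`."""

    def __init__(self, samples):
        self.samples = list(samples)


def subsample(dataset, data_fraction, seed=42):
    """Keep roughly `data_fraction` of the dataset in place (illustrative only)."""
    assert 0 < data_fraction < 1, "Only use data_fraction for values smaller than 1."
    rng = random.Random(seed)  # fixed seed so the subset is reproducible

    if isinstance(dataset, UnlabeledImages):
        # No labels to stratify on: draw a plain random subset of file paths.
        keep = max(1, int(len(dataset.images) * data_fraction))
        dataset.images = rng.sample(dataset.images, keep)
    else:
        # Labels available: sample each class separately to preserve class balance.
        by_label = defaultdict(list)
        for path, label in dataset.samples:
            by_label[label].append(path)
        kept = []
        for label in sorted(by_label):
            paths = by_label[label]
            keep = max(1, int(len(paths) * data_fraction))
            kept.extend((path, label) for path in rng.sample(paths, keep))
        dataset.samples = kept
    return dataset


all_files = [f"img_{i}.png" for i in range(20)]
unlabeled = subsample(UnlabeledImages(all_files), 0.5)
labeled = subsample(LabeledImages([(f, i % 2) for i, f in enumerate(all_files)]), 0.5)
print(len(unlabeled.images), len(labeled.samples))
```

Before this change, the unlabeled case would have gone through the `samples` path, which `CustomDatasetWithoutLabels` apparently does not provide; the diff suggests that is why the `isinstance` check was added.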