Skip to content

Commit d4cd23d

Browse files
authored
minor bug fixes (#401)
1 parent 19b541b commit d4cd23d

File tree

1 file changed

+9 -7 lines changed

1 file changed

+9 -7 lines changed

prompt2model/dataset_retriever/description_dataset_retriever.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ def initialize_search_index(self) -> None:
116116
# Download the reranking index if one is not on disk already.
117117
logger.info("Downloading the Reranking Dataset Index File")
118118
urllib.request.urlretrieve(
119-
"http://phontron.com/data/prompt2model/dataset_reranking_index.json",
119+
"http://phontron.com/data/prompt2model/reranking_dataset_index.json",
120120
self.reranking_dataset_info_file,
121121
)
122122
with open(self.reranking_dataset_info_file, "r") as f:
@@ -659,12 +659,14 @@ def get_datasets_of_required_size(
659659
prompt_spec,
660660
self.total_num_points_to_transform - curr_datasets_size,
661661
)
662-
curr_datasets_size += len(canonicalized_dataset["train"]["input_col"])
663-
inputs += canonicalized_dataset["train"]["input_col"]
664-
outputs += canonicalized_dataset["train"]["output_col"]
665-
dataset_contributions[f"{dataset_name}_{config_name}"] = len(
666-
canonicalized_dataset["train"]["input_col"]
667-
)
662+
if canonicalized_dataset is not None and "train" in canonicalized_dataset:
663+
664+
curr_datasets_size += len(canonicalized_dataset["train"]["input_col"])
665+
inputs += canonicalized_dataset["train"]["input_col"]
666+
outputs += canonicalized_dataset["train"]["output_col"]
667+
dataset_contributions[f"{dataset_name}_{config_name}"] = len(
668+
canonicalized_dataset["train"]["input_col"]
669+
)
668670

669671
if len(datasets_info[dataset_name]["configs"]) == 1:
670672
del datasets_info[dataset_name]

0 commit comments

Comments (0)