Skip to content

Commit 6709095

Browse files
authored
Improving Transform and Rerank Module (#396)
1 parent 04debc9 commit 6709095

20 files changed

+920
-790
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,9 @@ huggingface_data/huggingface_datasets/huggingface_datasets_datafinder_index
2222
huggingface_data/huggingface_datasets/reranking_dataset_index.json
2323
huggingface_data/huggingface_models/
2424
retrieved_dataset_dict/
25+
result/
26+
checkpoint/
2527
status.yaml
26-
2728
# Outputs generated by the colab demo
2829
trained_model/
2930
trained_tokenizer/

examples/create_transform_data_example.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -33,12 +33,13 @@
3333

3434
# run this pipeline to retrieve relevant datasets, rerank them,
3535
# and transform them based on the prompt
36-
retriever = DescriptionDatasetRetriever()
37-
num_points_to_transform = 20
36+
total_num_points_to_transform = 20
37+
retriever = DescriptionDatasetRetriever(
38+
auto_transform_data=True,
39+
total_num_points_to_transform=total_num_points_to_transform,
40+
)
3841
retrieved_dataset_dict = retriever.retrieve_dataset_dict(
3942
prompt_spec,
40-
auto_transform_data=True,
41-
num_points_to_transform=num_points_to_transform,
4243
)
4344

4445
# save the final dataset to disk

examples/huggingface_data/huggingface_datasets/dataset_index.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/huggingface_data/huggingface_datasets/reranking_dataset_index.json

Lines changed: 0 additions & 1 deletion
This file was deleted.
Binary file not shown.

prompt2model/dataset_retriever/description_dataset_retriever.py

Lines changed: 233 additions & 64 deletions
Large diffs are not rendered by default.

prompt2model/dataset_retriever/reranking_prompt.py

Lines changed: 78 additions & 127 deletions
Large diffs are not rendered by default.
prompt2model/dataset_retriever/task_expansion_prompt.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,26 @@
1+
"""This module contains the functions to construct the prompt for task expansion."""
2+
METAPROMPT_BASE = "Carefully analyse the task description and examples of the task, and explain the task to give a clearer description. Do not explain each example, but rather capture the general trends. Also place special focus on the format of the input/output examples." # noqa: E501
3+
4+
TASK = """
5+
Task Description: {task_description}
6+
7+
Task Examples: {examples}
8+
"""
9+
10+
11+
def construct_prompt_for_task_explanation(instruction: str, demonstrations: str):
    """Build the prompt that asks the model to explain a task more clearly.

    The instruction and demonstrations are substituted into the ``TASK``
    template, and the filled template is placed after ``METAPROMPT_BASE``
    with a dashed separator line between the two sections.

    Args:
        instruction (str): The task instruction.
        demonstrations (str): The task demonstrations.

    Returns:
        str: The constructed prompt.
    """
    section_separator = "\n--------\n"
    task_section = TASK.format(
        task_description=instruction,
        examples=demonstrations,
    )
    return METAPROMPT_BASE + section_separator + task_section

prompt2model/dataset_transformer/base.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@ def transform_data(
1717
self,
1818
prompt_spec: PromptSpec,
1919
dataset: datasets.Dataset,
20-
num_points_to_transform: int,
2120
) -> datasets.Dataset:
2221
"""Transform a split of data.
2322

prompt2model/dataset_transformer/prompt_based.py

Lines changed: 169 additions & 88 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,9 @@
66

77
import datasets
88

9+
from prompt2model.dataset_retriever.task_expansion_prompt import (
10+
construct_prompt_for_task_explanation,
11+
)
912
from prompt2model.dataset_transformer.base import DatasetTransformer
1013
from prompt2model.dataset_transformer.prompt_template import (
1114
construct_prompt_for_plan,
@@ -31,122 +34,200 @@ class PromptBasedDatasetTransformer(DatasetTransformer):
3134

3235
def __init__(
3336
self,
37+
num_points_to_transform: int = 10,
38+
max_allowed_failed_transforms: int = 3,
3439
plan_prompt_fn: Callable[
35-
[str, str, list[dict], int], str
40+
[str, str, list[dict]], str
3641
] = construct_prompt_for_plan,
3742
transform_prompt_fn: Callable[
38-
[str, str, str, dict], str
43+
[str, str, str, str], str
3944
] = construct_prompt_for_transform_data,
45+
num_retries: int = 10,
4046
):
41-
"""Initialize the class.
47+
"""Initializes an instance of the PromptBasedDatasetTransformer class.
4248
4349
Args:
44-
plan_prompt_fn: A function that takes in a description of the target task,
45-
example of the target task,
46-
list of dictionaries where each dictionary is a row from a potentially
47-
relevant dataset,
48-
and the number of rows to use from this potentially relevant dataset,
49-
and returns a plan prompt.
50-
51-
transform_prompt_fn: A function that takes in a description of the target
52-
task, an example of the target task,
53-
plan for dataset transformation,
54-
and the row from a potentially relevant dataset to be transformed.
50+
num_points_to_transform: The number of points to transform.
51+
max_allowed_failed_transforms: The maximum number of
52+
failed transforms allowed.
53+
plan_prompt_fn: The function to construct the prompt for plan
54+
transform_prompt_fn: The function to construct the prompt
55+
for transform data.
56+
num_retries: The number of retries to attempt for each API call.
5557
"""
5658
self.plan_prompt_fn = plan_prompt_fn
5759
self.transform_prompt_fn = transform_prompt_fn
5860
self.plan: str = ""
59-
60-
def make_dataset_from_samples(
61-
self,
62-
inputs: list[str],
63-
outputs: list[str],
64-
) -> datasets.DatasetDict:
65-
"""Given a list of inputs and outputs, make a dataset.
66-
67-
This function takes in inputs and outputs, both as list of strings,
68-
and returns a DatasetDict object with a single split, "train". It has
69-
two columns, "input_col" and "output_col".
70-
71-
72-
Args:
73-
inputs: A list of inputs, each input is a string.
74-
outputs: A list of outputs, each output is a string.
75-
76-
Returns:
77-
A DatasetDict object with a single split, "train". It has two
78-
columns, "input_col" and "output_col".
79-
"""
80-
if len(inputs) <= 0 or len(inputs) != len(outputs):
81-
raise ValueError("Length of inputs and outputs must be >0 and equal.")
82-
83-
dataset_dict = {}
84-
dataset_dict["train"] = datasets.Dataset.from_dict(
85-
{"input_col": inputs, "output_col": outputs}
61+
self.num_points_to_transform = num_points_to_transform
62+
self.curr_failed_transforms = 0
63+
self.max_allowed_failed_transforms = max_allowed_failed_transforms
64+
self.num_retries = num_retries
65+
66+
def generate_task_explanation(self, prompt_spec: PromptSpec) -> str:
    """Request an explanation of the task described by the prompt spec.

    Args:
        prompt_spec: Supplies the ``instruction`` and ``examples`` used to
            build the task-explanation prompt.

    Returns:
        The response returned by ``make_single_api_request`` for that prompt.
    """
    explanation_request = construct_prompt_for_task_explanation(
        prompt_spec.instruction,
        prompt_spec.examples,
    )
    # The instance's retry setting is forwarded as the request's
    # max_api_calls budget.
    api_reply = make_single_api_request(
        explanation_request,
        max_api_calls=self.num_retries,
    )
    return api_reply
87-
return datasets.DatasetDict(dataset_dict)
8874

89-
def transform_data(
90-
self,
91-
prompt_spec: PromptSpec,
92-
dataset: datasets.Dataset,
93-
num_points_to_transform: int,
94-
) -> datasets.DatasetDict:
95-
"""Transform the dataset according to the prompt_spec and dataset."""
75+
def generate_plan(
76+
self, task_explanation: str, dataset: datasets.Dataset, prompt_spec: PromptSpec
77+
) -> str:
78+
"""Generate plan for the task."""
9679
plan_prompt = self.plan_prompt_fn(
97-
prompt_spec.instruction,
98-
prompt_spec.examples,
99-
dataset,
100-
min(5, len(dataset)),
80+
task_explanation, prompt_spec.examples, dataset
10181
)
102-
self.plan = make_single_api_request(plan_prompt)
103-
104-
logger.info(f"Plan created. Plan: {self.plan}")
105-
106-
inputs = []
107-
outputs = []
82+
return make_single_api_request(plan_prompt, max_api_calls=self.num_retries)
10883

109-
max_len = min(num_points_to_transform, len(dataset))
110-
len_count = 0
84+
def generate_transform_prompts(
85+
self,
86+
task_explanation: str,
87+
dataset: datasets.Dataset,
88+
prompt_spec: PromptSpec,
89+
) -> list[str]:
90+
"""Get transform prompts for each row in the dataset."""
11191
transform_prompts = []
112-
for row in dataset:
92+
for i in range(min(self.num_points_to_transform, len(dataset))):
93+
row = dataset[i]
11394
transform_prompt = self.transform_prompt_fn(
114-
prompt_spec.instruction,
115-
prompt_spec.examples,
116-
self.plan,
117-
row,
95+
task_explanation, row, self.plan, prompt_spec.examples
11896
)
11997
transform_prompts.append(transform_prompt)
98+
return transform_prompts
12099

121-
len_count += 1
122-
if len_count >= max_len:
123-
break
100+
def generate_responses(
101+
self, transform_prompts_batch: list[str], model_name="gpt-3.5-turbo"
102+
) -> list[str]:
103+
"""Generate responses for the given transform prompts.
124104
125-
async def generate_responses(transform_prompts):
126-
responses = await api_tools.default_api_agent.generate_batch_completion(
127-
transform_prompts,
128-
temperature=0,
129-
responses_per_request=1,
130-
requests_per_minute=15,
131-
)
132-
return responses
105+
Args:
106+
transform_prompts_batch: A list of transform prompts.
107+
model_name: The name of the model to use. Defaults to
108+
"gpt-3.5-turbo" to save costs.
133109
134-
try:
135-
loop = asyncio.get_event_loop()
136-
responses = loop.run_until_complete(generate_responses(transform_prompts))
137-
except API_ERRORS as e:
138-
handle_api_error(e)
110+
Returns:
111+
A list of generated responses.
112+
113+
Raises:
114+
API_ERRORS: If there is an error with the API.
115+
116+
"""
117+
api_call_counter = 0
118+
last_error = None
119+
responses = []
120+
while True:
121+
api_call_counter += 1
122+
123+
async def generate_responses_async(transform_prompts):
124+
"""Generate responses asynchronously using the specified model."""
125+
responses = await api_tools.APIAgent(
126+
model_name=model_name
127+
).generate_batch_completion(
128+
transform_prompts,
129+
temperature=0,
130+
responses_per_request=1,
131+
requests_per_minute=15,
132+
)
133+
return responses
134+
135+
try:
136+
loop = asyncio.get_event_loop()
137+
responses = loop.run_until_complete(
138+
generate_responses_async(transform_prompts_batch)
139+
)
140+
break
141+
except API_ERRORS as e:
142+
last_error = e
143+
handle_api_error(e)
144+
if api_call_counter > self.num_retries:
145+
# In case we reach maximum number of API calls, we raise an error.
146+
logger.error("Maximum number of API calls reached.")
147+
raise RuntimeError(
148+
"Maximum number of API calls reached."
149+
) from last_error
150+
151+
return responses
152+
153+
def process_responses(
154+
self, responses: list, prompt_spec: PromptSpec
155+
) -> tuple[list[str], list[str]]:
156+
"""Process the responses received from the API.
157+
158+
Args:
159+
responses: A list of response strings from the API.
160+
prompt_spec: The PromptSpec object containing the instruction and examples.
139161
162+
Returns:
163+
A tuple containing two lists: inputs and outputs.
164+
- inputs: A list of transformed input strings.
165+
- outputs: A list of transformed output strings.
166+
"""
167+
inputs, outputs = [], []
168+
show_sample_flag = False
140169
for response in responses:
141170
try:
142171
extraction = find_and_parse_json(response, ["input", "output"], [])
143172
if extraction is not None:
144-
inputs.append(str(extraction["input"]))
145-
outputs.append(str(extraction["output"]))
173+
if extraction["input"] is None or extraction["output"] is None:
174+
raise ValueError("Input or output is None")
175+
input = str(extraction["input"]).strip()
176+
output = str(extraction["output"]).strip()
177+
if input in prompt_spec.examples:
178+
raise ValueError("Repeated Task Examples from prompt")
179+
180+
inputs.append(input)
181+
outputs.append(output)
182+
if show_sample_flag:
183+
logger.info(f"inputs\n{input}\n\nouputs\n{output}")
184+
show_sample_flag = False
185+
146186
except Exception as e:
147-
logger.error(f"Error extracting from response: {response}\nError: {e}")
148-
continue
187+
logger.error(f"Error extracting from response: {e}")
188+
self.curr_failed_transforms += 1
189+
if self.curr_failed_transforms > self.max_allowed_failed_transforms:
190+
break
149191

150-
logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n")
192+
return inputs, outputs
151193

152-
return self.make_dataset_from_samples(inputs, outputs)
194+
def transform_data(
    self, prompt_spec: PromptSpec, dataset: datasets.Dataset
) -> tuple[list[str], list[str]]:
    """Transform rows of a dataset into input/output pairs for the task.

    The method first obtains a task explanation, then a transformation plan
    (stored on ``self.plan``), then one transform prompt per selected row.
    The prompts are sent in batches of 100; each batch of responses is parsed
    and the extracted pairs are accumulated. Processing stops early once the
    running failure count exceeds ``max_allowed_failed_transforms``.

    Args:
        prompt_spec: The prompt specification handed to every helper step
            (explanation, plan, prompt construction, response processing).
        dataset: The dataset whose rows are to be transformed.

    Returns:
        A tuple ``(inputs, outputs)`` of two lists holding the transformed
        inputs and the transformed outputs collected from all batches.
    """
    batch_size = 100

    explanation = self.generate_task_explanation(prompt_spec)
    self.plan = self.generate_plan(explanation, dataset, prompt_spec)
    logger.info(f"Plan created. Plan: {self.plan}")

    all_prompts = self.generate_transform_prompts(
        explanation, dataset, prompt_spec
    )

    inputs: list[str] = []
    outputs: list[str] = []
    batch_start = 0
    while batch_start < len(all_prompts):
        prompt_batch = all_prompts[batch_start : batch_start + batch_size]
        batch_responses = self.generate_responses(prompt_batch)
        batch_inputs, batch_outputs = self.process_responses(
            batch_responses, prompt_spec
        )
        inputs += batch_inputs
        outputs += batch_outputs

        if self.curr_failed_transforms > self.max_allowed_failed_transforms:
            logger.error(
                f"Exceeded max allowed failed transforms: {self.curr_failed_transforms}"  # noqa: E501
            )
            # NOTE(review): this zeroes the *threshold* on the instance, and
            # the change persists for any later call; confirm this is
            # intended rather than resetting curr_failed_transforms.
            self.max_allowed_failed_transforms = 0
            break

        batch_start += batch_size

    logger.info(
        f"Requested length: {self.num_points_to_transform}\nActual length: {len(inputs)}\n"  # noqa: E501
    )
    return inputs, outputs

0 commit comments

Comments
 (0)