6
6
7
7
import datasets
8
8
9
+ from prompt2model .dataset_retriever .task_expansion_prompt import (
10
+ construct_prompt_for_task_explanation ,
11
+ )
9
12
from prompt2model .dataset_transformer .base import DatasetTransformer
10
13
from prompt2model .dataset_transformer .prompt_template import (
11
14
construct_prompt_for_plan ,
@@ -31,122 +34,200 @@ class PromptBasedDatasetTransformer(DatasetTransformer):
31
34
32
35
def __init__ (
33
36
self ,
37
+ num_points_to_transform : int = 10 ,
38
+ max_allowed_failed_transforms : int = 3 ,
34
39
plan_prompt_fn : Callable [
35
- [str , str , list [dict ], int ], str
40
+ [str , str , list [dict ]], str
36
41
] = construct_prompt_for_plan ,
37
42
transform_prompt_fn : Callable [
38
- [str , str , str , dict ], str
43
+ [str , str , str , str ], str
39
44
] = construct_prompt_for_transform_data ,
45
+ num_retries : int = 10 ,
40
46
):
41
- """Initialize the class.
47
+ """Initializes an instance of the PromptBasedDatasetTransformer class.
42
48
43
49
Args:
44
- plan_prompt_fn: A function that takes in a description of the target task,
45
- example of the target task,
46
- list of dictionaries where each dictionary is a row from a potentially
47
- relevant dataset,
48
- and the number of rows to use from this potentially relevant dataset,
49
- and returns a plan prompt.
50
-
51
- transform_prompt_fn: A function that takes in a description of the target
52
- task, an example of the target task,
53
- plan for dataset transformation,
54
- and the row from a potentially relevant dataset to be transformed.
50
+ num_points_to_transform: The number of points to transform.
51
+ max_allowed_failed_transforms: The maximum number of
52
+ failed transforms allowed.
53
+ plan_prompt_fn: The function to construct the prompt for plan
54
+ transform_prompt_fn: The function to construct the prompt
55
+ for transform data.
56
+ num_retries: The number of retries to attempt for each API call.
55
57
"""
56
58
self .plan_prompt_fn = plan_prompt_fn
57
59
self .transform_prompt_fn = transform_prompt_fn
58
60
self .plan : str = ""
59
-
60
- def make_dataset_from_samples (
61
- self ,
62
- inputs : list [str ],
63
- outputs : list [str ],
64
- ) -> datasets .DatasetDict :
65
- """Given a list of inputs and outputs, make a dataset.
66
-
67
- This function takes in inputs and outputs, both as list of strings,
68
- and returns a DatasetDict object with a single split, "train". It has
69
- two columns, "input_col" and "output_col".
70
-
71
-
72
- Args:
73
- inputs: A list of inputs, each input is a string.
74
- outputs: A list of outputs, each output is a string.
75
-
76
- Returns:
77
- A DatasetDict object with a single split, "train". It has two
78
- columns, "input_col" and "output_col".
79
- """
80
- if len (inputs ) <= 0 or len (inputs ) != len (outputs ):
81
- raise ValueError ("Length of inputs and outputs must be >0 and equal." )
82
-
83
- dataset_dict = {}
84
- dataset_dict ["train" ] = datasets .Dataset .from_dict (
85
- {"input_col" : inputs , "output_col" : outputs }
61
+ self .num_points_to_transform = num_points_to_transform
62
+ self .curr_failed_transforms = 0
63
+ self .max_allowed_failed_transforms = max_allowed_failed_transforms
64
+ self .num_retries = num_retries
65
+
66
+ def generate_task_explanation (self , prompt_spec : PromptSpec ) -> str :
67
+ """Generate task explanation."""
68
+ task_explanation_prompt = construct_prompt_for_task_explanation (
69
+ prompt_spec .instruction , prompt_spec .examples
70
+ )
71
+ return make_single_api_request (
72
+ task_explanation_prompt , max_api_calls = self .num_retries
86
73
)
87
- return datasets .DatasetDict (dataset_dict )
88
74
89
- def transform_data (
90
- self ,
91
- prompt_spec : PromptSpec ,
92
- dataset : datasets .Dataset ,
93
- num_points_to_transform : int ,
94
- ) -> datasets .DatasetDict :
95
- """Transform the dataset according to the prompt_spec and dataset."""
75
+ def generate_plan (
76
+ self , task_explanation : str , dataset : datasets .Dataset , prompt_spec : PromptSpec
77
+ ) -> str :
78
+ """Generate plan for the task."""
96
79
plan_prompt = self .plan_prompt_fn (
97
- prompt_spec .instruction ,
98
- prompt_spec .examples ,
99
- dataset ,
100
- min (5 , len (dataset )),
80
+ task_explanation , prompt_spec .examples , dataset
101
81
)
102
- self .plan = make_single_api_request (plan_prompt )
103
-
104
- logger .info (f"Plan created. Plan: { self .plan } " )
105
-
106
- inputs = []
107
- outputs = []
82
+ return make_single_api_request (plan_prompt , max_api_calls = self .num_retries )
108
83
109
- max_len = min (num_points_to_transform , len (dataset ))
110
- len_count = 0
84
+ def generate_transform_prompts (
85
+ self ,
86
+ task_explanation : str ,
87
+ dataset : datasets .Dataset ,
88
+ prompt_spec : PromptSpec ,
89
+ ) -> list [str ]:
90
+ """Get transform prompts for each row in the dataset."""
111
91
transform_prompts = []
112
- for row in dataset :
92
+ for i in range (min (self .num_points_to_transform , len (dataset ))):
93
+ row = dataset [i ]
113
94
transform_prompt = self .transform_prompt_fn (
114
- prompt_spec .instruction ,
115
- prompt_spec .examples ,
116
- self .plan ,
117
- row ,
95
+ task_explanation , row , self .plan , prompt_spec .examples
118
96
)
119
97
transform_prompts .append (transform_prompt )
98
+ return transform_prompts
120
99
121
- len_count += 1
122
- if len_count >= max_len :
123
- break
100
+ def generate_responses (
101
+ self , transform_prompts_batch : list [str ], model_name = "gpt-3.5-turbo"
102
+ ) -> list [str ]:
103
+ """Generate responses for the given transform prompts.
124
104
125
- async def generate_responses (transform_prompts ):
126
- responses = await api_tools .default_api_agent .generate_batch_completion (
127
- transform_prompts ,
128
- temperature = 0 ,
129
- responses_per_request = 1 ,
130
- requests_per_minute = 15 ,
131
- )
132
- return responses
105
+ Args:
106
+ transform_prompts_batch: A list of transform prompts.
107
+ model_name: The name of the model to use. Defaults to
108
+ "gpt-3.5-turbo" to save costs.
133
109
134
- try :
135
- loop = asyncio .get_event_loop ()
136
- responses = loop .run_until_complete (generate_responses (transform_prompts ))
137
- except API_ERRORS as e :
138
- handle_api_error (e )
110
+ Returns:
111
+ A list of generated responses.
112
+
113
+ Raises:
114
+ API_ERRORS: If there is an error with the API.
115
+
116
+ """
117
+ api_call_counter = 0
118
+ last_error = None
119
+ responses = []
120
+ while True :
121
+ api_call_counter += 1
122
+
123
+ async def generate_responses_async (transform_prompts ):
124
+ """Generate responses asynchronously using the specified model."""
125
+ responses = await api_tools .APIAgent (
126
+ model_name = model_name
127
+ ).generate_batch_completion (
128
+ transform_prompts ,
129
+ temperature = 0 ,
130
+ responses_per_request = 1 ,
131
+ requests_per_minute = 15 ,
132
+ )
133
+ return responses
134
+
135
+ try :
136
+ loop = asyncio .get_event_loop ()
137
+ responses = loop .run_until_complete (
138
+ generate_responses_async (transform_prompts_batch )
139
+ )
140
+ break
141
+ except API_ERRORS as e :
142
+ last_error = e
143
+ handle_api_error (e )
144
+ if api_call_counter > self .num_retries :
145
+ # In case we reach maximum number of API calls, we raise an error.
146
+ logger .error ("Maximum number of API calls reached." )
147
+ raise RuntimeError (
148
+ "Maximum number of API calls reached."
149
+ ) from last_error
150
+
151
+ return responses
152
+
153
+ def process_responses (
154
+ self , responses : list , prompt_spec : PromptSpec
155
+ ) -> tuple [list [str ], list [str ]]:
156
+ """Process the responses received from the API.
157
+
158
+ Args:
159
+ responses: A list of response strings from the API.
160
+ prompt_spec: The PromptSpec object containing the instruction and examples.
139
161
162
+ Returns:
163
+ A tuple containing two lists: inputs and outputs.
164
+ - inputs: A list of transformed input strings.
165
+ - outputs: A list of transformed output strings.
166
+ """
167
+ inputs , outputs = [], []
168
+ show_sample_flag = False
140
169
for response in responses :
141
170
try :
142
171
extraction = find_and_parse_json (response , ["input" , "output" ], [])
143
172
if extraction is not None :
144
- inputs .append (str (extraction ["input" ]))
145
- outputs .append (str (extraction ["output" ]))
173
+ if extraction ["input" ] is None or extraction ["output" ] is None :
174
+ raise ValueError ("Input or output is None" )
175
+ input = str (extraction ["input" ]).strip ()
176
+ output = str (extraction ["output" ]).strip ()
177
+ if input in prompt_spec .examples :
178
+ raise ValueError ("Repeated Task Examples from prompt" )
179
+
180
+ inputs .append (input )
181
+ outputs .append (output )
182
+ if show_sample_flag :
183
+ logger .info (f"inputs\n { input } \n \n ouputs\n { output } " )
184
+ show_sample_flag = False
185
+
146
186
except Exception as e :
147
- logger .error (f"Error extracting from response: { response } \n Error: { e } " )
148
- continue
187
+ logger .error (f"Error extracting from response: { e } " )
188
+ self .curr_failed_transforms += 1
189
+ if self .curr_failed_transforms > self .max_allowed_failed_transforms :
190
+ break
149
191
150
- logger . info ( f"Requested length: { max_len } \n Actual length: { len ( inputs ) } \n " )
192
+ return inputs , outputs
151
193
152
- return self .make_dataset_from_samples (inputs , outputs )
194
+ def transform_data (
195
+ self , prompt_spec : PromptSpec , dataset : datasets .Dataset
196
+ ) -> tuple [list [str ], list [str ]]:
197
+ """Transforms the given dataset based on the provided prompt specification.
198
+
199
+ Args:
200
+ prompt_spec: The prompt specification object that defines
201
+ the transformation rules.
202
+ dataset: The dataset to be transformed.
203
+
204
+ Returns:
205
+ A tuple containing two lists: inputs and outputs.
206
+ """
207
+ task_explanation = self .generate_task_explanation (prompt_spec )
208
+ self .plan = self .generate_plan (task_explanation , dataset , prompt_spec )
209
+ logger .info (f"Plan created. Plan: { self .plan } " )
210
+
211
+ transform_prompts = self .generate_transform_prompts (
212
+ task_explanation , dataset , prompt_spec
213
+ )
214
+ inputs , outputs = [], []
215
+ for batch_indices in range (0 , len (transform_prompts ), 100 ):
216
+ transform_prompt_batch = transform_prompts [
217
+ batch_indices : batch_indices + 100
218
+ ]
219
+ responses = self .generate_responses (transform_prompt_batch )
220
+ curr_inputs , curr_outputs = self .process_responses (responses , prompt_spec )
221
+ inputs .extend (curr_inputs )
222
+ outputs .extend (curr_outputs )
223
+ if self .curr_failed_transforms > self .max_allowed_failed_transforms :
224
+ logger .error (
225
+ f"Exceeded max allowed failed transforms: { self .curr_failed_transforms } " # noqa: E501
226
+ )
227
+ self .max_allowed_failed_transforms = 0
228
+ break
229
+
230
+ logger .info (
231
+ f"Requested length: { self .num_points_to_transform } \n Actual length: { len (inputs )} \n " # noqa: E501
232
+ )
233
+ return inputs , outputs
0 commit comments