@@ -142,80 +142,81 @@ async def generate_batch_completion(
         Returns:
             List of generated responses.
         """
-        openai.aiosession.set(ClientSession())
-        limiter = aiolimiter.AsyncLimiter(requests_per_minute)
-
-        async def _throttled_completion_acreate(
-            model: str,
-            messages: list[dict[str, str]],
-            temperature: float,
-            max_tokens: int,
-            n: int,
-            top_p: float,
-            limiter: aiolimiter.AsyncLimiter,
-        ):
-            async with limiter:
-                for _ in range(3):
-                    try:
-                        return await acompletion(
-                            model=model,
-                            messages=messages,
-                            api_base=self.api_base,
-                            temperature=temperature,
-                            max_tokens=max_tokens,
-                            n=n,
-                            top_p=top_p,
-                        )
-                    except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
-                        if isinstance(
-                            e,
-                            (
-                                openai.APIStatusError,
-                                openai.APIError,
-                            ),
-                        ):
-                            logging.warning(
-                                ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
+        async with ClientSession() as _:
+            limiter = aiolimiter.AsyncLimiter(requests_per_minute)
+
+            async def _throttled_completion_acreate(
+                model: str,
+                messages: list[dict[str, str]],
+                temperature: float,
+                max_tokens: int,
+                n: int,
+                top_p: float,
+                limiter: aiolimiter.AsyncLimiter,
+            ):
+                async with limiter:
+                    for _ in range(3):
+                        try:
+                            return await acompletion(
+                                model=model,
+                                messages=messages,
+                                api_base=self.api_base,
+                                temperature=temperature,
+                                max_tokens=max_tokens,
+                                n=n,
+                                top_p=top_p,
                             )
-                        elif isinstance(e, openai.BadRequestError):
-                            logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
-                            return {
-                                "choices": [
-                                    {
-                                        "message": {
-                                            "content": "Invalid Request: Prompt was filtered"  # noqa E501
+                        except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
+                            if isinstance(
+                                e,
+                                (
+                                    openai.APIStatusError,
+                                    openai.APIError,
+                                ),
+                            ):
+                                logging.warning(
+                                    ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
+                                )
+                            elif isinstance(e, openai.BadRequestError):
+                                logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
+                                return {
+                                    "choices": [
+                                        {
+                                            "message": {
+                                                "content": "Invalid Request: Prompt was filtered"  # noqa E501
+                                            }
                                         }
-                                }
-                            ]
-                        }
-                    else:
-                        logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
-                    await asyncio.sleep(10)
-            return {"choices": [{"message": {"content": ""}}]}
-
-        num_prompt_tokens = max(count_tokens_from_string(prompt) for prompt in prompts)
-        if self.max_tokens:
-            max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
-        else:
-            max_tokens = 3 * num_prompt_tokens
-
-        async_responses = [
-            _throttled_completion_acreate(
-                model=self.model_name,
-                messages=[
-                    {"role": "user", "content": f"{prompt}"},
-                ],
-                temperature=temperature,
-                max_tokens=max_tokens,
-                n=responses_per_request,
-                top_p=1,
-                limiter=limiter,
+                                    ]
+                                }
+                            else:
+                                logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
+                            await asyncio.sleep(10)
+                return {"choices": [{"message": {"content": ""}}]}
+
+            num_prompt_tokens = max(
+                count_tokens_from_string(prompt) for prompt in prompts
             )
-            for prompt in prompts
-        ]
-        responses = await tqdm_asyncio.gather(*async_responses)
+            if self.max_tokens:
+                max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
+            else:
+                max_tokens = 3 * num_prompt_tokens
+
+            async_responses = [
+                _throttled_completion_acreate(
+                    model=self.model_name,
+                    messages=[
+                        {"role": "user", "content": f"{prompt}"},
+                    ],
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    n=responses_per_request,
+                    top_p=1,
+                    limiter=limiter,
+                )
+                for prompt in prompts
+            ]
+            responses = await tqdm_asyncio.gather(*async_responses)
         # Note: will never be none because it's set, but mypy doesn't know that.
-        await openai.aiosession.get().close()
         return responses
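
For readers skimming the hunk, the pattern on both sides is the same: an `aiolimiter.AsyncLimiter` throttles concurrent requests, each request retries a few times before falling back to an empty response, and all prompts are awaited together with `tqdm_asyncio.gather`. Below is a minimal, self-contained sketch of that pattern only; `fake_completion`, the rate limit, and the retry count are illustrative placeholders, not helpers from this repository.

```python
# Sketch of the throttle-retry-gather pattern used in the diff above.
import asyncio
import logging

import aiolimiter
from tqdm.asyncio import tqdm_asyncio


async def fake_completion(prompt: str) -> dict:
    """Hypothetical stand-in for the real async completion call."""
    await asyncio.sleep(0.1)  # simulate network latency
    return {"choices": [{"message": {"content": f"echo: {prompt}"}}]}


async def throttled_call(
    prompt: str, limiter: aiolimiter.AsyncLimiter, retries: int = 3
) -> dict:
    """Rate-limited call with a small retry loop and an empty fallback."""
    async with limiter:
        for _ in range(retries):
            try:
                return await fake_completion(prompt)
            except Exception as e:  # the real code catches specific API errors
                logging.warning("request failed: %s", e)
                await asyncio.sleep(10)
    return {"choices": [{"message": {"content": ""}}]}


async def main() -> None:
    limiter = aiolimiter.AsyncLimiter(60)  # roughly 60 requests per minute
    prompts = ["first prompt", "second prompt"]
    responses = await tqdm_asyncio.gather(
        *(throttled_call(p, limiter) for p in prompts)
    )
    print([r["choices"][0]["message"]["content"] for r in responses])


asyncio.run(main())
```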