
Commit 6be1590

add client session fix (#398)
1 parent 6709095

File tree

1 file changed: 71 additions, 70 deletions

prompt2model/utils/api_tools.py

Lines changed: 71 additions & 70 deletions
@@ -142,80 +142,81 @@ async def generate_batch_completion(
         Returns:
             List of generated responses.
         """
-        openai.aiosession.set(ClientSession())
-        limiter = aiolimiter.AsyncLimiter(requests_per_minute)
-
-        async def _throttled_completion_acreate(
-            model: str,
-            messages: list[dict[str, str]],
-            temperature: float,
-            max_tokens: int,
-            n: int,
-            top_p: float,
-            limiter: aiolimiter.AsyncLimiter,
-        ):
-            async with limiter:
-                for _ in range(3):
-                    try:
-                        return await acompletion(
-                            model=model,
-                            messages=messages,
-                            api_base=self.api_base,
-                            temperature=temperature,
-                            max_tokens=max_tokens,
-                            n=n,
-                            top_p=top_p,
-                        )
-                    except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
-                        if isinstance(
-                            e,
-                            (
-                                openai.APIStatusError,
-                                openai.APIError,
-                            ),
-                        ):
-                            logging.warning(
-                                ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
+        async with ClientSession() as _:
+            limiter = aiolimiter.AsyncLimiter(requests_per_minute)
+
+            async def _throttled_completion_acreate(
+                model: str,
+                messages: list[dict[str, str]],
+                temperature: float,
+                max_tokens: int,
+                n: int,
+                top_p: float,
+                limiter: aiolimiter.AsyncLimiter,
+            ):
+                async with limiter:
+                    for _ in range(3):
+                        try:
+                            return await acompletion(
+                                model=model,
+                                messages=messages,
+                                api_base=self.api_base,
+                                temperature=temperature,
+                                max_tokens=max_tokens,
+                                n=n,
+                                top_p=top_p,
                             )
-                        elif isinstance(e, openai.BadRequestError):
-                            logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
-                            return {
-                                "choices": [
-                                    {
-                                        "message": {
-                                            "content": "Invalid Request: Prompt was filtered"  # noqa E501
+                        except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
+                            if isinstance(
+                                e,
+                                (
+                                    openai.APIStatusError,
+                                    openai.APIError,
+                                ),
+                            ):
+                                logging.warning(
+                                    ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
+                                )
+                            elif isinstance(e, openai.BadRequestError):
+                                logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
+                                return {
+                                    "choices": [
+                                        {
+                                            "message": {
+                                                "content": "Invalid Request: Prompt was filtered"  # noqa E501
+                                            }
                                         }
-                                    }
-                                ]
-                            }
-                        else:
-                            logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
-                        await asyncio.sleep(10)
-                return {"choices": [{"message": {"content": ""}}]}
-
-        num_prompt_tokens = max(count_tokens_from_string(prompt) for prompt in prompts)
-        if self.max_tokens:
-            max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
-        else:
-            max_tokens = 3 * num_prompt_tokens
-
-        async_responses = [
-            _throttled_completion_acreate(
-                model=self.model_name,
-                messages=[
-                    {"role": "user", "content": f"{prompt}"},
-                ],
-                temperature=temperature,
-                max_tokens=max_tokens,
-                n=responses_per_request,
-                top_p=1,
-                limiter=limiter,
+                                    ]
+                                }
+                            else:
+                                logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
+                            await asyncio.sleep(10)
+                    return {"choices": [{"message": {"content": ""}}]}
+
+            num_prompt_tokens = max(
+                count_tokens_from_string(prompt) for prompt in prompts
             )
-            for prompt in prompts
-        ]
-        responses = await tqdm_asyncio.gather(*async_responses)
+            if self.max_tokens:
+                max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
+            else:
+                max_tokens = 3 * num_prompt_tokens
+
+            async_responses = [
+                _throttled_completion_acreate(
+                    model=self.model_name,
+                    messages=[
+                        {"role": "user", "content": f"{prompt}"},
+                    ],
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    n=responses_per_request,
+                    top_p=1,
+                    limiter=limiter,
+                )
+                for prompt in prompts
+            ]
+            responses = await tqdm_asyncio.gather(*async_responses)
         # Note: will never be none because it's set, but mypy doesn't know that.
-        await openai.aiosession.get().close()
         return responses
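What the fix does: the old code installed a module-global aiohttp ClientSession via openai.aiosession.set(ClientSession()) and only closed it with an explicit await openai.aiosession.get().close() on the happy path, so any exception raised mid-batch skipped the close and leaked the session's connections. Wrapping the whole batch in async with ClientSession() ties the session's lifetime to the block. A minimal, self-contained sketch of that pattern, assuming only aiohttp; fetch_status and the example URL are illustrative, not from the repo:

    import asyncio

    from aiohttp import ClientSession

    async def fetch_status(url: str) -> int:
        # The context manager closes the session even when the body raises,
        # unlike a trailing session.close() call that an exception can skip.
        async with ClientSession() as session:
            async with session.get(url) as response:
                return response.status

    if __name__ == "__main__":
        print(asyncio.run(fetch_status("https://example.com")))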

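The retry loop in the diff attempts each completion up to three times, logging the message mapped in ERROR_ERRORS_TO_MESSAGES and sleeping ten seconds between attempts, then falls back to an empty completion dict rather than raising. A standalone sketch of that retry-then-sentinel pattern; flaky_request and the numbers are illustrative stand-ins, not the project's API:

    import asyncio
    import logging

    async def call_with_retries(attempts: int = 3, delay: float = 10.0) -> dict:
        async def flaky_request() -> dict:
            # Stand-in for the acompletion(...) call made in the diff.
            raise RuntimeError("simulated transient API failure")

        for _ in range(attempts):
            try:
                return await flaky_request()
            except RuntimeError as e:
                logging.warning("request failed: %s", e)
                await asyncio.sleep(delay)
        # Mirror the diff's sentinel: an empty choice instead of an exception.
        return {"choices": [{"message": {"content": ""}}]}

    if __name__ == "__main__":
        print(asyncio.run(call_with_retries(attempts=2, delay=0.1)))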
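Throughput is capped with aiolimiter.AsyncLimiter: its async with block admits at most max_rate entries per time window (60 seconds by default), which is how requests_per_minute is enforced across all the gathered coroutines. A runnable sketch of that throttling pattern; worker and the sleep are stand-ins for the real API call:

    import asyncio

    from aiolimiter import AsyncLimiter

    async def worker(i: int, limiter: AsyncLimiter) -> int:
        # Once the per-minute budget is spent, extra coroutines wait here.
        async with limiter:
            await asyncio.sleep(0.01)  # stand-in for the network request
            return i

    async def main() -> None:
        limiter = AsyncLimiter(100)  # 100 acquisitions per 60-second window
        print(await asyncio.gather(*(worker(i, limiter) for i in range(10))))

    if __name__ == "__main__":
        asyncio.run(main())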