Skip to content

Commit f9be9a6

Browse files
Merge pull request #345 from edenai/SD2-1328-important-streaming-should-work-with-the-new-llm-chat-endoints
[Fix] stream for llm completion
2 parents f5ab9fb + febdea3 commit f9be9a6

File tree

2 files changed

+14 additions, -5 deletions (lines changed)

2 files changed

+14 additions, -5 deletions (lines changed)

edenai_apis/features/llm/chat/chat_dataclass.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
1-
from typing import List, Optional, Union, Dict, Any, Literal
1+
from typing import List, Optional, Union, Dict, Any, Literal, Generator
22
from enum import Enum
33
from pydantic import BaseModel, Field, model_validator
4+
from litellm import ModelResponseStream
45

56

67
class ChatRole(str, Enum):
@@ -211,3 +212,7 @@ class ChatDataClass(BaseModel):
211212
system_fingerprint: Optional[str] = Field(
212213
None, description="Identifier for the system version that processed the request"
213214
)
215+
216+
217+
class StreamChat(BaseModel):
218+
stream: Generator[ModelResponseStream, None, None]

edenai_apis/llmengine/llm_engine.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
import base64
22
import json
33
import mimetypes
4-
import os
5-
import re
64
import uuid
75
from io import BytesIO
86
from typing import Dict, List, Literal, Optional, Type, Union
@@ -53,6 +51,9 @@
5351
AutomaticTranslationDataClass,
5452
LanguageDetectionDataClass,
5553
)
54+
from edenai_apis.features.llm.chat.chat_dataclass import (
55+
StreamChat as StreamChatCompletion,
56+
)
5657
from edenai_apis.llmengine.clients import LLM_COMPLETION_CLIENTS
5758
from edenai_apis.llmengine.clients.completion import CompletionClient
5859
from edenai_apis.llmengine.mapping import Mappings
@@ -845,7 +846,10 @@ def completion(
845846
completion_params = completion_params
846847
call_params = self._prepare_args(**completion_params)
847848
response = self.completion_client.completion(**call_params, **kwargs)
848-
response = ResponseModel.model_validate(response)
849-
return response
849+
if stream:
850+
return StreamChatCompletion(stream=response)
851+
else:
852+
response = ResponseModel.model_validate(response)
853+
return response
850854
except Exception as ex:
851855
raise ex

0 commit comments

Comments
 (0)