1
1
import os
2
- from typing import Optional
2
+ from typing import Optional , Union
3
3
4
- from litellm import acompletion , aembedding , completion , embedding
4
+ from litellm import (CustomStreamWrapper , EmbeddingResponse , ModelResponse ,
5
+ acompletion , aembedding , completion , embedding )
5
6
6
7
try :
7
8
import torch
@@ -39,8 +40,8 @@ def __init__(self, model: str):
39
40
)
40
41
41
42
self .device = "cuda" if torch .cuda .is_available () else "cpu"
42
- self .model = AutoModel .from_pretrained (self . model_name ).to (self .device )
43
- self .tokenizer = AutoTokenizer .from_pretrained (self . model_name )
43
+ self .model = AutoModel .from_pretrained (model ).to (self .device )
44
+ self .tokenizer = AutoTokenizer .from_pretrained (model )
44
45
45
46
def get_embedding (self , input : str ):
46
47
inputs = self .tokenizer (
@@ -97,7 +98,9 @@ def _build_completion_params(self, completion_params, messages):
97
98
def _build_embedding_params (self , embedding_params , input ):
98
99
return {** self ._embedding_params , ** embedding_params , "input" : input }
99
100
100
- def complete (self , messages : list , ** completion_params ):
101
+ def complete (
102
+ self , messages : list , ** completion_params
103
+ ) -> Union [ModelResponse , CustomStreamWrapper ]:
101
104
"""
102
105
Complete a message.
103
106
@@ -111,7 +114,9 @@ def complete(self, messages: list, **completion_params):
111
114
112
115
return completion (** completion_params )
113
116
114
- async def acomplete (self , messages : list , ** completion_params ):
117
+ async def acomplete (
118
+ self , messages : list , ** completion_params
119
+ ) -> Union [ModelResponse , CustomStreamWrapper ]:
115
120
"""
116
121
Complete a message asynchronously.
117
122
@@ -124,9 +129,9 @@ async def acomplete(self, messages: list, **completion_params):
124
129
125
130
return await acompletion (** completion_params )
126
131
127
- def _local_embed (self , input : list [str ]):
128
- return {
129
- " data" : [
132
+ def _local_embed (self , input : list [str ]) -> EmbeddingResponse :
133
+ return EmbeddingResponse (
134
+ data = [
130
135
{
131
136
"embedding" : torch .stack (
132
137
[self ._local_embedding_model .get_embedding (d )]
@@ -136,9 +141,9 @@ def _local_embed(self, input: list[str]):
136
141
}
137
142
for d in input
138
143
]
139
- }
144
+ )
140
145
141
- def embed (self , input : list [str ], ** embedding_params ):
146
+ def embed (self , input : list [str ], ** embedding_params ) -> EmbeddingResponse :
142
147
"""
143
148
Embed a message.
144
149
@@ -153,7 +158,7 @@ def embed(self, input: list[str], **embedding_params):
153
158
154
159
return embedding (** embedding_params )
155
160
156
- async def aembed (self , input : list [str ], ** embedding_params ):
161
+ async def aembed (self , input : list [str ], ** embedding_params ) -> EmbeddingResponse :
157
162
"""
158
163
Embed a message asynchronously.
159
164
0 commit comments