
Commit 12a8aa2

Fixed tests

1 parent 110ed24

4 files changed: +494, -28 lines

src/giskard_lmutils/model/litellm.py

Lines changed: 17 additions & 12 deletions
@@ -1,7 +1,8 @@
 import os
-from typing import Optional
+from typing import Optional, Union

-from litellm import acompletion, aembedding, completion, embedding
+from litellm import (CustomStreamWrapper, EmbeddingResponse, ModelResponse,
+                     acompletion, aembedding, completion, embedding)

 try:
     import torch
@@ -39,8 +40,8 @@ def __init__(self, model: str):
         )

         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model = AutoModel.from_pretrained(model).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(model)

     def get_embedding(self, input: str):
         inputs = self.tokenizer(
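Note on the fix above: __init__ receives the model name as the `model` argument, while the old code loaded weights from `self.model_name`, an attribute that is apparently never assigned, so instantiation would raise an AttributeError (presumably the test failure this commit fixes). A minimal sketch of the corrected constructor; the class name LocalEmbeddingModel is illustrative, not taken from the diff:

    import torch
    from transformers import AutoModel, AutoTokenizer

    class LocalEmbeddingModel:  # hypothetical name for the class in this file
        def __init__(self, model: str):
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            # Before the fix, from_pretrained(self.model_name) was called here,
            # but `self.model_name` was (apparently) never set on the instance.
            self.model = AutoModel.from_pretrained(model).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(model)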
@@ -97,7 +98,9 @@ def _build_completion_params(self, completion_params, messages):
     def _build_embedding_params(self, embedding_params, input):
         return {**self._embedding_params, **embedding_params, "input": input}

-    def complete(self, messages: list, **completion_params):
+    def complete(
+        self, messages: list, **completion_params
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         """
         Complete a message.

@@ -111,7 +114,9 @@ def complete(self, messages: list, **completion_params):

         return completion(**completion_params)

-    async def acomplete(self, messages: list, **completion_params):
+    async def acomplete(
+        self, messages: list, **completion_params
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         """
         Complete a message asynchronously.

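The Union[ModelResponse, CustomStreamWrapper] annotations mirror litellm's own completion signature: a regular call returns a ModelResponse, while passing stream=True yields a CustomStreamWrapper that is iterated chunk by chunk. A hedged usage sketch; the `client` object and its construction are assumptions, only the complete() method comes from the diff:

    from litellm import CustomStreamWrapper

    def print_completion(client, messages: list) -> None:
        response = client.complete(messages)
        if isinstance(response, CustomStreamWrapper):
            # Streaming path: content deltas arrive incrementally.
            for chunk in response:
                print(chunk.choices[0].delta.content or "", end="")
        else:
            # Non-streaming path: a single ModelResponse.
            print(response.choices[0].message.content)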
@@ -124,9 +129,9 @@ async def acomplete(self, messages: list, **completion_params):

         return await acompletion(**completion_params)

-    def _local_embed(self, input: list[str]):
-        return {
-            "data": [
+    def _local_embed(self, input: list[str]) -> EmbeddingResponse:
+        return EmbeddingResponse(
+            data=[
                 {
                     "embedding": torch.stack(
                         [self._local_embedding_model.get_embedding(d)]
@@ -136,9 +141,9 @@ def _local_embed(self, input: list[str]):
                 }
                 for d in input
             ]
-        }
+        )

-    def embed(self, input: list[str], **embedding_params):
+    def embed(self, input: list[str], **embedding_params) -> EmbeddingResponse:
         """
         Embed a message.

@@ -153,7 +158,7 @@ def embed(self, input: list[str], **embedding_params):

         return embedding(**embedding_params)

-    async def aembed(self, input: list[str], **embedding_params):
+    async def aembed(self, input: list[str], **embedding_params) -> EmbeddingResponse:
         """
         Embed a message asynchronously.
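With _local_embed now returning a litellm EmbeddingResponse instead of a bare dict, the local and API-backed embedding paths can be consumed identically: response.data is a list with one entry per input, and each entry's "embedding" field holds the vector, as in the dict built inside _local_embed. A small sketch, assuming a model instance exposing embed() as in the diff:

    # `model` is an instance of the class patched above (name assumed).
    response = model.embed(["hello", "world"])
    vectors = [item["embedding"] for item in response.data]
    assert len(vectors) == 2  # one embedding per input string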
