Skip to content

Commit 431d682

Browse files
authored
Merge pull request #419 from AbhishekRP2002/main
feat: add milvus vector db integration
2 parents b509cb8 + 1258f3a commit 431d682

File tree

6 files changed

+313
-7
lines changed

6 files changed

+313
-7
lines changed

backend/modules/vector_db/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def get_embedding_dimensions(self, embeddings: Embeddings) -> int:
9090
Fetch embedding dimensions
9191
"""
9292
# Calculate embedding size
93-
logger.debug(f"[VectorDB] Embedding a dummy doc to get vector dimensions")
93+
logger.debug("Embedding a dummy doc to get vector dimensions")
9494
partial_embeddings = embeddings.embed_documents(["Initial document"])
9595
vector_size = len(partial_embeddings[0])
9696
logger.debug(f"Vector size: {vector_size}")

backend/modules/vector_db/milvus.py

+300
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
from typing import List
2+
3+
from langchain.docstore.document import Document
4+
from langchain.embeddings.base import Embeddings
5+
from langchain_milvus import Milvus
6+
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
7+
8+
from backend.constants import (
9+
DATA_POINT_FQN_METADATA_KEY,
10+
DATA_POINT_HASH_METADATA_KEY,
11+
DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
12+
)
13+
from backend.logger import logger
14+
from backend.modules.vector_db.base import BaseVectorDB
15+
from backend.types import DataPointVector, VectorDBConfig
16+
17+
MAX_SCROLL_LIMIT = int(1e6)
18+
BATCH_SIZE = 1000
19+
20+
21+
class MilvusVectorDB(BaseVectorDB):
    """Vector database backend backed by Milvus.

    Collection management, listing and deletion go through a ``MilvusClient``;
    document insertion goes through the LangChain ``Milvus`` vector store
    returned by :meth:`get_vector_store`.
    """

    def __init__(self, config: VectorDBConfig):
        """
        Initialize Milvus vector database client
        Args:
            :param config: VectorDBConfig
                - provider: str
                - local: bool
                    When True, Milvus Lite is used and data is stored in the local
                    file `./cognita_milvus.db`; `url` and `api_key` are ignored.
                - url: str
                    URI of the Milvus server.
                    - If you only need a local vector database for small scale data or prototyping,
                      setting the uri as a local file, e.g.`./milvus.db`, is the most convenient method,
                      as it automatically utilizes [Milvus Lite](https://milvus.io/docs/milvus_lite.md)
                      to store all data in this file.
                    - If you have large scale of data, say more than a million vectors, you can set up
                      a more performant Milvus server on [Docker or Kubernetes](https://milvus.io/docs/quickstart.md).
                      In this setup, please use the server address and port as your uri, e.g.`http://localhost:19530`.
                      If you enable the authentication feature on Milvus,
                      use "<your_username>:<your_password>" as the token, otherwise don't set the token.
                    - If you use [Zilliz Cloud](https://zilliz.com/cloud), the fully managed cloud
                      service for Milvus, adjust the `uri` and `token`, which correspond to the
                      [Public Endpoint and API key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#cluster-details)
                - api_key: str
                    Token for authentication with the Milvus server.
                - config: dict
                    Optional extras: `metric_type` (default "COSINE") and
                    `db_name` (default "milvus_default_db").
        """
        # TODO: create an extended config for Milvus like done in Qdrant
        # The API key is a credential, so it is left out of the logged config.
        loggable_config = config.model_dump(exclude={"api_key"})
        logger.debug(f"Connecting to Milvus using config: {loggable_config}")
        self.config = config

        # Tolerate a missing/None extras dict instead of failing on `.get`.
        extra_config = config.config or {}
        self.metric_type = extra_config.get("metric_type", "COSINE")
        db_name = extra_config.get("db_name", "milvus_default_db")

        # Milvus-lite is used for local == True
        if config.local is True:
            # TODO: make this path customizable
            self.url = "./cognita_milvus.db"
            self.api_key = ""
            self.milvus_client = MilvusClient(uri=self.url, db_name=db_name)
        else:
            self.url = config.url
            self.api_key = config.api_key
            # An empty/missing key means "no authentication": pass None as the
            # token. (Previously the token variable was only defined in that
            # case, so any configured key raised UnboundLocalError here.)
            self.milvus_client = MilvusClient(
                uri=self.url,
                token=self.api_key or None,
                db_name=db_name,
            )

    @staticmethod
    def _quote(value) -> str:
        """Render `value` as a double-quoted string literal for a filter expression.

        Backslashes and double quotes are escaped so that the value cannot
        terminate the literal early or alter the surrounding expression.
        """
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def create_collection(self, collection_name: str, embeddings: Embeddings):
        """
        Create a collection in the vector database
        Args:
            :param collection_name: str - Name of the collection
            :param embeddings: Embeddings - Embeddings object to be used for creating embeddings of the documents

        The collection is created with an explicit schema (auto-generated INT64
        `id`, `vector`, `text` and JSON `metadata` fields) and a FLAT index is
        then built on the `vector` field using the configured metric type.
        """
        # TODO: Add customized setup with indexed params
        # NOTE(review): the collection is not explicitly loaded here; confirm
        # whether callers rely on it being loaded right after creation.
        logger.debug(f"[Milvus] Creating new collection {collection_name}")

        # The vector field dimension must match the embedding model's output.
        vector_size = self.get_embedding_dimensions(embeddings)

        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.JSON),
        ]

        schema = CollectionSchema(
            fields=fields, description=f"Collection for {collection_name}"
        )

        self.milvus_client.create_collection(
            collection_name=collection_name,
            dimension=vector_size,
            metric_type=self.metric_type,  # https://milvus.io/docs/metric.md#Metric-Types : check for other supported metrics
            schema=schema,
            auto_id=True,
        )

        # Can use this to create custom multiple indices
        index_params = self.milvus_client.prepare_index_params()
        index_params.add_index(
            field_name="vector", index_type="FLAT", metric_type=self.metric_type
        )
        self.milvus_client.create_index(
            collection_name=collection_name, index_params=index_params
        )

        logger.debug(f"[Milvus] Created new collection {collection_name}")

    def _delete_existing_documents(
        self, collection_name: str, documents: List[Document]
    ):
        """
        Delete existing documents from the collection

        Records are matched on metadata (data point FQN and hash) rather than
        on document IDs. Documents missing either metadata key are skipped.
        """
        # Several documents (chunks) can share the same FQN/hash pair; issue
        # each distinct delete only once, keeping first-seen order.
        delete_exprs = {}
        for doc in documents:
            if (
                DATA_POINT_FQN_METADATA_KEY in doc.metadata
                and DATA_POINT_HASH_METADATA_KEY in doc.metadata
            ):
                fqn_literal = self._quote(doc.metadata[DATA_POINT_FQN_METADATA_KEY])
                hash_literal = self._quote(doc.metadata[DATA_POINT_HASH_METADATA_KEY])
                delete_expr = (
                    f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == {fqn_literal} && '
                    f'metadata["{DATA_POINT_HASH_METADATA_KEY}"] == {hash_literal}'
                )
                delete_exprs[delete_expr] = None

        for delete_expr in delete_exprs:
            logger.debug(
                f"[Milvus] Deleting records matching expression: {delete_expr}"
            )

            self.milvus_client.delete(
                collection_name=collection_name,
                filter=delete_expr,
            )

    def upsert_documents(
        self,
        collection_name: str,
        documents: List[Document],
        embeddings: Embeddings,
        incremental: bool = True,
    ):
        """
        Upsert documents in the database.
        Upsert = Insert / update
        - Return early if there are no documents
        - Raise if the collection does not exist
        - If `incremental`, delete records whose FQN/hash metadata matches the incoming documents
        - Insert the documents exactly once

        Raises:
            Exception: if `collection_name` does not exist.
        """
        if not documents:
            logger.warning("No documents to index")
            return

        logger.debug(
            f"[Milvus] Adding {len(documents)} documents to collection {collection_name}"
        )

        if not self.milvus_client.has_collection(collection_name):
            raise Exception(
                f"Collection {collection_name} does not exist. Please create it first using `create_collection`."
            )

        stats = self.milvus_client.get_collection_stats(collection_name=collection_name)
        if stats["row_count"] == 0:
            logger.warning(
                f"[Milvus] Collection {collection_name} is empty. Inserting all documents."
            )

        # The delete still runs when the collection reports as empty, exactly as
        # the previous flow did; on a truly empty collection it matches nothing.
        if incremental:
            self._delete_existing_documents(collection_name, documents)

        # Single insertion point: the earlier version inserted once in the
        # "empty" branch and then again here, duplicating every document.
        self.get_vector_store(collection_name, embeddings).add_documents(
            documents=documents
        )

        logger.debug(
            f"[Milvus] Upserted {len(documents)} documents to collection {collection_name}"
        )

    def get_collections(self) -> List[str]:
        """Return the names of all collections reported by the Milvus client."""
        logger.debug("[Milvus] Fetching collections from the vector database")
        collections = self.milvus_client.list_collections()
        logger.debug(f"[Milvus] Fetched {len(collections)} collections")
        return collections

    def delete_collection(self, collection_name: str):
        """Drop the collection named `collection_name`."""
        logger.debug(f"[Milvus] Deleting {collection_name} collection")
        self.milvus_client.drop_collection(collection_name)
        logger.debug(f"[Milvus] Deleted {collection_name} collection")

    def get_vector_store(self, collection_name: str, embeddings: Embeddings):
        """Build a LangChain `Milvus` vector store bound to `collection_name`.

        The field names (`id`, `text`, `metadata`) mirror the schema defined in
        `create_collection`.
        """
        logger.debug(f"[Milvus] Getting vector store for collection {collection_name}")
        return Milvus(
            collection_name=collection_name,
            connection_args={
                "uri": self.url,
                "token": self.api_key,
            },
            embedding_function=embeddings,
            auto_id=True,
            primary_field="id",
            text_field="text",
            metadata_field="metadata",
        )

    def get_vector_client(self):
        """Return the underlying `MilvusClient` instance."""
        logger.debug("[Milvus] Getting Milvus client")
        return self.milvus_client

    def list_data_point_vectors(
        self,
        collection_name: str,
        data_source_fqn: str,
        batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
    ) -> List[DataPointVector]:
        """
        Get vectors from the collection

        Pages through the records whose metadata matches `data_source_fqn`,
        `batch_size` rows at a time, and returns a `DataPointVector` for every
        row that carries both the FQN and the hash metadata keys. Paging stops
        on a short page or once MAX_SCROLL_LIMIT entries have been collected.
        """
        logger.debug(
            f"[Milvus] Listing data point vectors for collection {collection_name}"
        )
        # NOTE(review): this is an exact-equality match of the data point FQN
        # metadata key against `data_source_fqn` - confirm that is the intended
        # comparison (as opposed to a prefix/substring match).
        filter_expr = (
            f'metadata["{DATA_POINT_FQN_METADATA_KEY}"] == '
            f"{self._quote(data_source_fqn)}"
        )

        data_point_vectors: List[DataPointVector] = []

        # NOTE(review): offset-based paging; Milvus may cap offset + limit -
        # verify against the server limits for very large data sources.
        offset = 0

        while True:
            rows = self.milvus_client.query(
                collection_name=collection_name,
                filter=filter_expr,
                # Only the primary key and metadata are read below, so the
                # vector and text payloads are not fetched.
                output_fields=["id", "metadata"],
                limit=batch_size,
                offset=offset,
            )

            for row in rows:
                metadata = row.get("metadata") or {}
                data_point_fqn = metadata.get(DATA_POINT_FQN_METADATA_KEY)
                data_point_hash = metadata.get(DATA_POINT_HASH_METADATA_KEY)
                if data_point_fqn and data_point_hash:
                    data_point_vectors.append(
                        DataPointVector(
                            data_point_vector_id=str(row["id"]),
                            data_point_fqn=data_point_fqn,
                            data_point_hash=data_point_hash,
                        )
                    )

            # A short page means the result set is exhausted.
            if (
                len(rows) < batch_size
                or len(data_point_vectors) >= MAX_SCROLL_LIMIT
            ):
                break

            offset += batch_size

        logger.debug(f"[Milvus] Listed {len(data_point_vectors)} data point vectors")

        return data_point_vectors

    def delete_data_point_vectors(
        self,
        collection_name: str,
        data_point_vectors: List[DataPointVector],
        batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
    ):
        """
        Delete vectors from the collection

        Vectors are removed by primary key, `batch_size` IDs per delete call.
        """
        logger.debug(f"[Milvus] Deleting {len(data_point_vectors)} data point vectors")

        for start in range(0, len(data_point_vectors), batch_size):
            batch_vectors = data_point_vectors[start : start + batch_size]

            # One `in` term instead of a long chain of `id == x or ...` terms.
            id_list = ", ".join(
                str(vector.data_point_vector_id) for vector in batch_vectors
            )
            delete_expr = f"id in [{id_list}]"

            self.milvus_client.delete(
                collection_name=collection_name, filter=delete_expr
            )

        logger.debug(f"[Milvus] Deleted {len(data_point_vectors)} data point vectors")

backend/modules/vector_db/mongo.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@ def _create_search_index(self, collection_name: str, embeddings: Embeddings):
6464
result = self.db[collection_name].create_search_index(model=search_index_model)
6565
logger.debug(f"New search index named {result} is building.")
6666

67-
# Immediate avaialbility of the index is not guaranteed upon creation.
67+
# Immediate availability of the index is not guaranteed upon creation.
6868
# MongoDB documentation recommends polling for the index to be ready.
6969
# Ensure this check to provide a seamless experience.
7070
# TODO (mnvsk97): We might want to introduce a new status in the ingestion runs to reflect this.
7171
logger.debug(
7272
"Polling to check if the index is ready. This may take up to a minute."
7373
)
74-
predicate = lambda index: index.get("queryable") is True
74+
predicate = lambda index: index.get("queryable") is True # noqa: E731
7575
while True:
7676
indices = list(
7777
self.db[collection_name].list_search_indexes("vector_search_index")
@@ -96,7 +96,7 @@ def upsert_documents(
9696
f"[Mongo] Adding {len(documents)} documents to collection {collection_name}"
9797
)
9898

99-
"""Upsert documenlots with their embeddings"""
99+
"""Upsert documents with their embeddings"""
100100
collection = self.db[collection_name]
101101

102102
data_point_fqns = []

backend/modules/vector_db/qdrant.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def create_collection(self, collection_name: str, embeddings: Embeddings):
4444
logger.debug(f"[Qdrant] Creating new collection {collection_name}")
4545

4646
# Calculate embedding size
47-
logger.debug(f"[Qdrant] Embedding a dummy doc to get vector dimensions")
4847
partial_embeddings = embeddings.embed_documents(["Initial document"])
4948
vector_size = len(partial_embeddings[0])
5049
logger.debug(f"Vector size: {vector_size}")
@@ -166,7 +165,7 @@ def upsert_documents(
166165
)
167166

168167
def get_collections(self) -> List[str]:
169-
logger.debug(f"[Qdrant] Fetching collections")
168+
logger.debug("[Qdrant] Fetching collections")
170169
collections = self.qdrant_client.get_collections().collections
171170
logger.debug(f"[Qdrant] Fetched {len(collections)} collections")
172171
return [collection.name for collection in collections]
@@ -185,7 +184,7 @@ def get_vector_store(self, collection_name: str, embeddings: Embeddings):
185184
)
186185

187186
def get_vector_client(self):
188-
logger.debug(f"[Qdrant] Getting Qdrant client")
187+
logger.debug("[Qdrant] Getting Qdrant client")
189188
return self.qdrant_client
190189

191190
def list_data_point_vectors(

backend/vectordb.requirements.txt

+5
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ weaviate-client==3.25.3
77
### MongoDB
88
pymongo==4.10.1
99
langchain-mongodb==0.2.0
10+
11+
12+
### Milvus
13+
pymilvus==2.4.10
14+
langchain-milvus==0.1.7

compose.env

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ ML_REPO_NAME=''
1515
VECTOR_DB_CONFIG='{"provider":"qdrant","url":"http://qdrant-server:6333", "config": {"grpc_port": 6334, "prefer_grpc": false}}'
1616
# MONGO Example
1717
# VECTOR_DB_CONFIG='{"provider":"mongo","url":"connection_uri", "config": {"database_name": "cognita"}}'
18+
# Milvus Example
19+
# VECTOR_DB_CONFIG='{"provider":"Milvus", "url":"connection_uri", "api_key":"milvus_auth_token", "config":{"db_name":"cognita", "metric_type":"COSINE"}}'
1820
COGNITA_BACKEND_PORT=8000
1921

2022
UNSTRUCTURED_IO_URL=http://unstructured-io-parsers:9500/

0 commit comments

Comments
 (0)