Skip to content

Commit 7aad8eb

Browse files
Support multiple collections (#26)
* Allow passing the collection name in each request to override the default
* Allow getting the collection names in QdrantConnector
* get vector size from model description
* ruff format
* add isort
* apply pre-commit hooks

---------

Co-authored-by: generall <[email protected]>
1 parent 13cf930 commit 7aad8eb

File tree

8 files changed

+178
-37
lines changed

8 files changed

+178
-37
lines changed

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
[![smithery badge](https://smithery.ai/badge/mcp-server-qdrant)](https://smithery.ai/protocol/mcp-server-qdrant)
44

55
> The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) is an open protocol that enables
6-
> seamless integration between LLM applications and external data sources and tools. Whether youre building an
6+
> seamless integration between LLM applications and external data sources and tools. Whether you're building an
77
> AI-powered IDE, enhancing a chat interface, or creating custom AI workflows, MCP provides a standardized way to
88
> connect LLMs with the context they need.
99
@@ -25,11 +25,15 @@ It acts as a semantic memory layer on top of the Qdrant database.
2525
- Input:
2626
- `information` (string): Information to store
2727
- `metadata` (JSON): Optional metadata to store
28+
- `collection_name` (string): Name of the collection to store the information in, optional. If not provided,
29+
the default collection name will be used.
2830
- Returns: Confirmation message
2931
2. `qdrant-find`
3032
- Retrieve relevant information from the Qdrant database
3133
- Input:
3234
- `query` (string): Query to use for searching
35+
- `collection_name` (string): Name of the collection to search in, optional. If not provided,
36+
the default collection name will be used.
3337
- Returns: Information stored in the Qdrant database as separate messages
3438

3539
## Environment Variables
@@ -40,7 +44,7 @@ The configuration of the server is done using environment variables:
4044
|--------------------------|---------------------------------------------------------------------|-------------------------------------------------------------------|
4145
| `QDRANT_URL` | URL of the Qdrant server | None |
4246
| `QDRANT_API_KEY` | API key for the Qdrant server | None |
43-
| `COLLECTION_NAME` | Name of the collection to use | *Required* |
47+
| `COLLECTION_NAME` | Name of the default collection to use. | *Required* |
4448
| `QDRANT_LOCAL_PATH` | Path to the local Qdrant database (alternative to `QDRANT_URL`) | None |
4549
| `EMBEDDING_PROVIDER` | Embedding provider to use (currently only "fastembed" is supported) | `fastembed` |
4650
| `EMBEDDING_MODEL` | Name of the embedding model to use | `sentence-transformers/all-MiniLM-L6-v2` |

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ build-backend = "hatchling.build"
1818

1919
[tool.uv]
2020
dev-dependencies = [
21+
"isort>=6.0.1",
2122
"pre-commit>=4.1.0",
2223
"pyright>=1.1.389",
2324
"pytest>=8.3.3",
2425
"pytest-asyncio>=0.23.0",
25-
"ruff>=0.8.0"
26+
"ruff>=0.8.0",
2627
]
2728

2829
[project.scripts]

src/mcp_server_qdrant/embeddings/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,8 @@ async def embed_query(self, query: str) -> List[float]:
1919
def get_vector_name(self) -> str:
2020
"""Get the name of the vector for the Qdrant collection."""
2121
pass
22+
23+
@abstractmethod
24+
def get_vector_size(self) -> int:
25+
"""Get the size of the vector for the Qdrant collection."""
26+
pass

src/mcp_server_qdrant/embeddings/fastembed.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List
33

44
from fastembed import TextEmbedding
5+
from fastembed.common.model_description import DenseModelDescription
56

67
from mcp_server_qdrant.embeddings.base import EmbeddingProvider
78

@@ -41,3 +42,10 @@ def get_vector_name(self) -> str:
4142
"""
4243
model_name = self.embedding_model.model_name.split("/")[-1].lower()
4344
return f"fast-{model_name}"
45+
46+
def get_vector_size(self) -> int:
47+
"""Get the size of the vector for the Qdrant collection."""
48+
model_description: DenseModelDescription = (
49+
self.embedding_model._get_model_description(self.model_name)
50+
)
51+
return model_description.dim

src/mcp_server_qdrant/qdrant.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,39 +41,29 @@ def __init__(
4141
):
4242
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
4343
self._qdrant_api_key = qdrant_api_key
44-
self._collection_name = collection_name
44+
self._default_collection_name = collection_name
4545
self._embedding_provider = embedding_provider
4646
self._client = AsyncQdrantClient(
4747
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
4848
)
4949

50-
async def _ensure_collection_exists(self):
51-
"""Ensure that the collection exists, creating it if necessary."""
52-
collection_exists = await self._client.collection_exists(self._collection_name)
53-
if not collection_exists:
54-
# Create the collection with the appropriate vector size
55-
# We'll get the vector size by embedding a sample text
56-
sample_vector = await self._embedding_provider.embed_query("sample text")
57-
vector_size = len(sample_vector)
58-
59-
# Use the vector name as defined in the embedding provider
60-
vector_name = self._embedding_provider.get_vector_name()
61-
await self._client.create_collection(
62-
collection_name=self._collection_name,
63-
vectors_config={
64-
vector_name: models.VectorParams(
65-
size=vector_size,
66-
distance=models.Distance.COSINE,
67-
)
68-
},
69-
)
50+
async def get_collection_names(self) -> list[str]:
51+
"""
52+
Get the names of all collections in the Qdrant server.
53+
:return: A list of collection names.
54+
"""
55+
response = await self._client.get_collections()
56+
return [collection.name for collection in response.collections]
7057

71-
async def store(self, entry: Entry):
58+
async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
7259
"""
7360
Store some information in the Qdrant collection, along with the specified metadata.
7461
:param entry: The entry to store in the Qdrant collection.
62+
:param collection_name: The name of the collection to store the information in, optional. If not provided,
63+
the default collection is used.
7564
"""
76-
await self._ensure_collection_exists()
65+
collection_name = collection_name or self._default_collection_name
66+
await self._ensure_collection_exists(collection_name)
7767

7868
# Embed the document
7969
embeddings = await self._embedding_provider.embed_documents([entry.content])
@@ -82,7 +72,7 @@ async def store(self, entry: Entry):
8272
vector_name = self._embedding_provider.get_vector_name()
8373
payload = {"document": entry.content, "metadata": entry.metadata}
8474
await self._client.upsert(
85-
collection_name=self._collection_name,
75+
collection_name=collection_name,
8676
points=[
8777
models.PointStruct(
8878
id=uuid.uuid4().hex,
@@ -92,13 +82,19 @@ async def store(self, entry: Entry):
9282
],
9383
)
9484

95-
async def search(self, query: str) -> list[Entry]:
85+
async def search(
86+
self, query: str, *, collection_name: Optional[str] = None, limit: int = 10
87+
) -> list[Entry]:
9688
"""
9789
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
9890
:param query: The query to use for the search.
91+
:param collection_name: The name of the collection to search in, optional. If not provided,
92+
the default collection is used.
93+
:param limit: The maximum number of entries to return.
9994
:return: A list of entries found.
10095
"""
101-
collection_exists = await self._client.collection_exists(self._collection_name)
96+
collection_name = collection_name or self._default_collection_name
97+
collection_exists = await self._client.collection_exists(collection_name)
10298
if not collection_exists:
10399
return []
104100

@@ -108,9 +104,9 @@ async def search(self, query: str) -> list[Entry]:
108104

109105
# Search in Qdrant
110106
search_results = await self._client.search(
111-
collection_name=self._collection_name,
107+
collection_name=collection_name,
112108
query_vector=models.NamedVector(name=vector_name, vector=query_vector),
113-
limit=10,
109+
limit=limit,
114110
)
115111

116112
return [
@@ -120,3 +116,25 @@ async def search(self, query: str) -> list[Entry]:
120116
)
121117
for result in search_results
122118
]
119+
120+
async def _ensure_collection_exists(self, collection_name: str):
121+
"""
122+
Ensure that the collection exists, creating it if necessary.
123+
:param collection_name: The name of the collection to ensure exists.
124+
"""
125+
collection_exists = await self._client.collection_exists(collection_name)
126+
if not collection_exists:
127+
# Create the collection with the appropriate vector size
128+
vector_size = self._embedding_provider.get_vector_size()
129+
130+
# Use the vector name as defined in the embedding provider
131+
vector_name = self._embedding_provider.get_vector_name()
132+
await self._client.create_collection(
133+
collection_name=collection_name,
134+
vectors_config={
135+
vector_name: models.VectorParams(
136+
size=vector_size,
137+
distance=models.Distance.COSINE,
138+
)
139+
},
140+
)

src/mcp_server_qdrant/server.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33
from contextlib import asynccontextmanager
4-
from typing import AsyncIterator, List
4+
from typing import AsyncIterator, List, Optional
55

66
from mcp.server import Server
77
from mcp.server.fastmcp import Context, FastMCP
@@ -75,36 +75,53 @@ async def store(
7575
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
7676
# handle the optional parameter correctly.
7777
metadata: Metadata = None,
78+
collection_name: Optional[str] = None,
7879
) -> str:
7980
"""
8081
Store some information in Qdrant.
8182
:param ctx: The context for the request.
8283
:param information: The information to store.
8384
:param metadata: JSON metadata to store with the information, optional.
85+
:param collection_name: The name of the collection to store the information in, optional. If not provided,
86+
the default collection is used.
8487
:return: A message indicating that the information was stored.
8588
"""
8689
await ctx.debug(f"Storing information {information} in Qdrant")
8790
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
8891
"qdrant_connector"
8992
]
9093
entry = Entry(content=information, metadata=metadata)
91-
await qdrant_connector.store(entry)
94+
await qdrant_connector.store(entry, collection_name=collection_name)
95+
if collection_name:
96+
return f"Remembered: {information} in collection {collection_name}"
9297
return f"Remembered: {information}"
9398

9499

95100
@mcp.tool(name="qdrant-find", description=tool_settings.tool_find_description)
96-
async def find(ctx: Context, query: str) -> List[str]:
101+
async def find(
102+
ctx: Context,
103+
query: str,
104+
collection_name: Optional[str] = None,
105+
limit: int = 10,
106+
) -> List[str]:
97107
"""
98108
Find memories in Qdrant.
99109
:param ctx: The context for the request.
100110
:param query: The query to use for the search.
111+
:param collection_name: The name of the collection to search in, optional. If not provided,
112+
the default collection is used.
113+
:param limit: The maximum number of entries to return, optional. Default is 10.
101114
:return: A list of entries found.
102115
"""
103116
await ctx.debug(f"Finding results for query {query}")
117+
if collection_name:
118+
await ctx.debug(f"Overriding the collection name with {collection_name}")
104119
qdrant_connector: QdrantConnector = ctx.request_context.lifespan_context[
105120
"qdrant_connector"
106121
]
107-
entries = await qdrant_connector.search(query)
122+
entries = await qdrant_connector.search(
123+
query, collection_name=collection_name, limit=limit
124+
)
108125
if not entries:
109126
return [f"No information found for the query '{query}'"]
110127
content = [

tests/test_qdrant_integration.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def test_ensure_collection_exists(qdrant_connector):
9797
"""Test that the collection is created if it doesn't exist."""
9898
# The collection shouldn't exist yet
9999
assert not await qdrant_connector._client.collection_exists(
100-
qdrant_connector._collection_name
100+
qdrant_connector._default_collection_name
101101
)
102102

103103
# Storing an entry should create the collection
@@ -106,7 +106,7 @@ async def test_ensure_collection_exists(qdrant_connector):
106106

107107
# Now the collection should exist
108108
assert await qdrant_connector._client.collection_exists(
109-
qdrant_connector._collection_name
109+
qdrant_connector._default_collection_name
110110
)
111111

112112

@@ -159,3 +159,80 @@ async def test_entry_without_metadata(qdrant_connector):
159159
assert len(results) == 1
160160
assert results[0].content == "Entry without metadata"
161161
assert results[0].metadata is None
162+
163+
164+
@pytest.mark.asyncio
165+
async def test_custom_collection_store_and_search(qdrant_connector):
166+
"""Test storing and searching in a custom collection."""
167+
# Define a custom collection name
168+
custom_collection = f"custom_collection_{uuid.uuid4().hex}"
169+
170+
# Store a test entry in the custom collection
171+
test_entry = Entry(
172+
content="This is stored in a custom collection",
173+
metadata={"custom": True},
174+
)
175+
await qdrant_connector.store(test_entry, collection_name=custom_collection)
176+
177+
# Search in the custom collection
178+
results = await qdrant_connector.search(
179+
"custom collection", collection_name=custom_collection
180+
)
181+
182+
# Verify results
183+
assert len(results) == 1
184+
assert results[0].content == test_entry.content
185+
assert results[0].metadata == test_entry.metadata
186+
187+
# Verify the entry is not in the default collection
188+
default_results = await qdrant_connector.search("custom collection")
189+
assert len(default_results) == 0
190+
191+
192+
@pytest.mark.asyncio
193+
async def test_multiple_collections(qdrant_connector):
194+
"""Test using multiple collections with the same connector."""
195+
# Define two custom collection names
196+
collection_a = f"collection_a_{uuid.uuid4().hex}"
197+
collection_b = f"collection_b_{uuid.uuid4().hex}"
198+
199+
# Store entries in different collections
200+
entry_a = Entry(
201+
content="This belongs to collection A", metadata={"collection": "A"}
202+
)
203+
entry_b = Entry(
204+
content="This belongs to collection B", metadata={"collection": "B"}
205+
)
206+
entry_default = Entry(content="This belongs to the default collection")
207+
208+
await qdrant_connector.store(entry_a, collection_name=collection_a)
209+
await qdrant_connector.store(entry_b, collection_name=collection_b)
210+
await qdrant_connector.store(entry_default)
211+
212+
# Search in collection A
213+
results_a = await qdrant_connector.search("belongs", collection_name=collection_a)
214+
assert len(results_a) == 1
215+
assert results_a[0].content == entry_a.content
216+
217+
# Search in collection B
218+
results_b = await qdrant_connector.search("belongs", collection_name=collection_b)
219+
assert len(results_b) == 1
220+
assert results_b[0].content == entry_b.content
221+
222+
# Search in default collection
223+
results_default = await qdrant_connector.search("belongs")
224+
assert len(results_default) == 1
225+
assert results_default[0].content == entry_default.content
226+
227+
228+
@pytest.mark.asyncio
229+
async def test_nonexistent_collection_search(qdrant_connector):
230+
"""Test searching in a collection that doesn't exist."""
231+
# Search in a collection that doesn't exist
232+
nonexistent_collection = f"nonexistent_{uuid.uuid4().hex}"
233+
results = await qdrant_connector.search(
234+
"test query", collection_name=nonexistent_collection
235+
)
236+
237+
# Verify results
238+
assert len(results) == 0

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)