|
@@ -31,7 +31,7 @@ class TestEsDB(unittest.TestCase):
|
|
|
# Create some dummy data.
|
|
|
embeddings = [[1, 2, 3], [4, 5, 6]]
|
|
|
documents = ["This is a document.", "This is another document."]
|
|
|
- metadatas = [{}, {}]
|
|
|
+ metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
|
|
ids = ["doc_1", "doc_2"]
|
|
|
|
|
|
# Add the data to the database.
|
|
@@ -40,8 +40,17 @@ class TestEsDB(unittest.TestCase):
|
|
|
search_response = {
|
|
|
"hits": {
|
|
|
"hits": [
|
|
|
- {"_source": {"text": "This is a document."}, "_score": 0.9},
|
|
|
- {"_source": {"text": "This is another document."}, "_score": 0.8},
|
|
|
+ {
|
|
|
+ "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
|
|
|
+ "_score": 0.9,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "_source": {
|
|
|
+ "text": "This is another document.",
|
|
|
+ "metadata": {"url": "url_2", "doc_id": "doc_id_2"},
|
|
|
+ },
|
|
|
+ "_score": 0.8,
|
|
|
+ },
|
|
|
]
|
|
|
}
|
|
|
}
|
|
@@ -54,7 +63,9 @@ class TestEsDB(unittest.TestCase):
|
|
|
results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
|
|
|
|
|
# Assert that the results are correct.
|
|
|
- self.assertEqual(results, ["This is a document.", "This is another document."])
|
|
|
+ self.assertEqual(
|
|
|
+ results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
|
|
+ )
|
|
|
|
|
|
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
|
|
def test_query_with_skip_embedding(self, mock_client):
|
|
@@ -68,7 +79,7 @@ class TestEsDB(unittest.TestCase):
|
|
|
# Create some dummy data.
|
|
|
embeddings = [[1, 2, 3], [4, 5, 6]]
|
|
|
documents = ["This is a document.", "This is another document."]
|
|
|
- metadatas = [{}, {}]
|
|
|
+ metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
|
|
ids = ["doc_1", "doc_2"]
|
|
|
|
|
|
# Add the data to the database.
|
|
@@ -77,8 +88,17 @@ class TestEsDB(unittest.TestCase):
|
|
|
search_response = {
|
|
|
"hits": {
|
|
|
"hits": [
|
|
|
- {"_source": {"text": "This is a document."}, "_score": 0.9},
|
|
|
- {"_source": {"text": "This is another document."}, "_score": 0.8},
|
|
|
+ {
|
|
|
+ "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
|
|
|
+ "_score": 0.9,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "_source": {
|
|
|
+ "text": "This is another document.",
|
|
|
+ "metadata": {"url": "url_2", "doc_id": "doc_id_2"},
|
|
|
+ },
|
|
|
+ "_score": 0.8,
|
|
|
+ },
|
|
|
]
|
|
|
}
|
|
|
}
|
|
@@ -91,7 +111,9 @@ class TestEsDB(unittest.TestCase):
|
|
|
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
|
|
|
|
|
|
# Assert that the results are correct.
|
|
|
- self.assertEqual(results, ["This is a document.", "This is another document."])
|
|
|
+ self.assertEqual(
|
|
|
+ results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
|
|
+ )
|
|
|
|
|
|
def test_init_without_url(self):
|
|
|
# Make sure it's not loaded from env
|