# test_lancedb.py: pytest tests for embedchain's LanceDB vector database wrapper.
  import os
  import shutil

  import pytest

  from embedchain import App
  from embedchain.config import AppConfig
  from embedchain.config.vector_db.lancedb import LanceDBConfig
  from embedchain.vectordb.lancedb import LanceDB

  # Placeholder credential, set at import time so it is already in the
  # environment before any fixture or test in this module constructs an App.
  # NOTE(review): presumably App's default components look for this key; confirm.
  os.environ["OPENAI_API_KEY"] = "test-api-key"
  @pytest.fixture
  def lancedb():
      """Provide a LanceDB store configured with dir "test-db" and collection "test-coll"."""
      store_config = LanceDBConfig(dir="test-db", collection_name="test-coll")
      return LanceDB(config=store_config)
  @pytest.fixture
  def app_with_settings():
      """Provide an App whose vector store is a resettable LanceDB under "test-db-reset"."""
      vector_store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db-reset"))
      # Metrics collection is switched off for the App used in tests.
      return App(config=AppConfig(collect_metrics=False), db=vector_store)
  @pytest.fixture(scope="session", autouse=True)
  def cleanup_db():
      """Remove the on-disk test databases after the test session.

      The fixture is session-scoped and autouse, so the code after ``yield``
      runs once as session teardown. Removal is best-effort: an ``OSError`` is
      printed and never raised.
      """
      yield
      # Each directory gets its own try block so that a failure on one (for
      # example, it was never created) does not skip removal of the other.
      for db_path in ("test-db.lance", "test-db-reset.lance"):
          try:
              shutil.rmtree(db_path)
          except OSError as e:
              print("Error: %s - %s." % (e.filename, e.strerror))
  def test_lancedb_duplicates_throw_warning(caplog):
      """Add the same document ID twice and check no duplicate-ID message is logged.

      NOTE(review): the test name says "throw_warning", yet both assertions
      require the messages to be absent from the captured log. Confirm which
      behavior is intended.
      """
      db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      app = App(config=AppConfig(collect_metrics=False), db=db)
      try:
          # Insert the identical record twice.
          for _ in range(2):
              app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
          assert "Insert of existing doc ID: 0" not in caplog.text
          assert "Add of existing doc ID: 0" not in caplog.text
      finally:
          # Reset even when an assertion fails, so leftover rows cannot leak
          # into later tests that reuse the "test-db" directory.
          app.db.reset()
  def test_lancedb_duplicates_collections_no_warning(caplog):
      """Add the same document ID to two collections and expect no duplicate-ID log message."""
      db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      app = App(config=AppConfig(collect_metrics=False), db=db)
      try:
          for collection_name in ("test_collection_1", "test_collection_2"):
              app.set_collection_name(collection_name)
              app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
          assert "Insert of existing doc ID: 0" not in caplog.text
          assert "Add of existing doc ID: 0" not in caplog.text
      finally:
          # Runs even when an assertion fails: reset the currently selected
          # collection, then switch to "test_collection_1" and reset it too.
          app.db.reset()
          app.set_collection_name("test_collection_1")
          app.db.reset()
  def test_lancedb_collection_init_with_default_collection():
      """With no collection name configured, the collection is named "embedchain_store"."""
      store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      application = App(config=AppConfig(collect_metrics=False), db=store)

      expected_name = "embedchain_store"
      assert application.db.collection.name == expected_name
  def test_lancedb_collection_init_with_custom_collection():
      """Selecting a collection through the ``name=`` keyword is reflected on the store."""
      store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      application = App(config=AppConfig(collect_metrics=False), db=store)

      # The name is passed by keyword here.
      application.set_collection_name(name="test_collection")

      expected_name = "test_collection"
      assert application.db.collection.name == expected_name
  def test_lancedb_collection_set_collection_name():
      """Selecting a collection with a positional name is reflected on the store."""
      store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      application = App(config=AppConfig(collect_metrics=False), db=store)

      # The name is passed positionally here.
      application.set_collection_name("test_collection")

      expected_name = "test_collection"
      assert application.db.collection.name == expected_name
  def test_lancedb_collection_changes_encapsulated():
      """Check that a document added to one collection is not counted in another."""
      db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      app = App(config=AppConfig(collect_metrics=False), db=db)
      try:
          app.set_collection_name("test_collection_1")
          assert app.db.count() == 0

          app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
          assert app.db.count() == 1

          # A different collection starts out empty.
          app.set_collection_name("test_collection_2")
          assert app.db.count() == 0

          # Adding to the second collection leaves the first one's count alone.
          app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
          app.set_collection_name("test_collection_1")
          assert app.db.count() == 1
      finally:
          # Select each collection by name before resetting it, so both are
          # cleared no matter which one was active when the test stopped.
          for collection_name in ("test_collection_1", "test_collection_2"):
              app.set_collection_name(collection_name)
              app.db.reset()
  def test_lancedb_collection_collections_are_persistent():
      """Check that a document is still counted after a new App is built on the same directory."""
      db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      app = App(config=AppConfig(collect_metrics=False), db=db)
      app.set_collection_name("test_collection_1")
      app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
      # Drop the first App before building a fresh one over the same directory.
      del app

      db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      app = App(config=AppConfig(collect_metrics=False), db=db)
      app.set_collection_name("test_collection_1")
      try:
          assert app.db.count() == 1
      finally:
          # Reset even if the count check fails, so the row does not leak
          # into other tests that use "test_collection_1".
          app.db.reset()
  def test_lancedb_collection_parallel_collections():
      """Check two Apps on the same directory, each with its own collection.

      Each App starts on its own collection; after adding documents the two
      Apps swap collection names and the counts are expected to follow the
      collection, not the App.
      """
      db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
      app1 = App(config=AppConfig(collect_metrics=False), db=db1)
      db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
      app2 = App(config=AppConfig(collect_metrics=False), db=db2)

      # cleanup if any previous tests failed or were interrupted
      app1.db.reset()
      app2.db.reset()

      try:
          app1.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
          assert app1.db.count() == 1
          assert app2.db.count() == 0

          app1.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"])
          app2.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

          # Swap: each App now looks at the collection the other one filled.
          app1.set_collection_name("test_collection_2")
          assert app1.db.count() == 1
          app2.set_collection_name("test_collection_1")
          assert app2.db.count() == 3
      finally:
          # cleanup, executed even when one of the assertions above fails
          app1.db.reset()
          app2.db.reset()
  def test_lancedb_collection_ids_share_collections():
      """Check the counts two Apps report after writing to the same named collection.

      NOTE(review): app1 is expected to report 2 even though app2 has added a
      third document to "one_collection" by then. Presumably app1 does not
      observe app2's later write; confirm this is the intended behavior.
      """
      db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      app1 = App(config=AppConfig(collect_metrics=False), db=db1)
      app1.set_collection_name("one_collection")
      db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
      app2 = App(config=AppConfig(collect_metrics=False), db=db2)
      app2.set_collection_name("one_collection")

      # cleanup
      app1.db.reset()
      app2.db.reset()

      try:
          app1.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"])
          app2.db.add(ids=["2"], documents=["doc3"], metadatas=["test"])

          assert app1.db.count() == 2
          assert app2.db.count() == 3
      finally:
          # cleanup, executed even when one of the assertions above fails
          app1.db.reset()
          app2.db.reset()
  def test_lancedb_collection_reset():
      """Check that resetting one App's collection leaves three other collections untouched."""
      collection_names = ("one_collection", "two_collection", "three_collection", "four_collection")

      # Build one App per collection, all on the same "test-db" directory.
      apps = []
      for collection_name in collection_names:
          db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
          app = App(config=AppConfig(collect_metrics=False), db=db)
          app.set_collection_name(collection_name)
          apps.append(app)
      app1, app2, app3, app4 = apps

      # cleanup if any previous tests failed or were interrupted
      for app in apps:
          app.db.reset()

      try:
          # One document per collection: ids "1".."4", documents "doc1".."doc4".
          for number, app in enumerate(apps, start=1):
              app.db.add(ids=[str(number)], documents=[f"doc{number}"], metadatas=["test"])

          app1.db.reset()

          assert app1.db.count() == 0
          assert app2.db.count() == 1
          assert app3.db.count() == 1
          assert app4.db.count() == 1
      finally:
          # cleanup, executed even when an assertion fails; app1 is included
          # so its row is cleared if the test stopped before the reset above.
          for app in apps:
              app.db.reset()
  def generate_embeddings(dummy_embed, embed_size):
      """Return a list holding ``dummy_embed`` repeated ``embed_size`` times.

      Every element is the same object that was passed in (no copies are made).
      An ``embed_size`` of zero or less yields an empty list.
      """
      return [dummy_embed for _ in range(embed_size)]