test_lancedb.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import os
  2. import shutil
  3. import pytest
  4. from embedchain import App
  5. from embedchain.config import AppConfig
  6. from embedchain.config.vectordb.lancedb import LanceDBConfig
  7. from embedchain.vectordb.lancedb import LanceDB
  8. os.environ["OPENAI_API_KEY"] = "test-api-key"
  9. @pytest.fixture
  10. def lancedb():
  11. return LanceDB(config=LanceDBConfig(dir="test-db", collection_name="test-coll"))
  12. @pytest.fixture
  13. def app_with_settings():
  14. lancedb_config = LanceDBConfig(allow_reset=True, dir="test-db-reset")
  15. lancedb = LanceDB(config=lancedb_config)
  16. app_config = AppConfig(collect_metrics=False)
  17. return App(config=app_config, db=lancedb)
  18. @pytest.fixture(scope="session", autouse=True)
  19. def cleanup_db():
  20. yield
  21. try:
  22. shutil.rmtree("test-db.lance")
  23. shutil.rmtree("test-db-reset.lance")
  24. except OSError as e:
  25. print("Error: %s - %s." % (e.filename, e.strerror))
  26. def test_lancedb_duplicates_throw_warning(caplog):
  27. db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  28. app = App(config=AppConfig(collect_metrics=False), db=db)
  29. app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
  30. app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
  31. assert "Insert of existing doc ID: 0" not in caplog.text
  32. assert "Add of existing doc ID: 0" not in caplog.text
  33. app.db.reset()
  34. def test_lancedb_duplicates_collections_no_warning(caplog):
  35. db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  36. app = App(config=AppConfig(collect_metrics=False), db=db)
  37. app.set_collection_name("test_collection_1")
  38. app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
  39. app.set_collection_name("test_collection_2")
  40. app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
  41. assert "Insert of existing doc ID: 0" not in caplog.text
  42. assert "Add of existing doc ID: 0" not in caplog.text
  43. app.db.reset()
  44. app.set_collection_name("test_collection_1")
  45. app.db.reset()
  46. def test_lancedb_collection_init_with_default_collection():
  47. db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  48. app = App(config=AppConfig(collect_metrics=False), db=db)
  49. assert app.db.collection.name == "embedchain_store"
  50. def test_lancedb_collection_init_with_custom_collection():
  51. db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  52. app = App(config=AppConfig(collect_metrics=False), db=db)
  53. app.set_collection_name(name="test_collection")
  54. assert app.db.collection.name == "test_collection"
  55. def test_lancedb_collection_set_collection_name():
  56. db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  57. app = App(config=AppConfig(collect_metrics=False), db=db)
  58. app.set_collection_name("test_collection")
  59. assert app.db.collection.name == "test_collection"
  60. def test_lancedb_collection_changes_encapsulated():
  61. db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  62. app = App(config=AppConfig(collect_metrics=False), db=db)
  63. app.set_collection_name("test_collection_1")
  64. assert app.db.count() == 0
  65. app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
  66. assert app.db.count() == 1
  67. app.set_collection_name("test_collection_2")
  68. assert app.db.count() == 0
  69. app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
  70. app.set_collection_name("test_collection_1")
  71. assert app.db.count() == 1
  72. app.db.reset()
  73. app.set_collection_name("test_collection_2")
  74. app.db.reset()
  75. def test_lancedb_collection_collections_are_persistent():
  76. db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  77. app = App(config=AppConfig(collect_metrics=False), db=db)
  78. app.set_collection_name("test_collection_1")
  79. app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
  80. del app
  81. db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  82. app = App(config=AppConfig(collect_metrics=False), db=db)
  83. app.set_collection_name("test_collection_1")
  84. assert app.db.count() == 1
  85. app.db.reset()
  86. def test_lancedb_collection_parallel_collections():
  87. db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
  88. app1 = App(
  89. config=AppConfig(collect_metrics=False),
  90. db=db1,
  91. )
  92. db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
  93. app2 = App(
  94. config=AppConfig(collect_metrics=False),
  95. db=db2,
  96. )
  97. # cleanup if any previous tests failed or were interrupted
  98. app1.db.reset()
  99. app2.db.reset()
  100. app1.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
  101. assert app1.db.count() == 1
  102. assert app2.db.count() == 0
  103. app1.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"])
  104. app2.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
  105. app1.set_collection_name("test_collection_2")
  106. assert app1.db.count() == 1
  107. app2.set_collection_name("test_collection_1")
  108. assert app2.db.count() == 3
  109. # cleanup
  110. app1.db.reset()
  111. app2.db.reset()
  112. def test_lancedb_collection_ids_share_collections():
  113. db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  114. app1 = App(config=AppConfig(collect_metrics=False), db=db1)
  115. app1.set_collection_name("one_collection")
  116. db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  117. app2 = App(config=AppConfig(collect_metrics=False), db=db2)
  118. app2.set_collection_name("one_collection")
  119. # cleanup
  120. app1.db.reset()
  121. app2.db.reset()
  122. app1.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"])
  123. app2.db.add(ids=["2"], documents=["doc3"], metadatas=["test"])
  124. assert app1.db.count() == 2
  125. assert app2.db.count() == 3
  126. # cleanup
  127. app1.db.reset()
  128. app2.db.reset()
  129. def test_lancedb_collection_reset():
  130. db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  131. app1 = App(config=AppConfig(collect_metrics=False), db=db1)
  132. app1.set_collection_name("one_collection")
  133. db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  134. app2 = App(config=AppConfig(collect_metrics=False), db=db2)
  135. app2.set_collection_name("two_collection")
  136. db3 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  137. app3 = App(config=AppConfig(collect_metrics=False), db=db3)
  138. app3.set_collection_name("three_collection")
  139. db4 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
  140. app4 = App(config=AppConfig(collect_metrics=False), db=db4)
  141. app4.set_collection_name("four_collection")
  142. # cleanup if any previous tests failed or were interrupted
  143. app1.db.reset()
  144. app2.db.reset()
  145. app3.db.reset()
  146. app4.db.reset()
  147. app1.db.add(ids=["1"], documents=["doc1"], metadatas=["test"])
  148. app2.db.add(ids=["2"], documents=["doc2"], metadatas=["test"])
  149. app3.db.add(ids=["3"], documents=["doc3"], metadatas=["test"])
  150. app4.db.add(ids=["4"], documents=["doc4"], metadatas=["test"])
  151. app1.db.reset()
  152. assert app1.db.count() == 0
  153. assert app2.db.count() == 1
  154. assert app3.db.count() == 1
  155. assert app4.db.count() == 1
  156. # cleanup
  157. app2.db.reset()
  158. app3.db.reset()
  159. app4.db.reset()
  160. def generate_embeddings(dummy_embed, embed_size):
  161. generated_embedding = []
  162. for i in range(embed_size):
  163. generated_embedding.append(dummy_embed)
  164. return generated_embedding