123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- import os
- import shutil
- import pytest
- from embedchain import App
- from embedchain.config import AppConfig
- from embedchain.config.vector_db.lancedb import LanceDBConfig
- from embedchain.vectordb.lancedb import LanceDB
# Placeholder credential exported at import time, before any App is built.
# NOTE(review): presumably needed so components that read OPENAI_API_KEY
# find a value during these tests — confirm.
os.environ.update(OPENAI_API_KEY="test-api-key")
@pytest.fixture
def lancedb():
    """Provide a LanceDB store rooted at ``test-db`` using the ``test-coll`` collection.

    ``allow_reset`` is not passed, so it keeps whatever default
    ``LanceDBConfig`` defines.
    """
    store_config = LanceDBConfig(dir="test-db", collection_name="test-coll")
    return LanceDB(config=store_config)
@pytest.fixture
def app_with_settings():
    """Provide an App (metrics collection off) backed by a resettable LanceDB in ``test-db-reset``."""
    return App(
        db=LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db-reset")),
        config=AppConfig(collect_metrics=False),
    )
@pytest.fixture(scope="session", autouse=True)
def cleanup_db():
    """Remove on-disk LanceDB test artifacts once the whole test session finishes.

    Each path is removed independently, so a missing or undeletable first
    directory no longer prevents the second one from being cleaned up.
    Failures are printed but never fail the session (cleanup is best-effort).
    """
    yield
    # NOTE(review): confirm these names match the directories LanceDB actually
    # creates for dir="test-db" and dir="test-db-reset".
    for path in ("test-db.lance", "test-db-reset.lance"):
        try:
            shutil.rmtree(path)
        except OSError as e:
            print(f"Error: {e.filename} - {e.strerror}.")
def test_lancedb_duplicates_throw_warning(caplog):
    """Add the same doc ID twice to one collection and check the captured log.

    Despite the test's name, the assertions require that NEITHER the
    "Insert of existing doc ID" nor the "Add of existing doc ID" message
    appears in the captured log output.
    """
    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    try:
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        assert "Insert of existing doc ID: 0" not in caplog.text
        assert "Add of existing doc ID: 0" not in caplog.text
    finally:
        # Always wipe the collection: the store lives on disk, so rows left by
        # a failed assertion would otherwise leak into later tests.
        app.db.reset()
def test_lancedb_duplicates_collections_no_warning(caplog):
    """Adding the same doc ID to two different collections must not log a duplicate-ID message."""
    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    try:
        app.set_collection_name("test_collection_1")
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        app.set_collection_name("test_collection_2")
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        assert "Insert of existing doc ID: 0" not in caplog.text
        assert "Add of existing doc ID: 0" not in caplog.text
    finally:
        # Select each collection explicitly before resetting it, so both are
        # emptied even if the body failed before switching collections.
        for name in ("test_collection_2", "test_collection_1"):
            app.set_collection_name(name)
            app.db.reset()
def test_lancedb_collection_init_with_default_collection():
    """Without an explicit collection name the store's collection is "embedchain_store"."""
    app = App(
        db=LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")),
        config=AppConfig(collect_metrics=False),
    )
    assert app.db.collection.name == "embedchain_store"
def test_lancedb_collection_init_with_custom_collection():
    """A collection name passed via the ``name`` keyword becomes the store's active collection."""
    expected_name = "test_collection"
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=store)
    app.set_collection_name(name=expected_name)
    assert app.db.collection.name == expected_name
def test_lancedb_collection_set_collection_name():
    """A collection name passed positionally becomes the store's active collection."""
    app = App(
        db=LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")),
        config=AppConfig(collect_metrics=False),
    )
    app.set_collection_name("test_collection")
    active_name = app.db.collection.name
    assert active_name == "test_collection"
def test_lancedb_collection_changes_encapsulated():
    """Rows added to one collection are not counted in another collection of the same store."""
    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    try:
        app.set_collection_name("test_collection_1")
        assert app.db.count() == 0
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        assert app.db.count() == 1

        app.set_collection_name("test_collection_2")
        assert app.db.count() == 0
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

        # Switching back must show only the row added to the first collection.
        app.set_collection_name("test_collection_1")
        assert app.db.count() == 1
    finally:
        # Reset both collections by name so a mid-test failure cannot leave
        # rows behind for later tests that expect them to start empty.
        for name in ("test_collection_1", "test_collection_2"):
            app.set_collection_name(name)
            app.db.reset()
def test_lancedb_collection_collections_are_persistent():
    """Rows written to a collection are still counted by a freshly created store and App."""
    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    app.set_collection_name("test_collection_1")
    # cleanup if any previous tests failed or were interrupted; otherwise
    # leftover rows would inflate the count asserted below.
    app.db.reset()
    app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
    del app

    # Reopen the same directory through a brand-new store and App.
    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    app.set_collection_name("test_collection_1")
    try:
        assert app.db.count() == 1
    finally:
        app.db.reset()
def test_lancedb_collection_parallel_collections():
    """Two apps on different collections of one directory keep separate row counts.

    After each app switches to the other's collection, it counts the rows the
    other app wrote (1 for test_collection_2, 3 for test_collection_1).
    """
    db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
    app1 = App(config=AppConfig(collect_metrics=False), db=db1)
    db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
    app2 = App(config=AppConfig(collect_metrics=False), db=db2)

    # cleanup if any previous tests failed or were interrupted
    app1.db.reset()
    app2.db.reset()
    try:
        app1.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        assert app1.db.count() == 1
        assert app2.db.count() == 0

        app1.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"])
        app2.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

        # Swap: each app now points at the collection the other one filled.
        app1.set_collection_name("test_collection_2")
        assert app1.db.count() == 1
        app2.set_collection_name("test_collection_1")
        assert app2.db.count() == 3
    finally:
        # Pin the collection before each reset so both collections are emptied
        # no matter how far the swap got before a failure.
        app1.set_collection_name("test_collection_1")
        app1.db.reset()
        app2.set_collection_name("test_collection_2")
        app2.db.reset()
def test_lancedb_collection_ids_share_collections():
    """Two independent apps opened on the same collection name both write into it.

    app1 adds two rows and app2 adds one; the assertions expect app1 to count
    2 and app2 to count 3.

    NOTE(review): if both handles saw the fully shared table, app1 would also
    be expected to count 3. Confirm whether 2 reflects app1 reading an older
    snapshot of the table, and that this is the intended behavior.
    """
    db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app1 = App(config=AppConfig(collect_metrics=False), db=db1)
    app1.set_collection_name("one_collection")
    db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app2 = App(config=AppConfig(collect_metrics=False), db=db2)
    app2.set_collection_name("one_collection")

    # cleanup if any previous tests failed or were interrupted
    app1.db.reset()
    app2.db.reset()
    try:
        app1.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"])
        app2.db.add(ids=["2"], documents=["doc3"], metadatas=["test"])
        assert app1.db.count() == 2
        assert app2.db.count() == 3
    finally:
        app1.db.reset()
        app2.db.reset()
def test_lancedb_collection_reset():
    """Resetting one collection empties only that collection, not its siblings in the same directory."""
    apps = []
    for name in ("one_collection", "two_collection", "three_collection", "four_collection"):
        db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
        app = App(config=AppConfig(collect_metrics=False), db=db)
        app.set_collection_name(name)
        apps.append(app)
    app1, app2, app3, app4 = apps

    # cleanup if any previous tests failed or were interrupted
    for app in apps:
        app.db.reset()
    try:
        # One row per collection: ids "1".."4", documents "doc1".."doc4".
        for number, app in enumerate(apps, start=1):
            app.db.add(ids=[str(number)], documents=[f"doc{number}"], metadatas=["test"])

        app1.db.reset()
        assert app1.db.count() == 0
        assert app2.db.count() == 1
        assert app3.db.count() == 1
        assert app4.db.count() == 1
    finally:
        # Reset all four, app1 included, so a failure before app1's reset
        # above cannot leave rows behind.
        for app in apps:
            app.db.reset()
def generate_embeddings(dummy_embed, embed_size):
    """Return a list holding ``dummy_embed`` repeated ``embed_size`` times.

    Args:
        dummy_embed: Value placed at every position.
        embed_size: Length of the result; zero or negative yields ``[]``.

    Returns:
        A new list. Every element is the *same* object (no copies are made),
        so a mutable ``dummy_embed`` is shared across all positions.
    """
    return [dummy_embed] * embed_size
|