|
- import os
- import shutil
- import pytest
- from embedchain import App
- from embedchain.config import AppConfig
- from embedchain.config.vector_db.lancedb import LanceDBConfig
- from embedchain.vectordb.lancedb import LanceDB
- os.environ["OPENAI_API_KEY"] = "test-api-key"
@pytest.fixture
def lancedb():
    """Provide a ``LanceDB`` built with ``dir="test-db"`` and collection ``"test-coll"``."""
    config = LanceDBConfig(dir="test-db", collection_name="test-coll")
    return LanceDB(config=config)
@pytest.fixture
def app_with_settings():
    """Provide an ``App`` with metrics collection disabled.

    The app is backed by a ``LanceDB`` configured with ``allow_reset=True``
    and ``dir="test-db-reset"``.
    """
    resettable_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db-reset"))
    settings = AppConfig(collect_metrics=False)
    return App(config=settings, db=resettable_db)
@pytest.fixture(scope="session", autouse=True)
def cleanup_db():
    """Remove the on-disk test databases after the whole test session has run.

    The fixture is session-scoped and autouse, so everything after ``yield``
    executes once, when the session ends. Each directory is removed inside its
    own ``try`` block so that an ``OSError`` on one path (for example because
    it was never created) does not prevent the other path from being removed.
    """
    yield
    for db_path in ("test-db.lance", "test-db-reset.lance"):
        try:
            shutil.rmtree(db_path)
        except OSError as e:
            # Best-effort cleanup: report the problem and keep going.
            print(f"Error: {e.filename} - {e.strerror}.")
def test_lancedb_duplicates_throw_warning(caplog):
    """Add the same document ID twice and check no duplicate-ID message is logged.

    NOTE(review): the test name says "throw_warning", yet both assertions
    require the messages to be absent from the captured log -- confirm which
    behavior is actually intended.
    """
    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    try:
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        assert "Insert of existing doc ID: 0" not in caplog.text
        assert "Add of existing doc ID: 0" not in caplog.text
    finally:
        # Reset even when an assertion fails, so leftover rows cannot leak
        # into later tests that use the same dir="test-db".
        app.db.reset()
def test_lancedb_duplicates_collections_no_warning(caplog):
    """Add the same document ID to two collections and check no duplicate-ID message is logged."""
    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    try:
        app.set_collection_name("test_collection_1")
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        app.set_collection_name("test_collection_2")
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        assert "Insert of existing doc ID: 0" not in caplog.text
        assert "Add of existing doc ID: 0" not in caplog.text
    finally:
        # Clean up both collections even when an assertion fails, so later
        # tests that reuse these collection names start empty.
        app.db.reset()
        app.set_collection_name("test_collection_1")
        app.db.reset()
def test_lancedb_collection_init_with_default_collection():
    """An app created without naming a collection uses ``"embedchain_store"``."""
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    settings = AppConfig(collect_metrics=False)
    app = App(config=settings, db=store)
    active_collection = app.db.collection.name
    assert active_collection == "embedchain_store"
def test_lancedb_collection_init_with_custom_collection():
    """Setting the collection via the ``name=`` keyword is reflected in ``app.db.collection.name``."""
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    settings = AppConfig(collect_metrics=False)
    app = App(config=settings, db=store)
    app.set_collection_name(name="test_collection")
    active_collection = app.db.collection.name
    assert active_collection == "test_collection"
def test_lancedb_collection_set_collection_name():
    """Setting the collection positionally is reflected in ``app.db.collection.name``."""
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    settings = AppConfig(collect_metrics=False)
    app = App(config=settings, db=store)
    app.set_collection_name("test_collection")
    active_collection = app.db.collection.name
    assert active_collection == "test_collection"
def test_lancedb_collection_changes_encapsulated():
    """Documents added to one collection do not change the count of another.

    The test expects both collections to be empty on entry, so it depends on
    earlier users of these collection names having cleaned up after themselves.
    """
    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    try:
        app.set_collection_name("test_collection_1")
        assert app.db.count() == 0
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
        assert app.db.count() == 1

        app.set_collection_name("test_collection_2")
        assert app.db.count() == 0
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

        # Switching back must show the first collection's own count.
        app.set_collection_name("test_collection_1")
        assert app.db.count() == 1
    finally:
        # Select each collection explicitly before resetting it: after a
        # mid-test failure the active collection is not known in advance.
        for name in ("test_collection_1", "test_collection_2"):
            app.set_collection_name(name)
            app.db.reset()
def test_lancedb_collection_collections_are_persistent():
    """A document added through one app is visible to a new app on the same dir and collection."""
    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    app.set_collection_name("test_collection_1")
    app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
    # Drop the first app before building a second one from scratch.
    del app

    db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=db)
    try:
        app.set_collection_name("test_collection_1")
        assert app.db.count() == 1
    finally:
        # Reset even when the assertion fails, so the stored document does
        # not leak into later tests that reuse "test_collection_1".
        app.db.reset()
def test_lancedb_collection_parallel_collections():
    """Two apps on different collections keep separate counts.

    After each app switches to the other's collection name, it reports the
    count of documents that the other app added there.
    """
    first_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
    first_app = App(config=AppConfig(collect_metrics=False), db=first_db)
    second_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
    second_app = App(config=AppConfig(collect_metrics=False), db=second_db)

    # start clean in case an earlier run failed or was interrupted
    for app in (first_app, second_app):
        app.db.reset()

    first_app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
    assert first_app.db.count() == 1
    assert second_app.db.count() == 0

    first_app.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"])
    second_app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    # swap collections and check each app sees the other's documents
    first_app.set_collection_name("test_collection_2")
    assert first_app.db.count() == 1
    second_app.set_collection_name("test_collection_1")
    assert second_app.db.count() == 3

    # final cleanup
    for app in (first_app, second_app):
        app.db.reset()
def test_lancedb_collection_ids_share_collections():
    """Two apps pointed at the same collection name both add documents to it.

    NOTE(review): the first app is expected to report 2 documents while the
    second reports 3 -- presumably the first app does not observe the second
    app's later write; confirm against the LanceDB implementation.
    """
    shared_name = "one_collection"

    first_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    first_app = App(config=AppConfig(collect_metrics=False), db=first_db)
    first_app.set_collection_name(shared_name)

    second_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    second_app = App(config=AppConfig(collect_metrics=False), db=second_db)
    second_app.set_collection_name(shared_name)

    # start clean
    for app in (first_app, second_app):
        app.db.reset()

    first_app.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"])
    second_app.db.add(ids=["2"], documents=["doc3"], metadatas=["test"])

    first_count = first_app.db.count()
    assert first_count == 2
    second_count = second_app.db.count()
    assert second_count == 3

    # final cleanup
    for app in (first_app, second_app):
        app.db.reset()
def test_lancedb_collection_reset():
    """Resetting one app's collection leaves the other apps' collections untouched."""
    collection_names = ("one_collection", "two_collection", "three_collection", "four_collection")

    # Build one app per collection name, all on the same dir="test-db".
    apps = []
    for name in collection_names:
        db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
        app = App(config=AppConfig(collect_metrics=False), db=db)
        app.set_collection_name(name)
        apps.append(app)
    app1, app2, app3, app4 = apps

    # start clean in case an earlier run failed or was interrupted
    for app in apps:
        app.db.reset()

    # app N stores a single document with id "N" and text "docN"
    for number, app in enumerate(apps, start=1):
        app.db.add(ids=[str(number)], documents=[f"doc{number}"], metadatas=["test"])

    app1.db.reset()

    assert app1.db.count() == 0
    for untouched in (app2, app3, app4):
        assert untouched.db.count() == 1

    # final cleanup (app1 was already reset above)
    for app in (app2, app3, app4):
        app.db.reset()
def generate_embeddings(dummy_embed, embed_size):
    """Return a list holding ``dummy_embed`` repeated ``embed_size`` times.

    Args:
        dummy_embed: The value to repeat. It is not copied; every element of
            the result is the same object.
        embed_size: Number of repetitions, passed to ``range``. Zero or a
            negative value produces an empty list.

    Returns:
        A new list of length ``max(embed_size, 0)``.
    """
    return [dummy_embed for _ in range(embed_size)]
|