test_app.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import os
  2. import pytest
  3. import yaml
  4. from embedchain import App
  5. from embedchain.config import ChromaDbConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.llm.base import BaseLlm
  8. from embedchain.vectordb.base import BaseVectorDB
  9. from embedchain.vectordb.chroma import ChromaDB
  10. @pytest.fixture
  11. def app():
  12. os.environ["OPENAI_API_KEY"] = "test_api_key"
  13. os.environ["OPENAI_API_BASE"] = "test_api_base"
  14. return App()
  15. def test_app(app):
  16. assert isinstance(app.llm, BaseLlm)
  17. assert isinstance(app.db, BaseVectorDB)
  18. assert isinstance(app.embedding_model, BaseEmbedder)
  19. class TestConfigForAppComponents:
  20. def test_constructor_config(self):
  21. collection_name = "my-test-collection"
  22. db = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
  23. app = App(db=db)
  24. assert app.db.config.collection_name == collection_name
  25. def test_component_config(self):
  26. collection_name = "my-test-collection"
  27. database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
  28. app = App(db=database)
  29. assert app.db.config.collection_name == collection_name
  30. class TestAppFromConfig:
  31. def load_config_data(self, yaml_path):
  32. with open(yaml_path, "r") as file:
  33. return yaml.safe_load(file)
  34. def test_from_chroma_config(self, mocker):
  35. mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
  36. yaml_path = "configs/chroma.yaml"
  37. config_data = self.load_config_data(yaml_path)
  38. app = App.from_config(config_path=yaml_path)
  39. # Check if the App instance and its components were created correctly
  40. assert isinstance(app, App)
  41. # Validate the AppConfig values
  42. assert app.config.id == config_data["app"]["config"]["id"]
  43. # Even though not present in the config, the default value is used
  44. assert app.config.collect_metrics is True
  45. # Validate the LLM config values
  46. llm_config = config_data["llm"]["config"]
  47. assert app.llm.config.temperature == llm_config["temperature"]
  48. assert app.llm.config.max_tokens == llm_config["max_tokens"]
  49. assert app.llm.config.top_p == llm_config["top_p"]
  50. assert app.llm.config.stream == llm_config["stream"]
  51. # Validate the VectorDB config values
  52. db_config = config_data["vectordb"]["config"]
  53. assert app.db.config.collection_name == db_config["collection_name"]
  54. assert app.db.config.dir == db_config["dir"]
  55. assert app.db.config.allow_reset == db_config["allow_reset"]
  56. # Validate the Embedder config values
  57. embedder_config = config_data["embedder"]["config"]
  58. assert app.embedding_model.config.model == embedder_config["model"]
  59. assert app.embedding_model.config.deployment_name == embedder_config.get("deployment_name")
  60. def test_from_opensource_config(self, mocker):
  61. mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
  62. yaml_path = "configs/opensource.yaml"
  63. config_data = self.load_config_data(yaml_path)
  64. app = App.from_config(yaml_path)
  65. # Check if the App instance and its components were created correctly
  66. assert isinstance(app, App)
  67. # Validate the AppConfig values
  68. assert app.config.id == config_data["app"]["config"]["id"]
  69. assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
  70. # Validate the LLM config values
  71. llm_config = config_data["llm"]["config"]
  72. assert app.llm.config.model == llm_config["model"]
  73. assert app.llm.config.temperature == llm_config["temperature"]
  74. assert app.llm.config.max_tokens == llm_config["max_tokens"]
  75. assert app.llm.config.top_p == llm_config["top_p"]
  76. assert app.llm.config.stream == llm_config["stream"]
  77. # Validate the VectorDB config values
  78. db_config = config_data["vectordb"]["config"]
  79. assert app.db.config.collection_name == db_config["collection_name"]
  80. assert app.db.config.dir == db_config["dir"]
  81. assert app.db.config.allow_reset == db_config["allow_reset"]
  82. # Validate the Embedder config values
  83. embedder_config = config_data["embedder"]["config"]
  84. assert app.embedding_model.config.deployment_name == embedder_config["deployment_name"]