test_apps.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import os
  2. import pytest
  3. import yaml
  4. from embedchain import App
  5. from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
  6. BaseLlmConfig, ChromaDbConfig)
  7. from embedchain.embedder.base import BaseEmbedder
  8. from embedchain.llm.base import BaseLlm
  9. from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
  10. from embedchain.vectordb.chroma import ChromaDB
  11. @pytest.fixture
  12. def app():
  13. os.environ["OPENAI_API_KEY"] = "test_api_key"
  14. return App()
  15. def test_app(app):
  16. assert isinstance(app.llm, BaseLlm)
  17. assert isinstance(app.db, BaseVectorDB)
  18. assert isinstance(app.embedder, BaseEmbedder)
  19. class TestConfigForAppComponents:
  20. def test_constructor_config(self):
  21. collection_name = "my-test-collection"
  22. app = App(db_config=ChromaDbConfig(collection_name=collection_name))
  23. assert app.db.config.collection_name == collection_name
  24. def test_component_config(self):
  25. collection_name = "my-test-collection"
  26. database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
  27. app = App(db=database)
  28. assert app.db.config.collection_name == collection_name
  29. def test_different_configs_are_proper_instances(self):
  30. app_config = AppConfig()
  31. wrong_config = AddConfig()
  32. with pytest.raises(TypeError):
  33. App(config=wrong_config)
  34. assert isinstance(app_config, AppConfig)
  35. llm_config = BaseLlmConfig()
  36. wrong_llm_config = "wrong_llm_config"
  37. with pytest.raises(TypeError):
  38. App(llm_config=wrong_llm_config)
  39. assert isinstance(llm_config, BaseLlmConfig)
  40. db_config = BaseVectorDbConfig()
  41. wrong_db_config = "wrong_db_config"
  42. with pytest.raises(TypeError):
  43. App(db_config=wrong_db_config)
  44. assert isinstance(db_config, BaseVectorDbConfig)
  45. embedder_config = BaseEmbedderConfig()
  46. wrong_embedder_config = "wrong_embedder_config"
  47. with pytest.raises(TypeError):
  48. App(embedder_config=wrong_embedder_config)
  49. assert isinstance(embedder_config, BaseEmbedderConfig)
  50. def test_components_raises_type_error_if_not_proper_instances(self):
  51. wrong_llm = "wrong_llm"
  52. with pytest.raises(TypeError):
  53. App(llm=wrong_llm)
  54. wrong_db = "wrong_db"
  55. with pytest.raises(TypeError):
  56. App(db=wrong_db)
  57. wrong_embedder = "wrong_embedder"
  58. with pytest.raises(TypeError):
  59. App(embedder=wrong_embedder)
  60. class TestAppFromConfig:
  61. def load_config_data(self, yaml_path):
  62. with open(yaml_path, "r") as file:
  63. return yaml.safe_load(file)
  64. def test_from_chroma_config(self, mocker):
  65. mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
  66. yaml_path = "configs/chroma.yaml"
  67. config_data = self.load_config_data(yaml_path)
  68. app = App.from_config(yaml_path)
  69. # Check if the App instance and its components were created correctly
  70. assert isinstance(app, App)
  71. # Validate the AppConfig values
  72. assert app.config.id == config_data["app"]["config"]["id"]
  73. assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
  74. # Even though not present in the config, the default value is used
  75. assert app.config.collect_metrics is True
  76. # Validate the LLM config values
  77. llm_config = config_data["llm"]["config"]
  78. assert app.llm.config.temperature == llm_config["temperature"]
  79. assert app.llm.config.max_tokens == llm_config["max_tokens"]
  80. assert app.llm.config.top_p == llm_config["top_p"]
  81. assert app.llm.config.stream == llm_config["stream"]
  82. # Validate the VectorDB config values
  83. db_config = config_data["vectordb"]["config"]
  84. assert app.db.config.collection_name == db_config["collection_name"]
  85. assert app.db.config.dir == db_config["dir"]
  86. assert app.db.config.allow_reset == db_config["allow_reset"]
  87. # Validate the Embedder config values
  88. embedder_config = config_data["embedder"]["config"]
  89. assert app.embedder.config.model == embedder_config["model"]
  90. assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
  91. def test_from_opensource_config(self, mocker):
  92. mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
  93. yaml_path = "configs/opensource.yaml"
  94. config_data = self.load_config_data(yaml_path)
  95. app = App.from_config(yaml_path)
  96. # Check if the App instance and its components were created correctly
  97. assert isinstance(app, App)
  98. # Validate the AppConfig values
  99. assert app.config.id == config_data["app"]["config"]["id"]
  100. assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
  101. assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
  102. # Validate the LLM config values
  103. llm_config = config_data["llm"]["config"]
  104. assert app.llm.config.model == llm_config["model"]
  105. assert app.llm.config.temperature == llm_config["temperature"]
  106. assert app.llm.config.max_tokens == llm_config["max_tokens"]
  107. assert app.llm.config.top_p == llm_config["top_p"]
  108. assert app.llm.config.stream == llm_config["stream"]
  109. # Validate the VectorDB config values
  110. db_config = config_data["vectordb"]["config"]
  111. assert app.db.config.collection_name == db_config["collection_name"]
  112. assert app.db.config.dir == db_config["dir"]
  113. assert app.db.config.allow_reset == db_config["allow_reset"]
  114. # Validate the Embedder config values
  115. embedder_config = config_data["embedder"]["config"]
  116. assert app.embedder.config.deployment_name == embedder_config["deployment_name"]