test_apps.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import os
  2. import pytest
  3. import yaml
  4. from embedchain import App, CustomApp, Llama2App, OpenSourceApp
  5. from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.llm.base import BaseLlm
  8. from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
  9. from embedchain.vectordb.chroma import ChromaDB
  10. @pytest.fixture
  11. def app():
  12. os.environ["OPENAI_API_KEY"] = "test_api_key"
  13. return App()
  14. @pytest.fixture
  15. def custom_app():
  16. os.environ["OPENAI_API_KEY"] = "test_api_key"
  17. return CustomApp()
  18. @pytest.fixture
  19. def opensource_app():
  20. os.environ["OPENAI_API_KEY"] = "test_api_key"
  21. return OpenSourceApp()
  22. @pytest.fixture
  23. def llama2_app():
  24. os.environ["OPENAI_API_KEY"] = "test_api_key"
  25. os.environ["REPLICATE_API_TOKEN"] = "-"
  26. return Llama2App()
  27. def test_app(app):
  28. assert isinstance(app.llm, BaseLlm)
  29. assert isinstance(app.db, BaseVectorDB)
  30. assert isinstance(app.embedder, BaseEmbedder)
  31. def test_custom_app(custom_app):
  32. assert isinstance(custom_app.llm, BaseLlm)
  33. assert isinstance(custom_app.db, BaseVectorDB)
  34. assert isinstance(custom_app.embedder, BaseEmbedder)
  35. def test_opensource_app(opensource_app):
  36. assert isinstance(opensource_app.llm, BaseLlm)
  37. assert isinstance(opensource_app.db, BaseVectorDB)
  38. assert isinstance(opensource_app.embedder, BaseEmbedder)
  39. def test_llama2_app(llama2_app):
  40. assert isinstance(llama2_app.llm, BaseLlm)
  41. assert isinstance(llama2_app.db, BaseVectorDB)
  42. assert isinstance(llama2_app.embedder, BaseEmbedder)
  43. class TestConfigForAppComponents:
  44. def test_constructor_config(self):
  45. collection_name = "my-test-collection"
  46. app = App(db_config=ChromaDbConfig(collection_name=collection_name))
  47. assert app.db.config.collection_name == collection_name
  48. def test_component_config(self):
  49. collection_name = "my-test-collection"
  50. database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
  51. app = App(db=database)
  52. assert app.db.config.collection_name == collection_name
  53. def test_different_configs_are_proper_instances(self):
  54. app_config = AppConfig()
  55. wrong_config = AddConfig()
  56. with pytest.raises(TypeError):
  57. App(config=wrong_config)
  58. assert isinstance(app_config, AppConfig)
  59. llm_config = BaseLlmConfig()
  60. wrong_llm_config = "wrong_llm_config"
  61. with pytest.raises(TypeError):
  62. App(llm_config=wrong_llm_config)
  63. assert isinstance(llm_config, BaseLlmConfig)
  64. db_config = BaseVectorDbConfig()
  65. wrong_db_config = "wrong_db_config"
  66. with pytest.raises(TypeError):
  67. App(db_config=wrong_db_config)
  68. assert isinstance(db_config, BaseVectorDbConfig)
  69. embedder_config = BaseEmbedderConfig()
  70. wrong_embedder_config = "wrong_embedder_config"
  71. with pytest.raises(TypeError):
  72. App(embedder_config=wrong_embedder_config)
  73. assert isinstance(embedder_config, BaseEmbedderConfig)
  74. class TestAppFromConfig:
  75. def load_config_data(self, yaml_path):
  76. with open(yaml_path, "r") as file:
  77. return yaml.safe_load(file)
  78. def test_from_chroma_config(self):
  79. yaml_path = "configs/chroma.yaml"
  80. config_data = self.load_config_data(yaml_path)
  81. app = App.from_config(yaml_path)
  82. # Check if the App instance and its components were created correctly
  83. assert isinstance(app, App)
  84. # Validate the AppConfig values
  85. assert app.config.id == config_data["app"]["config"]["id"]
  86. assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
  87. # Even though not present in the config, the default value is used
  88. assert app.config.collect_metrics is True
  89. # Validate the LLM config values
  90. llm_config = config_data["llm"]["config"]
  91. assert app.llm.config.temperature == llm_config["temperature"]
  92. assert app.llm.config.max_tokens == llm_config["max_tokens"]
  93. assert app.llm.config.top_p == llm_config["top_p"]
  94. assert app.llm.config.stream == llm_config["stream"]
  95. # Validate the VectorDB config values
  96. db_config = config_data["vectordb"]["config"]
  97. assert app.db.config.collection_name == db_config["collection_name"]
  98. assert app.db.config.dir == db_config["dir"]
  99. assert app.db.config.allow_reset == db_config["allow_reset"]
  100. # Validate the Embedder config values
  101. embedder_config = config_data["embedder"]["config"]
  102. assert app.embedder.config.model == embedder_config["model"]
  103. assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
  104. def test_from_opensource_config(self):
  105. yaml_path = "configs/opensource.yaml"
  106. config_data = self.load_config_data(yaml_path)
  107. app = App.from_config(yaml_path)
  108. # Check if the App instance and its components were created correctly
  109. assert isinstance(app, App)
  110. # Validate the AppConfig values
  111. assert app.config.id == config_data["app"]["config"]["id"]
  112. assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
  113. assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
  114. # Validate the LLM config values
  115. llm_config = config_data["llm"]["config"]
  116. assert app.llm.config.temperature == llm_config["temperature"]
  117. assert app.llm.config.max_tokens == llm_config["max_tokens"]
  118. assert app.llm.config.top_p == llm_config["top_p"]
  119. assert app.llm.config.stream == llm_config["stream"]
  120. # Validate the VectorDB config values
  121. db_config = config_data["vectordb"]["config"]
  122. assert app.db.config.collection_name == db_config["collection_name"]
  123. assert app.db.config.dir == db_config["dir"]
  124. assert app.db.config.allow_reset == db_config["allow_reset"]
  125. # Validate the Embedder config values
  126. embedder_config = config_data["embedder"]["config"]
  127. assert app.embedder.config.model == embedder_config["model"]
  128. assert app.embedder.config.deployment_name == embedder_config["deployment_name"]