test_embedder.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import pytest
  2. from chromadb.api.types import Documents, Embeddings
  3. from embedchain.config.embedder.base import BaseEmbedderConfig
  4. from embedchain.embedder.base import BaseEmbedder
  @pytest.fixture
  def base_embedder():
      """Provide a ``BaseEmbedder`` constructed with no arguments."""
      embedder = BaseEmbedder()
      return embedder
  def test_initialization(base_embedder):
      """Check the state of an embedder built without arguments.

      The config must be a ``BaseEmbedderConfig``; the embedding function
      and vector dimension attributes must not exist yet.
      """
      assert isinstance(base_embedder.config, BaseEmbedderConfig)
      # Neither attribute is expected on the instance before its setter runs.
      for unset_attribute in ("embedding_fn", "vector_dimension"):
          assert not hasattr(base_embedder, unset_attribute)
  def test_set_embedding_fn(base_embedder):
      """Check that ``set_embedding_fn`` stores a callable usable as ``embedding_fn``."""

      def fake_embedding(texts: Documents) -> Embeddings:
          # Deterministic stand-in: one labelled string per input document.
          labelled = []
          for text in texts:
              labelled.append(f"Embedding for {text}")
          return labelled

      base_embedder.set_embedding_fn(fake_embedding)

      assert hasattr(base_embedder, "embedding_fn")
      assert callable(base_embedder.embedding_fn)

      expected = ["Embedding for text1", "Embedding for text2"]
      produced = base_embedder.embedding_fn(["text1", "text2"])
      assert produced == expected
  def test_set_embedding_fn_when_not_a_function(base_embedder):
      """Expect ``ValueError`` when ``set_embedding_fn`` is given ``None``."""
      not_a_function = None
      with pytest.raises(ValueError):
          base_embedder.set_embedding_fn(not_a_function)
  def test_set_vector_dimension(base_embedder):
      """Check that ``set_vector_dimension`` exposes the value as ``vector_dimension``."""
      dimension = 256
      base_embedder.set_vector_dimension(dimension)

      assert hasattr(base_embedder, "vector_dimension")
      assert base_embedder.vector_dimension == dimension
  def test_set_vector_dimension_type_error(base_embedder):
      """Expect ``TypeError`` when ``set_vector_dimension`` is given ``None``."""
      invalid_dimension = None
      with pytest.raises(TypeError):
          base_embedder.set_vector_dimension(invalid_dimension)
  def test_embedder_with_config():
      """Check that an embedder built from an explicit config exposes a ``BaseEmbedderConfig``."""
      explicit_config = BaseEmbedderConfig()
      embedder = BaseEmbedder(explicit_config)
      assert isinstance(embedder.config, BaseEmbedderConfig)