test_embedder.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from unittest.mock import MagicMock
  2. import pytest
  3. from chromadb.api.types import Documents, Embeddings
  4. from embedchain.config.embedder.base import BaseEmbedderConfig
  5. from embedchain.embedder.base import BaseEmbedder
  @pytest.fixture
  def base_embedder():
      """Provide a ``BaseEmbedder`` constructed with no arguments."""
      embedder = BaseEmbedder()
      return embedder
  def test_initialization(base_embedder):
      """Check the state of an embedder right after construction.

      The embedder must expose a ``BaseEmbedderConfig`` as ``config`` and must
      not yet have ``embedding_fn`` or ``vector_dimension`` attributes.
      """
      assert isinstance(base_embedder.config, BaseEmbedderConfig)

      # Neither attribute should exist until its setter has been called.
      for attribute in ("embedding_fn", "vector_dimension"):
          assert not hasattr(base_embedder, attribute)
  def test_set_embedding_fn(base_embedder):
      """Check that ``set_embedding_fn`` exposes the given callable as ``embedding_fn``."""

      def fake_embedding_fn(texts: Documents) -> Embeddings:
          # Deterministic stand-in: one labelled string per input text.
          return [f"Embedding for {text}" for text in texts]

      base_embedder.set_embedding_fn(fake_embedding_fn)

      assert hasattr(base_embedder, "embedding_fn")
      assert callable(base_embedder.embedding_fn)

      sample_texts = ["text1", "text2"]
      produced = base_embedder.embedding_fn(sample_texts)
      assert produced == ["Embedding for text1", "Embedding for text2"]
  def test_set_embedding_fn_when_not_a_function(base_embedder):
      """Check that passing ``None`` to ``set_embedding_fn`` raises ``ValueError``."""
      not_a_function = None
      with pytest.raises(ValueError):
          base_embedder.set_embedding_fn(not_a_function)
  def test_set_vector_dimension(base_embedder):
      """Check that ``set_vector_dimension`` stores the value as ``vector_dimension``."""
      dimension = 256
      base_embedder.set_vector_dimension(dimension)

      assert hasattr(base_embedder, "vector_dimension")
      assert base_embedder.vector_dimension == dimension
  def test_set_vector_dimension_type_error(base_embedder):
      """Check that passing ``None`` to ``set_vector_dimension`` raises ``TypeError``."""
      invalid_dimension = None
      with pytest.raises(TypeError):
          base_embedder.set_vector_dimension(invalid_dimension)
  def test_langchain_default_concept():
      """Check the function built by ``BaseEmbedder._langchain_default_concept``.

      The mocked embeddings object returns a fixed list from
      ``embed_documents``; calling the built function must yield an equal list.
      """
      expected = ["Embedding1", "Embedding2"]

      embeddings = MagicMock()
      # Hand the mock its own copy so the expectation stays an independent object.
      embeddings.embed_documents.return_value = list(expected)

      embed_function = BaseEmbedder._langchain_default_concept(embeddings)
      produced = embed_function(["text1", "text2"])

      assert produced == expected
  def test_embedder_with_config():
      """Check that an embedder built from an explicit config exposes a ``BaseEmbedderConfig``."""
      config = BaseEmbedderConfig()
      embedder = BaseEmbedder(config)
      assert isinstance(embedder.config, BaseEmbedderConfig)