123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- import pytest
- from unittest.mock import MagicMock
- from embedchain.embedder.base import BaseEmbedder
- from embedchain.config.embedder.base import BaseEmbedderConfig
- from chromadb.api.types import Documents, Embeddings
@pytest.fixture
def base_embedder():
    """Provide a BaseEmbedder constructed without any arguments."""
    embedder = BaseEmbedder()
    return embedder
def test_initialization(base_embedder):
    """A freshly built embedder exposes a BaseEmbedderConfig and nothing else yet."""
    assert isinstance(base_embedder.config, BaseEmbedderConfig)

    # Nothing has been set on the instance yet, so neither attribute may exist.
    for attribute in ("embedding_fn", "vector_dimension"):
        assert not hasattr(base_embedder, attribute)
def test_set_embedding_fn(base_embedder):
    """After set_embedding_fn, the embedder exposes a callable `embedding_fn`
    that yields the stub's output for the given texts."""

    def embedding_function(texts: Documents) -> Embeddings:
        # Stub: one recognisable string per input text instead of real vectors.
        produced = []
        for text in texts:
            produced.append(f"Embedding for {text}")
        return produced

    base_embedder.set_embedding_fn(embedding_function)

    assert hasattr(base_embedder, "embedding_fn")
    assert callable(base_embedder.embedding_fn)

    outcome = base_embedder.embedding_fn(["text1", "text2"])
    assert outcome == ["Embedding for text1", "Embedding for text2"]
def test_set_embedding_fn_when_not_a_function(base_embedder):
    """Passing None to set_embedding_fn is expected to raise ValueError."""
    not_a_function = None
    with pytest.raises(ValueError):
        base_embedder.set_embedding_fn(not_a_function)
def test_set_vector_dimension(base_embedder):
    """set_vector_dimension stores the given value on `vector_dimension`."""
    dimension = 256
    base_embedder.set_vector_dimension(dimension)

    assert hasattr(base_embedder, "vector_dimension")
    assert base_embedder.vector_dimension == dimension
def test_set_vector_dimension_type_error(base_embedder):
    """Passing None to set_vector_dimension is expected to raise TypeError."""
    invalid_dimension = None
    with pytest.raises(TypeError):
        base_embedder.set_vector_dimension(invalid_dimension)
def test_langchain_default_concept():
    """The callable built by `_langchain_default_concept` returns whatever the
    mocked `embed_documents` yields."""
    fake_embeddings = MagicMock()
    fake_embeddings.embed_documents.return_value = ["Embedding1", "Embedding2"]

    wrapped = BaseEmbedder._langchain_default_concept(fake_embeddings)

    # Compare against a fresh literal rather than the mock's own list object.
    assert wrapped(["text1", "text2"]) == ["Embedding1", "Embedding2"]
def test_embedder_with_config():
    """An embedder built from an explicit config exposes a BaseEmbedderConfig."""
    explicit_config = BaseEmbedderConfig()
    configured = BaseEmbedder(explicit_config)
    assert isinstance(configured.config, BaseEmbedderConfig)
|