nvidia.py 1.0 KB

12345678910111213141516171819202122232425262728
  1. import logging
  2. import os
  3. from typing import Optional
  4. from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
  5. from embedchain.config import BaseEmbedderConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.models import VectorDimensions
  8. logger = logging.getLogger(__name__)
  9. class NvidiaEmbedder(BaseEmbedder):
  10. def __init__(self, config: Optional[BaseEmbedderConfig] = None):
  11. if "NVIDIA_API_KEY" not in os.environ:
  12. raise ValueError("NVIDIA_API_KEY environment variable must be set")
  13. super().__init__(config=config)
  14. model = self.config.model or "nvolveqa_40k"
  15. logger.info(f"Using NVIDIA embedding model: {model}")
  16. embedder = NVIDIAEmbeddings(model=model)
  17. embedding_fn = BaseEmbedder._langchain_default_concept(embedder)
  18. self.set_embedding_fn(embedding_fn=embedding_fn)
  19. vector_dimension = self.config.vector_dimension or VectorDimensions.NVIDIA_AI.value
  20. self.set_vector_dimension(vector_dimension=vector_dimension)