nvidia.py 992 B

1234567891011121314151617181920212223242526
  1. import logging
  2. import os
  3. from typing import Optional
  4. from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
  5. from embedchain.config import BaseEmbedderConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.models import VectorDimensions
  8. class NvidiaEmbedder(BaseEmbedder):
  9. def __init__(self, config: Optional[BaseEmbedderConfig] = None):
  10. if "NVIDIA_API_KEY" not in os.environ:
  11. raise ValueError("NVIDIA_API_KEY environment variable must be set")
  12. super().__init__(config=config)
  13. model = self.config.model or "nvolveqa_40k"
  14. logging.info(f"Using NVIDIA embedding model: {model}")
  15. embedder = NVIDIAEmbeddings(model=model)
  16. embedding_fn = BaseEmbedder._langchain_default_concept(embedder)
  17. self.set_embedding_fn(embedding_fn=embedding_fn)
  18. vector_dimension = self.config.vector_dimension or VectorDimensions.NVIDIA_AI.value
  19. self.set_vector_dimension(vector_dimension=vector_dimension)