
fix embedding model

tycoding, 1 year ago
parent commit b7928a5395

+ 0 - 5
langchat-core/pom.xml

@@ -85,11 +85,6 @@
             <artifactId>langchain4j-pgvector</artifactId>
             <version>${langchain4j.version}</version>
         </dependency>
-        <dependency>
-            <groupId>dev.langchain4j</groupId>
-            <artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
-            <version>${langchain4j.version}</version>
-        </dependency>
         <dependency>
             <groupId>dev.langchain4j</groupId>
             <artifactId>langchain4j-document-parser-apache-tika</artifactId>

+ 12 - 8
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/EmbedProvider.java

@@ -1,6 +1,5 @@
 package cn.tycoding.langchat.core.provider;
 
-import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import lombok.AllArgsConstructor;
 import org.springframework.context.ApplicationContext;
@@ -16,14 +15,19 @@ public class EmbedProvider {
 
     private final ApplicationContext context;
 
-    public EmbeddingModel embed(String model) {
-        if (context.containsBean(model)) {
-            return (EmbeddingModel) context.getBean(model);
+    public EmbeddingModel embed() {
+        if (context.containsBean("OpenAiEmbeddingModel")) {
+            return (EmbeddingModel) context.getBean("OpenAiEmbeddingModel");
+        }
+        if (context.containsBean("AzureOpenAiEmbeddingModel")) {
+            return (EmbeddingModel) context.getBean("AzureOpenAiEmbeddingModel");
+        }
+        if (context.containsBean("QianfanEmbeddingModel")) {
+            return (EmbeddingModel) context.getBean("QianfanEmbeddingModel");
+        }
+        if (context.containsBean("QwenEmbeddingModel")) {
+            return (EmbeddingModel) context.getBean("QwenEmbeddingModel");
         }
         throw new RuntimeException("No matching embedding model information found, please check the model configuration.");
     }
-
-    public EmbeddingModel embed() {
-        return new AllMiniLmL6V2EmbeddingModel();
-    }
 }
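With the model-name argument removed, callers resolve the embedding model purely from the beans registered by ProviderInitialize. A minimal usage sketch, assuming the provider is injected as a Spring bean and at least one embedding model has been registered under one of the fixed names (the surrounding EmbeddingDemoService class is hypothetical, not part of this commit):

import cn.tycoding.langchat.core.provider.EmbedProvider;
import dev.langchain4j.model.embedding.EmbeddingModel;
import lombok.AllArgsConstructor;
import org.springframework.stereotype.Service;

@Service
@AllArgsConstructor
public class EmbeddingDemoService {

    private final EmbedProvider embedProvider;

    public float[] embedText(String text) {
        // Resolves OpenAiEmbeddingModel, AzureOpenAiEmbeddingModel, QianfanEmbeddingModel
        // or QwenEmbeddingModel, in that order; throws if none is registered.
        EmbeddingModel model = embedProvider.embed();
        return model.embed(text).content().vector();
    }
}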

+ 24 - 3
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/ProviderInitialize.java

@@ -1,18 +1,20 @@
 package cn.tycoding.langchat.core.provider;
 
 import cn.hutool.core.util.StrUtil;
+import cn.tycoding.langchat.aigc.component.ProviderEnum;
 import cn.tycoding.langchat.aigc.entity.AigcModel;
 import cn.tycoding.langchat.aigc.service.AigcModelService;
 import cn.tycoding.langchat.common.component.SpringContextHolder;
-import cn.tycoding.langchat.aigc.component.ProviderEnum;
 import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
 import dev.langchain4j.model.azure.AzureOpenAiImageModel;
 import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel;
+import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
 import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
 import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
 import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
 import dev.langchain4j.model.openai.OpenAiImageModel;
 import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
+import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
 import dev.langchain4j.model.qianfan.QianfanStreamingChatModel;
 import dev.langchain4j.model.vertexai.VertexAiGeminiStreamingChatModel;
 import dev.langchain4j.model.zhipu.ZhipuAiStreamingChatModel;
@@ -193,7 +195,7 @@ public class ProviderInitialize implements ApplicationContextAware {
                             .modelName(model.getModel())
                             .dimensions(model.getDimensions())
                             .build();
-                    contextHolder.registerBean(model.getId(), build);
+                    contextHolder.registerBean("OpenAiEmbeddingModel", build);
                 }
 
                 if (ProviderEnum.AZURE_OPENAI.getModel().equals(model.getModelType())) {
@@ -202,7 +204,26 @@ public class ProviderInitialize implements ApplicationContextAware {
                             .apiKey(model.getApiKey())
                             .deploymentName(model.getBaseUrl())
                             .build();
-                    contextHolder.registerBean(model.getId(), build);
+                    contextHolder.registerBean("AzureOpenAiEmbeddingModel", build);
+                }
+
+                if (ProviderEnum.BAIDU.getModel().equals(model.getModelType())) {
+                    QianfanEmbeddingModel build = QianfanEmbeddingModel
+                            .builder()
+                            .apiKey(model.getApiKey())
+                            .modelName(model.getModel())
+                            .secretKey(model.getSecretKey())
+                            .build();
+                    contextHolder.registerBean("QianfanEmbeddingModel", build);
+                }
+
+                if (ProviderEnum.ALIBABA.getModel().equals(model.getModelType())) {
+                    QwenEmbeddingModel build = QwenEmbeddingModel
+                            .builder()
+                            .apiKey(model.getApiKey())
+                            .modelName(model.getModel())
+                            .build();
+                    contextHolder.registerBean("QwenEmbeddingModel", build);
                 }
             }
         });
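Each embedding provider is now registered under a fixed bean name rather than the model's database id, which is exactly the name EmbedProvider.embed() looks up. The SpringContextHolder implementation is not part of this diff, so the following is only a rough sketch of what registerBean might do, assuming it wraps Spring's singleton registry; the real class in langchat-common may differ:

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.stereotype.Component;

// Hypothetical sketch of runtime singleton registration.
@Component
public class SpringContextHolder implements ApplicationContextAware {

    private static ConfigurableApplicationContext context;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        context = (ConfigurableApplicationContext) applicationContext;
    }

    public void registerBean(String name, Object bean) {
        // Makes the instance visible to context.containsBean(name) / getBean(name),
        // which is how EmbedProvider.embed() finds the registered embedding model.
        context.getBeanFactory().registerSingleton(name, bean);
    }
}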

+ 2 - 2
langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/LangDocServiceImpl.java

@@ -61,7 +61,7 @@ public class LangDocServiceImpl implements LangDocService {
     public EmbeddingR embeddingText(ChatReq req) {
         TextSegment segment = TextSegment.from(req.getMessage(),
                 metadata(KNOWLEDGE, req.getKnowledgeId()).put(FILENAME, req.getDocsName()));
-        EmbeddingModel embeddingModel = embedProvider.embed(req.getModel());
+        EmbeddingModel embeddingModel = embedProvider.embed();
         Embedding embedding = embeddingModel.embed(segment).content();
 
         String id = embeddingStore.add(embedding, segment);
@@ -70,7 +70,7 @@ public class LangDocServiceImpl implements LangDocService {
 
     @Override
     public List<EmbeddingR> embeddingDocs(ChatReq req) {
-        EmbeddingModel model = embedProvider.embed(req.getModel());
+        EmbeddingModel model = embedProvider.embed();
 
         Document document = FileSystemDocumentLoader.loadDocument(req.getPath(), new ApacheTikaDocumentParser());
         document.metadata().put(KNOWLEDGE, req.getKnowledgeId()).put(FILENAME, req.getDocsName());

+ 20 - 2
langchat-ui/src/views/aigc/model/data.ts

@@ -1,3 +1,13 @@
+export enum EmbeddingProviderEnum {
+  OPENAI = 'openai',
+  AZURE_OPENAI = 'azure-openai',
+  GOOGLE = 'google',
+  OLLAMA = 'ollama',
+  BAIDU = 'baidu',
+  ALIBABA = 'alibaba',
+  ZHIPU = 'zhipu',
+}
+
 export enum ProviderEnum {
   OPENAI = 'openai',
   AZURE_OPENAI = 'azure-openai',
@@ -111,9 +121,17 @@ export const LLMProviders: any[] = [
     models: ['dall-e-2', 'dall-e-3'],
   },
   {
-    model: 'embedding',
+    model: ProviderEnum.EMBEDDING,
     name: 'Embedding',
-    models: ['text-embedding-3-small', 'text-embedding-3-large', 'text-embedding-ada-002'],
+    models: [
+      'text-embedding-3-small',
+      'text-embedding-3-large',
+      'text-embedding-ada-002',
+      'embedding-v1',
+      'bge_large_zh',
+      'bge_large_en',
+      'tao_8k',
+    ],
   },
   // {
   //   model: 'web-search',

+ 0 - 1
langchat-ui/src/views/aigc/model/edit.vue

@@ -5,7 +5,6 @@
   import { isNullOrWhitespace } from '@/utils/is';
   import { add, update } from '@/api/aigc/model';
   import { useMessage } from 'naive-ui';
-  import { getColumns } from '@/views/aigc/model/coumns';
 
   const props = defineProps<{
     provider: string;

+ 9 - 11
langchat-ui/src/views/aigc/model/schemas.ts

@@ -1,5 +1,5 @@
 import { FormSchema } from '@/components/Form';
-import { LLMProviders, ProviderEnum } from '@/views/aigc/model/data';
+import { EmbeddingProviderEnum, LLMProviders, ProviderEnum } from '@/views/aigc/model/data';
 
 const baseHeadSchemas: FormSchema[] = [
   {
@@ -336,16 +336,12 @@ export const embeddingSchemas: FormSchema[] = [
     component: 'NSelect',
     rules: [{ required: true, message: '请选择模型类型', trigger: ['blur'] }],
     componentProps: {
-      options: [
-        {
-          label: 'openai',
-          value: 'openai',
-        },
-        {
-          label: 'azure-openai',
-          value: 'azure-openai',
-        },
-      ],
+      options: Object.values(EmbeddingProviderEnum).map((i) => {
+        return {
+          label: i,
+          value: i,
+        };
+      }),
     },
   },
   {
@@ -362,6 +358,8 @@ export const embeddingSchemas: FormSchema[] = [
     component: 'NSelect',
     rules: [{ required: true, message: '请选择模型', trigger: ['blur'] }],
     componentProps: {
+      filterable: true,
+      tag: true,
       options: getModels(ProviderEnum.EMBEDDING),
     },
   },