|
@@ -3,6 +3,7 @@ package cn.tycoding.langchat.core.service.impl;
|
|
|
import cn.tycoding.langchat.common.dto.ChatReq;
|
|
|
import cn.tycoding.langchat.common.dto.ImageR;
|
|
|
import cn.tycoding.langchat.core.provider.ModelProvider;
|
|
|
+import cn.tycoding.langchat.core.provider.SearchProvider;
|
|
|
import cn.tycoding.langchat.core.service.Assistant;
|
|
|
import cn.tycoding.langchat.core.service.LangChatService;
|
|
|
import dev.langchain4j.data.image.Image;
|
|
@@ -19,7 +20,6 @@ import dev.langchain4j.rag.query.router.DefaultQueryRouter;
|
|
|
import dev.langchain4j.rag.query.router.QueryRouter;
|
|
|
import dev.langchain4j.service.AiServices;
|
|
|
import dev.langchain4j.service.TokenStream;
|
|
|
-import dev.langchain4j.web.search.google.customsearch.GoogleCustomWebSearchEngine;
|
|
|
import lombok.AllArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.stereotype.Service;
|
|
@@ -34,7 +34,7 @@ import org.springframework.stereotype.Service;
|
|
|
public class LangChatServiceImpl implements LangChatService {
|
|
|
|
|
|
private final ModelProvider provider;
|
|
|
- private final GoogleCustomWebSearchEngine googleCustomWebSearchEngine;
|
|
|
+ private final SearchProvider searchProvider;
|
|
|
|
|
|
@Override
|
|
|
public TokenStream chat(ChatReq req) {
|
|
@@ -42,10 +42,13 @@ public class LangChatServiceImpl implements LangChatService {
|
|
|
|
|
|
Assistant assistant;
|
|
|
if (req.getIsGoogleSearch()) {
|
|
|
- ContentRetriever webSearchContentRetriever = WebSearchContentRetriever.builder()
|
|
|
- .webSearchEngine(googleCustomWebSearchEngine)
|
|
|
- .maxResults(3)
|
|
|
- .build();
|
|
|
+ ContentRetriever webSearchContentRetriever;
|
|
|
+ if (searchProvider.get() == null) {
|
|
|
+ webSearchContentRetriever = WebSearchContentRetriever.builder().maxResults(3).build();
|
|
|
+ } else {
|
|
|
+ webSearchContentRetriever = WebSearchContentRetriever.builder().maxResults(3).webSearchEngine(searchProvider.get()).build();
|
|
|
+ }
|
|
|
+
|
|
|
QueryRouter queryRouter = new DefaultQueryRouter(webSearchContentRetriever);
|
|
|
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
|
|
|
.queryRouter(queryRouter)
|