@@ -2,10 +2,10 @@
 # @Author: privacy
 # @Date: 2024-06-27 09:33:01
 # @Last Modified by: privacy
-# @Last Modified time: 2024-09-05 10:38:48
+# @Last Modified time: 2024-09-06 14:12:50
 import os
 os.environ['TRANSFORMERS_OFFLINE'] = '1'
-from typing import List
+from typing import List, Union
 import torch
 import numpy as np
@@ -14,7 +14,6 @@ from sklearn.metrics.pairwise import cosine_similarity
 from transformers import AutoTokenizer, AutoModel
 
 
-
 class Matcher:
     def __init__(self):
         # Load model directly
@@ -22,7 +21,7 @@ class Matcher:
         self.tokenizer = AutoTokenizer.from_pretrained("GanymedeNil/text2vec-base-chinese")
         self.model = AutoModel.from_pretrained("GanymedeNil/text2vec-base-chinese")
 
-    def TopK1(self, title: str, keywords: list, query_embedding, option_embeddings: list) -> pd.Series:
+    def TopK1(self, title: str, keywords: list, query_embedding: np.ndarray, option_embeddings: List[np.ndarray]) -> pd.Series:
         """
         Retrieve the vector with the highest similarity.
         Args:
@@ -71,6 +70,32 @@ class Matcher:
             text_embeddings.append(np.mean(output.last_hidden_state.mean(dim=1).numpy(), axis=0))
         return text_embeddings
 
+    @classmethod
+    def mean_pooling(cls, token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            token_embeddings: First element of model_output contains all token embeddings
+        """
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+    def sentence_embeddings(self, sentence: Union[str, List[str]]) -> torch.Tensor:
+        encoded_input = self.tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')
+        with torch.no_grad():
+            model_output = self.model(**encoded_input)
+        return self.mean_pooling(model_output[0], encoded_input['attention_mask'])
+
+    def similarities(self, sentence: Union[str, List[str]], query: str, topk: int = 1) -> pd.DataFrame:
+        sentence_matrix = self.sentence_embeddings(sentence)
+        query_vector = self.sentence_embeddings(query)
+        cosine_similarities = cosine_similarity(query_vector, sentence_matrix)
+        similarity_df = pd.DataFrame(cosine_similarities[0], columns=['similarity'])
+        return similarity_df
+        # df_with_similarity = pd.concat([sentence, similarity_df], axis=1).sort_values(by='similarity', ascending=False)
+        # threshold = 0.7
+        # result = df_with_similarity[df_with_similarity['similarity'] > threshold]
+        # return result.head(topk)
+
 
 if __name__ == '__main__':
     matcher = Matcher()
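
Note on the added mean_pooling: it averages token embeddings while masking out padding, so short sentences in a padded batch are not diluted by pad tokens. Below is a minimal standalone sketch of the same computation on a toy tensor; the function body mirrors the one added in the diff, and the toy values are illustrative only.

import torch

def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Broadcast the mask over the hidden dimension, zero out padded positions,
    # then divide by the count of real tokens (clamped to avoid division by zero).
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

# One sequence of 3 tokens with hidden size 2; the third token is padding.
emb = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
mask = torch.tensor([[1, 1, 0]])
print(mean_pooling(emb, mask))  # tensor([[2., 3.]]): the padded token is ignored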
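
With sentence_embeddings and similarities in place, the class can score a query against a list of candidates end to end. A hedged usage sketch follows, assuming this file is importable as a module named matcher (the module name and the example sentences are illustrative, not from the diff). Since TRANSFORMERS_OFFLINE is set to '1' at import time, the model weights must already be in the local Hugging Face cache.

from matcher import Matcher  # hypothetical module name for the file above

m = Matcher()  # loads GanymedeNil/text2vec-base-chinese from the local cache
candidates = ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡', '今天天气怎么样']
scores = m.similarities(candidates, query='花呗如何更换绑定的银行卡')
print(scores)  # single-column DataFrame 'similarity'; row i scores candidates[i]

As committed, similarities returns the full score column and leaves the threshold/topk filtering commented out, so callers currently do their own ranking, e.g. scores['similarity'].idxmax().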