66from pydantic import Field
77import requests
88
9- from ..base_op import OPERATORS , Sample , Selector
9+ from ..base_op import OPERATORS , Sample , Selector , Param , DataType
1010
1111
1212OP_NAME = 'encode_and_get_nearest_mapper'
1313# 编码为嵌入向量
14- def get_embeddings (texts : List [str ], url : str = "https://ev19h0o3sv7k.space.opencsg.com/embed" ):
14+ def get_embeddings (texts : List [str ], model_url ):
1515 """
1616 Call API service to get text embeddings
1717
1818 Args:
1919 texts (List[str]): List of texts to encode
20- url (str): API address, defaults to hardcoded address
20+ model_url (str): API address, defaults to hardcoded address
2121
2222 Returns:
2323 List[List[float]]: List of embedding vectors
@@ -32,14 +32,14 @@ def get_embeddings(texts: List[str], url: str = "https://ev19h0o3sv7k.space.open
3232 "normalize" : True
3333 }
3434 try :
35- response = requests .post (url , json = payload )
35+ response = requests .post (model_url , json = payload )
3636 response .raise_for_status () # Raise exception for HTTP errors
3737 embeddings = response .json () # List of embeddings
3838 except requests .RequestException as e :
3939 raise requests .RequestException (f"Error calling API: { e } " )
4040 return embeddings
4141
42- def encode_texts (texts : List [str ], url : str = "https://ev19h0o3sv7k.space.opencsg.com/embed" ) -> List [List [float ]]:
42+ def encode_texts (texts : List [str ], model_url ) -> List [List [float ]]:
4343 """
4444 Encode multiple texts into embedding vectors
4545
@@ -50,7 +50,7 @@ def encode_texts(texts: List[str], url: str = "https://ev19h0o3sv7k.space.opencs
5050 Returns:
5151 List[List[float]]: List of embedding vectors
5252 """
53- return get_embeddings (texts , url = url )
53+ return get_embeddings (texts , model_url = model_url )
5454
5555
5656class FaissNearestNeighbour :
@@ -158,6 +158,7 @@ class EncodeAndGetNearestSelector(Selector):
158158 """Encode texts and find nearest neighbours using Faiss."""
159159
160160 def __init__ (self ,
161+ model_url : str = "https://ev19h0o3sv7k.space.opencsg.com/embed" ,
161162 * args ,
162163 ** kwargs ):
163164 """
@@ -168,6 +169,7 @@ def __init__(self,
168169 """
169170 super ().__init__ (* args , ** kwargs )
170171 self .first_prompt = []
172+ self .model_url = model_url
171173
172174 def process (self , dataset ):
173175 if len (dataset ) <= 0 :
@@ -176,7 +178,7 @@ def process(self, dataset):
176178
177179 first_prompt_list = dataset ["first_prompt" ].tolist ()
178180
179- embeddings = encode_texts (first_prompt_list )
181+ embeddings = encode_texts (first_prompt_list , self . model_url )
180182 dataset ['embedding' ] = embeddings
181183
182184 nearest_neighbour = FaissNearestNeighbour ()
@@ -202,3 +204,10 @@ def sample(cls):
202204 "如['What is artificial intelligence?', 'How does machine learning work?']" ,
203205 after = "数据集增加了embedding、nn_indices和nn_scores字段,包含文本的向量表示和最近邻信息"
204206 )
207+
208+ @classmethod
209+ @property
210+ def init_params (cls ):
211+ return [
212+ Param ("model_url" , DataType .STRING , {}, "https://ev19h0o3sv7k.space.opencsg.com/embed" ),
213+ ]
0 commit comments