Skip to content

Commit 6413cb0

Browse files
author
zhanglongbin
committed
Fix the issue preventing operators in the tool from being used
1 parent 473b4f5 commit 6413cb0

13 files changed

Lines changed: 541 additions & 34 deletions

File tree

2.83 KB
Loading
3.2 KB
Loading
2.81 KB
Loading

data_celery/datasource/mongo/tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def collection_mongo_task(task_uid: str,user_name: str,user_token: str):
5252
collection_task.task_status = DataSourceTaskStatusEnum.ERROR.value
5353
insert_datasource_run_task_log_error(task_uid, f"Task with UID {task_uid} has no associated datasource.")
5454
return False
55-
if collection_task.datasource.source_type != DataSourceTypeEnum.MYSQL.value:
55+
if collection_task.datasource.source_type != DataSourceTypeEnum.MONGODB.value:
5656
collection_task.task_status = DataSourceTaskStatusEnum.ERROR.value
5757
insert_datasource_run_task_log_error(task_uid, f"Task with UID {task_uid} is not a MySQL task.")
5858
return False

data_engine/ops/edu/encode_and_get_nearest.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@
66
from pydantic import Field
77
import requests
88

9-
from ..base_op import OPERATORS, Sample, Selector
9+
from ..base_op import OPERATORS, Sample, Selector,Param,DataType
1010

1111

1212
OP_NAME = 'encode_and_get_nearest_mapper'
1313
# 编码为嵌入向量
14-
def get_embeddings(texts: List[str], url: str = "https://ev19h0o3sv7k.space.opencsg.com/embed"):
14+
def get_embeddings(texts: List[str], model_url):
1515
"""
1616
Call API service to get text embeddings
1717
1818
Args:
1919
texts (List[str]): List of texts to encode
20-
url (str): API address, defaults to hardcoded address
20+
model_url (str): API address, defaults to hardcoded address
2121
2222
Returns:
2323
List[List[float]]: List of embedding vectors
@@ -32,14 +32,14 @@ def get_embeddings(texts: List[str], url: str = "https://ev19h0o3sv7k.space.open
3232
"normalize": True
3333
}
3434
try:
35-
response = requests.post(url, json=payload)
35+
response = requests.post(model_url, json=payload)
3636
response.raise_for_status() # Raise exception for HTTP errors
3737
embeddings = response.json() # List of embeddings
3838
except requests.RequestException as e:
3939
raise requests.RequestException(f"Error calling API: {e}")
4040
return embeddings
4141

42-
def encode_texts(texts: List[str], url: str = "https://ev19h0o3sv7k.space.opencsg.com/embed") -> List[List[float]]:
42+
def encode_texts(texts: List[str], model_url) -> List[List[float]]:
4343
"""
4444
Encode multiple texts into embedding vectors
4545
@@ -50,7 +50,7 @@ def encode_texts(texts: List[str], url: str = "https://ev19h0o3sv7k.space.opencs
5050
Returns:
5151
List[List[float]]: List of embedding vectors
5252
"""
53-
return get_embeddings(texts, url=url)
53+
return get_embeddings(texts, model_url=model_url)
5454

5555

5656
class FaissNearestNeighbour:
@@ -158,6 +158,7 @@ class EncodeAndGetNearestSelector(Selector):
158158
"""Encode texts and find nearest neighbours using Faiss."""
159159

160160
def __init__(self,
161+
model_url: str = "https://ev19h0o3sv7k.space.opencsg.com/embed",
161162
*args,
162163
**kwargs):
163164
"""
@@ -168,6 +169,7 @@ def __init__(self,
168169
"""
169170
super().__init__(*args, **kwargs)
170171
self.first_prompt = []
172+
self.model_url = model_url
171173

172174
def process(self, dataset):
173175
if len(dataset) <= 0:
@@ -176,7 +178,7 @@ def process(self, dataset):
176178

177179
first_prompt_list = dataset["first_prompt"].tolist()
178180

179-
embeddings = encode_texts(first_prompt_list)
181+
embeddings = encode_texts(first_prompt_list,self.model_url)
180182
dataset['embedding'] = embeddings
181183

182184
nearest_neighbour = FaissNearestNeighbour()
@@ -202,3 +204,10 @@ def sample(cls):
202204
"如['What is artificial intelligence?', 'How does machine learning work?']",
203205
after="数据集增加了embedding、nn_indices和nn_scores字段,包含文本的向量表示和最近邻信息"
204206
)
207+
208+
@classmethod
209+
@property
210+
def init_params(cls):
211+
return [
212+
Param("model_url", DataType.STRING, {}, "https://ev19h0o3sv7k.space.opencsg.com/embed"),
213+
]

data_engine/ops/edu/pipeline_magpie_zh.py

Lines changed: 447 additions & 0 deletions
Large diffs are not rendered by default.

data_engine/ops/filter/annotate_edu_train_bert_scorer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@
2121
class AnnotateEduTrainBertScorer(Filter):
2222
def __init__(self,
2323
auth_token: DataType.STRING = "",
24+
model_url: DataType.STRING = "https://esupw2o6m6f4.space.opencsg.com/rerank",
2425
*args,
2526
**kwargs):
2627
super().__init__(*args, **kwargs)
2728
self.auth_token = auth_token
28-
29+
self.model_url = model_url
2930

3031
def compute_stats(self, sample, context=False):
3132
score_field = f"{self.text_key}_score"
3233
content = sample[self.text_key]
3334
sample[score_field] = 0
3435

35-
url = "https://esupw2o6m6f4.space.opencsg.com/rerank"
3636
# auth_token = "9acc3ea387b5479607bdeb5386af6e3483fbf070"
3737
data = {
3838
"query": "What is Deep Learning?",
@@ -44,19 +44,19 @@ def compute_stats(self, sample, context=False):
4444
"truncate": False,
4545
"truncation_direction": "right"
4646
}
47-
score = self.get_score_from_model(url,self.auth_token, data)
47+
score = self.get_score_from_model(self.model_url,self.auth_token, data)
4848
if score is not None:
4949
sample[score_field] = score
5050
return sample
5151

52-
def get_score_from_model(self,url, auth_token, data):
52+
def get_score_from_model(self,model_url, auth_token, data):
5353

5454
headers = {
5555
'Content-Type': 'application/json',
5656
'Authorization': f'Bearer {auth_token}'
5757
}
5858

59-
response = requests.post(url, json=data, headers=headers)
59+
response = requests.post(model_url, json=data, headers=headers)
6060

6161
if response.status_code == 200:
6262
try:
@@ -86,4 +86,5 @@ def description(cls):
8686
def init_params(cls):
8787
return [
8888
Param("auth_token", DataType.STRING, {}, ""),
89+
Param("model_url", DataType.STRING, {}, "https://esupw2o6m6f4.space.opencsg.com/rerank"),
8990
]

data_engine/ops/mapper/text_make_cosmopedia.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
# https://github.com/yuyijiong/fineweb-edu-chinese/
33
# --------------------------------------------------------
44

5-
from ..base_op import OPERATORS, Mapper, Sample
5+
from ..base_op import OPERATORS, Mapper, Sample,Param,DataType
66
from ..common import chat_with_model
77

88
OP_NAME = 'make_cosmopedia_mapper'
99

10-
1110
@OPERATORS.register_module(OP_NAME)
1211
class MakeCosmopediaMapper(Mapper):
1312
"""Mapper to generate synthetic tutorial data from seed text samples."""
@@ -66,3 +65,21 @@ def sample(cls):
6665
'Training your dog to sit is one of the most fundamental commands...'
6766
)
6867

68+
@classmethod
69+
@property
70+
def init_params(cls):
71+
return [
72+
Param("web_text_max_len", DataType.STRING, {}, 800),
73+
Param("model_url", DataType.STRING, {}, "https://euqnoct5ophc.space.opencsg.com/v1/chat/completions"),
74+
Param("model", DataType.STRING, {}, "THUDM/LongWriter-glm4-9b"),
75+
Param("auth_token", DataType.STRING, {}, "9acc3ea387b5479607bdeb5386af6e3483fbf070"),
76+
Param("content", DataType.STRING, {}, '''网页摘录:“{web_text}”。
77+
以 WikiHow 的风格写一篇长而非常详细的教程,教程与此网页摘录有相关性。
78+
教程中需要包括对每个步骤的深入解释以及它如何帮助实现预期结果。你可以自由补充其他相关知识。
79+
确保清晰性和实用性,让读者能够轻松遵循教程完成任务。内容中不应包含广告或涉及隐私的信息。
80+
不要使用图像。请直接开始撰写教程。
81+
''')
82+
]
83+
84+
85+

data_engine/tools/base_tool.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(self, tool_defination: Tool_def, params: ExecutedParams):
6969
)
7070

7171
# whether the model can be accelerated using cuda
72+
7273
_accelerator = self.tool_def.accelerator if self.tool_def.accelerator else None
7374
if _accelerator is not None:
7475
self.accelerator = _accelerator
@@ -91,7 +92,7 @@ def run(self):
9192
# 0. ingest data
9293
self.tool_def.dataset_path = self.ingester.ingest()
9394
logger.info(f'Data ingested from {self.tool_def.dataset_path}')
94-
95+
print('_accelerator', 100 * '*5')
9596
# 1. data process
9697
with TRACE_HELPER_TOOL.trace_block(
9798
"run",
@@ -103,9 +104,11 @@ def run(self):
103104
"operation_name": self._name,
104105
}
105106
):
107+
106108
logger.info('Processing tool...')
107109
tstart = time.time()
108110
target_path: Path = self.process()
111+
print('_accelerator', 100 * '-5')
109112
tend = time.time()
110113
logger.info(f'Tool are done in {tend - tstart:.3f}s.')
111114

data_server/api/endpoints/job.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ async def read_log(id: int,
123123
session: Session = Depends(get_sync_session)):
124124
try:
125125
log = retreive_log(job_id=id, user_id=user_id,
126-
session=session, isadmin=isadmin, )
126+
session=session, isadmin=isadmin )
127127
if not log:
128128
raise HTTPException(
129129
status_code=status.HTTP_404_NOT_FOUND,
@@ -214,11 +214,26 @@ async def read_task_resource_info(id: int,
214214

215215
@router.post("", response_model=responses.JobCreate, description="Create the dataflow job")
216216
def create_job(
217+
217218
config: Union[Tool],
219+
220+
# config: Union[Recipe, Tool],
221+
# config: Union[Tool,Recipe],
222+
# config: Union[Tool],
223+
218224
user_id: Annotated[str | None, Header(alias="user_id")] = None,
219225
user_name: Annotated[str | None, Header(alias="user_name")] = None,
220226
user_token: Annotated[str | None, Header(alias="user_token")] = None
221227
):
228+
# print(user_id)
229+
# print(user_name)
230+
# print(user_token)
231+
# print(config)
232+
if isinstance(config, Recipe):
233+
print("匹配到 Recipe 类")
234+
# 处理 Recipe 逻辑(如解析 process 字段)
235+
elif isinstance(config, Tool):
236+
print("匹配到 Tool 类")
222237
try:
223238
result = create_new_job(
224239
job_cfg=config, user_id=user_id, user_name=user_name, user_token=user_token)

0 commit comments

Comments (0)