Skip to content

Commit d713f1a

Browse files
author
zhanglongbin
committed
1.add a new operator
2.modify the image loading method 3.Modify the loading method of operator permissions
1 parent d0977d7 commit d713f1a

16 files changed

Lines changed: 127 additions & 130 deletions

File tree

data_agents/utils/tools/load_samples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def parse(path: str):
5757
with open(readme_path, encoding='utf-8') as stream:
5858
content = stream.read()
5959
else:
60-
with open(readme_path) as stream: # 在 Linux 和 macOS 下不显式指定编码
60+
with open(readme_path) as stream:
6161
content = stream.read()
6262
plan.readme = content
6363

data_engine/exporter/csghub_exporter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
from pycsghub.cmd.repo_types import RepoType
55
from pycsghub.upload_large_folder.main import upload_large_folder_internal
66

7-
#from data_celery.mongo_tools.tools import insert_pipline_job_run_task_log_info
87
from data_engine.exporter.base_exporter import Exporter
98
import os
10-
import uuid
119
import re
1210
from loguru import logger
1311
from pycsghub.repository import Repository
@@ -161,7 +159,7 @@ def _export_common(self):
161159
)
162160
r.upload()
163161
logger.info(f'Done push {self.upload_path} to repo: {self.repo_id} with branch: {self.output_branch_name}')
164-
# insert_pipline_job_run_task_log_info(job_uid, f'Done push {self.upload_path} to repo: {self.repo_id} with branch: {self.output_branch_name}')
162+
#insert_pipline_job_run_task_log_info(job_uid, f'Done push {self.upload_path} to repo: {self.repo_id} with branch: {self.output_branch_name}')
165163
if os.path.exists(self.repo_work_dir):
166164
logger.info(f'Remove {self.repo_work_dir}')
167165
shutil.rmtree(self.repo_work_dir)

data_engine/ops/mapper/extract_qa_mapper.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -50,39 +50,7 @@ def __init__(self,
5050
sampling_params: Dict = {'temperature': 0.3},
5151
*args,
5252
**kwargs):
53-
"""
54-
Initialization method.
55-
:param hf_model: Hugginface model id.
56-
:param trust_remote_code: passed to transformers
57-
:param pattern: regular expression pattern to search for within text.
58-
:param qa_format: Output format of question and answer pair.
59-
:param enable_vllm: Whether to use vllm for inference acceleration.
60-
:param tensor_parallel_size: It is only valid when enable_vllm is True.
61-
The number of GPUs to use for distributed execution with tensor
62-
parallelism.
63-
:param max_model_len: It is only valid when enable_vllm is True.
64-
Model context length. If unspecified, will be automatically
65-
derived from the model config.
66-
:param max_num_seqs: It is only valid when enable_vllm is True.
67-
Maximum number of sequences to be processed in a single iteration.
68-
:param sampling_params: Sampling parameters for text generation.
69-
e.g {'temperature': 0.9, 'top_p': 0.95}
70-
:param args: extra args
71-
:param kwargs: extra args
72-
73-
The default data format parsed by this interface is as follows:
74-
Model Input:
75-
蒙古国的首都是乌兰巴托(Ulaanbaatar)
76-
冰岛的首都是雷克雅未克(Reykjavik)
77-
Model Output:
78-
蒙古国的首都是乌兰巴托(Ulaanbaatar)
79-
冰岛的首都是雷克雅未克(Reykjavik)
80-
Human: 请问蒙古国的首都是哪里?
81-
Assistant: 你好,根据提供的信息,蒙古国的首都是乌兰巴托(Ulaanbaatar)。
82-
Human: 冰岛的首都是哪里呢?
83-
Assistant: 冰岛的首都是雷克雅未克(Reykjavik)。
84-
...
85-
"""
53+
8654

8755
super().__init__(*args, **kwargs)
8856
self.num_proc = 1

data_engine/ops/mapper/nlpcda_zh_mapper.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -30,45 +30,7 @@ def __init__(self,
3030
replace_equivalent_num: bool = False,
3131
*args,
3232
**kwargs):
33-
"""
34-
Initialization method. All augmentation methods use default parameters
35-
in default. We recommend you to only use 1-3 augmentation methods at a
36-
time. Otherwise, the semantics of samples might be changed
37-
significantly. **Notice**: some augmentation method might not work for
38-
some special texts, so there might be no augmented texts generated.
39-
40-
:param sequential: whether combine all augmentation methods to a
41-
sequence. If it's True, a sample will be augmented by all opened
42-
augmentation methods sequentially. If it's False, each opened
43-
augmentation method would generate its augmented samples
44-
independently.
45-
:param aug_num: number of augmented samples to be generated. If
46-
`sequential` is True, there will be total aug_num augmented samples
47-
generated. If it's False, there will be (aug_num *
48-
#opened_aug_method) augmented samples generated.
49-
:param keep_original_sample: whether to keep the original sample. If
50-
it's set to False, there will be only generated texts in the final
51-
datasets and the original texts will be removed. It's True in
52-
default.
53-
:param replace_similar_word: whether to open the augmentation method of
54-
replacing random words with their similar words in the original
55-
texts. e.g. "这里一共有5种不同的数据增强方法" --> "这边一共有5种不同的数据增强方法"
56-
:param replace_homophone_char: whether to open the augmentation method
57-
of replacing random characters with their homophones in the
58-
original texts. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的濖据增强方法"
59-
:param delete_random_char: whether to open the augmentation method of
60-
deleting random characters from the original texts. e.g.
61-
"这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据增强"
62-
:param swap_random_char: whether to open the augmentation method of
63-
swapping random contiguous characters in the original texts. e.g.
64-
"这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据强增方法"
65-
:param replace_equivalent_num: whether to open the augmentation method
66-
of replacing random numbers with their equivalent representations
67-
in the original texts. **Notice**: Only for numbers for now. e.g.
68-
"这里一共有5种不同的数据增强方法" --> "这里一共有伍种不同的数据增强方法"
69-
:param args: extra args
70-
:param kwargs: extra args
71-
"""
33+
7234
super().__init__(*args, **kwargs)
7335

7436
self.aug_num = aug_num

data_engine/utils/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ class StatsKeysConstant(object):
133133
word_rep_ratio = 'word_rep_ratio'
134134
bloom = 'bloom'
135135
high_score = 'high_score'
136+
embedding = 'embedding'
136137

137138
# image
138139
aspect_ratios = 'aspect_ratios'

data_server/api/api_router.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@
2929
api_router.include_router(operator_permission.router, prefix="/operator_permission", tags=["算子权限相关接口"])
3030

3131
api_router.include_router(op_pic_upload.op_pic_router, prefix="/internal_api", tags=["文件上传接口"])
32+
api_router.include_router(op_pic_upload.image_getter_router, tags=["文件获取接口"])
3233

3334
api_router.include_router(algo_templates.router, prefix="/algo_templates", tags=["算法模板相关接口"])

data_server/api/endpoints/algo_templates.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class AlgoTemplateListResponse(BaseModel):
3131
page_size: int = Field(..., description="每页数量")
3232

3333

34-
@router.get("/", response_model=dict, summary="获取算法模板列表")
34+
@router.get("", response_model=dict, summary="获取算法模板列表")
3535
async def get_algo_templates(
3636
user_id: str = Header(..., alias="user_id", description="用户ID"),
3737
page: int = Query(1, ge=1, description="页码"),
@@ -101,7 +101,7 @@ async def get_algo_template_by_id(
101101
db.close()
102102

103103

104-
@router.post("/", response_model=dict, summary="创建新的算法模板")
104+
@router.post("", response_model=dict, summary="创建新的算法模板")
105105
async def create_algo_template(
106106
template_data: AlgoTemplateCreate,
107107
user_id: str = Header(..., alias="user_id", description="用户ID"),

data_server/api/endpoints/op_pic_upload.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from fastapi import APIRouter, UploadFile, File, HTTPException, status, Request
22
from typing import Dict, Any
33
import os
4+
import base64
5+
from pathlib import Path
46
from loguru import logger
57
from data_server.utils.file_storage import file_storage_manager
68
from data_server.schemas.responses import response_success, response_fail
9+
from data_celery.utils import get_project_root
710

811

912
op_pic_router = APIRouter()
13+
image_getter_router = APIRouter()
1014

1115

1216
@op_pic_router.post("/internal_api/upload", summary="上传operator图片")
@@ -78,3 +82,32 @@ async def delete_uploaded_file_by_name(filename: str) -> Dict[str, Any]:
7882
except Exception as e:
7983
logger.error(f"删除文件失败: {str(e)}")
8084
return response_fail(msg=f"删除文件失败: {str(e)}")
85+
86+
87+
@image_getter_router.get("/real_static_files/{category}/{filename}", summary="obtain_the_base64_encoding_of_the_image")
88+
async def get_image_base64(category: str, filename: str):
89+
try:
90+
project_root = get_project_root()
91+
image_path = Path(project_root) / 'attach' / category / filename
92+
93+
if not image_path.exists() or not image_path.is_file():
94+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Image not found")
95+
96+
with open(image_path, "rb") as image_file:
97+
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
98+
99+
file_extension = filename.split('.')[-1].lower()
100+
mime_type = f"image/{file_extension}"
101+
if file_extension == 'svg':
102+
mime_type = "image/svg+xml"
103+
104+
base64_image = encoded_string
105+
106+
return response_success(data={base64_image})
107+
108+
except HTTPException as http_exc:
109+
logger.warning(f"failed-to-obtain-the-picture: {http_exc.detail}")
110+
raise http_exc
111+
except Exception as e:
112+
logger.error(f"base64_encoding_failed: {str(e)}")
113+
return response_fail(msg=f"failed_to_obtain_the_picture: {str(e)}")

data_server/api/endpoints/operator.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
from fastapi import FastAPI, APIRouter, Depends, HTTPException, Query, Path
1+
from fastapi import FastAPI, APIRouter, Depends, HTTPException, Query, Path,Header
22
from sqlalchemy.orm import Session
33
from typing import List, Dict, Any, Optional, Annotated
4+
import base64
5+
from pathlib import Path
6+
from data_celery.utils import get_project_root
47

58
from data_server.database.session import get_sync_session
69

@@ -21,7 +24,7 @@
2124

2225

2326

24-
@router.post("/", summary="create_operator")
27+
@router.post("", summary="create_operator")
2528
def create_operator_api(
2629
operator_data: OperatorCreateRequest,
2730
db: Session = Depends(get_sync_session)
@@ -36,7 +39,7 @@ def create_operator_api(
3639
db.close()
3740

3841

39-
@router.get("/", summary="GET_LIST_OF_OPERATORS")
42+
@router.get("", summary="GET_LIST_OF_OPERATORS")
4043
def read_operators_api(
4144
skip: int = 0,
4245
limit: int = 100,
@@ -45,7 +48,27 @@ def read_operators_api(
4548

4649
try:
4750
operators = get_operators(db, skip, limit)
48-
return response_success(data=operators, msg="获取算子列表成功")
51+
operators_data = []
52+
project_root = get_project_root()
53+
for op in operators:
54+
op_dict = op.__dict__
55+
pic_base64 = None
56+
mime_type = None
57+
if op.icon:
58+
try:
59+
filename = Path(op.icon).name
60+
image_path = project_root / 'attach' / 'operator' / filename
61+
if image_path.exists() and image_path.is_file():
62+
with open(image_path, "rb") as image_file:
63+
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
64+
pic_base64 = encoded_string
65+
except Exception:
66+
# Ignore errors for individual images
67+
pass
68+
op_dict['pic_base64'] = pic_base64
69+
operators_data.append(op_dict)
70+
71+
return response_success(data=operators_data, msg="获取算子列表成功")
4972
except Exception as e:
5073
return response_fail(msg=f"获取算子列表失败: {str(e)}")
5174
finally:
@@ -126,7 +149,7 @@ def get_operator_config_select_option_by_id_api(
126149
db.close()
127150

128151

129-
@router.post("/config_select_options/", summary="添加下拉框选项")
152+
@router.post("/config_select_options", summary="添加下拉框选项")
130153
def create_operator_config_select_option_api(
131154
option: OperatorConfigSelectOptionsCreate,
132155
db: Session = Depends(get_sync_session)
@@ -148,14 +171,30 @@ def get_operators_grouped_by_type_api(
148171

149172
try:
150173
grouped_operators = get_operators_grouped_by_type(db)
174+
project_root = get_project_root()
175+
for group in grouped_operators:
176+
for op in group['list']:
177+
pic_base64 = None
178+
icon = op.get('icon')
179+
if icon:
180+
try:
181+
filename = Path(icon).name
182+
image_path = project_root / 'attach' / 'operator' / filename
183+
if image_path.exists() and image_path.is_file():
184+
with open(image_path, "rb") as image_file:
185+
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
186+
pic_base64 = encoded_string
187+
except Exception:
188+
pass
189+
op['pic_base64'] = pic_base64
151190
return response_success(data=grouped_operators, msg="获取分组算子列表成功")
152191
except Exception as e:
153192
return response_fail(msg=f"获取分组算子列表失败: {str(e)}")
154193
finally:
155194
db.close()
156195

157196
# find_operator_by_uuid_orgs
158-
@router.get("/types/grouped-by-condition/", summary="根据算子分类和权限返回算子数据")
197+
@router.get("/types/grouped-by-condition", summary="根据算子分类和权限返回算子数据")
159198
def get_operators_grouped_by_condition_api(
160199
payload: Dict = Depends(get_validated_token_payload),
161200
db: Session = Depends(get_sync_session),
@@ -174,8 +213,28 @@ def get_operators_grouped_by_condition_api(
174213
return response_fail("Token中缺少用户信息 (uuid)")
175214

176215
grouped_operators = get_operators_grouped_by_condition(db, user_id, paths)
216+
project_root = get_project_root()
217+
for group in grouped_operators:
218+
for op in group['list']:
219+
pic_base64 = None
220+
icon = op.get('icon')
221+
if icon:
222+
try:
223+
filename = Path(icon).name
224+
image_path = project_root / 'attach' / 'operator' / filename
225+
if image_path.exists() and image_path.is_file():
226+
with open(image_path, "rb") as image_file:
227+
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
228+
pic_base64 = encoded_string
229+
except Exception:
230+
pass
231+
op['pic_base64'] = pic_base64
177232
return response_success(data=grouped_operators, msg="获取分组算子列表成功")
178233
except Exception as e:
179234
return response_fail(msg=f"获取分组算子列表失败: {str(e)}")
180235
finally:
181236
db.close()
237+
238+
@router.get("/isAdmin/torf")
239+
def get_isAdmin_true_or_false(isadmin: str = Header(..., alias="isadmin", description="是否管理员")):
240+
return response_success(data={"isadmin":isadmin})

data_server/api/endpoints/operator_permission.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717

18-
@router.post("/", summary="创建算子权限")
18+
@router.post("", summary="创建算子权限")
1919
def create_permission_api(request_data: OperatorPermissionCreateRequest, db: Session = Depends(get_sync_session)):
2020

2121
try:
@@ -102,7 +102,7 @@ def create_permission_api(request_data: OperatorPermissionCreateRequest, db: Ses
102102

103103

104104

105-
@router.get("/", summary="获取权限列表")
105+
@router.get("", summary="获取权限列表")
106106
def read_permissions_api(skip: int = 0, limit: int = 100, db: Session = Depends(get_sync_session)):
107107

108108
try:

0 commit comments

Comments
 (0)