Skip to content

Commit cc9c9c2

Browse files
zhanglongbinHaiHui886
authored andcommitted
Fix the bug in #84: The search function for models in the MD2Json tool should display all searched models.
1 parent face599 commit cc9c9c2

13 files changed

Lines changed: 146 additions & 23 deletions

File tree

data_celery/formatify/tasks.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,9 @@ def format_task(task_id: int, user_name: str, user_token: str):
266266
continue
267267

268268
# Execute conversion
269-
# For PDF to MD conversion, need to pass mineru_api_url parameter
269+
# For PDF to MD conversion, need to pass mineru_api_url and mineru_backend parameters
270270
if format_task.from_data_type == DataFormatTypeEnum.PDF.value and format_task.to_data_type == DataFormatTypeEnum.Markdown.value:
271-
result = convert_func(file_path_full, format_task.task_uid, format_task.mineru_api_url)
271+
result = convert_func(file_path_full, format_task.task_uid, format_task.mineru_api_url, format_task.mineru_backend)
272272
else:
273273
result = convert_func(file_path_full, format_task.task_uid)
274274

@@ -665,7 +665,7 @@ def convert_ppt_to_markdown(file_path: str, task_uid) -> Optional[Dict[str, str]
665665
return None # Not a target file, return None
666666

667667

668-
def convert_pdf_to_markdown(file_path: str, task_uid, mineru_api_url: Optional[str] = None) -> Optional[Dict[str, str]]:
668+
def convert_pdf_to_markdown(file_path: str, task_uid, mineru_api_url: Optional[str] = None, mineru_backend: Optional[str] = None) -> Optional[Dict[str, str]]:
669669
if file_path.lower().endswith('.pdf'):
670670
insert_formatity_task_log_info(task_uid, f'Source file address:{file_path}')
671671
try:
@@ -679,10 +679,14 @@ def convert_pdf_to_markdown(file_path: str, task_uid, mineru_api_url: Optional[s
679679
server_url = mineru_api_url
680680
else:
681681
server_url = os.getenv("MINERU_API_URL", "http://111.4.242.20:30000")
682-
backend = "vlm-http-client"
682+
# MinerU backend: Priority: passed parameter > environment variable > default value
683+
if mineru_backend:
684+
backend = mineru_backend
685+
else:
686+
backend = os.getenv("MINERU_BACKEND", "http-client")
683687

684-
# Record used MinerU API address for debugging
685-
insert_formatity_task_log_info(task_uid, f'Using MinerU API server: {server_url}')
688+
# Record used MinerU API address and backend for debugging
689+
insert_formatity_task_log_info(task_uid, f'Using MinerU API server: {server_url}, backend: {backend}')
686690

687691
pdf_file_name = Path(file_path).stem
688692
temp_output_dir = Path(file_path).parent / f"_temp_pdf_convert_{pdf_file_name}"

data_engine/ops/base_op.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ class DataType(Enum):
501501
ClosedUnitInterval = 2
502502
from_2_to_20 = 3
503503
SEARCH_SELECT = "search-select"
504+
SELECT_MODEL = "select-model"
504505

505506
@dataclass
506507
class Sample:

data_engine/tools/preprocess/md_to_jsonl_preprocess.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def export_from_files(self, upload_path: Path):
168168
:param upload_path: path with files to upload
169169
:return: branch name
170170
"""
171-
# 🔍 诊断日志: 检查上传路径
171+
172172
logger.info(f'='*80)
173173
logger.info(f'[EXPORT DEBUG] Starting export_from_files()')
174174
logger.info(f'[EXPORT DEBUG] Upload path type: {type(upload_path)}, value: {upload_path}')
@@ -179,7 +179,7 @@ def export_from_files(self, upload_path: Path):
179179
logger.error(f'[EXPORT DEBUG] Upload path does not exist!')
180180
return None
181181

182-
# 🔍 诊断日志: 列出所有要上传的文件
182+
183183
try:
184184
files_to_upload = os.listdir(upload_path)
185185
logger.info(f'[EXPORT DEBUG] Files in upload directory: {len(files_to_upload)} files')
@@ -204,7 +204,7 @@ def export_from_files(self, upload_path: Path):
204204
logger.info('[EXPORT DEBUG] No repo_id specified, skip upload.')
205205
return 'N/A'
206206

207-
# 🔍 诊断日志: repo信息
207+
208208
logger.info(f'[EXPORT DEBUG] Target repo_id: {self.parent.tool_def.repo_id}')
209209
logger.info(f'[EXPORT DEBUG] User token (first 10 chars): {self.parent.executed_params.user_token[:10] if self.parent.executed_params.user_token else "None"}...')
210210

@@ -216,7 +216,7 @@ def export_from_files(self, upload_path: Path):
216216
output_branch_name = self.parent._get_available_branch(branch)
217217
logger.info(f'[EXPORT DEBUG] Output branch name: {output_branch_name}')
218218

219-
# 🔍 诊断日志: 上传参数
219+
220220
endpoint = get_endpoint(endpoint=GetHubEndpoint())
221221
logger.info(f'[EXPORT DEBUG] Upload parameters:')
222222
logger.info(f'[EXPORT DEBUG] - repo_id: {self.parent.tool_def.repo_id}')
@@ -701,7 +701,7 @@ def init_params(cls, userid: str = None, isadmin: bool = False):
701701
"token": "token",
702702
"sentence": "sentence"
703703
}, "token"),
704-
Param("hf_tokenizer", DataType.SEARCH_SELECT, {
704+
Param("hf_tokenizer", DataType.SELECT_MODEL, {
705705
"EleutherAI/pythia-6.9b-deduped": "EleutherAI/pythia-6.9b-deduped",
706706
"hfl/chinese-bert-wwm-ext": "hfl/chinese-bert-wwm-ext"
707707
}, "EleutherAI/pythia-6.9b-deduped"),

data_server/api/endpoints/formatify.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,35 @@
2121
@router.get("/get_mineru_api_url", response_model=dict)
2222
async def get_mineru_api_url():
2323
"""
24-
获取当前配置的 MinerU API 地址
24+
获取当前配置的 MinerU 引擎参数
2525
Returns:
26-
Dict: 包含当前 MinerU API 地址的响应
26+
Dict: 包含当前 MinerU 引擎参数的响应
2727
- mineru_api_url: MinerU API 服务器地址
28-
- source: 配置来源 ("environment" | "default")
28+
- mineru_backend: MinerU 后端类型
29+
- sources: 配置来源字典
30+
- mineru_api_url_source: mineru_api_url 的配置来源 ("environment" | "default")
31+
- mineru_backend_source: mineru_backend 的配置来源 ("environment" | "default")
2932
"""
3033
try:
31-
# Get from environment variable, use default value if not set
34+
# Get mineru_api_url from environment variable, use default value if not set
3235
mineru_api_url = os.getenv("MINERU_API_URL", "http://111.4.242.20:30000")
33-
source = "environment" if os.getenv("MINERU_API_URL") else "default"
36+
mineru_api_url_source = "environment" if os.getenv("MINERU_API_URL") else "default"
37+
38+
# Get mineru_backend from environment variable, use default value if not set
39+
mineru_backend = os.getenv("MINERU_BACKEND", "http-client")
40+
mineru_backend_source = "environment" if os.getenv("MINERU_BACKEND") else "default"
3441

3542
return response_success(data={
3643
"mineru_api_url": mineru_api_url,
37-
"source": source
44+
"mineru_backend": mineru_backend,
45+
"sources": {
46+
"mineru_api_url_source": mineru_api_url_source,
47+
"mineru_backend_source": mineru_backend_source
48+
}
3849
})
3950
except Exception as e:
4051
logger.error(f"Failed to get mineru_api_url: {str(e)}")
41-
return response_fail(msg="获取 MinerU API 地址失败")
52+
return response_fail(msg="获取 MinerU 引擎参数失败")
4253

4354

4455
@router.get("/formatify/get_format_type_list", response_model=dict)

data_server/api/endpoints/model_validator.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from loguru import logger
88
from ...schemas.responses import response_success, response_fail
99
from data_engine.utils.cache_utils import DATA_JUICER_MODELS_CACHE
10+
from typing import Optional
1011

1112
router = APIRouter()
1213

@@ -55,7 +56,7 @@ def check_model_for_md_to_jsonl(
5556

5657
# Get the first matching model information
5758
model_info = data['data'][0]
58-
model_name_from_api = model_info.get('name', model_name) # Get name field from API response
59+
model_name_from_api = model_info.get('path', model_name) # Get path field from API response
5960
repository = model_info.get('repository', {})
6061
http_clone_url = repository.get('http_clone_url')
6162

@@ -219,3 +220,87 @@ def check_model_for_md_to_jsonl(
219220
except Exception as e:
220221
logger.error(f'Error checking model: {str(e)}')
221222
return response_fail(msg=f'检查模型时发生错误: {str(e)}')
223+
224+
225+
@router.get("/list-models", summary="获取模型列表")
226+
def list_models(
227+
page: int = Query(1, description="页码,从1开始"),
228+
per_page: int = Query(16, description="每页数量,默认16条"),
229+
search: Optional[str] = Query(None, description="搜索关键词"),
230+
sort: str = Query("trending", description="排序方式")
231+
):
232+
"""
233+
Fetch model list from OpenCSG Hub and return basic information.
234+
235+
Returns:
236+
- path: Model path
237+
- updated_at: Last update time
238+
- first_tag: show_name of the first tag with category == "task"
239+
- downloads: Download count
240+
- description: Model description
241+
242+
:param page: Page number
243+
:param per_page: Items per page, default 16
244+
:param search: Search keyword (optional)
245+
:param sort: Sort method, default trending
246+
:return: Model list
247+
"""
248+
try:
249+
# Get CSGHUB_ENDPOINT from environment
250+
csghub_endpoint = os.getenv('CSGHUB_ENDPOINT', 'https://hub.opencsg.com')
251+
api_url = f'{csghub_endpoint}/api/v1/models'
252+
253+
# Set request parameters
254+
params = {
255+
'page': page,
256+
'per': per_page,
257+
'search': search or '',
258+
'sort': sort,
259+
'source': ''
260+
}
261+
262+
logger.info(f'Fetching models from {api_url} with params: {params}')
263+
264+
# Send request
265+
response = requests.get(api_url, params=params, timeout=30)
266+
response.raise_for_status()
267+
data = response.json()
268+
269+
# Extract required fields
270+
models_data = []
271+
if data.get('data'):
272+
for model in data['data']:
273+
# Find the first tag with category == "task" and get its show_name
274+
first_tag_show_name = ''
275+
if model.get('tags'):
276+
for tag in model['tags']:
277+
if tag.get('category') == 'task':
278+
first_tag_show_name = tag.get('show_name', '')
279+
break
280+
281+
model_info = {
282+
'path': model.get('path', ''),
283+
'updated_at': model.get('updated_at', ''),
284+
'first_tag': first_tag_show_name,
285+
'downloads': model.get('downloads', 0),
286+
'description': model.get('description', '')
287+
}
288+
models_data.append(model_info)
289+
290+
# Build result
291+
result = {
292+
'models': models_data,
293+
'total': data.get('total', 0),
294+
'page': page,
295+
'per_page': per_page
296+
}
297+
298+
logger.info(f'Successfully fetched {len(models_data)} models')
299+
return response_success(data=result, msg='获取模型列表成功')
300+
301+
except requests.RequestException as e:
302+
logger.error(f'Failed to request OpenCSG API: {str(e)}')
303+
return response_fail(msg=f'无法连接到OpenCSG Hub API: {str(e)}')
304+
except Exception as e:
305+
logger.error(f'Error fetching models: {str(e)}')
306+
return response_fail(msg=f'获取模型列表时发生错误: {str(e)}')

data_server/database/session.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,21 @@ def add_mineru_api_url_column():
170170
logger.info("Column 'mineru_api_url' added successfully to data_format_tasks table")
171171

172172

173+
def add_mineru_backend_column():
174+
"""Add mineru_backend column to data_format_tasks table"""
175+
with get_sync_session() as session:
176+
with session.begin():
177+
result = session.execute(text("""
178+
SELECT column_name
179+
FROM information_schema.columns
180+
WHERE table_name = 'data_format_tasks' AND column_name = 'mineru_backend';
181+
"""))
182+
183+
if not result.fetchone():
184+
session.execute(text("ALTER TABLE data_format_tasks ADD COLUMN mineru_backend VARCHAR(100);"))
185+
logger.info("Column 'mineru_backend' added successfully to data_format_tasks table")
186+
187+
173188
_initialized = False
174189
from data_server.database.bean.work import Worker
175190
from data_server.job.JobModels import Job
@@ -202,6 +217,7 @@ def create_tables():
202217

203218
add_first_op_column()
204219
add_mineru_api_url_column()
220+
add_mineru_backend_column()
205221
def is_table_initialized(table_name: str) -> bool:
206222
"""
207223
Check if a specific table contains any data.

data_server/formatify/FormatifyManager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def create_formatify_task(db_session: Session, dataFormatTask: DataFormatTaskReq
2929
to_csg_hub_repo_id=dataFormatTask.to_csg_hub_repo_id,
3030
to_data_type=dataFormatTask.to_data_type,
3131
mineru_api_url=dataFormatTask.mineru_api_url,
32+
mineru_backend=dataFormatTask.mineru_backend,
3233
task_uid=task_uid,
3334
task_status=DataFormatTaskStatusEnum.WAITING.value,
3435
owner_id=user_id)
@@ -75,7 +76,7 @@ def update_formatify_task(db_session: Session, formatify_id: int, dataFormatTask
7576
'name', 'des', 'from_csg_hub_dataset_name', 'from_csg_hub_dataset_id',
7677
'from_csg_hub_dataset_branch', 'from_data_type', 'to_csg_hub_dataset_name',
7778
'to_csg_hub_dataset_id', 'to_csg_hub_dataset_default_branch', 'to_data_type',
78-
'mineru_api_url'
79+
'mineru_api_url', 'mineru_backend'
7980
}
8081
for field in updatable_fields:
8182
value = getattr(dataFormatTaskRequest, field, None)

data_server/formatify/FormatifyModels.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class DataFormatTask(Base):
5454
task_status = Column(Integer, nullable=False, comment="任务状态 DataFormatTaskStatusEnum 枚举")
5555
owner_id = Column(Integer, comment="所属用户")
5656
mineru_api_url = Column(String(500), comment="MinerU API 地址")
57+
mineru_backend = Column(String(100), comment="MinerU 后端类型")
5758
start_run_at = Column(DateTime, comment='运行开始时间')
5859
end_run_at = Column(DateTime, comment='运行结束时间')
5960
created_at = Column(DateTime, default=datetime.datetime.now, comment='任务创建时间')
@@ -79,6 +80,7 @@ def to_dict(self):
7980
"task_status": self.task_status,
8081
"owner_id": self.owner_id,
8182
"mineru_api_url": self.mineru_api_url,
83+
"mineru_backend": self.mineru_backend,
8284
"start_run_at": self.start_run_at,
8385
"end_run_at": self.end_run_at,
8486
"created_at": self.created_at.strftime("%Y-%m-%d %H:%M:%S") if self.created_at else None,

data_server/formatify/schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ class DataFormatTaskRequest(BaseModel):
1616
to_csg_hub_repo_id: Optional[str] = None
1717
to_data_type: Optional[int] = None
1818
mineru_api_url: Optional[str] = None
19+
mineru_backend: Optional[str] = None

data_server/logic/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class Sample(BaseModelExtended):
8181
class Param(BaseModelExtended):
8282
name: str
8383
type: Optional[Union[Literal["STRING"], Literal["INTEGER"], Literal["FLOAT"],
84-
Literal["BOOLEAN"], Literal["DICTIONARY"], Literal["TUPLE"], Literal["LIST"], Literal["PositiveFloat"], Literal["ClosedUnitInterval"], Literal["from_2_to_20"], Literal["search-select"]]] = None
84+
Literal["BOOLEAN"], Literal["DICTIONARY"], Literal["TUPLE"], Literal["LIST"], Literal["PositiveFloat"], Literal["ClosedUnitInterval"], Literal["from_2_to_20"], Literal["search-select"], Literal["select-model"]]] = None
8585
option_values: Optional[list[Option]] = None
8686
value: Optional[Any] = None
8787
tempVal: Optional[Any] = None

0 commit comments

Comments
 (0)