Skip to content

Commit 5564ad1

Browse files
author
zhanglongbin
committed
Fix the dataflow bug with ID #43
1 parent 02d53ce commit 5564ad1

4 files changed

Lines changed: 233 additions & 111 deletions

File tree

data_celery/datasource/hive/tasks.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ensure_directory_exists_remove, get_datasource_csg_hub_server_dir)
1313
from data_celery.mongo_tools.tools import insert_datasource_run_task_log_info,insert_datasource_run_task_log_error
1414
from data_server.datasource.services.datasource import get_datasource_connector
15+
from data_server.datasource.schemas import DataSourceCreate
1516
from data_engine.exporter.load import load_exporter
1617
from pathlib import Path
1718
import pandas as pd
@@ -103,13 +104,32 @@ def collection_hive_task(task_uid: str,user_name: str,user_token: str):
103104
max_line = extra_config["max_line_json"]
104105
if use_type == "sql":
105106
if use_sql:
106-
connector = get_datasource_connector(collection_task.datasource)
107-
if not connector.test_connection():
107+
try:
108+
# 将数据库对象转换为 DataSourceCreate 对象
109+
datasource_create = DataSourceCreate(
110+
name=collection_task.datasource.name,
111+
des=collection_task.datasource.des,
112+
source_type=collection_task.datasource.source_type,
113+
host=collection_task.datasource.host,
114+
port=collection_task.datasource.port,
115+
username=collection_task.datasource.username,
116+
password=collection_task.datasource.password,
117+
database=collection_task.datasource.database,
118+
auth_type=collection_task.datasource.auth_type
119+
)
120+
connector = get_datasource_connector(datasource_create)
121+
test_result = connector.test_connection()
122+
if not test_result or not test_result.get("success", False):
123+
collection_task.task_status = DataSourceTaskStatusEnum.ERROR.value
124+
error_msg = test_result.get("message", "Connection failed") if test_result else "Connection test returned None"
125+
insert_datasource_run_task_log_error(task_uid, f"Task with UID {task_uid} failed to connect to the database: {error_msg}")
126+
return False
127+
get_table_dataset_by_sql(connector, task_uid, use_sql, db_session, collection_task,
128+
datasource_temp_parquet_dir, max_line=max_line)
129+
except Exception as e:
108130
collection_task.task_status = DataSourceTaskStatusEnum.ERROR.value
109-
insert_datasource_run_task_log_error(task_uid, f"Task with UID {task_uid} failed to connect to the database.")
131+
insert_datasource_run_task_log_error(task_uid, f"Error occurred while executing the task: {str(e)}")
110132
return False
111-
get_table_dataset_by_sql(connector, task_uid, use_sql, db_session, collection_task,
112-
datasource_temp_parquet_dir, max_line=max_line)
113133
upload_path = datasource_temp_parquet_dir.join('run_sql')
114134
upload_to_csg_hub_server(csg_hub_dataset_id,
115135
csg_hub_dataset_name,
@@ -125,14 +145,34 @@ def collection_hive_task(task_uid: str,user_name: str,user_token: str):
125145
source = hive_config["source"]
126146
total_count = 0
127147
records_count = 0
128-
connector = get_datasource_connector(collection_task.datasource)
129-
if not connector.test_connection():
148+
try:
149+
# 将数据库对象转换为 DataSourceCreate 对象
150+
datasource_create = DataSourceCreate(
151+
name=collection_task.datasource.name,
152+
des=collection_task.datasource.des,
153+
source_type=collection_task.datasource.source_type,
154+
host=collection_task.datasource.host,
155+
port=collection_task.datasource.port,
156+
username=collection_task.datasource.username,
157+
password=collection_task.datasource.password,
158+
database=collection_task.datasource.database,
159+
auth_type=collection_task.datasource.auth_type
160+
)
161+
connector = get_datasource_connector(datasource_create)
162+
test_result = connector.test_connection()
163+
if not test_result or not test_result.get("success", False):
164+
collection_task.task_status = DataSourceTaskStatusEnum.ERROR.value
165+
error_msg = test_result.get("message", "Connection failed") if test_result else "Connection test returned None"
166+
insert_datasource_run_task_log_error(task_uid, f"Task with UID {task_uid} failed to connect to the database: {error_msg}")
167+
return False
168+
for table_name in source.keys():
169+
table_total = connector.get_table_total_count_hive(table_name)
170+
total_count += table_total
171+
except Exception as e:
130172
collection_task.task_status = DataSourceTaskStatusEnum.ERROR.value
131-
insert_datasource_run_task_log_error(task_uid, f"Task with UID {task_uid} failed to connect to the database.")
173+
insert_datasource_run_task_log_error(task_uid, f"Error occurred while executing the task: {str(e)}")
132174
return False
133-
for table_name in source.keys():
134-
table_total = connector.get_table_total_count_hive(table_name)
135-
total_count += table_total
175+
136176
collection_task.total_count = total_count
137177
collection_task.records_count = records_count
138178
db_session.commit()
@@ -165,8 +205,14 @@ def collection_hive_task(task_uid: str,user_name: str,user_token: str):
165205
except Exception as e:
166206
if collection_task:
167207
collection_task.task_status = DataSourceTaskStatusEnum.ERROR.value
208+
error_type = type(e).__name__
209+
error_msg = str(e)
210+
error_traceback = traceback.format_exc()
211+
logger.error(f"Task {task_uid} error: {error_type}: {error_msg}")
212+
logger.error(f"Full traceback:\n{error_traceback}")
168213
traceback.print_exc()
169-
insert_datasource_run_task_log_error(task_uid, f"Error occurred while executing the task: {e}")
214+
insert_datasource_run_task_log_error(task_uid, f"Error occurred while executing the task: {error_type}: {error_msg}")
215+
insert_datasource_run_task_log_error(task_uid, f"Traceback: {error_traceback}")
170216
return False
171217
finally:
172218
if collection_task:

data_server/api/endpoints/datasource.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import asyncio
23

34
from fastapi import FastAPI, APIRouter, HTTPException, status, Header, Depends,Body
45
from sqlalchemy import func
@@ -93,9 +94,25 @@ async def create_datasource(datasource: DataSourceCreate, db: Session = Depends(
9394
if datasource.source_type not in [item.value for item in DataSourceTypeEnum]:
9495
return response_fail(msg="不支持的数据源类型")
9596
# user_id = 54
97+
98+
# 处理分支信息:修正前端传递的分支信息
99+
# 前端可能将用户填写的分支名错误地放在了 csg_hub_dataset_name 中
100+
if datasource.extra_config is None:
101+
datasource.extra_config = {}
102+
103+
current_branch = datasource.extra_config.get("csg_hub_dataset_default_branch", "")
104+
dataset_name = datasource.extra_config.get("csg_hub_dataset_name", "")
105+
dataset_id = datasource.extra_config.get("csg_hub_dataset_id", "")
106+
107+
# 如果用户选择了数据流向,且分支是 main,但 dataset_name 有值,使用 dataset_name 作为分支
108+
if dataset_id and current_branch == "main" and dataset_name and dataset_name != "main" and dataset_name.strip():
109+
datasource.extra_config["csg_hub_dataset_default_branch"] = dataset_name
96110

97111
connector = get_datasource_connector(datasource)
98-
if not connector.test_connection():
112+
# Run the synchronized test_connection method in the thread pool
113+
loop = asyncio.get_event_loop()
114+
test_result = await loop.run_in_executor(None, connector.test_connection)
115+
if not test_result.get('success', False):
99116
datasource.source_status = DataSourceStatusEnum.INACTIVE.value
100117
else:
101118
if datasource.is_run:
@@ -104,7 +121,7 @@ async def create_datasource(datasource: DataSourceCreate, db: Session = Depends(
104121
datasource.source_status = DataSourceStatusEnum.WAITING.value
105122
if not user_id:
106123
return response_fail(msg="用户ID不能为空")
107-
data_source_id = create_data_source(connector.test_connection(), db, datasource, int(user_id), user_name,
124+
data_source_id = create_data_source(test_result, db, datasource, int(user_id), user_name,
108125
user_token)
109126
return response_success(data=data_source_id)
110127
except Exception as e:
@@ -145,7 +162,10 @@ async def test_datasource_connection(datasource: DataSourceCreate):
145162

146163
try:
147164
connector = get_datasource_connector(datasource)
148-
return response_success(data=connector.test_connection())
165+
# Run the synchronized test_connection method in the thread pool to avoid blocking the event loop
166+
loop = asyncio.get_event_loop()
167+
result = await loop.run_in_executor(None, connector.test_connection)
168+
return response_success(data=result)
149169
except Exception as e:
150170
logger.error(f"test_datasource_connection: {str(e)}")
151171
return response_fail(msg=f"测试连接失败:{str(e)}")
@@ -154,6 +174,16 @@ async def test_datasource_connection(datasource: DataSourceCreate):
154174
@router.put("/datasource/edit/{datasource_id}", response_model=dict)
155175
async def update_datasource(datasource_id: int, datasource: DataSourceUpdate, db: Session = Depends(get_sync_session)):
156176
try:
177+
# 处理分支信息:修正前端传递的分支信息(与创建接口相同的逻辑)
178+
if datasource.extra_config is not None:
179+
current_branch = datasource.extra_config.get("csg_hub_dataset_default_branch", "")
180+
dataset_name = datasource.extra_config.get("csg_hub_dataset_name", "")
181+
dataset_id = datasource.extra_config.get("csg_hub_dataset_id", "")
182+
183+
# 如果用户选择了数据流向,且分支是 main,但 dataset_name 有值,使用 dataset_name 作为分支
184+
if dataset_id and current_branch == "main" and dataset_name and dataset_name != "main" and dataset_name.strip():
185+
datasource.extra_config["csg_hub_dataset_default_branch"] = dataset_name
186+
157187
data_source = update_data_source(db, datasource_id, datasource)
158188
if not data_source:
159189
return response_fail(msg="更新失败")
@@ -213,12 +243,14 @@ async def get_datasource_tables(datasource: DataSourceCreate):
213243
try:
214244
# if datasource.source_type == DataSourceTypeEnum.MONGODB.value:
215245

216-
217246
connector = get_datasource_connector(datasource)
218-
if not connector.test_connection():
247+
# test_the_connection_in_the_thread_pool
248+
loop = asyncio.get_event_loop()
249+
test_result = await loop.run_in_executor(None, connector.test_connection)
250+
if not test_result.get('success', False):
219251
return response_fail(msg="数据源连接失败")
220252

221-
tables = connector.get_tables()
253+
tables = await loop.run_in_executor(None, connector.get_tables)
222254
return response_success(data=tables)
223255
except Exception as e:
224256
logger.error(f"获取表列表失败: {str(e)}")
@@ -233,17 +265,16 @@ async def get_datasource_table_columns(datasource: DataSourceCreate, table_name:
233265
return response_fail(msg="MongoDB不支持获取表和字段列表")
234266

235267
connector = get_datasource_connector(datasource)
236-
if not connector.test_connection():
268+
loop = asyncio.get_event_loop()
269+
test_result = await loop.run_in_executor(None, connector.test_connection)
270+
if not test_result.get('success', False):
237271
return response_fail(msg="数据源连接失败")
238272

239-
columns = connector.get_table_columns(table_name)
273+
columns = await loop.run_in_executor(None, connector.get_table_columns, table_name)
240274
return response_success(data=columns)
241275
except Exception as e:
242276
logger.error(f"获取表字段失败: {str(e)}")
243277
return response_fail(msg=f"获取表字段失败: {str(e)}")
244-
except Exception as e:
245-
logger.error(f"获取表字段失败: {str(e)}")
246-
return response_fail(msg=f"获取字段失败: {str(e)}")
247278

248279

249280
@router.get("/datasource/info", response_model=dict)
@@ -266,8 +297,10 @@ async def get_datasource_tables_and_columns(datasource: DataSourceCreate):
266297
return response_fail(msg="MongoDB不支持获取表和字段列表")
267298

268299
connector = get_datasource_connector(datasource)
269-
if not connector.test_connection():
270-
return response_fail(msg="数据源连接失败")
300+
test_result = connector.test_connection()
301+
if not test_result or not test_result.get("success", False):
302+
error_msg = test_result.get("message", "Connection failed") if test_result else "Connection test returned None"
303+
return response_fail(msg=f"数据源连接失败: {error_msg}")
271304

272305
tables_and_columns = connector.get_tables_and_columns()
273306
return response_success(data=tables_and_columns)

data_server/datasource/DatasourceManager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def create_data_source(is_connection: bool, db_session: Session, datasource: Dat
3636
int: ID of the created data source
3737
"""
3838
# create db model
39+
extra_config = datasource.extra_config or {}
3940
data_source_db = DataSource(name=datasource.name,
4041
des=datasource.des,
4142
source_type=datasource.source_type,
@@ -44,8 +45,9 @@ def create_data_source(is_connection: bool, db_session: Session, datasource: Dat
4445
username=datasource.username,
4546
password=datasource.password,
4647
database=datasource.database,
48+
auth_type=datasource.auth_type,
4749
task_run_time=datasource.task_run_time,
48-
extra_config=json.dumps(datasource.extra_config, ensure_ascii=False, indent=4))
50+
extra_config=json.dumps(extra_config, ensure_ascii=False, indent=4))
4951
data_source_db.source_status = datasource.source_status
5052
data_source_db.owner_id = user_id
5153
db_session.add(data_source_db)
@@ -140,6 +142,8 @@ def update_data_source(db_session: Session, data_source_id: int, update_data: Da
140142
data_source.password = update_data.password
141143
if update_data.database is not None:
142144
data_source.database = update_data.database
145+
if update_data.auth_type is not None:
146+
data_source.auth_type = update_data.auth_type
143147
if update_data.extra_config is not None:
144148
data_source.extra_config = json.dumps(update_data.extra_config, ensure_ascii=False, indent=4)
145149
db_session.commit()

0 commit comments

Comments (0)