diff --git a/pandajedi/jedibrokerage/AtlasBrokerUtils.py b/pandajedi/jedibrokerage/AtlasBrokerUtils.py index 8256c0278..897f20d07 100644 --- a/pandajedi/jedibrokerage/AtlasBrokerUtils.py +++ b/pandajedi/jedibrokerage/AtlasBrokerUtils.py @@ -17,6 +17,7 @@ from pandaserver.dataservice import DataServiceUtils from pandaserver.dataservice.DataServiceUtils import select_scope from pandaserver.taskbuffer import JobUtils, ProcessGroups, SiteSpec +from pandaserver.taskbuffer.DdmSpec import DOWNTIME_STATUSES # get nuclei where data is available @@ -1094,10 +1095,7 @@ def check( # check vendor if host_gpu_spec["vendor"] != "*": - if not wn_gpus or not any( - g.get("vendor") and re.match(host_gpu_spec["vendor"], g["vendor"], re.IGNORECASE) - for g in wn_gpus - ): + if not wn_gpus or not any(g.get("vendor") and re.match(host_gpu_spec["vendor"], g["vendor"], re.IGNORECASE) for g in wn_gpus): continue # check model (include or exclude pattern) @@ -1110,19 +1108,13 @@ def check( model_excl = False if not wn_gpus: continue - matches = any( - g.get("model") and re.match(model_pattern, g["model"], re.IGNORECASE) - for g in wn_gpus - ) + matches = any(g.get("model") and re.match(model_pattern, g["model"], re.IGNORECASE) for g in wn_gpus) if matches == model_excl: continue # check VRAM (in MB); supports operators: ==, >=, <=, >, <, != (e.g. ">=40960") if "vram" in host_gpu_spec: - if not wn_gpus or not any( - g.get("vram") and compare_version_string(str(g["vram"]), host_gpu_spec["vram"]) - for g in wn_gpus - ): + if not wn_gpus or not any(g.get("vram") and compare_version_string(str(g["vram"]), host_gpu_spec["vram"]) for g in wn_gpus): continue # check GPU microarchitecture generation (e.g. Ampere, Hopper, Ada Lovelace) @@ -1136,16 +1128,14 @@ def check( # check minimum CUDA version if "version" in host_gpu_spec: if not wn_gpus or not any( - g.get("framework_version") and compare_version_string(g["framework_version"], host_gpu_spec["version"]) - for g in wn_gpus + g.get("framework_version") and compare_version_string(g["framework_version"], host_gpu_spec["version"]) for g in wn_gpus ): continue # check minimum GPU driver version (kernel driver, e.g. 575.57.08) if "driver_version" in host_gpu_spec: if not wn_gpus or not any( - g.get("driver_version") and compare_version_string(g["driver_version"], host_gpu_spec["driver_version"]) - for g in wn_gpus + g.get("driver_version") and compare_version_string(g["driver_version"], host_gpu_spec["driver_version"]) for g in wn_gpus ): continue go_ahead = True @@ -1271,12 +1261,12 @@ def check_endpoints_with_blacklist( tmp_read_input_over_lan = tmp_input_endpoint["detailed_status"].get("read_lan") tmp_receive_input_over_wan = tmp_input_endpoint["detailed_status"].get("write_wan") # can read input from local - if tmp_read_input_over_lan not in ["OFF", "TEST"]: + if tmp_read_input_over_lan not in DOWNTIME_STATUSES: read_input_over_lan = True # can receive input from remote to local if tmp_site_name not in sites_in_nucleus: # satellite sites - if tmp_receive_input_over_wan not in ["OFF", "TEST"]: + if tmp_receive_input_over_wan not in DOWNTIME_STATUSES: receive_input_over_wan = True else: # NA for nucleus sites @@ -1287,12 +1277,12 @@ def check_endpoints_with_blacklist( tmp_write_output_over_lan = tmp_output_endpoint["detailed_status"].get("write_lan") tmp_send_output_over_wan = tmp_output_endpoint["detailed_status"].get("read_wan") # can write output to local - if tmp_write_output_over_lan not in ["OFF", "TEST"]: + if tmp_write_output_over_lan not in DOWNTIME_STATUSES: write_output_over_lan = True # can send output from local to remote if tmp_site_name not in sites_in_nucleus: # satellite sites - if tmp_send_output_over_wan not in ["OFF", "TEST"]: + if tmp_send_output_over_wan not in DOWNTIME_STATUSES: send_output_over_wan = True else: # NA for nucleus sites diff --git a/pandajedi/jedibrokerage/AtlasProdJobBroker.py b/pandajedi/jedibrokerage/AtlasProdJobBroker.py index be59b992b..701b22900 100644 --- a/pandajedi/jedibrokerage/AtlasProdJobBroker.py +++ b/pandajedi/jedibrokerage/AtlasProdJobBroker.py @@ -12,6 +12,7 @@ from pandaserver.dataservice.DataServiceUtils import select_scope from pandaserver.srvcore import CoreUtils from pandaserver.taskbuffer import EventServiceUtils, JobUtils +from pandaserver.taskbuffer.DdmSpec import DOWNTIME_STATUSES from . import AtlasBrokerUtils from .JobBrokerBase import JobBrokerBase @@ -262,7 +263,7 @@ def doBrokerage(self, taskSpec, cloudName, inputChunk, taskParamMap, hintForTB=F if not default_endpoint_out: continue receive_output_over_wan = default_endpoint_out["detailed_status"].get("write_wan") - if receive_output_over_wan in ["OFF", "TEST"]: + if receive_output_over_wan in DOWNTIME_STATUSES: nucleus_with_storages_unwritable_over_wan[tmp_name] = receive_output_over_wan else: diff --git a/pandajedi/jedibrokerage/AtlasProdTaskBroker.py b/pandajedi/jedibrokerage/AtlasProdTaskBroker.py index 791a6346c..26aa38c89 100644 --- a/pandajedi/jedibrokerage/AtlasProdTaskBroker.py +++ b/pandajedi/jedibrokerage/AtlasProdTaskBroker.py @@ -18,6 +18,7 @@ ) from pandajedi.jedirefine import RefinerUtils from pandaserver.dataservice import DataServiceUtils +from pandaserver.taskbuffer.DdmSpec import DOWNTIME_STATUSES from . import AtlasBrokerUtils from .AtlasProdJobBroker import AtlasProdJobBroker @@ -345,14 +346,14 @@ def runImpl(self): break # check blacklist read_wan_status = tmpEP["detailed_status"].get("read_wan") - if read_wan_status in ["OFF", "TEST"]: + if read_wan_status in DOWNTIME_STATUSES: tmpLog.info( f" skip nucleus={tmpNucleus} since {tmp_ddm_endpoint_name} has read_wan={read_wan_status} criteria=-source_blacklist" ) to_skip = True break write_wan_status = tmpEP["detailed_status"].get("write_wan") - if write_wan_status in ["OFF", "TEST"]: + if write_wan_status in DOWNTIME_STATUSES: tmpLog.info( f" skip nucleus={tmpNucleus} since {tmp_ddm_endpoint_name} has write_wan={write_wan_status} criteria=-destination_blacklist" ) diff --git a/pandajedi/jedicore/JediTaskBuffer.py b/pandajedi/jedicore/JediTaskBuffer.py index 1ac8bb9cd..514d2b0fe 100644 --- a/pandajedi/jedicore/JediTaskBuffer.py +++ b/pandajedi/jedicore/JediTaskBuffer.py @@ -74,6 +74,7 @@ def insertFilesForDataset_JEDI( order_by, maxFileRecords, skip_short_output, + lfn_constituent_map=None, ): with self.proxyPool.get() as proxy: return proxy.insertFilesForDataset_JEDI( @@ -111,6 +112,7 @@ def insertFilesForDataset_JEDI( order_by, maxFileRecords, skip_short_output, + lfn_constituent_map=lfn_constituent_map, ) # get files from the JEDI contents table with jediTaskID and/or datasetID @@ -409,6 +411,7 @@ def registerTaskInOneShot_JEDI( unmergeDatasetSpecMap, uniqueTaskName, oldTaskStatus, + in_content_dataset_spec_list, ): with self.proxyPool.get() as proxy: return proxy.registerTaskInOneShot_JEDI( @@ -424,6 +427,7 @@ def registerTaskInOneShot_JEDI( unmergeDatasetSpecMap, uniqueTaskName, oldTaskStatus, + in_content_dataset_spec_list, ) # set tasks to be assigned @@ -590,9 +594,9 @@ def retryTask_JEDI( ) # append input datasets for incremental execution - def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList): + def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList, in_content_dataset_specs): with self.proxyPool.get() as proxy: - return proxy.appendDatasets_JEDI(jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList) + return proxy.appendDatasets_JEDI(jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList, in_content_dataset_specs) # record retry history def recordRetryHistory_JEDI(self, jediTaskID, oldNewPandaIDs, relationType): diff --git a/pandajedi/jedidog/AtlasAnalWatchDog.py b/pandajedi/jedidog/AtlasAnalWatchDog.py index 026eaf150..8b48202c7 100644 --- a/pandajedi/jedidog/AtlasAnalWatchDog.py +++ b/pandajedi/jedidog/AtlasAnalWatchDog.py @@ -8,6 +8,7 @@ from pandajedi.jedicore.MsgWrapper import MsgWrapper from pandaserver.dataservice.activator import Activator +from pandaserver.taskbuffer.JediDatasetSpec import JediDatasetSpec from .TypicalWatchDogBase import TypicalWatchDogBase @@ -752,7 +753,7 @@ def do_periodic_action(self): {"type=.+": {"lifetime": lifetime}, "(SCRATCH|USER)DISK": {"lifetime": lifetime}}, ) # get input datasets - _, tmp_datasets = self.taskBufferIF.getDatasetsWithJediTaskID_JEDI(task_id, ["input"]) + _, tmp_datasets = self.taskBufferIF.getDatasetsWithJediTaskID_JEDI(task_id, ["input", JediDatasetSpec.get_constituent_input_type()]) for dataset_spec in tmp_datasets: # get locations rses = self.taskBufferIF.get_dataset_locality(task_id, dataset_spec.datasetID) @@ -768,6 +769,7 @@ def do_periodic_action(self): f"reset frozen time for taskID={task_id} since all locations {rses} of input dataset {dataset_spec.datasetName} are in downtime" ) self.taskBufferIF.reset_frozen_time_for_task(task_id) + break except Exception as e: tmp_log.error(f"failed with {str(e)}{traceback.format_exc()}") diff --git a/pandajedi/jediorder/ContentsFeeder.py b/pandajedi/jediorder/ContentsFeeder.py index c50d5f7d3..bb9216020 100644 --- a/pandajedi/jediorder/ContentsFeeder.py +++ b/pandajedi/jediorder/ContentsFeeder.py @@ -13,6 +13,7 @@ from pandajedi.jedicore.MsgWrapper import MsgWrapper from pandajedi.jedicore.ThreadUtils import ListWithLock, ThreadPool, WorkerThread from pandajedi.jedirefine import RefinerUtils +from pandaserver.taskbuffer.JediDatasetSpec import JediDatasetSpec from .JediKnight import JediKnight @@ -129,6 +130,12 @@ def feed_contents_to_tasks(self, task_ds_list, real_run=True): if not tmpStat or taskSpec is None: self.logger.debug(f"failed to get taskSpec for jediTaskID={jediTaskID}") continue + # get constituent datasets grouped by their master input datasetID + constituent_by_master = {} + _, c_datasets = self.taskBufferIF.getDatasetsWithJediTaskID_JEDI(jediTaskID, [JediDatasetSpec.get_constituent_input_type()]) + if c_datasets: + for c_ds in c_datasets: + constituent_by_master.setdefault(c_ds.masterID, []).append((c_ds.datasetID, c_ds.datasetName)) # make logger try: @@ -531,6 +538,16 @@ def feed_contents_to_tasks(self, task_ds_list, real_run=True): orderBy = taskSpec.order_input_by() else: orderBy = None + # build LFN -> constituent datasetID map for this input dataset + lfn_constituent_map = {} + if datasetSpec.datasetID in constituent_by_master: + for c_id, c_name in constituent_by_master[datasetSpec.datasetID]: + try: + c_files = ddmIF.getFilesInDataset(c_name) + for f_data in c_files.values(): + lfn_constituent_map[str(f_data["lfn"])] = c_id + except Exception: + tmpLog.warning(f"failed to get files for constituent dataset {c_name}") # feed files to the contents table tmpLog.debug("update contents") res_dict = self.taskBufferIF.insertFilesForDataset_JEDI( @@ -568,6 +585,7 @@ def feed_contents_to_tasks(self, task_ds_list, real_run=True): orderBy, maxFileRecords, skip_short_output, + lfn_constituent_map=lfn_constituent_map, ) retDB = res_dict["ret_val"] missingFileList = res_dict["missingFileList"] diff --git a/pandajedi/jediorder/TaskRefiner.py b/pandajedi/jediorder/TaskRefiner.py index d3b89bb56..6b17ebfbb 100644 --- a/pandajedi/jediorder/TaskRefiner.py +++ b/pandajedi/jediorder/TaskRefiner.py @@ -564,6 +564,7 @@ def runImpl(self): impl.unmergeDatasetSpecMap, uniqueTaskName, taskStatus, + impl.in_content_dataset_specs, ) if not tmpStat: tmpErrStr = "failed to register the task to JEDI in a single shot" @@ -592,7 +593,9 @@ def runImpl(self): # update task with new params self.taskBufferIF.updateTask_JEDI(impl.taskSpec, {"jediTaskID": impl.taskSpec.jediTaskID}, oldStatus=[taskStatus]) # appending for incremental execution - tmpStat = self.taskBufferIF.appendDatasets_JEDI(jediTaskID, impl.inMasterDatasetSpec, impl.inSecDatasetSpecList) + tmpStat = self.taskBufferIF.appendDatasets_JEDI( + jediTaskID, impl.inMasterDatasetSpec, impl.inSecDatasetSpecList, impl.in_content_dataset_specs + ) if not tmpStat: tmpLog.error("failed to append datasets for incexec") except Exception: diff --git a/pandajedi/jedirefine/TaskRefinerBase.py b/pandajedi/jedirefine/TaskRefinerBase.py index 805a0a893..239bc4b8d 100644 --- a/pandajedi/jedirefine/TaskRefinerBase.py +++ b/pandajedi/jedirefine/TaskRefinerBase.py @@ -51,6 +51,7 @@ def initializeRefiner(self, tmpLog): self.unmergeDatasetSpecMap = {} self.oldTaskStatus = None self.unknownDatasetList = [] + self.in_content_dataset_specs = [] # set jobParamsTemplate def setJobParamsTemplate(self, jobParamsTemplate): @@ -607,6 +608,7 @@ def doBasicRefine(self, taskParamMap): if datasetName == "": continue # expand + constituent_dataset_names_not_used_as_input = [] if datasetSpec.isPseudo() or datasetSpec.type in ["random_seed"] or datasetName == "DBR_LATEST": # pseudo input tmpDatasetNameList = [datasetName] @@ -617,54 +619,64 @@ def doBasicRefine(self, taskParamMap): if not tmpIF: tmpDatasetNameList = [] else: + # get datasets in dataset container + dataset_names_in_container = tmpIF.expandContainer(datasetName) + # sort datasets to process online complete replicas first + tmp_ok_list = [] + tmp_ng_list = [] + for tmp_dataset_name_in_container in dataset_names_in_container: + # skip the check if enough datasets are OK + if len(tmp_ok_list) > 10: + is_ok = True + else: + # check if complete replica is available at online endpoint + is_ok = False + tmp_dict = tmpIF.listDatasetReplicas(tmp_dataset_name_in_container) + for tmp_endpoint, tmp_data_list in tmp_dict.items(): + tmp_data = tmp_data_list[0] + if ( + tmp_data["total"] + and tmp_data["total"] == tmp_data["found"] + and self.siteMapper.is_readable_remotely(tmp_endpoint) + ): + is_ok = True + break + if is_ok: + tmp_ok_list.append(tmp_dataset_name_in_container) + else: + tmp_ng_list.append(tmp_dataset_name_in_container) + dataset_names_in_container = tmp_ok_list + tmp_ng_list if "expand" in tmpItem and tmpItem["expand"] is True: # expand dataset container - tmpDatasetNameList = tmpIF.expandContainer(datasetName) - # sort datasets to process online complete replicas first - tmp_ok_list = [] - tmp_ng_list = [] - for tmp_dataset_name in tmpDatasetNameList: - # skip the check if enough datasets are OK - if len(tmp_ok_list) > 10: - is_ok = True - else: - # check if complete replica is available at online endpoint - is_ok = False - tmp_dict = tmpIF.listDatasetReplicas(tmp_dataset_name) - for tmp_endpoint, tmp_data_list in tmp_dict.items(): - tmp_data = tmp_data_list[0] - if ( - tmp_data["total"] - and tmp_data["total"] == tmp_data["found"] - and self.siteMapper.is_readable_remotely(tmp_endpoint) - ): - is_ok = True - break - if is_ok: - tmp_ok_list.append(tmp_dataset_name) - else: - tmp_ng_list.append(tmp_dataset_name) - tmpDatasetNameList = tmp_ok_list + tmp_ng_list + tmpDatasetNameList = dataset_names_in_container else: # normal dataset name tmpDatasetNameList = tmpIF.listDatasets(datasetName) + constituent_dataset_names_not_used_as_input = [i for i in dataset_names_in_container if i not in tmpDatasetNameList] i_element = 0 for elementDatasetName in tmpDatasetNameList: - if "expandedList" in tmpItem: - if elementDatasetName not in tmpItem["expandedList"]: - tmpItem["expandedList"].append(elementDatasetName) - inDatasetSpec = copy.copy(datasetSpec) - inDatasetSpec.datasetName = elementDatasetName - if nIn > 0 or not self.taskSpec.is_hpo_workflow(): - inDatasetSpec.containerName = datasetName + if elementDatasetName not in tmpItem["expandedList"]: + tmpItem["expandedList"].append(elementDatasetName) + inDatasetSpec = copy.copy(datasetSpec) + inDatasetSpec.datasetName = elementDatasetName + if nIn > 0 or not self.taskSpec.is_hpo_workflow(): + inDatasetSpec.containerName = datasetName + # add remaining constituent datasets if they are not used as master input + if nIn == 0: + for tmp_dataset_name_in_container in constituent_dataset_names_not_used_as_input: + if tmp_dataset_name_in_container not in [ds.datasetName for ds in self.in_content_dataset_specs]: + tmp_dataset_spec = copy.copy(inDatasetSpec) + tmp_dataset_spec.type = JediDatasetSpec.get_constituent_input_type() + tmp_dataset_spec.datasetName = tmp_dataset_name_in_container + self.in_content_dataset_specs.append(tmp_dataset_spec) + else: + if self.taskSpec.is_work_segmented(): + inDatasetSpec.containerName = "{}/{}".format( + taskParamMap["segmentSpecs"][i_element]["name"], taskParamMap["segmentSpecs"][i_element]["id"] + ) else: - if self.taskSpec.is_work_segmented(): - inDatasetSpec.containerName = "{}/{}".format( - taskParamMap["segmentSpecs"][i_element]["name"], taskParamMap["segmentSpecs"][i_element]["id"] - ) - else: - inDatasetSpec.containerName = "None/None" - inDatasetSpecList.append(inDatasetSpec) + inDatasetSpec.containerName = "None/None" + inDatasetSpecList.append(inDatasetSpec) i_element += 1 # empty input if inDatasetSpecList == [] and self.oldTaskStatus != "rerefine": @@ -785,6 +797,7 @@ def doBasicRefine(self, taskParamMap): else: # append self.unmergeDatasetSpecMap[datasetSpec.outputMapKey()] = umDatasetSpec + self.tmpLog.debug(f"input constituent datasets: {len(self.in_content_dataset_specs)}") # set attributes for merging if "mergeOutput" in taskParamMap and taskParamMap["mergeOutput"] is True: self.setSplitRule(None, 1, JediTaskSpec.splitRuleToken["mergeOutput"]) diff --git a/pandaserver/brokerage/SiteMapper.py b/pandaserver/brokerage/SiteMapper.py index 913c84a76..7ca45da96 100644 --- a/pandaserver/brokerage/SiteMapper.py +++ b/pandaserver/brokerage/SiteMapper.py @@ -6,6 +6,7 @@ from pandaserver.config import panda_config from pandaserver.dataservice.DataServiceUtils import select_scope +from pandaserver.taskbuffer.DdmSpec import DOWNTIME_STATUSES from pandaserver.taskbuffer.NucleusSpec import NucleusSpec from pandaserver.taskbuffer.SiteSpec import SiteSpec @@ -363,7 +364,7 @@ def is_readable_remotely(self, endpoint_name: str) -> bool: bool: True if the endpoint is readable, False otherwise """ endpoints_with_read_wan_status = self.endpoint_detailed_status_summary.get("read_wan", {}) - bad_endpoints = endpoints_with_read_wan_status.get("OFF", []) + endpoints_with_read_wan_status.get("TEST", []) + bad_endpoints = [ep for status in DOWNTIME_STATUSES for ep in endpoints_with_read_wan_status.get(status, [])] return endpoint_name not in bad_endpoints def is_readable_locally(self, endpoint_name: str) -> bool: @@ -374,7 +375,7 @@ def is_readable_locally(self, endpoint_name: str) -> bool: bool: True if the endpoint is readable, False otherwise """ endpoints_with_read_lan_status = self.endpoint_detailed_status_summary.get("read_lan", {}) - bad_endpoints = endpoints_with_read_lan_status.get("OFF", []) + endpoints_with_read_lan_status.get("TEST", []) + bad_endpoints = [ep for status in DOWNTIME_STATUSES for ep in endpoints_with_read_lan_status.get(status, [])] return endpoint_name not in bad_endpoints def make_endpoint_to_sites_map(self) -> None: diff --git a/pandaserver/taskbuffer/DdmSpec.py b/pandaserver/taskbuffer/DdmSpec.py index 4f38e2ec3..f70011f4d 100644 --- a/pandaserver/taskbuffer/DdmSpec.py +++ b/pandaserver/taskbuffer/DdmSpec.py @@ -5,6 +5,9 @@ import re +# DDM endpoint activity statuses that indicate downtime (activity disabled) +DOWNTIME_STATUSES = ("OFF", "TEST") + class DdmSpec(object): # constructor diff --git a/pandaserver/taskbuffer/JediDatasetSpec.py b/pandaserver/taskbuffer/JediDatasetSpec.py index 77f795831..fbbad5e32 100644 --- a/pandaserver/taskbuffer/JediDatasetSpec.py +++ b/pandaserver/taskbuffer/JediDatasetSpec.py @@ -234,6 +234,12 @@ def getUnknownInputType(cls): getUnknownInputType = classmethod(getUnknownInputType) + # get type of constituent input + def get_constituent_input_type(cls): + return "in_constituent" + + get_constituent_input_type = classmethod(get_constituent_input_type) + # check if JEDI needs to keep track of file usage def toKeepTrack(self): if self.isNoSplit() and self.isRepeated(): diff --git a/pandaserver/taskbuffer/JediFileSpec.py b/pandaserver/taskbuffer/JediFileSpec.py index 4a3425be7..45548dedc 100644 --- a/pandaserver/taskbuffer/JediFileSpec.py +++ b/pandaserver/taskbuffer/JediFileSpec.py @@ -40,6 +40,7 @@ class JediFileSpec(object): "ramCount", "is_waiting", "proc_status", + "constituent_id", ) # attributes which have 0 by default _zeroAttrs = ("fsize", "attemptNr", "failedAttempt", "ramCount") diff --git a/pandaserver/taskbuffer/db_proxy_mods/misc_standalone_module.py b/pandaserver/taskbuffer/db_proxy_mods/misc_standalone_module.py index 40e34f816..3f7d5f296 100644 --- a/pandaserver/taskbuffer/db_proxy_mods/misc_standalone_module.py +++ b/pandaserver/taskbuffer/db_proxy_mods/misc_standalone_module.py @@ -3582,7 +3582,7 @@ def deleteOutdatedDatasetLocality_JEDI(self, before_timestamp): return retVal # append input datasets for incremental execution - def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList): + def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList, in_content_dataset_specs): comment = " /* JediDBProxy.appendDatasets_JEDI */" tmpLog = self.create_tagged_logger(comment, f"jediTaskID={jediTaskID}") tmpLog.debug("start") @@ -3616,16 +3616,17 @@ def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetS # get existing input datasets varMap = {} varMap[":jediTaskID"] = jediTaskID - sqlDS = "SELECT datasetName,datasetID,status,nFilesTobeUsed,nFilesUsed,masterID " - sqlDS += f"FROM {panda_config.schemaJEDI}.JEDI_Datasets " - sqlDS += f"WHERE jediTaskID=:jediTaskID AND type IN ({INPUT_TYPES_var_str}) " + sql_get_constituent = "SELECT datasetName,datasetID,status,nFilesTobeUsed,nFilesUsed,masterID,containerName " + sql_get_constituent += f"FROM {panda_config.schemaJEDI}.JEDI_Datasets " + sql_get_constituent += f"WHERE jediTaskID=:jediTaskID AND type IN ({INPUT_TYPES_var_str}) " varMap.update(INPUT_TYPES_var_map) - self.cur.execute(sqlDS + comment, varMap) + self.cur.execute(sql_get_constituent + comment, varMap) resDS = self.cur.fetchall() # check if existing datasets are available, and update status if necessary sql_ex = f"UPDATE {panda_config.schemaJEDI}.JEDI_Datasets SET status=:status WHERE jediTaskID=:jediTaskID AND datasetID=:datasetID " existingDatasets = {} - for datasetName, dataset_id, datasetStatus, nFilesTobeUsed, nFilesUsed, masterID in resDS: + in_container_name_id_map = {} + for datasetName, dataset_id, datasetStatus, nFilesTobeUsed, nFilesUsed, masterID, containerName in resDS: # only master datasets with remaining files try: if masterID is None and (nFilesTobeUsed - nFilesUsed > 0 or datasetStatus in JediDatasetSpec.statusToUpdateContents()): @@ -3650,6 +3651,8 @@ def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetS except Exception: pass existingDatasets[datasetName] = datasetStatus + if containerName and containerName not in in_container_name_id_map: + in_container_name_id_map[containerName] = dataset_id # insert datasets sqlID = f"INSERT INTO {panda_config.schemaJEDI}.JEDI_Datasets ({JediDatasetSpec.columnNames()}) " sqlID += JediDatasetSpec.bindValuesExpression() @@ -3672,6 +3675,8 @@ def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetS datasetID = int(val) masterID = datasetID datasetSpec.datasetID = datasetID + if datasetSpec.containerName and datasetSpec.containerName not in in_container_name_id_map: + in_container_name_id_map[datasetSpec.containerName] = datasetID # insert secondary datasets for datasetSpec in inSecDatasetSpecList: datasetSpec.creationTime = timeNow @@ -3685,6 +3690,29 @@ def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetS datasetID = int(val) datasetSpec.datasetID = datasetID goDefined = True + # get existing input constituent datasets + varMap = {} + varMap[":jediTaskID"] = jediTaskID + varMap[":type_input_constituent"] = JediDatasetSpec.get_constituent_input_type() + sql_get_constituent = "SELECT datasetName " + sql_get_constituent += f"FROM {panda_config.schemaJEDI}.JEDI_Datasets " + sql_get_constituent += f"WHERE jediTaskID=:jediTaskID AND type=:type_input_constituent " + self.cur.execute(sql_get_constituent + comment, varMap) + res_ds = self.cur.fetchall() + existing_constituent_datasets = [i[0] for i in res_ds] + # insert constituent datasets + for datasetSpec in in_content_dataset_specs: + if datasetSpec.datasetName in existing_constituent_datasets: + continue + if datasetSpec.containerName and datasetSpec.containerName in in_container_name_id_map: + datasetSpec.masterID = in_container_name_id_map[datasetSpec.containerName] + datasetSpec.creationTime = timeNow + datasetSpec.modificationTime = timeNow + varMap = datasetSpec.valuesMap(useSeq=True) + varMap[":newDatasetID"] = self.cur.var(varNUMBER) + # insert dataset + tmpLog.debug(f"append constituent {datasetSpec.datasetName}") + self.cur.execute(sqlID + comment, varMap) # update task deft_staus = None sqlUT = f"UPDATE {panda_config.schemaJEDI}.JEDI_Tasks " diff --git a/pandaserver/taskbuffer/db_proxy_mods/task_complex_module.py b/pandaserver/taskbuffer/db_proxy_mods/task_complex_module.py index 42606ab5c..d76680e27 100644 --- a/pandaserver/taskbuffer/db_proxy_mods/task_complex_module.py +++ b/pandaserver/taskbuffer/db_proxy_mods/task_complex_module.py @@ -1,5 +1,6 @@ import copy import datetime +import json import math import os import random @@ -24,6 +25,7 @@ from pandaserver.taskbuffer.db_proxy_mods.metrics_module import get_metrics_module from pandaserver.taskbuffer.db_proxy_mods.task_event_module import get_task_event_module from pandaserver.taskbuffer.db_proxy_mods.task_utils_module import get_task_utils_module +from pandaserver.taskbuffer.DdmSpec import DOWNTIME_STATUSES from pandaserver.taskbuffer.InputChunk import InputChunk from pandaserver.taskbuffer.JediDatasetSpec import ( INPUT_TYPES_var_map, @@ -207,6 +209,7 @@ def insertFilesForDataset_JEDI( order_by, maxFileRecords, skip_short_output, + lfn_constituent_map=None, ): comment = " /* JediDBProxy.insertFilesForDataset_JEDI */" tmpLog = self.create_tagged_logger(comment, f"jediTaskID={datasetSpec.jediTaskID} datasetID={datasetSpec.datasetID}") @@ -442,6 +445,8 @@ def insertFilesForDataset_JEDI( fileSpec.maxAttempt = maxAttempt fileSpec.maxFailure = maxFailure fileSpec.ramCount = ramCount + if lfn_constituent_map: + fileSpec.constituent_id = lfn_constituent_map.get(str(fileVal["lfn"])) tmpNumEvents = None if "events" in fileVal: try: @@ -2640,10 +2645,63 @@ def _read_unprocessed_inputs( :param ds_with_fake_co_jumbo: Set of datasets with fake co-jumbo. :return: Tuple of (input chunks, the typical number of files per job). """ + # find constituent datasets that are available only at RSEs in downtime, to skip their files + skip_constituent_ids = set() + # get RSEs of constituent datasets + sql_constituent_rses = ( + f"SELECT d.datasetID, l.rse FROM {panda_config.schemaJEDI}.JEDI_Datasets d, {panda_config.schemaJEDI}.JEDI_Dataset_Locality l " + "WHERE d.jediTaskID=:jediTaskID AND d.masterID=:masterID AND d.type=:type " + "AND l.jediTaskID=d.jediTaskID AND l.datasetID=d.datasetID " + ) + var_map_constituent_rses = { + ":jediTaskID": jedi_task_id, + ":masterID": primary_dataset_id, + ":type": JediDatasetSpec.get_constituent_input_type(), + } + self.cur.execute(sql_constituent_rses + comment, var_map_constituent_rses) + constituent_rses_map = {} # datasetID -> set of RSEs + for tmp_const_id, tmp_rse in self.cur.fetchall(): + constituent_rses_map.setdefault(tmp_const_id, set()).add(tmp_rse) + if constituent_rses_map: + # check read_lan status of each RSE in ddm_endpoint + sql_ddm_status = f"SELECT detailed_status FROM {panda_config.schemaPANDA}.ddm_endpoint WHERE ddm_endpoint_name=:rse " + rse_in_downtime = {} # rse -> bool + for tmp_rse in {rse for rses in constituent_rses_map.values() for rse in rses}: + self.cur.execute(sql_ddm_status + comment, {":rse": tmp_rse}) + res_ddm = self.cur.fetchone() + in_downtime = False + if res_ddm and res_ddm[0]: + try: + detailed_status = res_ddm[0] + if not isinstance(detailed_status, dict): + detailed_status = json.loads(detailed_status) + if detailed_status.get("read_lan") in DOWNTIME_STATUSES: + in_downtime = True + except Exception: + pass + rse_in_downtime[tmp_rse] = in_downtime + # a constituent dataset is in downtime when all of its RSEs are in downtime + for tmp_const_id, tmp_rses in constituent_rses_map.items(): + if tmp_rses and all(rse_in_downtime.get(rse, False) for rse in tmp_rses): + skip_constituent_ids.add(tmp_const_id) + if skip_constituent_ids: + tmp_log.debug(f"jediTaskID={jedi_task_id} skipping files of constituent datasets in downtime: {sorted(skip_constituent_ids)}") + # build a filter to skip files belonging to downtime constituent datasets + constituent_filter = "" + constituent_var_map = {} + if skip_constituent_ids: + bind_keys = [] + for i, tmp_const_id in enumerate(skip_constituent_ids): + key = f":constituent_{i}" + bind_keys.append(key) + constituent_var_map[key] = tmp_const_id + constituent_filter = f"AND (constituent_id IS NULL OR constituent_id NOT IN ({','.join(bind_keys)})) " + # sql to read files sql_read_files = f"SELECT * FROM (SELECT {JediFileSpec.columnNames()} " sql_read_files += f"FROM {panda_config.schemaJEDI}.JEDI_Dataset_Contents WHERE " sql_read_files += "jediTaskID=:jediTaskID AND datasetID=:datasetID " + sql_read_files += constituent_filter if not simulation_with_file_stat: sql_read_files += "AND status=:status AND (maxAttempt IS NULL OR attemptNr