Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 10 additions & 20 deletions pandajedi/jedibrokerage/AtlasBrokerUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pandaserver.dataservice import DataServiceUtils
from pandaserver.dataservice.DataServiceUtils import select_scope
from pandaserver.taskbuffer import JobUtils, ProcessGroups, SiteSpec
from pandaserver.taskbuffer.DdmSpec import DOWNTIME_STATUSES


# get nuclei where data is available
Expand Down Expand Up @@ -1094,10 +1095,7 @@ def check(

# check vendor
if host_gpu_spec["vendor"] != "*":
if not wn_gpus or not any(
g.get("vendor") and re.match(host_gpu_spec["vendor"], g["vendor"], re.IGNORECASE)
for g in wn_gpus
):
if not wn_gpus or not any(g.get("vendor") and re.match(host_gpu_spec["vendor"], g["vendor"], re.IGNORECASE) for g in wn_gpus):
continue

# check model (include or exclude pattern)
Expand All @@ -1110,19 +1108,13 @@ def check(
model_excl = False
if not wn_gpus:
continue
matches = any(
g.get("model") and re.match(model_pattern, g["model"], re.IGNORECASE)
for g in wn_gpus
)
matches = any(g.get("model") and re.match(model_pattern, g["model"], re.IGNORECASE) for g in wn_gpus)
if matches == model_excl:
continue

# check VRAM (in MB); supports operators: ==, >=, <=, >, <, != (e.g. ">=40960")
if "vram" in host_gpu_spec:
if not wn_gpus or not any(
g.get("vram") and compare_version_string(str(g["vram"]), host_gpu_spec["vram"])
for g in wn_gpus
):
if not wn_gpus or not any(g.get("vram") and compare_version_string(str(g["vram"]), host_gpu_spec["vram"]) for g in wn_gpus):
continue

# check GPU microarchitecture generation (e.g. Ampere, Hopper, Ada Lovelace)
Expand All @@ -1136,16 +1128,14 @@ def check(
# check minimum CUDA version
if "version" in host_gpu_spec:
if not wn_gpus or not any(
g.get("framework_version") and compare_version_string(g["framework_version"], host_gpu_spec["version"])
for g in wn_gpus
g.get("framework_version") and compare_version_string(g["framework_version"], host_gpu_spec["version"]) for g in wn_gpus
):
continue

# check minimum GPU driver version (kernel driver, e.g. 575.57.08)
if "driver_version" in host_gpu_spec:
if not wn_gpus or not any(
g.get("driver_version") and compare_version_string(g["driver_version"], host_gpu_spec["driver_version"])
for g in wn_gpus
g.get("driver_version") and compare_version_string(g["driver_version"], host_gpu_spec["driver_version"]) for g in wn_gpus
):
continue
go_ahead = True
Expand Down Expand Up @@ -1271,12 +1261,12 @@ def check_endpoints_with_blacklist(
tmp_read_input_over_lan = tmp_input_endpoint["detailed_status"].get("read_lan")
tmp_receive_input_over_wan = tmp_input_endpoint["detailed_status"].get("write_wan")
# can read input from local
if tmp_read_input_over_lan not in ["OFF", "TEST"]:
if tmp_read_input_over_lan not in DOWNTIME_STATUSES:
read_input_over_lan = True
# can receive input from remote to local
if tmp_site_name not in sites_in_nucleus:
# satellite sites
if tmp_receive_input_over_wan not in ["OFF", "TEST"]:
if tmp_receive_input_over_wan not in DOWNTIME_STATUSES:
receive_input_over_wan = True
else:
# NA for nucleus sites
Expand All @@ -1287,12 +1277,12 @@ def check_endpoints_with_blacklist(
tmp_write_output_over_lan = tmp_output_endpoint["detailed_status"].get("write_lan")
tmp_send_output_over_wan = tmp_output_endpoint["detailed_status"].get("read_wan")
# can write output to local
if tmp_write_output_over_lan not in ["OFF", "TEST"]:
if tmp_write_output_over_lan not in DOWNTIME_STATUSES:
write_output_over_lan = True
# can send output from local to remote
if tmp_site_name not in sites_in_nucleus:
# satellite sites
if tmp_send_output_over_wan not in ["OFF", "TEST"]:
if tmp_send_output_over_wan not in DOWNTIME_STATUSES:
send_output_over_wan = True
else:
# NA for nucleus sites
Expand Down
3 changes: 2 additions & 1 deletion pandajedi/jedibrokerage/AtlasProdJobBroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pandaserver.dataservice.DataServiceUtils import select_scope
from pandaserver.srvcore import CoreUtils
from pandaserver.taskbuffer import EventServiceUtils, JobUtils
from pandaserver.taskbuffer.DdmSpec import DOWNTIME_STATUSES

from . import AtlasBrokerUtils
from .JobBrokerBase import JobBrokerBase
Expand Down Expand Up @@ -262,7 +263,7 @@ def doBrokerage(self, taskSpec, cloudName, inputChunk, taskParamMap, hintForTB=F
if not default_endpoint_out:
continue
receive_output_over_wan = default_endpoint_out["detailed_status"].get("write_wan")
if receive_output_over_wan in ["OFF", "TEST"]:
if receive_output_over_wan in DOWNTIME_STATUSES:
nucleus_with_storages_unwritable_over_wan[tmp_name] = receive_output_over_wan

else:
Expand Down
5 changes: 3 additions & 2 deletions pandajedi/jedibrokerage/AtlasProdTaskBroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from pandajedi.jedirefine import RefinerUtils
from pandaserver.dataservice import DataServiceUtils
from pandaserver.taskbuffer.DdmSpec import DOWNTIME_STATUSES

from . import AtlasBrokerUtils
from .AtlasProdJobBroker import AtlasProdJobBroker
Expand Down Expand Up @@ -345,14 +346,14 @@ def runImpl(self):
break
# check blacklist
read_wan_status = tmpEP["detailed_status"].get("read_wan")
if read_wan_status in ["OFF", "TEST"]:
if read_wan_status in DOWNTIME_STATUSES:
tmpLog.info(
f" skip nucleus={tmpNucleus} since {tmp_ddm_endpoint_name} has read_wan={read_wan_status} criteria=-source_blacklist"
)
to_skip = True
break
write_wan_status = tmpEP["detailed_status"].get("write_wan")
if write_wan_status in ["OFF", "TEST"]:
if write_wan_status in DOWNTIME_STATUSES:
tmpLog.info(
f" skip nucleus={tmpNucleus} since {tmp_ddm_endpoint_name} has write_wan={write_wan_status} criteria=-destination_blacklist"
)
Expand Down
8 changes: 6 additions & 2 deletions pandajedi/jedicore/JediTaskBuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def insertFilesForDataset_JEDI(
order_by,
maxFileRecords,
skip_short_output,
lfn_constituent_map=None,
):
with self.proxyPool.get() as proxy:
return proxy.insertFilesForDataset_JEDI(
Expand Down Expand Up @@ -111,6 +112,7 @@ def insertFilesForDataset_JEDI(
order_by,
maxFileRecords,
skip_short_output,
lfn_constituent_map=lfn_constituent_map,
)

# get files from the JEDI contents table with jediTaskID and/or datasetID
Expand Down Expand Up @@ -409,6 +411,7 @@ def registerTaskInOneShot_JEDI(
unmergeDatasetSpecMap,
uniqueTaskName,
oldTaskStatus,
in_content_dataset_spec_list,
):
with self.proxyPool.get() as proxy:
return proxy.registerTaskInOneShot_JEDI(
Expand All @@ -424,6 +427,7 @@ def registerTaskInOneShot_JEDI(
unmergeDatasetSpecMap,
uniqueTaskName,
oldTaskStatus,
in_content_dataset_spec_list,
)

# set tasks to be assigned
Expand Down Expand Up @@ -590,9 +594,9 @@ def retryTask_JEDI(
)

# append input datasets for incremental execution
def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList):
def appendDatasets_JEDI(self, jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList, in_content_dataset_specs):
with self.proxyPool.get() as proxy:
return proxy.appendDatasets_JEDI(jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList)
return proxy.appendDatasets_JEDI(jediTaskID, inMasterDatasetSpecList, inSecDatasetSpecList, in_content_dataset_specs)

# record retry history
def recordRetryHistory_JEDI(self, jediTaskID, oldNewPandaIDs, relationType):
Expand Down
4 changes: 3 additions & 1 deletion pandajedi/jedidog/AtlasAnalWatchDog.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pandajedi.jedicore.MsgWrapper import MsgWrapper
from pandaserver.dataservice.activator import Activator
from pandaserver.taskbuffer.JediDatasetSpec import JediDatasetSpec

from .TypicalWatchDogBase import TypicalWatchDogBase

Expand Down Expand Up @@ -752,7 +753,7 @@ def do_periodic_action(self):
{"type=.+": {"lifetime": lifetime}, "(SCRATCH|USER)DISK": {"lifetime": lifetime}},
)
# get input datasets
_, tmp_datasets = self.taskBufferIF.getDatasetsWithJediTaskID_JEDI(task_id, ["input"])
_, tmp_datasets = self.taskBufferIF.getDatasetsWithJediTaskID_JEDI(task_id, ["input", JediDatasetSpec.get_constituent_input_type()])
for dataset_spec in tmp_datasets:
# get locations
rses = self.taskBufferIF.get_dataset_locality(task_id, dataset_spec.datasetID)
Expand All @@ -768,6 +769,7 @@ def do_periodic_action(self):
f"reset frozen time for taskID={task_id} since all locations {rses} of input dataset {dataset_spec.datasetName} are in downtime"
)
self.taskBufferIF.reset_frozen_time_for_task(task_id)
break

except Exception as e:
tmp_log.error(f"failed with {str(e)}{traceback.format_exc()}")
18 changes: 18 additions & 0 deletions pandajedi/jediorder/ContentsFeeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pandajedi.jedicore.MsgWrapper import MsgWrapper
from pandajedi.jedicore.ThreadUtils import ListWithLock, ThreadPool, WorkerThread
from pandajedi.jedirefine import RefinerUtils
from pandaserver.taskbuffer.JediDatasetSpec import JediDatasetSpec

from .JediKnight import JediKnight

Expand Down Expand Up @@ -129,6 +130,12 @@ def feed_contents_to_tasks(self, task_ds_list, real_run=True):
if not tmpStat or taskSpec is None:
self.logger.debug(f"failed to get taskSpec for jediTaskID={jediTaskID}")
continue
# get constituent datasets grouped by their master input datasetID
constituent_by_master = {}
_, c_datasets = self.taskBufferIF.getDatasetsWithJediTaskID_JEDI(jediTaskID, [JediDatasetSpec.get_constituent_input_type()])
if c_datasets:
for c_ds in c_datasets:
constituent_by_master.setdefault(c_ds.masterID, []).append((c_ds.datasetID, c_ds.datasetName))

# make logger
try:
Expand Down Expand Up @@ -531,6 +538,16 @@ def feed_contents_to_tasks(self, task_ds_list, real_run=True):
orderBy = taskSpec.order_input_by()
else:
orderBy = None
# build LFN -> constituent datasetID map for this input dataset
lfn_constituent_map = {}
if datasetSpec.datasetID in constituent_by_master:
for c_id, c_name in constituent_by_master[datasetSpec.datasetID]:
try:
c_files = ddmIF.getFilesInDataset(c_name)
for f_data in c_files.values():
lfn_constituent_map[str(f_data["lfn"])] = c_id
except Exception:
tmpLog.warning(f"failed to get files for constituent dataset {c_name}")
# feed files to the contents table
tmpLog.debug("update contents")
res_dict = self.taskBufferIF.insertFilesForDataset_JEDI(
Expand Down Expand Up @@ -568,6 +585,7 @@ def feed_contents_to_tasks(self, task_ds_list, real_run=True):
orderBy,
maxFileRecords,
skip_short_output,
lfn_constituent_map=lfn_constituent_map,
)
retDB = res_dict["ret_val"]
missingFileList = res_dict["missingFileList"]
Expand Down
5 changes: 4 additions & 1 deletion pandajedi/jediorder/TaskRefiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def runImpl(self):
impl.unmergeDatasetSpecMap,
uniqueTaskName,
taskStatus,
impl.in_content_dataset_specs,
)
if not tmpStat:
tmpErrStr = "failed to register the task to JEDI in a single shot"
Expand Down Expand Up @@ -592,7 +593,9 @@ def runImpl(self):
# update task with new params
self.taskBufferIF.updateTask_JEDI(impl.taskSpec, {"jediTaskID": impl.taskSpec.jediTaskID}, oldStatus=[taskStatus])
# appending for incremental execution
tmpStat = self.taskBufferIF.appendDatasets_JEDI(jediTaskID, impl.inMasterDatasetSpec, impl.inSecDatasetSpecList)
tmpStat = self.taskBufferIF.appendDatasets_JEDI(
jediTaskID, impl.inMasterDatasetSpec, impl.inSecDatasetSpecList, impl.in_content_dataset_specs
)
if not tmpStat:
tmpLog.error("failed to append datasets for incexec")
except Exception:
Expand Down
Loading
Loading